In [None]:
# RAG-Based Health Misinformation Detection for COVID-19 Tweets
# Complete implementation for Kaggle environment with fixed API and training args

import os
import pandas as pd
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import requests
import time
from datasets import Dataset
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore")

# Report the library versions and device availability for this run.
runtime_report = {
    "Using transformers version": transformers.__version__,
    "PyTorch version": torch.__version__,
    "CUDA available": torch.cuda.is_available(),
}
for report_label, report_value in runtime_report.items():
    print(f"{report_label}: {report_value}")

# Seed both the torch and NumPy generators with the same value so the
# random sampling further down is repeatable.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)

# ------------------------------------------
# 1. Load and Preprocess the Dataset
# ------------------------------------------

# Read the merged tweet dataset from the Kaggle input mount.
DATA_PATH = '/kaggle/input/merged-covid-misinformation/merged_dataset.csv'  # Adjust path if needed
df = pd.read_csv(DATA_PATH)
print("Loaded dataset with {} entries".format(len(df)))
print(df.head())

# Show how many rows carry each raw label value.
label_counts = df['label'].value_counts()
print("\nClass distribution:")
print(label_counts)

# Convert labels to binary format: 1 = misinformation, 0 = everything else.
# The raw labels in this dataset are capitalised ('Misinformation' /
# 'Reliable'), so the comparison is done case-insensitively; an exact match
# against lowercase 'misinformation' would encode every row as 0 and leave
# the classifier with a single class.
df['label_encoded'] = (
    df['label']
    .astype(str)
    .str.strip()
    .str.lower()
    .eq('misinformation')
    .astype(int)
)

# Basic preprocessing
def preprocess_tweet(text):
    """Normalise a raw tweet for downstream matching and tokenisation.

    The text is lower-cased, split on whitespace, stripped of every token
    that begins with ``http`` (a cheap URL filter), and re-joined with single
    spaces. Anything that is not a ``str`` (e.g. a missing value) yields the
    empty string.
    """
    if not isinstance(text, str):
        return ""
    # Splitting on whitespace and re-joining with one space also collapses
    # runs of spaces, tabs and newlines.
    kept_tokens = (
        token for token in text.lower().split() if not token.startswith('http')
    )
    return ' '.join(kept_tokens)

# Clean every tweet once; all later steps read this column instead of 'content'.
df['processed_text'] = df['content'].map(preprocess_tweet)

# ------------------------------------------
# 2. External Knowledge Retrieval Functions
# ------------------------------------------

# For fact-checking sources (using a local database of COVID facts)
# Key phrase -> fact text. Built once at import time rather than on every call.
_COVID_FACTS = {
    "covid cure": "There is no known cure for COVID-19, but vaccines are effective in preventing severe illness.",
    "covid vaccine": "COVID-19 vaccines have been scientifically proven to be safe and effective.",
    "5g covid": "There is no scientific evidence linking 5G technology to COVID-19.",
    "mask": "Masks help reduce the spread of COVID-19 by blocking respiratory droplets.",
    "hydroxychloroquine": "Studies have not shown hydroxychloroquine to be effective against COVID-19.",
    "vitamin": "While vitamins support immune health, no vitamin has been proven to prevent or cure COVID-19.",
    "microchip": "COVID-19 vaccines do not contain microchips or tracking devices.",
    "bill gates": "Claims that Bill Gates is using vaccines for population control are false.",
    "covid lab": "The scientific consensus is that COVID-19 was not artificially created in a laboratory.",
    "covid hoax": "COVID-19 is a real disease that has caused millions of deaths worldwide.",
    "covid fake": "COVID-19 is a real disease, not a hoax or conspiracy.",
    "lockdown": "Lockdowns were implemented to slow the spread of COVID-19 and prevent healthcare systems from being overwhelmed.",
    "pcr test": "PCR tests are reliable for detecting the presence of the SARS-CoV-2 virus that causes COVID-19.",
    "covid deaths": "COVID-19 has caused millions of deaths globally, as confirmed by excess mortality studies.",
    "covid origin": "Scientific evidence suggests COVID-19 originated from animal-to-human transmission.",
    "covid symptoms": "Common COVID-19 symptoms include fever, cough, fatigue, and loss of taste or smell.",
    "quarantine": "Quarantine helps prevent the spread of COVID-19 by isolating potentially infected individuals.",
    "covid test": "COVID-19 tests are designed to detect current or past infection with the SARS-CoV-2 virus.",
    "covid treatment": "COVID-19 treatments may include antivirals, monoclonal antibodies, or supportive care.",
    "covid statistics": "COVID-19 case and death statistics are tracked by health organizations worldwide.",
    "covid immunity": "Both vaccination and prior infection can provide some immunity against COVID-19.",
    "covid variants": "COVID-19 variants emerge through genetic mutations in the SARS-CoV-2 virus.",
    "vaccine side effects": "COVID-19 vaccines can cause temporary side effects like fatigue or soreness, but serious side effects are extremely rare.",
    "ivermectin": "Medical authorities do not recommend ivermectin for COVID-19 treatment outside of clinical trials.",
    "covid children": "Children can contract and transmit COVID-19, though they typically have milder symptoms than adults.",
    "natural immunity": "Natural immunity from infection provides some protection, but vaccination is still recommended.",
    "vaccine mandate": "Vaccine mandates have been implemented in some places to increase vaccination rates and protect public health.",
    "covid restrictions": "COVID-19 restrictions were implemented to reduce transmission and save lives.",
    "covid conspiracy": "Scientific evidence contradicts conspiracy theories about COVID-19's origin or purpose.",
    "covid survival rate": "While many people survive COVID-19, it has caused millions of deaths worldwide.",
    "wuhan": "The first identified cases of COVID-19 were in Wuhan, China in late 2019.",
    "who covid": "The World Health Organization provides guidance on COVID-19 prevention, detection, and treatment.",
    "cdc covid": "The CDC provides evidence-based guidance on COVID-19 for the United States.",
    "covid pneumonia": "COVID-19 can cause pneumonia, a serious lung infection.",
    "covid testing": "COVID-19 testing is an important tool for detecting and controlling the spread of the virus.",
    "asymptomatic": "People with asymptomatic COVID-19 can still spread the virus to others.",
    "covid vaccine safety": "COVID-19 vaccines have undergone rigorous safety testing and continuous monitoring.",
    "covid hospitalization": "COVID-19 can lead to hospitalization, especially for unvaccinated individuals and those with risk factors.",
    "long covid": "Some COVID-19 patients experience persistent symptoms, known as Long COVID.",
    "false positive": "False positives in COVID-19 testing are possible but rare with PCR tests when performed correctly.",
}

# Keys ordered longest-first so the most specific phrase wins. Scanning in
# plain insertion order made "covid testing" and "covid vaccine safety"
# unreachable, because the shorter "covid test" / "covid vaccine" always
# matched before them. The sort is stable, so equally long keys keep their
# original relative order.
_COVID_FACT_KEYS_BY_SPECIFICITY = sorted(_COVID_FACTS, key=len, reverse=True)

_NO_FACT_FOUND = "No specific fact-check information found for this query."


def retrieve_from_factcheck(query):
    """Look up a fact-check statement for ``query`` in the local fact table.

    Matching is a case-insensitive substring test of each key phrase against
    the query; when several phrases match, the longest one is used.

    Args:
        query: Text to search (must be a ``str``).

    Returns:
        The fact text for the best-matching key phrase, or the sentence
        "No specific fact-check information found for this query." when no
        key phrase occurs in the query.
    """
    # Lower-case once instead of once per key.
    needle = query.lower()
    for key in _COVID_FACT_KEYS_BY_SPECIFICITY:
        if key in needle:
            return _COVID_FACTS[key]
    return _NO_FACT_FOUND

# Combined retrieval function that doesn't rely on external APIs
def retrieve_knowledge(tweet):
    """Retrieve a short knowledge snippet for ``tweet`` using only local data.

    A compact query (at most seven words longer than three characters, minus
    a few stop words) is looked up in the fact-check table first. If that
    finds nothing, the whole tweet is looked up as well: the compact query
    drops short tokens such as "5g", "pcr", "who", "cdc" and "lab" and
    everything after the seventh word, so key phrases containing those tokens
    or appearing late in the tweet could otherwise never match.

    Args:
        tweet: Tweet text (must be a ``str``).

    Returns:
        "FACT CHECK: <fact>" when a fact-check entry matches, otherwise
        "GENERAL INFO: <text>" with a topic-specific or generic COVID-19
        statement.
    """
    # Sentinel returned by retrieve_from_factcheck when nothing matched.
    no_match = "No specific fact-check information found for this query."
    stop_words = {'this', 'that', 'with', 'from', 'what', 'when'}

    tweet_lower = tweet.lower()

    # Extract key phrases from tweet (simplified approach)
    query = " ".join(
        w for w in tweet_lower.split() if len(w) > 3 and w not in stop_words
    )

    # Add 'covid' to the query if not present and the tweet is about COVID
    if 'covid' not in query and ('covid' in tweet_lower or 'coronavirus' in tweet_lower):
        query = 'covid ' + query

    # Limit query length
    query = ' '.join(query.split()[:7])

    # Compact query first (keeps every match the short query already found),
    # then the whitespace-normalised full tweet as a fallback.
    full_text = ' '.join(tweet_lower.split())
    for candidate in (query, full_text):
        factcheck_info = retrieve_from_factcheck(candidate)
        if factcheck_info != no_match:
            return "FACT CHECK: " + factcheck_info

    # If no specific fact check found, provide general COVID information
    covid_general_info = {
        "general": "COVID-19 is caused by the SARS-CoV-2 virus and spreads primarily through respiratory droplets.",
        "symptoms": "Common COVID-19 symptoms include fever, cough, fatigue, and loss of taste or smell.",
        "prevention": "Preventive measures for COVID-19 include vaccination, masks, physical distancing, and hand hygiene.",
        "treatment": "COVID-19 treatment may include antivirals, monoclonal antibodies, or supportive care depending on severity.",
        "vaccine": "COVID-19 vaccines are safe, effective, and reduce risk of severe illness and hospitalization."
    }

    # Select relevant general information
    for key, info in covid_general_info.items():
        if key in query:
            return "GENERAL INFO: " + info

    return "GENERAL INFO: " + covid_general_info["general"]

# ------------------------------------------
# 3. Testing the Retrieval Function
# ------------------------------------------

# Smoke-test the retrieval function on the first five cleaned tweets,
# printing at most 100 characters of each tweet and of its retrieved snippet.
test_tweets = list(df['processed_text'].head(5))
print("\nTesting retrieval function on sample tweets:")
for tweet in test_tweets:
    knowledge = retrieve_knowledge(tweet)
    print("\nTWEET:", tweet[:100], "...")
    print("RETRIEVED:", knowledge[:100], "...")

# ------------------------------------------
# 4. Retrieve Knowledge for All Tweets
# ------------------------------------------

# Work on at most 1000 tweets: the whole dataset when it is that small,
# otherwise a random sample drawn without replacement.
sample_size = min(1000, len(df))
if len(df) <= sample_size:
    df_sample = df.copy()
else:
    print(f"\nUsing a sample of {sample_size} tweets to avoid computational overhead")
    sampled_indices = np.random.choice(len(df), size=sample_size, replace=False)
    df_sample = df.iloc[sampled_indices].copy()

# Retrieve knowledge for the sampled tweets (one lookup per row, with a
# progress bar).
print("\nRetrieving external knowledge for tweets...")
df_sample['retrieved_knowledge'] = [
    retrieve_knowledge(cleaned_text)
    for cleaned_text in tqdm(df_sample['processed_text'], total=len(df_sample))
]

# Display a few examples
print("\nExamples of retrieved knowledge:")
preview_rows = df_sample.head(5)
for shown_tweet, shown_knowledge, shown_label in zip(
    preview_rows['processed_text'],
    preview_rows['retrieved_knowledge'],
    preview_rows['label'],
):
    print("\nTWEET:", shown_tweet[:100], "...")
    print("RETRIEVED:", shown_knowledge[:100], "...")
    print("LABEL:", shown_label)

# ------------------------------------------
# 5. Prepare Data for RAG Model
# ------------------------------------------

# Stratified 70 / 15 / 15 split: first hold out 30% of the sample, then cut
# that hold-out in half for validation and test. Both cuts stratify on the
# encoded label and use a fixed random_state.
train_df, temp_df = train_test_split(
    df_sample,
    test_size=0.3,
    random_state=42,
    stratify=df_sample['label_encoded'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['label_encoded'],
)

split_sizes = (len(train_df), len(val_df), len(test_df))
print("\nData split sizes:")
print("Train: {}, Validation: {}, Test: {}".format(*split_sizes))

# Template for the model input that pairs a tweet with its retrieved knowledge.
def combine_text_and_knowledge(text, knowledge):
    """Return ``text`` and ``knowledge`` joined into one labelled input string.

    The result has the form "Tweet: <text> [SEP] Knowledge: <knowledge>".
    """
    return "Tweet: {} [SEP] Knowledge: {}".format(text, knowledge)

def _row_to_combined(row):
    """Build the combined model input for one DataFrame row."""
    return combine_text_and_knowledge(row['processed_text'], row['retrieved_knowledge'])


# Add the combined tweet + knowledge column to each split in turn.
for split_frame in (train_df, val_df, test_df):
    split_frame['combined_text'] = split_frame.apply(_row_to_combined, axis=1)

# Convert to HuggingFace datasets.
# Trainer drops every dataset column whose name is not an argument of the
# model's forward() (plus `label` / `label_ids`), so a target column called
# `label_encoded` is silently removed and the model cannot compute a loss.
# The column is therefore exposed as `labels`. The shuffled pandas index is
# not needed downstream, so it is not carried into the dataset.
def _to_rag_dataset(frame):
    """Build a Dataset with `combined_text` and `labels` columns from a split."""
    columns = frame[['combined_text', 'label_encoded']].rename(
        columns={'label_encoded': 'labels'}
    )
    return Dataset.from_pandas(columns, preserve_index=False)


train_dataset = _to_rag_dataset(train_df)
val_dataset = _to_rag_dataset(val_df)
test_dataset = _to_rag_dataset(test_df)

# ------------------------------------------
# 6. Define Model and Tokenizer
# ------------------------------------------

# Checkpoint shared by the tokenizer and by both classifiers trained below
# (the RAG-style model and the tweet-only baseline).
model_name = "distilbert-base-uncased"  # Smaller model for faster training
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define maximum sequence length (in tokens) used by both tokenization
# functions for truncation and for padding.
# NOTE(review): those functions pad with padding="max_length", so every
# example becomes exactly this long; a smaller cap would presumably train
# much faster on CPU — check the longest combined text before lowering it.
max_length = 512  # Long enough for tweet + retrieved knowledge

# Tokenization function
def tokenize_function(examples):
    """Tokenize the `combined_text` field of a batch of examples.

    Every sequence is truncated and padded to the module-level ``max_length``.
    Returns whatever the module-level ``tokenizer`` returns for the batch.
    """
    texts = examples['combined_text']
    encode_options = dict(padding="max_length", truncation=True, max_length=max_length)
    return tokenizer(texts, **encode_options)

# Tokenize the three splits in batched mode, in train / val / test order.
tokenized_train, tokenized_val, tokenized_test = (
    split_dataset.map(tokenize_function, batched=True)
    for split_dataset in (train_dataset, val_dataset, test_dataset)
)

# ------------------------------------------
# 7. Define Performance Metrics
# ------------------------------------------

def compute_metrics(pred):
    """Compute accuracy and binary precision / recall / F1 for a prediction set.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is taken as the
            predicted class).

    Returns:
        Dict with the keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # Tuple layout: (precision, recall, f1, support).
    scores = precision_recall_fscore_support(y_true, y_pred, average='binary')
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': scores[2],
        'precision': scores[0],
        'recall': scores[1],
    }

# ------------------------------------------
# 8. Train the RAG-based Model
# ------------------------------------------

# Load pre-trained model for sequence classification
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Define training arguments.
# NOTE: `eval_strategy` is the argument name in recent Transformers releases
# (this notebook ran on 4.51.x); older releases call it `evaluation_strategy`.
batch_size = 16
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # The sample is capped at 1000 tweets (<= 700 training rows), i.e. at most
    # ~132 optimisation steps in total. A fixed 500-step warm-up would never
    # finish, so the learning rate would never reach its target value. Warm up
    # over the first 10% of the run instead.
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    # Evaluate and checkpoint once per epoch. With so few steps, a 100-step
    # cadence produced a single checkpoint, so "best model" selection had
    # nothing to choose from. Both strategies must match for
    # load_best_model_at_end.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Disable external experiment trackers; otherwise an installed tracker
    # (e.g. Weights & Biases on Kaggle) prompts for credentials when
    # training starts.
    report_to="none",
)
# Initialize Trainer for the knowledge-augmented (RAG-style) classifier:
# trains on the tokenized combined text and evaluates on the validation split
# with compute_metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    compute_metrics=compute_metrics
)

# Train the model
print("\nTraining the RAG-based model...")
trainer.train()

# ------------------------------------------
# 9. Evaluate on Test Set
# ------------------------------------------

# Score the trained model on the held-out test split. The comparison section
# later reads the metrics from this dict under 'eval_'-prefixed keys.
# NOTE(review): training_args sets load_best_model_at_end=True, so this
# presumably scores the best checkpoint rather than the last step — confirm.
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_test)
print(f"Test results: {test_results}")

# ------------------------------------------
# 10. Compare with Baseline (No Knowledge)
# ------------------------------------------

print("\nPreparing baseline model (without retrieved knowledge)...")

# Create datasets without retrieved knowledge
train_df['tweet_only'] = train_df['processed_text']
val_df['tweet_only'] = val_df['processed_text']
test_df['tweet_only'] = test_df['processed_text']


# Trainer drops every dataset column whose name is not an argument of the
# model's forward() (plus `label` / `label_ids`), so a target column called
# `label_encoded` is silently removed and the model cannot compute a loss.
# The column is therefore exposed as `labels`, and the shuffled pandas index
# is not carried into the dataset.
def _to_baseline_dataset(frame):
    """Build a Dataset with `tweet_only` and `labels` columns from a split."""
    columns = frame[['tweet_only', 'label_encoded']].rename(
        columns={'label_encoded': 'labels'}
    )
    return Dataset.from_pandas(columns, preserve_index=False)


baseline_train = _to_baseline_dataset(train_df)
baseline_val = _to_baseline_dataset(val_df)
baseline_test = _to_baseline_dataset(test_df)

# Tokenization function for baseline
def tokenize_baseline(examples):
    """Tokenize the `tweet_only` field of a batch of examples.

    Uses the same settings as the knowledge-augmented input: every sequence
    is truncated and padded to the module-level ``max_length``.
    """
    tweets = examples['tweet_only']
    encode_options = dict(padding="max_length", truncation=True, max_length=max_length)
    return tokenizer(tweets, **encode_options)

# Tokenize the three baseline splits in batched mode, in train / val / test order.
tokenized_baseline_train, tokenized_baseline_val, tokenized_baseline_test = (
    baseline_split.map(tokenize_baseline, batched=True)
    for baseline_split in (baseline_train, baseline_val, baseline_test)
)

# Train baseline model: a fresh copy of the same checkpoint, fine-tuned on
# the tweet text alone so it can be compared against the knowledge-augmented
# model.
baseline_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# Reuses training_args, i.e. the same hyper-parameters and the same
# output_dir as the RAG run.
# NOTE(review): because output_dir is shared, baseline checkpoints presumably
# land next to / over the RAG checkpoints — confirm whether that matters.
baseline_trainer = Trainer(
    model=baseline_model,
    args=training_args,
    train_dataset=tokenized_baseline_train,
    eval_dataset=tokenized_baseline_val,
    compute_metrics=compute_metrics
)

print("\nTraining the baseline model...")
baseline_trainer.train()

# Evaluate baseline model on the same test rows as the RAG model; the
# comparison section reads these metrics under 'eval_'-prefixed keys.
print("\nEvaluating baseline model on test set...")
baseline_results = baseline_trainer.evaluate(tokenized_baseline_test)
print(f"Baseline results: {baseline_results}")

# ------------------------------------------
# 11. Compare Results
# ------------------------------------------

# Side-by-side table of test metrics: baseline vs. RAG, with the signed
# difference (RAG minus baseline) and a direction marker.
print("\n=== COMPARISON OF RESULTS ===")
print("Metric    | Baseline | RAG-based")
print("--------------------------------")
for metric in ['accuracy', 'f1', 'precision', 'recall']:
    # Trainer.evaluate() results are read under 'eval_'-prefixed keys; a
    # missing metric falls back to 0.
    baseline_value = baseline_results.get(f'eval_{metric}', 0)
    rag_value = test_results.get(f'eval_{metric}', 0)
    diff = rag_value - baseline_value
    # Three-way marker: a tie is neither an improvement nor a regression, so
    # it must not be shown with a down arrow.
    if diff > 0:
        direction = '↑'
    elif diff < 0:
        direction = '↓'
    else:
        direction = '='
    print(f"{metric.ljust(10)}| {baseline_value:.4f} | {rag_value:.4f} ({diff:+.4f} {direction})")

# ------------------------------------------
# 12. Case Study - Qualitative Analysis
# ------------------------------------------

# Qualitative look at test tweets on which the two models disagree.
print("\n=== CASE STUDIES ===")


def _label_name(class_id):
    """Map an encoded class id to its display name."""
    return "Misinformation" if class_id == 1 else "Reliable"


# Make predictions using both models
baseline_preds = baseline_trainer.predict(tokenized_baseline_test)
rag_preds = trainer.predict(tokenized_test)

# Predicted class = arg-max over the last axis of each model's scores.
baseline_labels = baseline_preds.predictions.argmax(-1)
rag_labels = rag_preds.predictions.argmax(-1)
true_labels = test_df['label_encoded'].values

# Positions (row order of test_df) where the two predictions differ.
disagreement_indices = np.where(baseline_labels != rag_labels)[0]
print(f"Found {len(disagreement_indices)} examples where models disagree")

# Show at most the first five disagreements.
case_study_indices = disagreement_indices[:5]
for idx in case_study_indices:
    tweet = test_df['processed_text'].iloc[idx]
    knowledge = test_df['retrieved_knowledge'].iloc[idx]
    true_label = _label_name(true_labels[idx])
    baseline_pred = _label_name(baseline_labels[idx])
    rag_pred = _label_name(rag_labels[idx])

    print("\n---")
    print(f"Tweet: {tweet}")
    print(f"Retrieved Knowledge: {knowledge}")
    print(f"True Label: {true_label}")
    print(f"Baseline Prediction: {baseline_pred}")
    print(f"RAG Model Prediction: {rag_pred}")

    # Highlight which model was correct
    rag_is_right = rag_labels[idx] == true_labels[idx]
    baseline_is_right = baseline_labels[idx] == true_labels[idx]
    if rag_is_right and not baseline_is_right:
        print("✓ RAG model was correct, baseline was wrong")
    elif baseline_is_right and not rag_is_right:
        print("✗ Baseline was correct, RAG model was wrong")
    else:
        print("Both models were incorrect")

# ------------------------------------------
# 13. Save Models and Results
# ------------------------------------------

# Persist both fine-tuned models, each to its own directory.
print("\nSaving models...")
for fitted_trainer, target_dir in (
    (trainer, './rag_model'),
    (baseline_trainer, './baseline_model'),
):
    fitted_trainer.save_model(target_dir)

# Collect both test-metric dicts and the number of tweets used into one
# summary and write it out as JSON.
results_summary = dict(
    RAG=test_results,
    Baseline=baseline_results,
    Sample_Size=len(df_sample),
)

import json
with open('./results_summary.json', 'w') as f:
    f.write(json.dumps(results_summary))

print("\n=== COMPLETED RAG-BASED HEALTH MISINFORMATION DETECTION ===")

Using transformers version: 4.51.1
PyTorch version: 2.5.1+cu124
CUDA available: False
Loaded dataset with 12900 entries
                                             content           label
0  The CDC currently reports 99031 deaths. In gen...        Reliable
1  States reported 1121 deaths a small rise from ...        Reliable
2  Politically Correct Woman (Almost) Uses Pandem...  Misinformation
3  #IndiaFightsCorona: We have 1524 #COVID testin...        Reliable
4  Populous states can generate large case counts...        Reliable

Class distribution:
label
Reliable          6719
Misinformation    6181
Name: count, dtype: int64

Testing retrieval function on sample tweets:

TWEET: the cdc currently reports 99031 deaths. in general the discrepancies in death counts between differe ...
RETRIEVED: GENERAL INFO: COVID-19 is caused by the SARS-CoV-2 virus and spreads primarily through respiratory d ...

TWEET: states reported 1121 deaths a small rise from last tuesday. southern states reported

  0%|          | 0/1000 [00:00<?, ?it/s]


Examples of retrieved knowledge:

TWEET: bill gates who is supporting covid-19 vaccine research visited in new zealand during may. ...
RETRIEVED: FACT CHECK: Claims that Bill Gates is using vaccines for population control are false. ...
LABEL: Misinformation

TWEET: pak pm imran khan's wife tested positive for covid-19. ...
RETRIEVED: GENERAL INFO: COVID-19 is caused by the SARS-CoV-2 virus and spreads primarily through respiratory d ...
LABEL: Misinformation

TWEET: ???clearly, the obama administration did not leave any kind of game plan for something like this.??� ...
RETRIEVED: GENERAL INFO: COVID-19 is caused by the SARS-CoV-2 virus and spreads primarily through respiratory d ...
LABEL: Misinformation

TWEET: aaaaaaaaaaaaaaaaaaaaaa it had to hit while i was on spring break ...
RETRIEVED: GENERAL INFO: COVID-19 is caused by the SARS-CoV-2 virus and spreads primarily through respiratory d ...
LABEL: Misinformation

TWEET: ukrainian media registered the first confirmed case of the ne

Map:   0%|          | 0/700 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Training the RAG-based model...


<IPython.core.display.Javascript object>

In [3]:
!pip install --upgrade transformers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Downloading transformers-4.51.3-py3-none-any.whl (10.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m82.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.51.1
    Uninstalling transformers-4.51.1:
      Successfully uninstalled transformers-4.51.1
Successfully installed transformers-4.51.3
