In [4]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from torchvision.models import resnet50
import csv
import timm


import os, random, torch, numpy as np

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    ##torch.backends.cudnn.deterministic = False
    ##torch.backends.cudnn.benchmark = True
    torch.use_deterministic_algorithms(True, warn_only=False)
    print(f"[INFO] All seeds set to {seed}")

seed_everything(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

def free_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# --------------------- CONFIG ---------------------

TEXT_MODEL_NAME = 'UBC-NLP/MARBERTv2'
NUM_CLASSES = 2
BATCH_SIZE = 8
NUM_EPOCHS = 3
FOLDS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt", sep="\t",encoding="utf-8")
print(len(df))

image_encoder = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.DEFAULT)

image_encoder.classifier = nn.Identity()


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(0.4)

    def forward(self, query, key, value, key_padding_mask=None):
        attn_output, _ = self.attn(query, key, value, key_padding_mask=key_padding_mask)
        x = self.norm(query + self.dropout(attn_output))
        ff_output = self.norm(x + self.dropout(self.ff(x)))
        return ff_output

# Define the Full Multimodal Classifier
class Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(nn.Module):
    def __init__(self, text_model_name, image_encoder, num_classes):
        super().__init__()

        # Text encoder (e.g., BERT)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, 512)

        # Image encoder (e.g., ResNet, ViT)
        self.image_encoder = image_encoder 
        self.image_proj = nn.Linear(1280, 512)

        # Cross-modal interaction attention
        self.cross_modal_attn = CrossAttentionLayer(embed_dim=512, num_heads=4)

        # Final classifier
        self.classifier = nn.Sequential(nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, num_classes))
    def forward(self, input_ids, attention_mask, image):
        # Encode text
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = self.text_proj(text_output.last_hidden_state) 

        # Encode image
        image_feats1 = self.image_encoder(image) 
        image_feats = self.image_proj(image_feats1).unsqueeze(1) 
        visual_guided_text = self.cross_modal_attn(query=image_feats, key=text_embeds, value=text_embeds)
        
        final_multi_repr = torch.mean(visual_guided_text, dim=1) 

        output = self.classifier(final_multi_repr)
        return output


# --------------------- DATASET ---------------------
class ArabicMultimodal_Fake_News_Dataset(Dataset):
    def __init__(self, samples, tokenizer, transform):
        self.samples = samples
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        ##image = Image.open(sample['image_path']).convert('RGB')
        image = Image.open(sample[1]).convert('RGB')
        image = self.transform(image)

        text_enc = self.tokenizer(
            sample[0], padding='max_length', truncation=True, max_length=64, return_tensors='pt'
        )
        return {
            'input_ids': text_enc['input_ids'].squeeze(0),
            'attention_mask': text_enc['attention_mask'].squeeze(0),
            'image': image,
            'label': torch.tensor(sample[2], dtype=torch.long)
        }

# --------------------- UTILS ---------------------


def load_data(annotation_file, image_root):
    samples = []
    label2folder = {0: 'Fake', 1: 'Real'}

    with open(annotation_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            tweet_id = row['id']
            text = row['preprocess1']
            label = int(row['label'])
            labelstr=row['label']
            dataset = row['dataset']

            folder_name = label2folder[label]

            # Try different extensions
            image_path = None
            for ext in ['jpg', 'png', 'jpeg']:
                temp_path = os.path.join(image_root, folder_name, f"{tweet_id}.{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break

            if image_path:
                samples.append((text, image_path, label,dataset))
            else:
                print(f"Image not found for {tweet_id} with label {label}")

    return samples

# --------------------- TRAIN + EVAL ---------------------
def train_model(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, total_correct = 0, 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
    return total_loss / len(dataloader), total_correct / len(dataloader.dataset)

def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            images = batch['image'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            outputs = model(input_ids, attention_mask, images)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels


# --------------------- MAIN ---------------------
if __name__ == '__main__':
    ##label2idx = {'Fake': 1, 'Real': 0}
    label2idx = {1: 1, 0: 0}  
    idx2label = {v: k for k, v in label2idx.items()}
    samples = load_data(
    '/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt',
    '/kaggle/input/ekafnewsforkhawla/sorted_imagesOur')
    
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    transform = models.EfficientNet_B1_Weights.DEFAULT.transforms()
    
    labels = [sample[2] for sample in samples]
    
    strat =labels
    skf = StratifiedKFold(n_splits=FOLDS, shuffle=False, random_state=None)
    all_preds, all_trues = [], []
    for fold, (train_idx, val_idx) in enumerate(skf.split(samples, strat)):
        print(f"Fold {fold + 1}/{FOLDS}")
        train_samples = [samples[i] for i in train_idx]
        val_samples = [samples[i] for i in val_idx]

        train_dataset = ArabicMultimodal_Fake_News_Dataset(train_samples, tokenizer, transform)
        val_dataset = ArabicMultimodal_Fake_News_Dataset(val_samples, tokenizer, transform)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,worker_init_fn=seed_worker, generator=g)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,worker_init_fn=seed_worker, generator=g)

        model = Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(text_model_name=TEXT_MODEL_NAME,image_encoder=image_encoder,num_classes=NUM_CLASSES).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(NUM_EPOCHS):
            train_loss, train_acc = train_model(model, train_loader, optimizer, criterion)
            print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f}")

        preds, trues = evaluate_model(model, val_loader)
        print(f"Fold {fold + 1} Classification Report:")
        print(classification_report(trues, preds, target_names=["0","1"], digits=4))
        print("Confusion Matrix:")
        print(confusion_matrix(trues, preds))

        all_preds.extend(preds)
        all_trues.extend(trues)
        free_gpu_memory()

    print("\n=== Overall Classification Report ===")
    print(TEXT_MODEL_NAME)
    print("\n=== Overall Classification Report ===")
    print(classification_report(all_trues, all_preds, target_names=["0","1"], digits=4))
    print("=== Overall Confusion Matrix ===")
    print(confusion_matrix(all_trues, all_preds))


[INFO] All seeds set to 42
5138
Fold 1/5




Epoch 1/3 - Loss: 0.4596 - Train Acc: 0.7835
Epoch 2/3 - Loss: 0.2259 - Train Acc: 0.9051
Epoch 3/3 - Loss: 0.0895 - Train Acc: 0.9684
Fold 1 Classification Report:
              precision    recall  f1-score   support

           0     0.7639    0.8566    0.8076       272
           1     0.9461    0.9048    0.9249       756

    accuracy                         0.8920      1028
   macro avg     0.8550    0.8807    0.8663      1028
weighted avg     0.8979    0.8920    0.8939      1028

Confusion Matrix:
[[233  39]
 [ 72 684]]
Fold 2/5




Epoch 1/3 - Loss: 0.4033 - Train Acc: 0.8073
Epoch 2/3 - Loss: 0.1706 - Train Acc: 0.9316
Epoch 3/3 - Loss: 0.0623 - Train Acc: 0.9803
Fold 2 Classification Report:
              precision    recall  f1-score   support

           0     0.7895    0.8824    0.8333       272
           1     0.9558    0.9153    0.9351       756

    accuracy                         0.9066      1028
   macro avg     0.8726    0.8988    0.8842      1028
weighted avg     0.9118    0.9066    0.9082      1028

Confusion Matrix:
[[240  32]
 [ 64 692]]
Fold 3/5




Epoch 1/3 - Loss: 0.3799 - Train Acc: 0.8316
Epoch 2/3 - Loss: 0.1360 - Train Acc: 0.9496
Epoch 3/3 - Loss: 0.0624 - Train Acc: 0.9808
Fold 3 Classification Report:
              precision    recall  f1-score   support

           0     0.7914    0.8755    0.8313       273
           1     0.9532    0.9166    0.9345       755

    accuracy                         0.9056      1028
   macro avg     0.8723    0.8960    0.8829      1028
weighted avg     0.9102    0.9056    0.9071      1028

Confusion Matrix:
[[239  34]
 [ 63 692]]
Fold 4/5
Epoch 1/3 - Loss: 0.3363 - Train Acc: 0.8523
Epoch 2/3 - Loss: 0.1045 - Train Acc: 0.9613
Epoch 3/3 - Loss: 0.0430 - Train Acc: 0.9852




Fold 4 Classification Report:
              precision    recall  f1-score   support

           0     0.7823    0.9118    0.8421       272
           1     0.9662    0.9086    0.9365       755

    accuracy                         0.9094      1027
   macro avg     0.8743    0.9102    0.8893      1027
weighted avg     0.9175    0.9094    0.9115      1027

Confusion Matrix:
[[248  24]
 [ 69 686]]
Fold 5/5




Epoch 1/3 - Loss: 0.2897 - Train Acc: 0.8762
Epoch 2/3 - Loss: 0.0867 - Train Acc: 0.9715
Epoch 3/3 - Loss: 0.0529 - Train Acc: 0.9849
Fold 5 Classification Report:
              precision    recall  f1-score   support

           0     0.8781    0.9007    0.8893       272
           1     0.9639    0.9550    0.9594       755

    accuracy                         0.9406      1027
   macro avg     0.9210    0.9279    0.9244      1027
weighted avg     0.9412    0.9406    0.9408      1027

Confusion Matrix:
[[245  27]
 [ 34 721]]

=== Overall Classification Report ===
UBC-NLP/MARBERTv2

=== Overall Classification Report ===
              precision    recall  f1-score   support

           0     0.7996    0.8854    0.8403      1361
           1     0.9570    0.9200    0.9382      3777

    accuracy                         0.9109      5138
   macro avg     0.8783    0.9027    0.8892      5138
weighted avg     0.9153    0.9109    0.9123      5138

=== Overall Confusion Matrix ===
[[1205  156