In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from torchvision.models import resnet50
import csv
import timm


import os, random, torch, numpy as np

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    ##torch.backends.cudnn.deterministic = False
    ##torch.backends.cudnn.benchmark = True
    torch.use_deterministic_algorithms(True, warn_only=False)
    print(f"[INFO] All seeds set to {seed}")

seed_everything(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

def free_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# --------------------- CONFIG ---------------------

TEXT_MODEL_NAME = 'UBC-NLP/MARBERTv2'
NUM_CLASSES = 2
BATCH_SIZE = 8
NUM_EPOCHS = 3
FOLDS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt", sep="\t",encoding="utf-8")
print(len(df))


image_encoder = models.vit_b_32(weights=models.ViT_B_32_Weights.IMAGENET1K_V1)

image_encoder.heads = nn.Identity()

class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(0.4)

    def forward(self, query, key, value, key_padding_mask=None):
        attn_output, _ = self.attn(query, key, value, key_padding_mask=key_padding_mask)
        x = self.norm(query + self.dropout(attn_output))
        ff_output = self.norm(x + self.dropout(self.ff(x)))
        return ff_output

# Define the Full Multimodal Classifier
class Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(nn.Module):
    def __init__(self, text_model_name, image_encoder, num_classes):
        super().__init__()

        # Text encoder (e.g., BERT)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, 512)

        # Image encoder (e.g., ResNet, ViT)
        self.image_encoder = image_encoder  
        self.image_proj = nn.Linear(768, 512)

        # Cross-modal interaction attention
        self.cross_modal_attn = CrossAttentionLayer(embed_dim=512, num_heads=4)

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, image):
        # Encode text
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = self.text_proj(text_output.last_hidden_state)  # (B, L, 512)

        # Encode image
        image_feats1 = self.image_encoder(image)
        image_feats = self.image_proj(image_feats1).unsqueeze(1)  # (B, 1, 512)
        ##visual_guided_text = self.cross_modal_attn(query=image_feats, key=text_embeds, value=text_embeds)
        # Cross-modal attention - level 1 
        img_cross = self.cross_modal_attn(query=text_embeds, key=image_feats, value=image_feats)
        txt_cross = self.cross_modal_attn(query=image_feats, key=text_embeds, value=text_embeds,key_padding_mask=~attention_mask.bool())
        ##visual_guided_text = self.cross_modal_attn(query=img_cross, key=txt_cross, value=txt_cross)

        # Cross-modal attention - level 2 
        cross_modal_img2txt = self.cross_modal_attn(query=txt_cross, key=img_cross, value=img_cross)
        cross_modal_txt2img = self.cross_modal_attn(query=img_cross, key=txt_cross, value=txt_cross)
        ##visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt, key=cross_modal_txt2img, value=cross_modal_txt2img)

        # Cross-modal attention - level 3 
        cross_modal_img2txt1 = self.cross_modal_attn(query=cross_modal_txt2img, key=cross_modal_img2txt, value=cross_modal_img2txt)
        cross_modal_txt2img1 = self.cross_modal_attn(query=cross_modal_img2txt, key=cross_modal_txt2img, value=cross_modal_txt2img)
        ##visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt1, key=cross_modal_txt2img1, value=cross_modal_txt2img1)
        # Cross-modal attention - level 4   
        cross_modal_img2txt2 = self.cross_modal_attn(query=cross_modal_txt2img1, key=cross_modal_img2txt1, value=cross_modal_img2txt1)
        cross_modal_txt2img2 = self.cross_modal_attn(query=cross_modal_img2txt1, key=cross_modal_txt2img1, value=cross_modal_txt2img1)

        # Visual Guided Text attention
        visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt2, key=cross_modal_txt2img2, value=cross_modal_txt2img2)

        # Mean pooling
        final_multi_repr = torch.mean(visual_guided_text, dim=1)  # (B, 512)

        output = self.classifier(final_multi_repr)
        return output


# --------------------- DATASET ---------------------
class ArabicMultimodal_Fake_News_Dataset(Dataset):
    def __init__(self, samples, tokenizer, transform):
        self.samples = samples
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        ##image = Image.open(sample['image_path']).convert('RGB')
        image = Image.open(sample[1]).convert('RGB')
        image = self.transform(image)

        text_enc = self.tokenizer(
            sample[0], padding='max_length', truncation=True, max_length=64, return_tensors='pt'
        )
        return {
            'input_ids': text_enc['input_ids'].squeeze(0),
            'attention_mask': text_enc['attention_mask'].squeeze(0),
            'image': image,
            'label': torch.tensor(sample[2], dtype=torch.long)
        }

# --------------------- UTILS ---------------------


def load_data(annotation_file, image_root):
    samples = []
    label2folder = {0: 'Fake', 1: 'Real'}

    with open(annotation_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            tweet_id = row['id']
            text = row['preprocess1']
            label = int(row['label'])
            labelstr=row['label']
            dataset = row['dataset']

            folder_name = label2folder[label]

            # Try different extensions
            image_path = None
            for ext in ['jpg', 'png', 'jpeg']:
                temp_path = os.path.join(image_root, folder_name, f"{tweet_id}.{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break

            if image_path:
                samples.append((text, image_path, label,dataset))
            else:
                print(f"Image not found for {tweet_id} with label {label}")

    return samples

# --------------------- TRAIN + EVAL ---------------------
def train_model(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, total_correct = 0, 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
    return total_loss / len(dataloader), total_correct / len(dataloader.dataset)

def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            images = batch['image'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            outputs = model(input_ids, attention_mask, images)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels


# --------------------- MAIN ---------------------
if __name__ == '__main__':
    ##label2idx = {'Fake': 1, 'Real': 0}
    label2idx = {1: 1, 0: 0}  
    idx2label = {v: k for k, v in label2idx.items()}
    samples = load_data(
    '/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt',
    '/kaggle/input/ekafnewsforkhawla/sorted_imagesOur')
    
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    transform = models.ViT_B_32_Weights.IMAGENET1K_V1.transforms()
    
    labels = [sample[2] for sample in samples]
    strat =labels
    skf = StratifiedKFold(n_splits=FOLDS, shuffle=False, random_state=None)
    all_preds, all_trues = [], []
    for fold, (train_idx, val_idx) in enumerate(skf.split(samples, strat)):
        print(f"Fold {fold + 1}/{FOLDS}")
        train_samples = [samples[i] for i in train_idx]
        val_samples = [samples[i] for i in val_idx]

        train_dataset = ArabicMultimodal_Fake_News_Dataset(train_samples, tokenizer, transform)
        val_dataset = ArabicMultimodal_Fake_News_Dataset(val_samples, tokenizer, transform)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,worker_init_fn=seed_worker, generator=g)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,worker_init_fn=seed_worker, generator=g)

        model = Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(text_model_name=TEXT_MODEL_NAME,image_encoder=image_encoder,num_classes=NUM_CLASSES).to(DEVICE)


        ##model = MultiModalFusionModel(TEXT_MODEL_NAME, NUM_CLASSES).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(NUM_EPOCHS):
            train_loss, train_acc = train_model(model, train_loader, optimizer, criterion)
            print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f}")

        preds, trues = evaluate_model(model, val_loader)
        print(f"Fold {fold + 1} Classification Report:")
        print(classification_report(trues, preds, target_names=["0","1"], digits=4))
        print("Confusion Matrix:")
        print(confusion_matrix(trues, preds))

        all_preds.extend(preds)
        all_trues.extend(trues)
        free_gpu_memory()

    print("\n=== Overall Classification Report ===")
    print(TEXT_MODEL_NAME)
    print("\n=== Overall Classification Report ===")
    print(classification_report(all_trues, all_preds, target_names=["0","1"], digits=4))
    print("=== Overall Confusion Matrix ===")
    print(confusion_matrix(all_trues, all_preds))




[INFO] All seeds set to 42
5138
Downloading: "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth" to /root/.cache/torch/hub/checkpoints/vit_b_32-d86f8d99.pth


100%|██████████| 337M/337M [00:01<00:00, 188MB/s]  


tokenizer_config.json:   0%|          | 0.00/439 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Fold 1/5


config.json:   0%|          | 0.00/757 [00:00<?, ?B/s]

2026-01-31 20:16:26.309617: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769890586.700109      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769890586.862037      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769890587.936299      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769890587.936356      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769890587.936359      55 computation_placer.cc:177] computation placer alr

pytorch_model.bin:   0%|          | 0.00/654M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/654M [00:00<?, ?B/s]



Epoch 1/3 - Loss: 0.4462 - Train Acc: 0.7881
Epoch 2/3 - Loss: 0.2146 - Train Acc: 0.9105
Epoch 3/3 - Loss: 0.0972 - Train Acc: 0.9601
Fold 1 Classification Report:
              precision    recall  f1-score   support

           0     0.8565    0.7463    0.7976       272
           1     0.9128    0.9550    0.9334       756

    accuracy                         0.8998      1028
   macro avg     0.8847    0.8507    0.8655      1028
weighted avg     0.8979    0.8998    0.8975      1028

Confusion Matrix:
[[203  69]
 [ 34 722]]
Fold 2/5




Epoch 1/3 - Loss: 0.4114 - Train Acc: 0.8109
Epoch 2/3 - Loss: 0.2201 - Train Acc: 0.9114
Epoch 3/3 - Loss: 0.1143 - Train Acc: 0.9606
Fold 2 Classification Report:
              precision    recall  f1-score   support

           0     0.9146    0.6691    0.7728       272
           1     0.8914    0.9775    0.9325       756

    accuracy                         0.8959      1028
   macro avg     0.9030    0.8233    0.8527      1028
weighted avg     0.8976    0.8959    0.8902      1028

Confusion Matrix:
[[182  90]
 [ 17 739]]
Fold 3/5




Epoch 1/3 - Loss: 0.4394 - Train Acc: 0.7866
Epoch 2/3 - Loss: 0.2137 - Train Acc: 0.9114
Epoch 3/3 - Loss: 0.0995 - Train Acc: 0.9659
Fold 3 Classification Report:
              precision    recall  f1-score   support

           0     0.8692    0.6813    0.7639       273
           1     0.8931    0.9629    0.9267       755

    accuracy                         0.8881      1028
   macro avg     0.8811    0.8221    0.8453      1028
weighted avg     0.8868    0.8881    0.8835      1028

Confusion Matrix:
[[186  87]
 [ 28 727]]
Fold 4/5
Epoch 1/3 - Loss: 0.4211 - Train Acc: 0.8066
Epoch 2/3 - Loss: 0.2056 - Train Acc: 0.9253
Epoch 3/3 - Loss: 0.1096 - Train Acc: 0.9645




Fold 4 Classification Report:
              precision    recall  f1-score   support

           0     0.7860    0.7831    0.7845       272
           1     0.9220    0.9232    0.9226       755

    accuracy                         0.8861      1027
   macro avg     0.8540    0.8531    0.8535      1027
weighted avg     0.8859    0.8861    0.8860      1027

Confusion Matrix:
[[213  59]
 [ 58 697]]
Fold 5/5




Epoch 1/3 - Loss: 0.4155 - Train Acc: 0.8151
Epoch 2/3 - Loss: 0.1928 - Train Acc: 0.9241
Epoch 3/3 - Loss: 0.0832 - Train Acc: 0.9703
Fold 5 Classification Report:
              precision    recall  f1-score   support

           0     0.8550    0.8235    0.8390       272
           1     0.9373    0.9497    0.9434       755

    accuracy                         0.9163      1027
   macro avg     0.8961    0.8866    0.8912      1027
weighted avg     0.9155    0.9163    0.9158      1027

Confusion Matrix:
[[224  48]
 [ 38 717]]

=== Overall Classification Report ===
UBC-NLP/MARBERTv2

=== Overall Classification Report ===
              precision    recall  f1-score   support

           0     0.8521    0.7406    0.7925      1361
           1     0.9107    0.9537    0.9317      3777

    accuracy                         0.8972      5138
   macro avg     0.8814    0.8471    0.8621      5138
weighted avg     0.8952    0.8972    0.8948      5138

=== Overall Confusion Matrix ===
[[1008  353