In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from torchvision.models import resnet50
import csv
import timm


import os, random, torch, numpy as np

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    ##torch.backends.cudnn.deterministic = False
    ##torch.backends.cudnn.benchmark = True
    torch.use_deterministic_algorithms(True, warn_only=False)
    print(f"[INFO] All seeds set to {seed}")

seed_everything(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

def free_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# --------------------- CONFIG ---------------------

TEXT_MODEL_NAME = 'sentence-transformers/distilbert-multilingual-nli-stsb-quora-ranking'
NUM_CLASSES = 2
BATCH_SIZE = 8
NUM_EPOCHS = 3
FOLDS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt", sep="\t",encoding="utf-8")
print(len(df))

image_encoder = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.DEFAULT)


image_encoder.classifier = nn.Identity()

class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(0.4)

    def forward(self, query, key, value, key_padding_mask=None):
        attn_output, _ = self.attn(query, key, value, key_padding_mask=key_padding_mask)
        x = self.norm(query + self.dropout(attn_output))
        ff_output = self.norm(x + self.dropout(self.ff(x)))
        return ff_output

# Define the Full Multimodal Classifier
class Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(nn.Module):
    def __init__(self, text_model_name, image_encoder, num_classes):
        super().__init__()

        # Text encoder (e.g., BERT)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, 512)

        # Image encoder (e.g., ResNet, ViT)
        self.image_encoder = image_encoder 
        self.image_proj = nn.Linear(1280, 512)

        # Cross-modal interaction attention
        self.cross_modal_attn = CrossAttentionLayer(embed_dim=512, num_heads=4)

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, image):
        # Encode text
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = self.text_proj(text_output.last_hidden_state)  # (B, L, 512)

        # Encode image
        image_feats1 = self.image_encoder(image) 
        image_feats = self.image_proj(image_feats1).unsqueeze(1)  # (B, 1, 512)
        ##visual_guided_text = self.cross_modal_attn(query=image_feats, key=text_embeds, value=text_embeds)
        # Cross-modal attention - level 1 
        img_cross = self.cross_modal_attn(query=text_embeds, key=image_feats, value=image_feats)
        txt_cross = self.cross_modal_attn(query=image_feats, key=text_embeds, value=text_embeds,key_padding_mask=~attention_mask.bool())
        ##visual_guided_text = self.cross_modal_attn(query=img_cross, key=txt_cross, value=txt_cross)

        # Cross-modal attention - level 2 
        cross_modal_img2txt = self.cross_modal_attn(query=txt_cross, key=img_cross, value=img_cross)
        cross_modal_txt2img = self.cross_modal_attn(query=img_cross, key=txt_cross, value=txt_cross)
        ##visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt, key=cross_modal_txt2img, value=cross_modal_txt2img)

        # Cross-modal attention - level 3 
        cross_modal_img2txt1 = self.cross_modal_attn(query=cross_modal_txt2img, key=cross_modal_img2txt, value=cross_modal_img2txt)
        cross_modal_txt2img1 = self.cross_modal_attn(query=cross_modal_img2txt, key=cross_modal_txt2img, value=cross_modal_txt2img)
        ##visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt1, key=cross_modal_txt2img1, value=cross_modal_txt2img1)
        # Cross-modal attention - level 4   
        cross_modal_img2txt2 = self.cross_modal_attn(query=cross_modal_txt2img1, key=cross_modal_img2txt1, value=cross_modal_img2txt1)
        cross_modal_txt2img2 = self.cross_modal_attn(query=cross_modal_img2txt1, key=cross_modal_txt2img1, value=cross_modal_txt2img1)

        # Visual Guided Text attention
        visual_guided_text = self.cross_modal_attn(query=cross_modal_img2txt2, key=cross_modal_txt2img2, value=cross_modal_txt2img2)

        # Mean pooling
        final_multi_repr = torch.mean(visual_guided_text, dim=1)  # (B, 512)


        output = self.classifier(final_multi_repr)
        return output


# --------------------- DATASET ---------------------
class ArabicMultimodal_Fake_News_Dataset(Dataset):
    def __init__(self, samples, tokenizer, transform):
        self.samples = samples
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        ##image = Image.open(sample['image_path']).convert('RGB')
        image = Image.open(sample[1]).convert('RGB')
        image = self.transform(image)

        text_enc = self.tokenizer(
            sample[0], padding='max_length', truncation=True, max_length=64, return_tensors='pt'
        )
        return {
            'input_ids': text_enc['input_ids'].squeeze(0),
            'attention_mask': text_enc['attention_mask'].squeeze(0),
            'image': image,
            'label': torch.tensor(sample[2], dtype=torch.long)
        }

# --------------------- UTILS ---------------------


def load_data(annotation_file, image_root):
    samples = []
    label2folder = {0: 'Fake', 1: 'Real'}

    with open(annotation_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            tweet_id = row['id']
            text = row['preprocess1']
            label = int(row['label'])
            labelstr=row['label']
            dataset = row['dataset']

            folder_name = label2folder[label]

            # Try different extensions
            image_path = None
            for ext in ['jpg', 'png', 'jpeg']:
                temp_path = os.path.join(image_root, folder_name, f"{tweet_id}.{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break

            if image_path:
                samples.append((text, image_path, label,dataset))
            else:
                print(f"Image not found for {tweet_id} with label {label}")

    return samples

# --------------------- TRAIN + EVAL ---------------------
def train_model(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, total_correct = 0, 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
    return total_loss / len(dataloader), total_correct / len(dataloader.dataset)

def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            images = batch['image'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            outputs = model(input_ids, attention_mask, images)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels


# --------------------- MAIN ---------------------
if __name__ == '__main__':
    ##label2idx = {'Fake': 1, 'Real': 0}
    label2idx = {1: 1, 0: 0}  
    idx2label = {v: k for k, v in label2idx.items()}
    samples = load_data(
    '/kaggle/input/ekafnewsforkhawla/shuffled_cleaned_text_file for khawla.txt',
    '/kaggle/input/ekafnewsforkhawla/sorted_imagesOur')
    
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    transform = models.EfficientNet_B1_Weights.DEFAULT.transforms()
    
    labels = [sample[2] for sample in samples]  
    strat =labels
    skf = StratifiedKFold(n_splits=FOLDS, shuffle=False, random_state=None)
    all_preds, all_trues = [], []
    for fold, (train_idx, val_idx) in enumerate(skf.split(samples, strat)):
        print(f"Fold {fold + 1}/{FOLDS}")
        train_samples = [samples[i] for i in train_idx]
        val_samples = [samples[i] for i in val_idx]

        train_dataset = ArabicMultimodal_Fake_News_Dataset(train_samples, tokenizer, transform)
        val_dataset = ArabicMultimodal_Fake_News_Dataset(val_samples, tokenizer, transform)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False,worker_init_fn=seed_worker, generator=g)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,worker_init_fn=seed_worker, generator=g)

        model = Multimodal_Arabic_Fake_news_Identification_via_Hybrid_Attention_Networks(text_model_name=TEXT_MODEL_NAME,image_encoder=image_encoder,num_classes=NUM_CLASSES).to(DEVICE)


        ##model = MultiModalFusionModel(TEXT_MODEL_NAME, NUM_CLASSES).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(NUM_EPOCHS):
            train_loss, train_acc = train_model(model, train_loader, optimizer, criterion)
            print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f}")

        preds, trues = evaluate_model(model, val_loader)
        print(f"Fold {fold + 1} Classification Report:")
        print(classification_report(trues, preds, target_names=["0","1"], digits=4))
        print("Confusion Matrix:")
        print(confusion_matrix(trues, preds))

        all_preds.extend(preds)
        all_trues.extend(trues)
        free_gpu_memory()

    print("\n=== Overall Classification Report ===")
    print(TEXT_MODEL_NAME)
    print("\n=== Overall Classification Report ===")
    print(classification_report(all_trues, all_preds, target_names=["0","1"], digits=4))
    print("=== Overall Confusion Matrix ===")
    print(confusion_matrix(all_trues, all_preds))


[INFO] All seeds set to 42
5138
Downloading: "https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b1-c27df63c.pth


100%|██████████| 30.1M/30.1M [00:00<00:00, 120MB/s] 


tokenizer_config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/589 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Fold 1/5


model.safetensors:   0%|          | 0.00/539M [00:00<?, ?B/s]



Epoch 1/3 - Loss: 0.4346 - Train Acc: 0.7983
Epoch 2/3 - Loss: 0.2340 - Train Acc: 0.9080
Epoch 3/3 - Loss: 0.1343 - Train Acc: 0.9506
Fold 1 Classification Report:
              precision    recall  f1-score   support

           0     0.7671    0.8235    0.7943       272
           1     0.9348    0.9101    0.9223       756

    accuracy                         0.8872      1028
   macro avg     0.8510    0.8668    0.8583      1028
weighted avg     0.8904    0.8872    0.8884      1028

Confusion Matrix:
[[224  48]
 [ 68 688]]
Fold 2/5




Epoch 1/3 - Loss: 0.3903 - Train Acc: 0.8178
Epoch 2/3 - Loss: 0.1827 - Train Acc: 0.9285
Epoch 3/3 - Loss: 0.1036 - Train Acc: 0.9618
Fold 2 Classification Report:
              precision    recall  f1-score   support

           0     0.8379    0.7794    0.8076       272
           1     0.9226    0.9458    0.9340       756

    accuracy                         0.9018      1028
   macro avg     0.8803    0.8626    0.8708      1028
weighted avg     0.9002    0.9018    0.9006      1028

Confusion Matrix:
[[212  60]
 [ 41 715]]
Fold 3/5




Epoch 1/3 - Loss: 0.3571 - Train Acc: 0.8401
Epoch 2/3 - Loss: 0.1644 - Train Acc: 0.9394
Epoch 3/3 - Loss: 0.0859 - Train Acc: 0.9720
Fold 3 Classification Report:
              precision    recall  f1-score   support

           0     0.7024    0.8645    0.7750       273
           1     0.9465    0.8675    0.9053       755

    accuracy                         0.8667      1028
   macro avg     0.8245    0.8660    0.8402      1028
weighted avg     0.8817    0.8667    0.8707      1028

Confusion Matrix:
[[236  37]
 [100 655]]
Fold 4/5
Epoch 1/3 - Loss: 0.3341 - Train Acc: 0.8489
Epoch 2/3 - Loss: 0.1311 - Train Acc: 0.9509
Epoch 3/3 - Loss: 0.0845 - Train Acc: 0.9691




Fold 4 Classification Report:
              precision    recall  f1-score   support

           0     0.7743    0.9081    0.8359       272
           1     0.9647    0.9046    0.9337       755

    accuracy                         0.9056      1027
   macro avg     0.8695    0.9064    0.8848      1027
weighted avg     0.9143    0.9056    0.9078      1027

Confusion Matrix:
[[247  25]
 [ 72 683]]
Fold 5/5




Epoch 1/3 - Loss: 0.2632 - Train Acc: 0.8903
Epoch 2/3 - Loss: 0.1126 - Train Acc: 0.9579
Epoch 3/3 - Loss: 0.0733 - Train Acc: 0.9752
Fold 5 Classification Report:
              precision    recall  f1-score   support

           0     0.6897    0.9559    0.8012       272
           1     0.9815    0.8450    0.9082       755

    accuracy                         0.8744      1027
   macro avg     0.8356    0.9005    0.8547      1027
weighted avg     0.9042    0.8744    0.8799      1027

Confusion Matrix:
[[260  12]
 [117 638]]

=== Overall Classification Report ===
sentence-transformers/distilbert-multilingual-nli-stsb-quora-ranking

=== Overall Classification Report ===
              precision    recall  f1-score   support

           0     0.7476    0.8663    0.8026      1361
           1     0.9489    0.8946    0.9210      3777

    accuracy                         0.8871      5138
   macro avg     0.8483    0.8805    0.8618      5138
weighted avg     0.8956    0.8871    0.8896     