In [6]:
# =========================
# IMPORTS
# =========================
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings("ignore")

from torchvision.models import (
    ResNet50_Weights,
    MobileNet_V2_Weights,
    MobileNet_V3_Small_Weights,
    EfficientNet_B1_Weights,
    Inception_V3_Weights
)

# =========================
# SEED (FULL DETERMINISM)
# =========================
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# TRANSFORM FACTORY
# =========================
def get_transform(model_name):
    if model_name == "resnet50":
        return ResNet50_Weights.IMAGENET1K_V1.transforms()
    elif model_name == "mobilenet_v2":
        return MobileNet_V2_Weights.IMAGENET1K_V1.transforms()
    elif model_name == "mobilenet_v3_small":
        return MobileNet_V3_Small_Weights.IMAGENET1K_V1.transforms()
    elif model_name == "efficientnet_b1":
        return EfficientNet_B1_Weights.IMAGENET1K_V1.transforms()
    elif model_name == "inception_v3":
        return Inception_V3_Weights.IMAGENET1K_V1.transforms()  # 299x299
    else:
        raise ValueError("Unsupported model")

# =========================
# MODEL FACTORY (CNN ONLY)
# =========================
def get_model(name, num_classes):
    if name == "resnet50":
        model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        head_params = model.fc.parameters()

    elif name == "mobilenet_v2":
        model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        head_params = model.classifier.parameters()

    elif name == "mobilenet_v3_small":
        model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        head_params = model.classifier.parameters()

    elif name == "efficientnet_b1":
        model = models.efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        head_params = model.classifier.parameters()

    elif name == "inception_v3":
        ##model = models.inception_v3(
            ##weights=Inception_V3_Weights.IMAGENET1K_V1,
            ##aux_logits=False
        ##)
        model = models.inception_v3(aux_logits=True)
        model.aux_logits = False
        model.AuxLogits = None
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        head_params = model.fc.parameters()

    else:
        raise ValueError("Model not supported")

    # ğŸ”’ Freeze backbone (FAIR CNN BASELINE)
    for p in model.parameters():
        p.requires_grad = False
    for p in head_params:
        p.requires_grad = True

    return model.to(device)

# =========================
# DATA PATH
# =========================
data_dir = "/kaggle/input/ekafnewsforkhawla/sorted_imagesOur"

models_list = [
    "resnet50",
    "mobilenet_v2",
    "mobilenet_v3_small",
    "efficientnet_b1",
    "inception_v3"
]

# =========================
# TRAIN & EVAL
# =========================
for model_name in models_list:
    print(f"\nğŸš€ MODEL: {model_name.upper()}")

    transform = get_transform(model_name)
    dataset = ImageFolder(data_dir, transform=transform)
    targets = np.array(dataset.targets)
    class_names = dataset.classes

    # ğŸ”¥ CLASS WEIGHTS (CRITICAL)
    class_counts = np.bincount(targets)
    class_weights = 1. / class_counts
    weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=weights)

    skf = StratifiedKFold(n_splits=5, shuffle=False, random_state=None)

    all_y_true, all_y_pred = [], []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets), 1):
        print(f"ğŸ”¹ Fold {fold}/5")

        train_ds = Subset(dataset, train_idx)
        val_ds   = Subset(dataset, val_idx)

        train_loader = DataLoader(train_ds,batch_size=8,shuffle=True,num_workers=2,pin_memory=True)
        val_loader = DataLoader(val_ds,batch_size=8, shuffle=False,  num_workers=2,pin_memory=True)
        model = get_model(model_name, num_classes=len(class_names))
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

        # -------- TRAIN --------
        model.train()
        for epoch in range(3):
            for imgs, lbls in train_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                optimizer.zero_grad()
                outputs = model(imgs)
                loss = criterion(outputs, lbls)
                loss.backward()
                optimizer.step()

        # -------- VALIDATION --------
        model.eval()
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs = imgs.to(device)
                outputs = model(imgs)
                preds = torch.argmax(outputs, dim=1).cpu().numpy()
                all_y_pred.extend(preds)
                all_y_true.extend(lbls.numpy())

    print("âœ…" * 60)
    print(f"FINAL RESULTS â€” {model_name.upper()}")
    print(classification_report(all_y_true, all_y_pred, target_names=class_names, digits=4))
    print(confusion_matrix(all_y_true, all_y_pred))
    print("ğŸ”¥" * 60)



ğŸš€ MODEL: RESNET50
ğŸ”¹ Fold 1/5
ğŸ”¹ Fold 2/5
ğŸ”¹ Fold 3/5
ğŸ”¹ Fold 4/5
ğŸ”¹ Fold 5/5
âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…âœ…
FINAL RESULTS â€” RESNET50
              precision    recall  f1-score   support

        Fake     0.2670    0.3600    0.3066      1361
        Real     0.7363    0.6439    0.6870      3777

    accuracy                         0.5687      5138
   macro avg     0.5017    0.5020    0.4968      5138
weighted avg     0.6120    0.5687    0.5862      5138

[[ 490  871]
 [1345 2432]]
ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥ğŸ”¥

ğŸš€ MODEL: MOBILENET_V2
ğŸ”¹ Fold 1/5
ğŸ”¹ Fold 2/5
ğŸ”¹ Fold 3/5
ğŸ”¹ Fold 4/5
ğŸ”¹ Fold 5/5
âœ…âœ…âœ…