**Mean Teacher + Swin**

**1. Imports**

In [None]:
# Import

import math, random, time, copy, os, json, warnings, pathlib
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, datasets, models
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from torch.cuda.amp import autocast, GradScaler

import timm  # ADDED: for Swin Transformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 42
# Seed every RNG the pipeline touches from the single SEED constant, so changing
# SEED in one place re-seeds Python, NumPy, and PyTorch (CPU and all CUDA devices;
# the CUDA call is a no-op when CUDA is unavailable).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

**Device Check**

In [None]:
# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

**2. Dataset Load**

In [None]:
# Dataset Load

data_dir = "/kaggle/input/kidney-dataset/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone"

# One set of normalization statistics for every pipeline. The Swin backbone is
# loaded with timm's pretrained weights, whose preprocessing config uses the
# ImageNet mean/std. Sharing the values also guarantees that the teacher (weak
# view) and the student (strong view) receive identically normalized inputs, so
# the consistency loss compares the augmentations rather than a normalization
# mismatch.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Deterministic preprocessing used for validation/evaluation.
base_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Weak augmentation: flip + small padded crop (labeled batches and teacher view).
weak_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Strong augmentation (student view): grayscale, blur, and RandAugment with
# 7 ops at magnitude 15, applied on the PIL image before tensor conversion.
strong_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.3, 2.0)),
    transforms.RandAugment(7, 15),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])



**3. Custom Dataset Returning Weak and Strong Augmentations**

In [None]:
# dataset (weak + strong augmentations) 

class UnlabeledPair(Dataset):
    """Dataset wrapper that yields a (weak view, strong view) pair per image.

    The label from the wrapped dataset is discarded. If the wrapped dataset
    yields tensors, they are converted back to PIL images so that both
    transform pipelines receive the same input type.

    Args:
        base_dataset: indexable dataset whose items are ``(image, label)`` pairs.
        weak_tf: callable producing the first (weak) view.
        strong_tf: callable producing the second (strong) view.
    """

    def __init__(self, base_dataset, weak_tf, strong_tf):
        self.base_dataset = base_dataset
        self.weak_tf = weak_tf
        self.strong_tf = strong_tf

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, _label = self.base_dataset[idx]
        if torch.is_tensor(image):
            image = transforms.ToPILImage()(image)
        # Weak view is produced first, then the strong view.
        weak_view = self.weak_tf(image)
        strong_view = self.strong_tf(image)
        return weak_view, strong_view


**4. Split Dataset**

In [None]:
# Build splits: (train_labeled, train_unlabeled, val)

FULL = datasets.ImageFolder(data_dir)

# ----  extract labels for stratified splitting   ------
# ImageFolder.targets holds the class index of each sample, in FULL.samples order.
labels = np.array(FULL.targets)

# --- split into training and validation sets (90% train, 10% val) ---
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.10, random_state=SEED)
train_idx, val_idx = next(sss.split(np.zeros(len(labels)), labels))
train_labels = labels[train_idx]

# --- split training set into labeled and unlabeled (50% labeled) ---
p_labeled = 0.50
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=1 - p_labeled,
                              random_state=SEED)
lab_idx, unlab_idx = next(sss2.split(np.zeros(len(train_labels)), train_labels))

# sss2 returned positions within train_idx; map them back to full-dataset indices.
lab_idx = train_idx[lab_idx]
unlab_idx = train_idx[unlab_idx]

# Each split wraps its own ImageFolder over the same data_dir so it can carry a
# different transform; indices therefore refer to the same files as FULL.

# ---------- labelled ----------
train_lab_ds = Subset(
    datasets.ImageFolder(data_dir, transform=weak_tf),
    lab_idx)

# ---------- unlabeled ----------
# No transform here: UnlabeledPair applies the weak and strong views itself.
base_unlab = Subset(
    datasets.ImageFolder(data_dir),
    unlab_idx)
train_unlab_ds = UnlabeledPair(base_unlab, weak_tf, strong_tf)

# ---------- validation ----------
val_ds = Subset(
    datasets.ImageFolder(data_dir, transform=base_tf),
    val_idx)

# ---------- create data loaders ----------
BATCH_L = 32
BATCH_U = 64
lab_loader = DataLoader(train_lab_ds, batch_size=BATCH_L,
                        shuffle=True, drop_last=True, num_workers=2)
unlab_loader = DataLoader(train_unlab_ds, batch_size=BATCH_U,
                          shuffle=True, drop_last=True, num_workers=2)
# Validation metrics do not depend on order; a fixed order keeps evaluation
# deterministic and avoids consuming RNG state between training epochs.
val_loader = DataLoader(val_ds, batch_size=64,
                        shuffle=False, num_workers=2)


**Available Swin Transformer Models in timm**

In [None]:
import timm
# List every Swin variant timm can build. The result gets its own name so it
# does not shadow the torchvision ``models`` module imported in the first cell.
swin_models = timm.list_models('*swin*')
print(swin_models)

**5. Swin Backbone (tiny variant)**

In [None]:
# Define a custom backbone model using Swin Transformer (tiny variant)

# === Swin Backbone ===
class SwinBackbone(nn.Module):
    """Frozen Swin-Tiny feature extractor built with timm.

    The model is created with ``num_classes=0``, so timm omits the
    classification layer and ``forward`` returns pooled features instead of
    logits. Every backbone parameter has ``requires_grad`` turned off.
    """

    def __init__(self):
        super().__init__()
        swin = timm.create_model('swin_tiny_patch4_window7_224',
                                 pretrained=True, num_classes=0)
        self.backbone = swin
        # Freeze all backbone weights; only layers stacked on top are trained.
        self.backbone.requires_grad_(False)

    def forward(self, x):
        features = self.backbone(x)
        return features

**6. Linear Head**

In [None]:
# Define a linear classification head
class LinearHead(nn.Module):
    """Single fully-connected layer mapping backbone features to class logits.

    Args:
        in_features: width of the incoming feature vector (768 by default —
            presumably the Swin-Tiny pooled width; must match the backbone).
        num_classes: number of output logits.
    """

    def __init__(self, in_features=768, num_classes=4):
        super().__init__()
        # Attribute name ``classifier`` is part of the saved state_dict keys.
        self.classifier = nn.Linear(in_features=in_features,
                                    out_features=num_classes)

    def forward(self, x):
        logits = self.classifier(x)
        return logits


**7. Mean Teacher Model with a Swin Transformer Backbone**

In [None]:
# Full model (Student/Teacher)

# --- Swin Backbone + Linear Head wrapped in a model for Mean Teacher ----
class SwinMeanTeacherModel(nn.Module):
    """Frozen Swin backbone followed by a trainable linear classifier.

    The same architecture is used for both the student and the teacher.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Must equal the backbone's output feature width.
        feature_dim = 768
        self.backbone = SwinBackbone()
        self.head = LinearHead(feature_dim, num_classes)

    def forward(self, x):
        # The backbone is frozen, so skip building an autograd graph for it;
        # gradients flow only through the head.
        with torch.no_grad():
            embeddings = self.backbone(x)
        logits = self.head(embeddings)
        return logits

**8. Initialize Student and Teacher Models with EMA**

In [None]:
# Load the model

num_classes = 4
student = SwinMeanTeacherModel(num_classes).to(device)

# The teacher starts as an exact copy of the student. deepcopy avoids building
# (and loading pretrained weights into) a second Swin backbone only to
# overwrite all of its weights via load_state_dict.
teacher = copy.deepcopy(student)
# The optimizer only receives student parameters; the teacher is updated by EMA
# alone, so it needs no autograd tracking.
teacher.requires_grad_(False)

ema_decay = 0.999


def update_ema(student, teacher, decay):
    """Move ``teacher`` toward ``student`` with an exponential moving average.

    Each parameter is updated in place as ``t <- decay * t + (1 - decay) * s``.
    Buffers (e.g. normalization running statistics, integer index tables) are
    not learned by the optimizer, so they are copied verbatim. This keeps the
    teacher's non-parameter state consistent with the student's.

    Args:
        student: the module being optimized.
        teacher: a module with the same architecture; modified in place.
        decay: EMA factor; values closer to 1 make the teacher change more slowly.

    Raises:
        ValueError: if the two modules do not have the same number of
            parameters or buffers (zip would otherwise silently truncate).
    """
    s_params, t_params = list(student.parameters()), list(teacher.parameters())
    s_bufs, t_bufs = list(student.buffers()), list(teacher.buffers())
    if len(s_params) != len(t_params) or len(s_bufs) != len(t_bufs):
        raise ValueError("student and teacher must share the same architecture")

    with torch.no_grad():
        for s, t in zip(s_params, t_params):
            t.mul_(decay).add_(s, alpha=1 - decay)
        for s_buf, t_buf in zip(s_bufs, t_bufs):
            t_buf.copy_(s_buf)


**9. Loss, Optimizer, Sigmoid Ramp-up, and Evaluation Helper**

In [None]:
from sklearn.metrics import classification_report, accuracy_score, precision_recall_fscore_support

# --- loss function and optimizer setup ---
# Supervised loss for the labeled batches.
criterion_sup = nn.CrossEntropyLoss()
# Only the student is optimized by gradient descent.
opt = torch.optim.AdamW(
    student.parameters(),
    lr=3e-4,
    weight_decay=2e-4,
)
# Cosine learning-rate schedule; it only changes the LR when sched.step() is called.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

# --- ramp-up function for consistency weight ---
def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped ramp-up weight in ``[exp(-5), 1]``.

    Computes ``exp(-5 * (1 - t / L) ** 2)``, where ``t`` is ``current``
    clamped to ``[0, L]`` and ``L`` is ``rampup_length``. Returns 1.0
    immediately when ``rampup_length`` is 0.
    """
    if rampup_length == 0:
        return 1.0
    clamped = min(max(current, 0.0), rampup_length)
    remaining = 1.0 - clamped / rampup_length
    return math.exp(-5.0 * remaining * remaining)
    
def evaluate_model(model, loader):
    """Return ``(accuracy, macro precision, macro recall, macro F1)`` of ``model`` on ``loader``.

    The model is switched to eval mode (and left there), autograd is disabled,
    and each batch is moved to the module-level ``device``. Predictions are the
    argmax over dimension 1 of the model output.
    """
    model.eval()
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_pred = model(inputs).argmax(dim=1)
            pred_chunks.append(batch_pred.cpu().numpy())
            true_chunks.append(targets.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    acc = accuracy_score(y_true, y_pred)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0)
    return acc, precision, recall, f1


**Class Names of the Dataset**

In [None]:
# Display names in class-index order (ImageFolder indexes classes by sorted
# folder name — assumed to match these names).
class_names = "Cyst Normal Stone Tumor".split()

**10. Mean Teacher Training with Consistency Loss, Mixed Precision, and EMA**

In [None]:
best_f1, best_state = 0.0, None
history = []
train_loss_list = []
val_loss_list = []
EPOCHS = 200
lambda_max = 30.0   # maximum weight of the consistency loss
ramp_len = 4        # epochs over which that weight ramps up (sigmoid)

# Mixed precision and loss scaling are enabled only on CUDA. On CPU both are
# explicitly disabled instead of relying on a hard-coded 'cuda' autocast.
use_amp = device.type == 'cuda'
scaler = GradScaler(enabled=use_amp)

# Cosine LR schedule whose horizon matches the run length; stepped once per epoch.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)


def _cycle_next(iterator, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` when ``iterator`` is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def _validate_teacher(model, loader):
    """Single pass over ``loader``: return (per-sample mean loss, acc, macro P, macro R, macro F1)."""
    model.eval()
    loss_sum, n_seen = 0.0, 0
    all_p, all_t = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # criterion_sup averages over the batch; re-weight by batch size so a
            # short final batch is not over-counted in the epoch mean.
            loss_sum += criterion_sup(logits, y).item() * y.size(0)
            n_seen += y.size(0)
            all_p.append(torch.argmax(logits, 1).cpu().numpy())
            all_t.append(y.cpu().numpy())
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)
    acc = accuracy_score(y_true, y_pred)
    pr, rc, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0)
    return loss_sum / n_seen, acc, pr, rc, f1


for epoch in range(1, EPOCHS + 1):
    student.train()
    lab_iter, unlab_iter = iter(lab_loader), iter(unlab_loader)
    # Run as many steps as the longer loader; the shorter one is restarted as needed.
    n_steps = max(len(lab_loader), len(unlab_loader))
    total_loss = 0.0
    # The ramp-up weight depends only on the epoch, so compute it once here.
    lam = lambda_max * sigmoid_rampup(epoch - 1, ramp_len)

    # --- training loop with progress bar ---
    pbar = tqdm(range(n_steps), desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for _ in pbar:
        (x_lab, y_lab), lab_iter = _cycle_next(lab_iter, lab_loader)
        (x_w, x_s), unlab_iter = _cycle_next(unlab_iter, unlab_loader)

        x_lab, y_lab = x_lab.to(device), y_lab.to(device)
        x_w, x_s = x_w.to(device), x_s.to(device)

        opt.zero_grad()

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            # Supervised term on the labeled batch.
            logits_lab = student(x_lab)
            loss_sup = criterion_sup(logits_lab, y_lab)

            # Consistency term: student(strong view) should match teacher(weak view).
            with torch.no_grad():
                t_prob = F.softmax(teacher(x_w), dim=1)
            s_prob = F.softmax(student(x_s), dim=1)
            loss_cons = F.mse_loss(s_prob, t_prob)

            loss = loss_sup + lam * loss_cons

        # --- backward and optimizer step with scaler ---
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        # --- teacher follows the student via EMA ---
        update_ema(student, teacher, ema_decay)

        total_loss += loss.item()
        pbar.set_postfix(sup=loss_sup.item(), cons=loss_cons.item(), lam=lam)

    sched.step()

    # --- compute and store average training loss ---
    avg_train_loss = total_loss / n_steps
    train_loss_list.append(avg_train_loss)

    # --- one validation pass with the teacher: loss + metrics ---
    val_loss, acc, pr, rc, f1 = _validate_teacher(teacher, val_loader)
    val_loss_list.append(val_loss)
    history.append(dict(epoch=epoch, acc=float(acc), prec=float(pr),
                        rec=float(rc), f1=float(f1)))
    print(f"Epoch {epoch:02d}: train_loss={avg_train_loss:.4f} val_loss={val_loss:.4f} val_acc={acc:.4f} F1={f1:.4f}")

    if f1 > best_f1:
        best_f1 = f1
        best_state = copy.deepcopy(teacher.state_dict())
        torch.save(best_state, "best_mean_teacher.pth")


print(f"\nBest F1 = {best_f1:.4f}")
with open("history.json", "w") as f:
    json.dump(history, f, indent=2)

# === Plot Loss Curves ===
epochs_axis = range(1, len(train_loss_list) + 1)
plt.figure(figsize=(8, 5))
plt.plot(epochs_axis, train_loss_list, label="Train Loss")
plt.plot(epochs_axis, val_loss_list, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss Curve")
plt.legend()
plt.grid()
plt.savefig("loss_curve.png")
plt.show()


**11. Best Epoch Selection Based on F1-Score**

In [None]:
import json

# Reload the per-epoch validation history written by the training cell.
with open("history.json", "r") as f:
    history = json.load(f)

# Record with the highest validation F1; on ties, max() keeps the earliest one.
best_epoch_info = max(history, key=lambda record: record['f1'])
print(f"Best epoch was: {best_epoch_info['epoch']} with F1 = {best_epoch_info['f1']:.4f}")


**12. Evaluation**

In [None]:
def evaluate_model(model, loader, class_names=None):
    """Evaluate ``model`` on ``loader`` and return macro-averaged metrics.

    The model is switched to eval mode (and left there), autograd is disabled,
    and batches are moved to the module-level ``device``.

    Args:
        model: classifier whose output is argmax-ed over dimension 1.
        loader: iterable of ``(inputs, integer targets)`` batches.
        class_names: display names indexed by class id. When given, overall
            metrics and a per-class report are printed. When ``None`` (the
            default), nothing is printed. This keeps two-argument calls
            ``evaluate_model(model, loader)`` working after this cell
            redefines the function.

    Returns:
        ``(accuracy, macro precision, macro recall, macro F1)``.
    """
    model.eval()
    all_p, all_t = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, 1)
            all_p.append(preds.cpu().numpy())
            all_t.append(y.cpu().numpy())

    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    # Overall metrics
    acc = accuracy_score(y_true, y_pred)
    pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)

    if class_names is None:
        return acc, pr, rc, f1

    print(f"\nOverall Metrics:")
    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {pr:.4f}")
    print(f"Recall   : {rc:.4f}")
    print(f"F1 Score : {f1:.4f}")

    # Per-class report. Passing the full label list keeps target_names aligned
    # even if some class never appears in y_true / y_pred for this split.
    print(f"\nPer-Class Metrics:")
    report = classification_report(
        y_true, y_pred,
        labels=np.arange(len(class_names)),
        target_names=class_names,
        zero_division=0,
    )
    print(report)

    return acc, pr, rc, f1

**13. Classification Report**

In [None]:
class_names = ['Cyst', 'Normal', 'Stone', 'Tumor']

# Report on the best-F1 checkpoint saved during training rather than on
# whatever weights the teacher held after the final epoch.
best_ckpt = Path("best_mean_teacher.pth")
if best_ckpt.exists():
    teacher.load_state_dict(torch.load(best_ckpt, map_location=device))
    print(f"Loaded best checkpoint from {best_ckpt}")
else:
    print("No best checkpoint found; evaluating current teacher weights.")
evaluate_model(model=teacher, loader=val_loader, class_names=class_names)