**Import**

In [None]:
# 1. Imports
import os, random, math, itertools, numpy as np
import matplotlib.pyplot as plt
import timm
import torch.nn as nn
import torch.optim as optim
import torch, torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm, trange
from sklearn.metrics import classification_report

SEED = 42

# Seed Python's, NumPy's and PyTorch's global RNGs so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on", device)

**Dataset path and Transforms**

In [None]:
# 2. Paths & Transforms

data_dir = "/kaggle/input/ct-kidney-dataset-normal-cyst-tumor-and-stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone"   # ✏️ CHANGE if needed

# Standard ImageNet mean/std, shared by every pipeline below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
IMG_SIZE = 224

# Evaluation: deterministic resize + normalise.
base_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Weak view: first resize to the same geometry as evaluation, then apply a
# horizontal flip and a small random translation (pad 4 px, crop back to
# IMG_SIZE). Without the Resize, RandomCrop would cut a random 224-px patch
# out of the full-size image (a different field of view than base_tf), and
# images smaller than IMG_SIZE would make RandomCrop raise.
weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Strong view: same resize, then vertical flip + RandAugment + translation.
# NOTE(review): num_ops=7 is far heavier than the 2 ops used in the FixMatch
# paper; confirm this is intentional.
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomVerticalFlip(),
    transforms.RandAugment(num_ops=7, magnitude=15),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

**Dataset Split**

In [None]:
# 3 · ImageFolder (no transform attached yet) & train/val/test split

ds_raw = datasets.ImageFolder(root=data_dir)  # samples are (PIL image, class index)
class_names = ds_raw.classes
num_classes = len(class_names)
print(f"Total images: {len(ds_raw)} | Classes: {class_names}")

# Split sizes: 10 % validation and 10 % test (rounded down); training gets
# everything that is left (~80 %).
len_full = len(ds_raw)
val_sz = test_sz = len_full // 10
train_sz = len_full - (val_sz + test_sz)

# Seeded generator -> the same split on every run.
split_gen = torch.Generator().manual_seed(SEED)
train_ds, val_ds, test_ds = random_split(
    ds_raw, [train_sz, val_sz, test_sz], generator=split_gen
)
print(f"Train {train_sz} | Val {val_sz} | Test {test_sz}")

**Labeled & Unlabeled**

In [None]:
# 4 · Labeled vs Unlabeled split inside TRAIN

label_frac = 0.10             # 10 % of train data has labels

# A dedicated RNG seeded with SEED keeps this split reproducible even when the
# cell is re-run or other code has consumed the global `random` state first.
label_rng = random.Random(SEED)
train_indices = list(range(train_sz))   # positions inside train_ds
label_rng.shuffle(train_indices)
n_lab = int(label_frac * train_sz)
lab_idx, unlab_idx = train_indices[:n_lab], train_indices[n_lab:]

lab_ds   = Subset(train_ds, lab_idx)
unlab_ds = Subset(train_ds, unlab_idx)
print(f"Labeled {len(lab_ds)} | Unlabeled {len(unlab_ds)}")


**FixMatch Dataset Wrappers**

In [None]:
# 5 · Dataset wrappers

class FixMatchDataset(Dataset):
    """Wrap a subset of ``(image, label)`` pairs for FixMatch training.

    In labelled mode each item is ``(weak(img), label)``. In unlabelled mode
    it is ``(weak(img), strong(img))`` and the label is discarded.

    Args:
        subset: indexable object whose items are ``(img, label)`` pairs.
        labelled: chooses between the two output forms above.
        weak: transform for the weakly augmented view; ``None`` uses the
            module-level ``weak_tf``.
        strong: transform for the strongly augmented view; ``None`` uses the
            module-level ``strong_tf``.
    """

    def __init__(self, subset, labelled=True, weak=None, strong=None):
        self.subset = subset
        self.labelled = labelled
        self.weak = weak
        self.strong = strong

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, i):
        img, y = self.subset[i]
        # Defaults are looked up at access time, so redefining weak_tf /
        # strong_tf later in the notebook still takes effect.
        weak = weak_tf if self.weak is None else self.weak
        if self.labelled:
            return weak(img), y
        strong = strong_tf if self.strong is None else self.strong
        # Weak view is produced before the strong one (same call order as
        # evaluating the tuple left to right).
        return weak(img), strong(img)

class EvalDataset(Dataset):
    """Wrap a subset of ``(image, label)`` pairs for validation/testing.

    Each item is ``(transform(img), label)``.

    Args:
        subset: indexable object whose items are ``(img, label)`` pairs.
        transform: deterministic eval transform; ``None`` uses the
            module-level ``base_tf``.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, i):
        img, y = self.subset[i]
        # Resolve the default lazily so a redefined base_tf is picked up.
        tf = base_tf if self.transform is None else self.transform
        return tf(img), y


**Load Dataset**

In [None]:
# 6 · DataLoaders

batch_lab = 32     # labelled images per training step
batch_unlab = 64   # unlabelled images per training step (2x batch_lab)

# Training streams: reshuffled every epoch, ragged final batch dropped.
train_loader_kw = dict(shuffle=True, drop_last=True, num_workers=0)
lab_loader = DataLoader(FixMatchDataset(lab_ds, True),
                        batch_size=batch_lab, **train_loader_kw)
unlab_loader = DataLoader(FixMatchDataset(unlab_ds, False),
                          batch_size=batch_unlab, **train_loader_kw)

# Evaluation streams: fixed order, every sample kept.
eval_loader_kw = dict(batch_size=64, shuffle=False, num_workers=0)
val_loader = DataLoader(EvalDataset(val_ds), **eval_loader_kw)
test_loader = DataLoader(EvalDataset(test_ds), **eval_loader_kw)

**Swin Transformer: Frozen Backbone, Model, and Hyperparameters**

In [None]:
# 7 · Swin Transformer (Frozen Backbone) with Linear Probe

# Pretrained Swin-Tiny with no classification head (num_classes=0), so it
# outputs feature vectors of width `backbone.num_features`.
backbone = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=0)

# Freeze every backbone parameter; only the linear head will be trained.
backbone.requires_grad_(False)

# Define a linear probe on top of the frozen backbone
class SwinWithLinearHead(nn.Module):
    """Linear probe: a frozen feature extractor followed by one ``nn.Linear``.

    Args:
        backbone: module mapping inputs to feature vectors; must expose
            ``num_features`` (the width fed into the head).
        num_classes: number of output logits.
    """

    def __init__(self, backbone: nn.Module, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(backbone.num_features, num_classes)

    def train(self, mode: bool = True):
        """Set train/eval mode on the head but keep the backbone in eval mode.

        The backbone is meant to stay frozen. Any train-time stochastic
        layers inside it (dropout / stochastic depth, if present) would only
        add noise to the features the head learns from and to the weak-view
        pseudo-labels. ``eval()`` routes through this method too.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        return self.head(self.backbone(x))

# Set number of classes from the dataset classes found by ImageFolder instead
# of a hand-maintained constant, so the head always has one logit per class.
num_classes = len(class_names)

# Create the model
model = SwinWithLinearHead(backbone, num_classes).to(device)

# Only the linear head is optimised; the backbone's parameters are frozen.
optimizer = optim.AdamW(model.head.parameters(), lr=3e-4, weight_decay=1e-4)

# Hyperparameters
epochs   = 300
tau      = 0.95     # confidence threshold for accepting a pseudo-label
lambda_u = 1.0      # weight of the unsupervised loss term

# Per-epoch histories filled by the training loop
sup_hist, unsup_hist, val_hist = [], [], []


**Training Loop**

In [None]:
# 8 · Training loop  (FixMatch + tqdm)
#
# Each step pairs one labelled batch with one unlabelled batch. zip() stops at
# the shorter loader, so an epoch runs min(len(lab_loader), len(unlab_loader))
# steps; epoch losses are averaged over the steps actually taken.

for epoch in trange(1, epochs + 1, desc="Epoch"):
    model.train()
    sup_meter = unsup_meter = 0.0
    n_steps = 0

    # --- training progress bar ---
    train_bar = tqdm(
        zip(lab_loader, unlab_loader),
        total=min(len(lab_loader), len(unlab_loader)),
        desc=f"Train {epoch:02d}", leave=False
    )

    for (x_lab, y_lab), (w_unlab, s_unlab) in train_bar:
        x_lab, y_lab     = x_lab.to(device), y_lab.to(device)
        w_unlab, s_unlab = w_unlab.to(device), s_unlab.to(device)

        # --- supervised loss on the labelled batch ---
        logits_lab = model(x_lab)
        loss_sup   = F.cross_entropy(logits_lab, y_lab)

        # --- pseudo-labels from the weak view (no gradient) ---
        with torch.no_grad():
            logits_w = model(w_unlab)
            probs_w  = F.softmax(logits_w, dim=1)
            max_p, pseudo = probs_w.max(1)
            mask = (max_p >= tau).float()  # 1 where the prediction is confident

        # --- unsupervised loss: strong view vs. confident pseudo-labels ---
        # Averaged over the whole unlabelled batch; masked samples contribute 0.
        logits_s = model(s_unlab)
        loss_uns = (F.cross_entropy(logits_s, pseudo,
                                    reduction="none") * mask).mean()

        # --- total loss and optimization ---
        loss = loss_sup + lambda_u * loss_uns

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sup_meter   += loss_sup.item()
        unsup_meter += loss_uns.item()
        n_steps     += 1

        train_bar.set_postfix(Sup=f"{loss_sup.item():.4f}",
                              Unsup=f"{loss_uns.item():.4f}")

    # --- record average training losses ---
    # Divide by the number of steps actually run, not by either loader's own
    # length: zip() leaves the longer loader partly unconsumed. max(..., 1)
    # avoids ZeroDivisionError if a loader yields no batches.
    steps = max(n_steps, 1)
    sup_epoch   = sup_meter / steps
    unsup_epoch = unsup_meter / steps
    sup_hist.append(sup_epoch)
    unsup_hist.append(unsup_epoch)

    # ---------- validation ----------
    model.eval()
    val_loss = correct = total = 0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            logits = model(x_val)
            # Sum (not mean) so the epoch value is an exact per-sample mean.
            val_loss += F.cross_entropy(logits, y_val,
                                        reduction="sum").item()
            preds = logits.argmax(1)
            correct += (preds == y_val).sum().item()
            total   += y_val.size(0)

    # --- record validation loss and accuracy ---
    val_epoch = val_loss / total
    val_hist.append(val_epoch)
    val_acc   = correct / total

    tqdm.write(f"Epoch {epoch:02d} | "
               f"Sup {sup_epoch:.4f} | Unsup {unsup_epoch:.4f} | "
               f"ValLoss {val_epoch:.4f} | ValAcc {val_acc:.3%}")




**Loss Curve**

In [None]:
# 9 · Training vs Validation Loss Curves

import matplotlib.pyplot as plt

epochs_range = range(1, len(sup_hist) + 1)
# Total objective optimised each epoch: supervised + weighted unsupervised.
train_loss = [s + lambda_u * u for s, u in zip(sup_hist, unsup_hist)]

plt.figure(figsize=(8, 5))
plt.plot(epochs_range, train_loss, label="Training Loss", marker='o')
# Validation loss is plain cross-entropy on labelled data, so the supervised
# training term is the like-for-like curve to compare it against.
plt.plot(epochs_range, sup_hist, label="Training Loss (supervised CE)", linestyle="--")
plt.plot(epochs_range, val_hist, label="Validation Loss", marker='x')

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("FixMatch – Swin Transformer (Frozen Backbone + Linear Probe)")
plt.legend()
plt.grid(True, linestyle="--", linewidth=0.5)
plt.tight_layout()
plt.show()


**Best Validation Epoch**

In [None]:
# 10 · Best Validation Epoch

# Lowest validation loss wins; on ties the earliest epoch is kept.
best_idx, best_val_loss = min(enumerate(val_hist), key=lambda pair: pair[1])
best_epoch = best_idx + 1   # epochs are numbered from 1

print(f"\n✅ Best Validation Epoch: {best_epoch}")
print(f"📉 Best Validation Loss: {best_val_loss:.4f}")


**Accuracy Report**

In [None]:
# 11 · Test-set evaluation

model.eval()
total = correct = 0
all_preds, all_lbls = [], []
with torch.no_grad():
    for x_test, y_test in tqdm(test_loader, desc="Testing", leave=False):
        x_test, y_test = x_test.to(device), y_test.to(device)
        logits = model(x_test)
        preds  = logits.argmax(1)
        correct += (preds == y_test).sum().item()
        total   += y_test.size(0)
        all_preds.extend(preds.cpu().tolist())
        all_lbls.extend(y_test.cpu().tolist())

test_acc = correct / total
print(f"\nTest accuracy: {test_acc:.3%}\n")
# Pass every class index explicitly. Without `labels`, sklearn infers the
# classes from y_true ∪ y_pred and raises when a class never appears, because
# the count then no longer matches `target_names`. zero_division=0 reports
# undefined precision/recall as 0 without a warning.
print(classification_report(all_lbls, all_preds,
                            labels=list(range(len(class_names))),
                            target_names=class_names, digits=3,
                            zero_division=0))



**Save the Model**

In [None]:
# 12 · Save the model

# Persist only the weights (state_dict), not the pickled module object.
ckpt_path = "fixmatch_swin_kidney.pth"
torch.save(model.state_dict(), ckpt_path)
print(f"Model saved to {ckpt_path}")