In [None]:
# === SimCLR + End-to-End Fine-Tune (no freezing) ===
# With EMA (fine-tune), MixUp, and TTA + Curves (fp16-safe)
import os, math, numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# Config
# -----------------------
# Paths
DATA_DIR = "/kaggle/input/riceds-original/Original"
SAVE_DIR = "./"
os.makedirs(SAVE_DIR, exist_ok=True)

# Loader settings and reproducibility
BATCH_SIZE, NUM_WORKERS, SEED = 64, 4, 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)
np.random.seed(SEED)

# SimCLR pretraining hyper-parameters
SIMCLR_EPOCHS = 100            # 200-400 is worth trying if compute allows
SIMCLR_LR = 3e-4
SIMCLR_WEIGHT_DECAY = 1e-6
TEMPERATURE = 0.2

# Supervised fine-tuning hyper-parameters (the encoder is NOT frozen)
FINETUNE_EPOCHS = 30
FT_LR_BACKBONE = 1e-4          # encoder gets the smaller learning rate
FT_LR_HEAD = 1e-3              # classifier head gets the larger one
FT_WEIGHT_DECAY = 1e-4
LABEL_SMOOTH = 0.0             # keep at 0 (or tiny) while MixUp is on

USE_IMAGENET_WEIGHTS = True

# Optional extras
USE_EMA = True                 # keep an EMA of the fine-tuned weights
EMA_DECAY = 0.999
USE_MIXUP = True
MIXUP_ALPHA = 0.2
USE_TTA = True                 # horizontal-flip TTA at evaluation time

# Per-epoch histories, plotted after training
simclr_loss_hist, ft_loss_hist, ft_acc_hist, val_acc_hist = [], [], [], []

# -----------------------
# Augmentations
# -----------------------
class TwoCropsTransform:
    """Produce two independently sampled augmentations of one image (SimCLR views)."""

    def __init__(self, size=224):
        # Colour jitter is only applied to 80% of the samples.
        jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        )
        pipeline = [
            transforms.RandomResizedCrop(size, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
            jitter,
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        # Each call re-samples the random ops, so the two views differ.
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

# Channel statistics shared by the training and evaluation pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: mild random crop + horizontal flip.
supervised_train_tf = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Evaluation: deterministic resize followed by a centre crop.
eval_tf = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# -----------------------
# Stratified 80/20 split BEFORE training
# -----------------------
# This dataset instance is only used to read the (path, label) list for the split.
_for_split = datasets.ImageFolder(DATA_DIR, transform=transforms.ToTensor())
labels_all = [sample[1] for sample in _for_split.samples]

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
# The splitter stratifies on the labels only; the feature array is a placeholder.
_split_iter = sss.split(np.zeros(len(labels_all)), labels_all)
train_idx, test_idx = next(_split_iter)

# -----------------------
# SSL datasets/loaders with SAME indices
# -----------------------
ssl_dataset = datasets.ImageFolder(DATA_DIR, transform=TwoCropsTransform())
train_ssl = Subset(ssl_dataset, train_idx)   # pretraining sees the training indices only

_ssl_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,                 # BN stability
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
train_loader_ssl = DataLoader(train_ssl, **_ssl_loader_kwargs)

# -----------------------
# Encoder & Heads
# -----------------------
class Encoder(nn.Module):
    """ResNet-50 with its final child module removed, returning flat 2048-d features.

    Args:
        use_imagenet: when True, start from ImageNet weights; otherwise build
            the network without loading any weights.
    """

    def __init__(self, use_imagenet=True):
        super().__init__()
        if use_imagenet:
            try:
                base = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
            except (AttributeError, TypeError):
                # Only fall back when the `weights=` API itself is missing (older
                # torchvision). Any other failure should surface instead of being
                # hidden behind a silent retry with the legacy keyword.
                base = models.resnet50(pretrained=True)
        else:
            try:
                base = models.resnet50(weights=None)
            except TypeError:
                # Same legacy fallback as the pretrained branch.
                base = models.resnet50(pretrained=False)
        # Drop the last child of the torchvision model and keep the rest.
        self.backbone = nn.Sequential(*list(base.children())[:-1])  # (B, 2048, 1, 1)
        self.feature_dim = 2048

    def forward(self, x):
        """Return backbone features flattened to (B, 2048)."""
        x = self.backbone(x)
        return torch.flatten(x, 1)                                   # (B, 2048)

class MLP(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Linear, optionally followed by a non-affine BatchNorm1d."""

    def __init__(self, in_dim, hidden_dim, out_dim, bn_last=False):
        super().__init__()
        stem = (
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )
        # The optional output normalisation has no learnable scale/shift.
        tail = (nn.BatchNorm1d(out_dim, affine=False),) if bn_last else ()
        self.net = nn.Sequential(*stem, *tail)

    def forward(self, x):
        """Run `x` through the stacked layers."""
        return self.net(x)

# -----------------------
# SimCLR model
# -----------------------
class SimCLR(nn.Module):
    """Encoder plus projection MLP; maps two augmented batches to projection vectors."""

    def __init__(self, encoder, proj_dim=256, hidden=2048):
        super().__init__()
        self.encoder = encoder
        # Projection head ends with a BatchNorm layer (bn_last=True).
        self.projector = MLP(encoder.feature_dim, hidden, proj_dim, bn_last=True)

    def forward(self, x1, x2):
        """Return the projections (z1, z2) of the two views."""
        feats_a = self.encoder(x1)
        feats_b = self.encoder(x2)
        return self.projector(feats_a), self.projector(feats_b)

# ---- FP16-safe InfoNCE (compute logits in float32) ----
def nt_xent_loss(z1, z2, temperature=0.2):
    """Compute the SimCLR NT-Xent (InfoNCE) loss in float32.

    Args:
        z1: (B, d) projections of the first view.
        z2: (B, d) projections of the second view; row i pairs with row i of z1.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar float32 loss averaged over all 2B anchors.
    """
    b = z1.size(0)
    # Callers invoke this inside an autocast region, where matmul would be run
    # in reduced precision even on float32 inputs. Switch autocast off locally
    # so the similarity logits really are float32.
    with torch.autocast(device_type=z1.device.type, enabled=False):
        # Cast before normalising so the norm is accumulated in float32 too.
        z1 = F.normalize(z1.float(), dim=1)
        z2 = F.normalize(z2.float(), dim=1)
        z = torch.cat([z1, z2], dim=0)                               # (2B, d) float32
        sim = torch.matmul(z, z.T) / float(temperature)              # (2B, 2B) float32
        # An anchor must never be matched with itself.
        mask = torch.eye(2 * b, device=z.device, dtype=torch.bool)
        sim = sim.masked_fill(mask, float('-inf'))
        # The positive for row i is the other view of the same image: (i + B) mod 2B.
        targets = torch.arange(2 * b, device=z.device)
        targets = (targets + b) % (2 * b)
        loss = F.cross_entropy(sim, targets)                         # fp32 CE
    return loss

# -----------------------
# SimCLR Pretraining
# -----------------------
encoder = Encoder(use_imagenet=USE_IMAGENET_WEIGHTS).to(DEVICE)
simclr = SimCLR(encoder).to(DEVICE)

ssl_optimizer = torch.optim.AdamW(simclr.parameters(), lr=SIMCLR_LR, weight_decay=SIMCLR_WEIGHT_DECAY)
ssl_sched = torch.optim.lr_scheduler.CosineAnnealingLR(ssl_optimizer, T_max=SIMCLR_EPOCHS)
# Loss scaling is only active on CUDA; on CPU the scaler is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

for epoch in range(SIMCLR_EPOCHS):
    simclr.train()
    running = 0.0
    pbar = tqdm(train_loader_ssl, desc=f"SimCLR Epoch {epoch+1}/{SIMCLR_EPOCHS}")
    # Count steps explicitly. `pbar.n` is only refreshed at tqdm's display
    # interval and lags the current step, so dividing by it gives a wrong
    # running average.
    for step, ((v1, v2), _) in enumerate(pbar, start=1):              # ((view1, view2), label)
        x1 = v1.to(DEVICE, non_blocking=True)
        x2 = v2.to(DEVICE, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
            z1, z2 = simclr(x1, x2)
            loss = nt_xent_loss(z1, z2, temperature=TEMPERATURE)
        ssl_optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(ssl_optimizer)
        scaler.update()
        running += loss.item()
        pbar.set_postfix(loss=f"{running / step:.4f}")
    epoch_loss = running / len(train_loader_ssl)
    simclr_loss_hist.append(epoch_loss)
    ssl_sched.step()                                                  # cosine schedule advances once per epoch
    print(f"SimCLR Epoch {epoch+1}: loss={epoch_loss:.4f}")

# Save encoder after SimCLR pretrain (the projection head is not saved)
enc_path = os.path.join(SAVE_DIR, "simclr_encoder.pth")
torch.save(simclr.encoder.state_dict(), enc_path)
print(f"Saved SimCLR encoder to: {enc_path}")

# -----------------------
# Supervised Fine-Tuning (NO FREEZING) + EMA + MixUp
# -----------------------
# Two views of the same folder: augmented for training, deterministic for evaluation.
sup_train_ds = datasets.ImageFolder(DATA_DIR, transform=supervised_train_tf)
sup_test_ds = datasets.ImageFolder(DATA_DIR, transform=eval_tf)
num_classes = len(sup_train_ds.classes)

# Reuse the index split computed before pretraining.
train_sup, test_sup = Subset(sup_train_ds, train_idx), Subset(sup_test_ds, test_idx)

_sup_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader_sup = DataLoader(train_sup, shuffle=True, **_sup_loader_kwargs)
test_loader_sup = DataLoader(test_sup, shuffle=False, **_sup_loader_kwargs)

class SupModel(nn.Module):
    """Classifier: an encoder followed by a single linear head."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        # Head width follows the encoder's advertised feature size.
        self.head = nn.Linear(encoder.feature_dim, num_classes)

    def forward(self, x):
        """Return class logits for the batch `x`."""
        return self.head(self.encoder(x))

# The SimCLR checkpoint is loaded strictly right below and replaces every encoder
# tensor, so loading ImageNet weights here first would be wasted work.
finetune_encoder = Encoder(use_imagenet=False).to(DEVICE)
finetune_encoder.load_state_dict(torch.load(enc_path, map_location=DEVICE))
sup_model = SupModel(finetune_encoder, num_classes).to(DEVICE)

# Nothing is frozen: the encoder trains with a lower LR than the new head.
param_groups = [
    {"params": sup_model.encoder.parameters(), "lr": FT_LR_BACKBONE, "weight_decay": FT_WEIGHT_DECAY},
    {"params": sup_model.head.parameters(),    "lr": FT_LR_HEAD,     "weight_decay": FT_WEIGHT_DECAY},
]
ft_optimizer = torch.optim.AdamW(param_groups)
ft_sched = torch.optim.lr_scheduler.CosineAnnealingLR(ft_optimizer, T_max=FINETUNE_EPOCHS)

# --- MixUp helpers ---
def one_hot_with_smoothing(y, num_classes, eps=0.0):
    """Turn integer labels into one-hot rows, optionally label-smoothed.

    With eps > 0 each row becomes (1 - eps) * one_hot + eps / num_classes.
    """
    column = y.view(-1, 1)
    encoded = torch.zeros(column.size(0), num_classes, device=y.device)
    encoded.scatter_(1, column, 1.0)
    if eps > 0:
        return encoded * (1 - eps) + eps / num_classes
    return encoded

def mixup_data(x, y, alpha=0.2, num_classes=1000, eps=0.0):
    """Blend a batch with a shuffled copy of itself (MixUp).

    Returns (inputs, soft_targets, lam). When alpha <= 0 the batch is returned
    unmixed with plain (optionally smoothed) one-hot targets and lam = 1.0.
    """
    targets = one_hot_with_smoothing(y, num_classes, eps)
    if alpha <= 0:
        return x, targets, 1.0

    lam = np.random.beta(alpha, alpha)
    shuffle = torch.randperm(x.size(0), device=x.device)
    # The same coefficient and permutation mix both inputs and targets.
    blended_x = lam * x + (1 - lam) * x[shuffle]
    blended_y = lam * targets + (1 - lam) * targets[shuffle]
    return blended_x, blended_y, lam

def soft_cross_entropy(logits, target_prob):
    """Cross-entropy against a per-sample probability vector, averaged over the batch."""
    per_class = target_prob * F.log_softmax(logits, dim=1)
    per_sample = -per_class.sum(dim=1)
    return per_sample.mean()

# --- EMA for fine-tune ---
class EMA:
    """Exponential moving average over a model's floating-point state-dict entries."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for name, tensor in model.state_dict().items():
            # Non-floating entries are not averaged.
            if torch.is_floating_point(tensor):
                self.shadow[name] = tensor.detach().clone()

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current values into the running average, in place."""
        keep = self.decay
        blend = 1 - self.decay
        for name, tensor in model.state_dict().items():
            averaged = self.shadow.get(name)
            if averaged is None or not torch.is_floating_point(tensor):
                continue
            averaged.mul_(keep).add_(tensor.detach(), alpha=blend)

    @torch.no_grad()
    def copy_to(self, model):
        """Overwrite the model's tracked entries with the averaged values."""
        target = model.state_dict()
        for name, averaged in self.shadow.items():
            target[name].copy_(averaged)

# Only build the weight average when EMA is switched on.
if USE_EMA:
    ema = EMA(sup_model, decay=EMA_DECAY)
else:
    ema = None

# Loss scaler for fine-tuning; enabled on CUDA only.
scaler_ft = torch.cuda.amp.GradScaler(enabled=DEVICE.type == 'cuda')

def accuracy_top1(logits, targets):
    """Return the fraction of rows whose arg-max class equals the target, as a Python float."""
    predicted = torch.argmax(logits, dim=1)
    hits = torch.eq(predicted, targets)
    return hits.float().mean().item()

def predict_tta(model, x):
    """Average the model's outputs on `x` and on `x` flipped along dim 3."""
    straight = model(x)
    mirrored = model(x.flip(3))  # dim 3 is the width axis for (B, C, H, W) input
    return (straight + mirrored) / 2

best_acc = 0.0
best_path = os.path.join(SAVE_DIR, "simclr_finetune_best.pt")

for epoch in range(FINETUNE_EPOCHS):
    sup_model.train()
    run_loss, run_acc = 0.0, 0.0
    for x, y in tqdm(train_loader_sup, desc=f"FT Epoch {epoch+1}/{FINETUNE_EPOCHS}"):
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)

        # Build the (possibly mixed) inputs and their soft targets.
        if USE_MIXUP:
            x_in, y_soft, _ = mixup_data(x, y, alpha=MIXUP_ALPHA, num_classes=num_classes, eps=LABEL_SMOOTH)
        else:
            x_in, y_soft = x, one_hot_with_smoothing(y, num_classes, eps=LABEL_SMOOTH)

        with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
            logits = sup_model(x_in)
            loss = soft_cross_entropy(logits, y_soft)

        ft_optimizer.zero_grad(set_to_none=True)
        scaler_ft.scale(loss).backward()
        # After backward the gradients still carry the loss-scale factor.
        # Unscale them first so the 5.0 threshold applies to the true gradient
        # norm; scaler.step() knows they were already unscaled.
        scaler_ft.unscale_(ft_optimizer)
        torch.nn.utils.clip_grad_norm_(sup_model.parameters(), 5.0)
        scaler_ft.step(ft_optimizer)
        scaler_ft.update()

        if USE_EMA:
            ema.update(sup_model)

        # track training accuracy on *non-mixed* view for readability
        # NOTE(review): the model is still in train mode here, so any BatchNorm
        # layers also update their running statistics from this extra pass.
        with torch.no_grad():
            logits_nomix = sup_model(x)
            run_acc += accuracy_top1(logits_nomix, y)

        run_loss += loss.item()

    epoch_train_loss = run_loss / len(train_loader_sup)
    epoch_train_acc  = run_acc  / len(train_loader_sup)
    ft_loss_hist.append(epoch_train_loss)
    ft_acc_hist.append(epoch_train_acc)
    ft_sched.step()
    print(f"[FT] Epoch {epoch+1}: loss={epoch_train_loss:.4f} | acc={epoch_train_acc:.4f}")

    # --- Eval each epoch with EMA (if enabled) + TTA ---
    sup_model.eval()
    backup = None
    if USE_EMA:
        # Stash the live floating-point weights, then swap in the EMA average.
        backup = {k: v.detach().clone() for k, v in sup_model.state_dict().items() if v.dtype.is_floating_point}
        ema.copy_to(sup_model)

    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader_sup:
            x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
            if USE_TTA:
                logits = predict_tta(sup_model, x)
            else:
                logits = sup_model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    epoch_acc = correct / max(total, 1)
    val_acc_hist.append(epoch_acc)

    # save best (EMA-TTA) checkpoint
    # The model still holds exactly the weights that were just evaluated (EMA
    # if enabled, live weights otherwise), so it can be saved directly without
    # a second copy/restore round trip.
    # NOTE(review): the checkpoint is selected on the same held-out split that is
    # reported at the end, which makes the final number optimistic.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(sup_model.state_dict(), best_path)

    # restore the live training weights for the next epoch
    if backup is not None:
        sd = sup_model.state_dict()
        for k, v in backup.items():
            sd[k].copy_(v)

    print(f"[Eval] top1={epoch_acc:.4f} (best={best_acc:.4f})")

# -----------------------
# Final evaluation on held-out 20% (using best checkpoint, EMA if enabled)
# -----------------------
# Load fresh model for clean eval
# best_path stores the full SupModel state dict (encoder + head) and is loaded
# strictly, so the encoder is built without ImageNet weights and without
# re-reading the SimCLR checkpoint: both would be overwritten anyway.
finetune_encoder = Encoder(use_imagenet=False).to(DEVICE)
sup_model = SupModel(finetune_encoder, num_classes).to(DEVICE)
sup_model.load_state_dict(torch.load(best_path, map_location=DEVICE))
sup_model.eval()

y_true, y_pred = [], []
with torch.no_grad():
    for x, y in test_loader_sup:
        x = x.to(DEVICE, non_blocking=True)
        logits = predict_tta(sup_model, x) if USE_TTA else sup_model(x)
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(y.numpy())          # labels never left the CPU

print("\n=== SimCLR -> End-to-End Fine-Tune: Held-out Test ===")
print(classification_report(y_true, y_pred, digits=4))

# Confusion matrix over the same predictions, saved next to the checkpoints.
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens')
plt.title("Confusion Matrix (SimCLR Fine-Tune, EMA+MixUp+TTA)")
plt.xlabel("Predicted"); plt.ylabel("True")
plt.tight_layout(); plt.savefig(os.path.join(SAVE_DIR, "simclr_confusion_matrix.png"), dpi=150)
plt.show()

# -----------------------
# Curves: pretrain loss & fine-tune loss/accuracy
# -----------------------
# Stand-alone pretraining-loss figure.
plt.figure(figsize=(7, 4))
plt.plot(simclr_loss_hist, marker='o')
plt.title("SimCLR Pretraining Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, "simclr_pretrain_loss.png"), dpi=150)
plt.show()

# Three-panel summary: fine-tune loss, train/val accuracy, pretraining loss.
fig, ax = plt.subplots(1, 3, figsize=(18, 4))
loss_ax, acc_ax, pretrain_ax = ax[0], ax[1], ax[2]

loss_ax.plot(ft_loss_hist, marker='o')
acc_ax.plot(ft_acc_hist, marker='o', label='Train')
acc_ax.plot(val_acc_hist, marker='s', label='Val (EMA+TTA)')
pretrain_ax.plot(simclr_loss_hist, marker='o')

_panel_labels = [
    (loss_ax, "Fine-tune Loss", "Loss"),
    (acc_ax, "Accuracy", "Top-1 Acc"),
    (pretrain_ax, "Pretrain Loss", "Loss"),
]
for panel, panel_title, y_label in _panel_labels:
    panel.set_title(panel_title)
    panel.set_xlabel("Epoch")
    panel.set_ylabel(y_label)
    panel.grid(True, alpha=0.3)
acc_ax.legend()

plt.tight_layout()
fig.savefig(os.path.join(SAVE_DIR, "simclr_training_curves.png"), dpi=150)
plt.show()



=== End-to-End Fine-Tune Evaluation (Held-out 20%) ===
              precision    recall  f1-score   support

           0     0.8704    0.9400    0.9038       100
           1     0.9468    0.9271    0.9368        96
           2     1.0000    1.0000    1.0000       109
           3     0.8941    0.7917    0.8398        96
           4     0.9884    1.0000    0.9942        85
           5     0.9189    0.9623    0.9401       106
           6     0.9268    0.9661    0.9461       118
           7     1.0000    1.0000    1.0000        92
           8     0.9870    0.9500    0.9682        80
           9     0.9429    1.0000    0.9706        99
          10     0.9121    0.8737    0.8925        95
          11     1.0000    0.9727    0.9862       110
          12     0.9391    0.9818    0.9600       110
          13     0.9205    0.9643    0.9419        84
          14     0.9895    0.9792    0.9843        96
          15     0.9905    0.9811    0.9858       106
          16     0.9890    0.9375    0.9626        96
          17     1.0000    1.0000    1.0000        94
          18     1.0000    0.9286    0.9630        98
          19     0.9818    1.0000    0.9908       108
          20     0.9783    0.9890    0.9836        91
          21     0.9792    0.9792    0.9792        96
          22     1.0000    1.0000    1.0000       101
          23     0.8191    0.9167    0.8652        84
          24     0.8862    0.9820    0.9316       111
          25     0.9038    0.9691    0.9353        97
          26     0.9118    0.8158    0.8611       114
          27     0.8509    0.9417    0.8940       103
          28     0.9524    0.9091    0.9302       110
          29     0.9792    0.9216    0.9495       102
          30     1.0000    1.0000    1.0000       104
          31     0.9570    0.8165    0.8812       109
          32     0.9478    0.9732    0.9604       112
          33     0.8989    0.8511    0.8743        94
          34     0.9817    1.0000    0.9907       107
          35     0.8246    0.9592    0.8868        98
          36     0.9294    0.8587    0.8927        92
          37     0.9540    0.8557    0.9022        97

    accuracy                         0.9447      3800
   macro avg     0.9461    0.9446    0.9443      3800
weighted avg     0.9463    0.9447    0.9445      3800