In [None]:
# === MoCo (v2-style) + End-to-End Fine-Tune (no freezing) ===
import os, copy, math, numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, Subset

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# Config
# -----------------------
# Dataset root; it is read through datasets.ImageFolder further down.
DATA_DIR = '/kaggle/input/riceds-original/Original'
# Directory that receives the saved encoder checkpoint (created if missing).
SAVE_DIR = './'
os.makedirs(SAVE_DIR, exist_ok=True)

# Batch size shared by the SSL, fine-tune and evaluation DataLoaders.
BATCH_SIZE = 64

# === MoCo pretrain ===
MOCO_EPOCHS = 100          # increase if possible (e.g., 200–400)
SSL_LR = 3e-4              # AdamW learning rate for query encoder + projector
SSL_WEIGHT_DECAY = 1e-6    # AdamW weight decay during pretraining
# MoCo hyperparams
MOCO_DIM = 256             # projection dimension
MOCO_K = 16384             # queue size (power of 2 is nice)
MOCO_M = 0.996             # EMA/momentum for key encoder
MOCO_T = 0.2               # softmax temperature

# Supervised fine-tune (NO FREEZING)
FINETUNE_EPOCHS = 30
FT_LR_BACKBONE = 1e-4      # learning rate of the encoder parameter group
FT_LR_HEAD = 1e-3          # learning rate of the linear head parameter group
FT_WEIGHT_DECAY = 1e-4     # weight decay applied to both parameter groups
LABEL_SMOOTH = 0.1         # label_smoothing passed to the fine-tune CrossEntropyLoss

# Forwarded to Encoder(use_imagenet=...) for both the MoCo backbone and the
# fine-tune encoder.
USE_IMAGENET_WEIGHTS = True
NUM_WORKERS = 4            # num_workers of every DataLoader in this file
SEED = 42                  # seeds torch, numpy and the train/test split generator

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(SEED)
np.random.seed(SEED)

# -----------------------
# Augmentations
# -----------------------
class TwoCropsTransform:
    """Callable that turns one image into two augmented views.

    The augmentation pipeline (random resized crop, horizontal flip, colour
    jitter, random grayscale, Gaussian blur, tensor conversion and mean/std
    normalisation) is assembled once in the constructor and run twice on the
    same input for every call.

    Args:
        size: side length passed to ``RandomResizedCrop``.
    """
    def __init__(self, size=224):
        # Colour jitter is only applied to 80% of the samples.
        colour_jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        )
        pipeline = [
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            colour_jitter,
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]
        self.base = transforms.Compose(pipeline)

    def __call__(self, x):
        """Return ``(view_1, view_2)``, two separate passes of ``x`` through the pipeline."""
        first_view = self.base(x)
        second_view = self.base(x)
        return first_view, second_view

# weaker augs for supervised training
def _tensor_and_normalize():
    """Return the tail shared by both pipelines: ToTensor, then mean/std normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]


# Training pipeline: mild random crop + flip before the shared tail.
supervised_train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    *_tensor_and_normalize(),
])

# Evaluation pipeline: resize + centre crop (no randomness) before the shared tail.
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    *_tensor_and_normalize(),
])

# -----------------------
# Dataset & 80/20 split BEFORE training
# -----------------------
# With TwoCropsTransform each sample is a pair of augmented views, so this
# dataset object is only used for the MoCo pretraining loader.
full_ssl = datasets.ImageFolder(DATA_DIR, transform=TwoCropsTransform())
num_total = len(full_ssl)
num_train = int(0.8 * num_total)  # floor of 80%; the remainder becomes the test split
num_test = num_total - num_train

# The split draws from its own generator seeded with SEED.
train_subset, test_subset = random_split(
    full_ssl, [num_train, num_test],
    generator=torch.Generator().manual_seed(SEED)
)

# Keep the raw index lists: the supervised stage re-applies them to datasets
# built with different transforms.
train_indices = train_subset.indices
test_indices  = test_subset.indices

# Pretraining only sees the training indices. drop_last=True discards a
# trailing partial batch, so every SSL batch holds exactly BATCH_SIZE samples.
train_loader_ssl = DataLoader(
    Subset(full_ssl, train_indices),
    batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
    num_workers=NUM_WORKERS, pin_memory=True
)

# -----------------------
# Encoder (ResNet-50 trunk)
# -----------------------
class Encoder(nn.Module):
    """ResNet-50 trunk (everything except the final fc layer).

    Args:
        use_imagenet: when true, start from torchvision's ImageNet weights;
            otherwise use torchvision's default random initialisation.

    Attributes:
        backbone: the ResNet-50 children up to and including global pooling.
        feature_dim: width of the flattened feature vector (2048).
    """
    def __init__(self, use_imagenet=True):
        super().__init__()
        # Pick the API by feature detection instead of a blanket
        # ``except Exception``: a failed weight download must surface as an
        # error rather than silently retrying through the legacy
        # ``pretrained=True`` path, which selects a different weight set.
        if hasattr(models, "ResNet50_Weights"):
            weights = models.ResNet50_Weights.IMAGENET1K_V2 if use_imagenet else None
            base = models.resnet50(weights=weights)
        else:
            # Older torchvision releases only know the boolean flag.
            base = models.resnet50(pretrained=use_imagenet)
        self.backbone = nn.Sequential(*list(base.children())[:-1])  # (B, 2048, 1, 1)
        self.feature_dim = 2048

    def forward(self, x):
        """Run the trunk and flatten everything after the batch dimension."""
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        return x

# -----------------------
# MLP Head (MoCo v2 uses 2-layer MLP)
# -----------------------
class ProjectionMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_dim: size of the incoming feature vector.
        hidden_dim: width of the hidden layer.
        out_dim: size of the projected output.
    """
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` through the MLP."""
        projected = self.net(x)
        return projected

# -----------------------
# === MoCo model ===
# -----------------------
class MoCo(nn.Module):
    """Momentum Contrast (v2-style): query encoder, key encoder (EMA), queue.

    Args:
        base_encoder: backbone module exposing ``feature_dim``; it is used as
            the query encoder and a deep copy of it becomes the key encoder.
        dim: output dimension of the projection heads and of each queue entry.
        K: number of keys held in the queue.
        m: momentum coefficient of the key-side EMA update.
        T: temperature the logits are divided by.
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T

        # encoders: the key side starts as an exact copy and never receives gradients
        self.encoder_q = base_encoder
        self.encoder_k = copy.deepcopy(base_encoder)
        for p in self.encoder_k.parameters():
            p.requires_grad = False

        # projection heads (same copy-and-freeze scheme as the encoders)
        self.proj_q = ProjectionMLP(self.encoder_q.feature_dim, 2048, dim)
        self.proj_k = copy.deepcopy(self.proj_q)
        for p in self.proj_k.parameters():
            p.requires_grad = False

        # queue (features are L2-normalized); one key per column
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        # index of the column the next batch of keys is written to
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA update for key encoder and projector."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.proj_q.parameters(), self.proj_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` and advance the pointer.

        The queue is treated as a ring buffer, so the batch size does not have
        to divide ``K``: a batch that runs past the last column continues at
        column 0.

        Args:
            keys: tensor with one key per row, ``(batch, dim)``.

        Raises:
            ValueError: if the batch holds more keys than the queue has columns.
        """
        batch_size = keys.shape[0]
        if batch_size > self.K:
            raise ValueError(
                f"Batch of {batch_size} keys does not fit into a queue of size {self.K}."
            )
        ptr = int(self.queue_ptr)
        end = ptr + batch_size
        if end <= self.K:
            self.queue[:, ptr:end] = keys.T
        else:
            # Wrap around: fill the tail of the queue, then continue at the front.
            # Slice assignment (rather than index tensors) keeps the implicit
            # dtype conversion of the original single-slice write.
            tail = self.K - ptr
            self.queue[:, ptr:] = keys[:tail].T
            self.queue[:, :end - self.K] = keys[tail:].T
        self.queue_ptr[0] = end % self.K

    def forward(self, im_q, im_k):
        """Build the contrastive logits for one batch and update the queue.

        Args:
            im_q: batch of query views.
            im_k: batch of key views (same samples, different augmentation).

        Returns:
            ``(logits, labels)`` where ``logits`` is ``(B, 1 + K)`` with the
            positive pair in column 0, already divided by ``T``, and
            ``labels`` is an all-zero long tensor of length ``B``.
        """
        # compute query features
        q = self.encoder_q(im_q)
        q = self.proj_q(q)
        q = F.normalize(q, dim=1)  # (B, dim)

        # compute key features (no gradient flows to the key side)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = self.proj_k(k)
            k = F.normalize(k, dim=1)

        # positive logits: q @ k
        l_pos = torch.einsum('nd,nd->n', [q, k]).unsqueeze(-1)  # (B, 1)
        # negative logits: q @ queue (snapshot taken before this batch is enqueued)
        l_neg = torch.einsum('nd,dk->nk', [q, self.queue.clone().detach()])  # (B, K)

        logits = torch.cat([l_pos, l_neg], dim=1)    # (B, 1+K)
        logits /= self.T

        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)  # positives are 0

        # dequeue and enqueue
        with torch.no_grad():
            self._dequeue_and_enqueue(k)

        return logits, labels

# -----------------------
# MoCo Pretraining
# -----------------------
encoder = Encoder(use_imagenet=USE_IMAGENET_WEIGHTS).to(DEVICE)
model = MoCo(
    base_encoder=encoder,
    dim=MOCO_DIM, K=MOCO_K, m=MOCO_M, T=MOCO_T
).to(DEVICE)

# Only the query-side weights are optimised; the key side is updated by the
# EMA step inside MoCo.forward.
ssl_optimizer = torch.optim.AdamW(
    list(model.encoder_q.parameters()) + list(model.proj_q.parameters()),
    lr=SSL_LR, weight_decay=SSL_WEIGHT_DECAY
)
# Cosine schedule, stepped once per epoch below.
ssl_sched = torch.optim.lr_scheduler.CosineAnnealingLR(ssl_optimizer, T_max=MOCO_EPOCHS)
# Mixed precision is only enabled on CUDA; on CPU the scaler is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

# The positive pair sits in column 0 of the logits, so this is cross-entropy
# against an all-zero target.
criterion_ssl = nn.CrossEntropyLoss()

for epoch in range(MOCO_EPOCHS):
    model.train()
    running = 0.0
    pbar = tqdm(train_loader_ssl, desc=f"MoCo Epoch {epoch+1}/{MOCO_EPOCHS}")
    # Count steps explicitly: tqdm's own counter (pbar.n) is only refreshed
    # when the bar redraws and does not yet include the batch being processed,
    # so dividing by it inflates the displayed running average.
    for step, ((v1, v2), _) in enumerate(pbar, start=1):
        x_q = v1.to(DEVICE, non_blocking=True)  # query view
        x_k = v2.to(DEVICE, non_blocking=True)  # key view

        with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
            logits, labels = model(x_q, x_k)
            loss = criterion_ssl(logits, labels)

        ssl_optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(ssl_optimizer)
        scaler.update()

        running += loss.item()
        pbar.set_postfix(loss=f"{running / step:.4f}")

    ssl_sched.step()
    print(f"MoCo Epoch {epoch+1}: loss={running / len(train_loader_ssl):.4f}")

# Save encoder weights (query encoder after MoCo pretrain)
encoder_path = os.path.join(SAVE_DIR, "moco_encoder.pth")
torch.save(model.encoder_q.state_dict(), encoder_path)
print(f"Saved pretrain encoder to: {encoder_path}")

# -----------------------
# Supervised Fine-Tuning (NO FREEZING)
#   - Use SAME 80/20 indices
# -----------------------
# Two ImageFolder views of the same directory that differ only in transform.
# NOTE(review): reusing the index lists assumes ImageFolder enumerates the
# files in the same order on every instantiation — confirm.
full_sup_train = datasets.ImageFolder(DATA_DIR, transform=supervised_train_tf)
full_sup_test  = datasets.ImageFolder(DATA_DIR, transform=eval_tf)
num_classes = len(full_sup_train.classes)

# Training samples get the augmenting transform, held-out samples the
# deterministic evaluation transform.
train_sup = Subset(full_sup_train, train_indices)
test_sup  = Subset(full_sup_test,  test_indices)

train_loader_sup = DataLoader(
    train_sup, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True
)
# No shuffling for evaluation.
test_loader_sup = DataLoader(
    test_sup, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)

# Build encoder and load MoCo-pretrained weights
# NOTE(review): whatever initial weights the constructor picks are presumably
# replaced entirely by the checkpoint loaded on the next line (default
# load_state_dict behaviour) — confirm.
finetune_encoder = Encoder(use_imagenet=USE_IMAGENET_WEIGHTS).to(DEVICE)
finetune_encoder.load_state_dict(torch.load(encoder_path, map_location=DEVICE))

# Classifier head
class SupModel(nn.Module):
    """Encoder followed by a single linear classification layer.

    Args:
        encoder: feature extractor exposing ``feature_dim``.
        num_classes: number of output logits.
    """
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        in_features = encoder.feature_dim
        self.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        return self.head(self.encoder(x))

sup_model = SupModel(finetune_encoder, num_classes).to(DEVICE)

# Optimizer with differential LR: one parameter group per sub-module, the
# backbone group first and the head group second, both with the same weight decay.
param_groups = [
    {"params": module.parameters(), "lr": group_lr, "weight_decay": FT_WEIGHT_DECAY}
    for module, group_lr in (
        (sup_model.encoder, FT_LR_BACKBONE),
        (sup_model.head, FT_LR_HEAD),
    )
]
ft_optimizer = torch.optim.AdamW(param_groups)
# Cosine decay over the whole fine-tuning run.
ft_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    ft_optimizer,
    T_max=FINETUNE_EPOCHS,
)
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

# Gradient scaling is only active when running on CUDA.
scaler_ft = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

def accuracy_top1(logits, targets):
    """Return the fraction of rows whose arg-max logit equals the target, as a Python float."""
    hits = torch.eq(torch.argmax(logits, dim=1), targets)
    return hits.to(torch.float32).mean().item()

# End-to-end fine-tuning: encoder and head are both updated every step.
for epoch in range(FINETUNE_EPOCHS):
    sup_model.train()
    # Metrics are accumulated per sample so a smaller final batch does not
    # get the same weight as a full one.
    running_loss, running_acc, seen = 0.0, 0.0, 0
    for x, y in tqdm(train_loader_sup, desc=f"FT Epoch {epoch+1}/{FINETUNE_EPOCHS}"):
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
            logits = sup_model(x)
            loss = criterion(logits, y)
        ft_optimizer.zero_grad(set_to_none=True)
        scaler_ft.scale(loss).backward()
        # After the scaled backward pass the gradients still carry the loss
        # scale factor. Unscale them first so max_norm applies to the true
        # gradient norm; scaler_ft.step() knows they are already unscaled.
        scaler_ft.unscale_(ft_optimizer)
        torch.nn.utils.clip_grad_norm_(sup_model.parameters(), max_norm=5.0)
        scaler_ft.step(ft_optimizer)
        scaler_ft.update()

        batch_n = y.size(0)
        running_loss += loss.item() * batch_n
        running_acc += accuracy_top1(logits, y) * batch_n
        seen += batch_n

    ft_sched.step()
    print(f"[FT] Epoch {epoch+1}: loss={running_loss/max(seen, 1):.4f} | "
          f"acc={running_acc/max(seen, 1):.4f}")

# -----------------------
# Evaluation on held-out 20%
# -----------------------
# eval() switches the model out of training mode for inference.
sup_model.eval()
# Ground-truth and predicted class indices, collected batch by batch.
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in test_loader_sup:
        x = x.to(DEVICE, non_blocking=True)
        logits = sup_model(x)
        # Predicted class = index of the largest logit; moved to the CPU for sklearn.
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred.extend(preds)
        # The targets are never moved to DEVICE, so they convert directly.
        y_true.extend(y.numpy())

print("\n=== End-to-End Fine-Tune Evaluation (Held-out 20%) ===")
# No target_names are passed, so rows are labelled by integer class index.
print(classification_report(y_true, y_pred, digits=4))

# Confusion matrix of true vs. predicted class indices, drawn as an annotated heatmap.
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens')
plt.title("Confusion Matrix (End-to-End Fine-Tune)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()


=== End-to-End Fine-Tune Evaluation (Held-out 20%) ===
              precision    recall  f1-score   support

           0     0.9048    0.9500    0.9268       100
           1     0.9184    0.9375    0.9278        96
           2     1.0000    0.9908    0.9954       109
           3     0.9277    0.8021    0.8603        96
           4     0.9655    0.9882    0.9767        85
           5     0.8938    0.9528    0.9224       106
           6     0.8810    0.9407    0.9098       118
           7     1.0000    0.9891    0.9945        92
           8     0.9630    0.9750    0.9689        80
           9     0.9515    0.9899    0.9703        99
          10     0.8817    0.8632    0.8723        95
          11     0.9910    1.0000    0.9955       110
          12     0.8793    0.9273    0.9027       110
          13     0.9268    0.9048    0.9157        84
          14     0.9896    0.9896    0.9896        96
          15     0.9806    0.9528    0.9665       106
          16     0.9333    0.8750    0.9032        96
          17     0.9894    0.9894    0.9894        94
          18     1.0000    0.9184    0.9574        98
          19     0.9908    1.0000    0.9954       108
          20     0.9889    0.9780    0.9834        91
          21     0.9688    0.9688    0.9688        96
          22     1.0000    1.0000    1.0000       101
          23     0.8152    0.8929    0.8523        84
          24     0.8730    0.9910    0.9283       111
          25     0.9596    0.9794    0.9694        97
          26     0.9320    0.8421    0.8848       114
          27     0.8870    0.9903    0.9358       103
          28     0.9903    0.9273    0.9577       110
          29     0.9796    0.9412    0.9600       102
          30     1.0000    1.0000    1.0000       104
          31     0.9184    0.8257    0.8696       109
          32     0.9474    0.9643    0.9558       112
          33     0.9506    0.8191    0.8800        94
          34     0.9907    0.9907    0.9907       107
          35     0.8785    0.9592    0.9171        98
          36     0.9062    0.9457    0.9255        92
          37     0.9121    0.8557    0.8830        97

    accuracy                         0.9426      3800
   macro avg     0.9438    0.9423    0.9422      3800
weighted avg     0.9440    0.9426    0.9424      3800
