In [None]:
# === MOCO IMPLEMENTATION ===
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np, matplotlib.pyplot as plt, seaborn as sns
from tqdm import tqdm

# Config: dataset location, MoCo pre-training batch size / epoch count, and
# the device every model and batch is moved to.
DATA_DIR = '/kaggle/input/riceds-original/Original'
BATCH_SIZE = 64
EPOCHS = 100
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MoCo Augmentations: the random pipeline applied separately to each view of
# an image (crop, flip, colour jitter, occasional grayscale), followed by
# tensor conversion and per-channel mean/std normalisation.
moco_tf = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# MoCo Dataset (two views)
class MoCoDataset(torch.utils.data.Dataset):
    """Image-folder dataset that returns two transformed views per image.

    The class labels discovered by the underlying ``ImageFolder`` are ignored;
    only its sample paths and image loader are used.
    """

    def __init__(self, root, transform):
        """Index the images under ``root`` and keep ``transform`` for the views.

        Args:
            root: Directory handed to ``datasets.ImageFolder``.
            transform: Callable applied twice to every loaded image.
        """
        self.transform = transform
        self.base = datasets.ImageFolder(root, transform=transform)

    def __len__(self):
        """Return the number of images in the underlying folder dataset."""
        return len(self.base)

    def __getitem__(self, index):
        """Load image ``index`` once and return two separate transforms of it."""
        image_path = self.base.samples[index][0]
        image = self.base.loader(image_path)
        first_view = self.transform(image)
        second_view = self.transform(image)
        return first_view, second_view

# Two-view dataset and loader for MoCo pre-training.
# drop_last=True keeps every batch at exactly BATCH_SIZE images: the MoCo key
# queue then advances in equal steps, and a leftover 1-image batch never
# reaches the encoders.
train_dataset = MoCoDataset(DATA_DIR, moco_tf)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Encoder + projection
def get_encoder():
    """Build a ResNet-50 trunk (no pretrained weights) without its FC head.

    Returns:
        tuple: ``(encoder, feat_dim)`` where ``encoder`` is an
        ``nn.Sequential`` of every ResNet-50 child module except the last one,
        and ``feat_dim`` is the ``in_features`` of the removed ``fc`` layer.
    """
    backbone = models.resnet50(pretrained=False)
    feature_width = backbone.fc.in_features
    trunk_layers = list(backbone.children())
    trunk_layers.pop()  # Remove FC (the last child)
    return nn.Sequential(*trunk_layers), feature_width

class MoCo(nn.Module):
    """Contrastive model with a query branch, a momentum key branch and a key queue.

    The query branch (``encoder_q`` + ``fc_q``) receives gradients. The key
    branch (``encoder_k`` + ``fc_k``) starts as a copy of it, has
    ``requires_grad=False`` and is only changed by the momentum update run at
    the start of every forward pass. Normalised keys from each batch are
    written into a fixed-size queue whose columns serve as negatives.

    Args:
        dim: Output size of both projection heads, and the length of every
            queued key.
        K: Number of keys (columns) held in the queue.
        m: Momentum of the key-branch update; each update keeps ``m`` of the
            key weights and blends in ``1 - m`` of the query weights.
        T: Temperature the logits are divided by.
    """

    def __init__(self, dim=128, K=4096, m=0.999, T=0.07):
        super().__init__()
        self.K, self.m, self.T = K, m, T
        self.encoder_q, feat_dim = get_encoder()
        self.encoder_k, _ = get_encoder()
        self.fc_q = nn.Linear(feat_dim, dim)
        self.fc_k = nn.Linear(feat_dim, dim)

        # The queue starts as random unit-norm columns; queue_ptr is the
        # column the next batch of keys will be written to.
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # Key branch: exact copy of the query branch, excluded from autograd.
        for param_q, param_k in self._param_pairs():
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    def _param_pairs(self):
        """Yield ``(query_param, key_param)`` pairs: encoders first, then heads."""
        yield from zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        yield from zip(self.fc_q.parameters(), self.fc_k.parameters())

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Move every key-branch parameter towards its query-branch twin."""
        for param_q, param_k in self._param_pairs():
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` and advance the pointer.

        The queue is treated as a ring buffer: a batch that does not fit
        before column ``K`` continues at column 0, so batch sizes that do not
        divide ``K`` (for example a short final batch) are handled.

        Args:
            keys: Tensor of shape ``(batch, dim)``.
        """
        keys = concat_all_gather(keys)
        if keys.shape[0] > self.K:
            # More keys than slots: only the last K can survive anyway.
            keys = keys[-self.K:]
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        # Columns that fit before the end of the queue ...
        tail = min(batch_size, self.K - ptr)
        self.queue[:, ptr:ptr + tail] = keys[:tail].T
        # ... and the remainder, if any, wraps around to the front.
        if tail < batch_size:
            self.queue[:, :batch_size - tail] = keys[tail:].T

        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k):
        """Compute contrastive logits for a batch of query/key image pairs.

        Side effects: runs the momentum update of the key branch and writes
        the new keys into the queue.

        Args:
            im_q: Image batch fed to the query branch.
            im_k: Image batch fed to the key branch (same batch size).

        Returns:
            tuple: ``(logits, labels)``. ``logits`` has shape
            ``(batch, 1 + K)`` with the query-key similarity in column 0 and
            the query-queue similarities after it, all divided by ``T``.
            ``labels`` is an all-zero ``long`` tensor of length ``batch`` on
            the same device as ``logits``.
        """
        # flatten(..., 1) keeps the batch axis even for a single-image batch,
        # which a bare squeeze() would drop.
        q = F.normalize(self.fc_q(torch.flatten(self.encoder_q(im_q), 1)), dim=1)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = F.normalize(self.fc_k(torch.flatten(self.encoder_k(im_k), 1)), dim=1)

        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Snapshot the queue so the enqueue below cannot touch the graph.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        # The positive is always column 0; build the targets on the logits'
        # own device instead of relying on a module-level constant.
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        self._dequeue_and_enqueue(k)
        return logits, labels


@torch.no_grad()
def concat_all_gather(tensor):
    """Return ``tensor`` unchanged.

    Identity stand-in at the point where keys would be gathered from other
    processes; this script does not gather across devices.
    """
    return tensor  # No multi-GPU needed here

# Training: contrastive pre-training of the MoCo model, followed by saving the
# query encoder's weights for the linear-evaluation stage.
model = MoCo().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for im_q, im_k in tqdm(train_loader, desc=f"MoCo Epoch {epoch+1}"):
        im_q = im_q.to(DEVICE)
        im_k = im_k.to(DEVICE)

        # Clear the previous step's gradients before the new forward/backward.
        optimizer.zero_grad()
        logits, labels = model(im_q, im_k)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean loss over the batches of this epoch.
    print(f"Epoch {epoch+1} Loss: {total_loss / len(train_loader):.4f}")

# Only the query encoder trunk is saved; the projection head is not.
torch.save(model.encoder_q.state_dict(), "moco_encoder.pth")

# === Linear Evaluation ===
class LinearClassifier(nn.Module):
    """Single fully connected layer mapping feature vectors to class logits."""

    def __init__(self, feat_dim, num_classes):
        """Create the layer.

        Args:
            feat_dim: Length of each input feature vector.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.fc = nn.Linear(in_features=feat_dim, out_features=num_classes)

    def forward(self, x):
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.fc(x)
        return logits

# Load encoder: rebuild the trunk, restore the pre-trained query-encoder
# weights, and freeze it so only the linear head is trained.
encoder, feat_dim = get_encoder()
encoder.load_state_dict(torch.load("moco_encoder.pth"))
encoder.eval()
for param in encoder.parameters():
    param.requires_grad = False
encoder.to(DEVICE)

# Deterministic preprocessing for linear evaluation (no random augmentation).
eval_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
dataset = datasets.ImageFolder(DATA_DIR, transform=eval_tf)

# 80/20 train/test split of the labelled images.
train_size = int(0.8 * len(dataset))
train_ds, test_ds = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader_eval = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader_eval = DataLoader(test_ds, batch_size=64, shuffle=False)

# Size the head from the encoder and the data instead of hard-coding it: the
# input width is the encoder's feature size and there is one logit per class
# folder found by ImageFolder.
classifier = LinearClassifier(feat_dim, len(dataset.classes)).to(DEVICE)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Train classifier: fit only the linear head on features from the frozen
# encoder.
for epoch in range(10):
    classifier.train()
    for x, y in train_loader_eval:
        x, y = x.to(DEVICE), y.to(DEVICE)
        # flatten(..., 1) collapses the encoder output to (batch, features)
        # and, unlike squeeze(), keeps the batch axis when the final batch
        # holds a single image.
        with torch.no_grad():
            features = torch.flatten(encoder(x), 1)
        logits = classifier(features)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate: collect true and predicted labels over the held-out split.
classifier.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in test_loader_eval:
        x = x.to(DEVICE)
        # flatten(..., 1) keeps the batch axis even for a 1-image final
        # batch, so argmax over dim=1 always sees (batch, classes) logits.
        features = torch.flatten(encoder(x), 1)
        preds = torch.argmax(classifier(features), dim=1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(y.numpy())

# Per-class precision/recall/F1 on the held-out split.
print("\n MoCo Evaluation:")
print(classification_report(y_true, y_pred, digits=4))

# Confusion matrix as an annotated heatmap (rows: true class, columns:
# predicted class), drawn on an explicit figure/axes pair.
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Oranges', ax=ax)
ax.set_title("MoCo Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
fig.tight_layout()
plt.show()
