# In [None]:
# === SIMCLR IMPLEMENTATION ===
import os, torch, torch.nn as nn, torch.nn.functional as F, numpy as np
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt, seaborn as sns
from tqdm import tqdm

# Hyperparameters and runtime configuration.
DATA_DIR = '/kaggle/input/riceds-original/Original'  # ImageFolder root (one sub-folder per class)
EPOCHS = 100      # SimCLR pre-training epochs
BATCH_SIZE = 64   # images per batch; each image contributes two augmented views

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# SimCLR augmentation pipeline. It is applied twice per image by SimCLRDataset,
# so each call must draw fresh random parameters.
simclr_tf = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),   # random crop, resized to 224x224
        transforms.RandomHorizontalFlip(),   # mirror with the default probability
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset loader for two augmented views
class SimCLRDataset(torch.utils.data.Dataset):
    """Unlabelled dataset that yields two augmented views of every image.

    Args:
        root: directory in the layout expected by ``torchvision.datasets.ImageFolder``.
        transform: augmentation callable, invoked separately for each of the two views.
    """

    def __init__(self, root, transform):
        self.base_dataset = datasets.ImageFolder(root, transform=transform)
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        # Load the raw image ourselves; the class label is deliberately discarded.
        image_path, _label = self.base_dataset.samples[index]
        image = self.base_dataset.loader(image_path)
        # Two separate calls to the transform -> the positive pair (view 1, view 2).
        return tuple(self.transform(image) for _ in range(2))

train_dataset = SimCLRDataset(DATA_DIR, simclr_tf)
# drop_last=True: every optimisation step sees a full batch (2 * BATCH_SIZE views),
# so NT-Xent always has the same number of negatives and no tiny trailing batch
# (possibly a single image) ever reaches the model.
# NOTE: requires at least BATCH_SIZE images under DATA_DIR, otherwise the loader is empty.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# SimCLR Model: Encoder + MLP head
class SimCLR(nn.Module):
    """SimCLR network: a ResNet backbone without its classifier, plus a 2-layer MLP head.

    Args:
        base_model: name of a ResNet constructor in ``torchvision.models``
            (e.g. ``'resnet18'``, ``'resnet50'``). It must expose an ``fc`` layer.
            Previously this argument was ignored and resnet50 was always built.
        out_dim: dimensionality of the projection ``z``.

    ``forward`` returns L2-normalised projections of shape ``(batch, out_dim)``.
    """

    def __init__(self, base_model='resnet50', out_dim=128):
        super().__init__()
        # Look the constructor up by name so `base_model` is actually honoured.
        resnet = getattr(models, base_model)(pretrained=False)
        feature_dim = resnet.fc.in_features
        # Keep every child up to and including global pooling; drop the final `fc`.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim),
        )

    def forward(self, x):
        # flatten(..., 1) instead of squeeze(): squeeze() also removes the batch
        # dimension when the batch holds a single image, which makes
        # F.normalize(dim=1) fail.
        h = torch.flatten(self.encoder(x), 1)
        z = self.projection_head(h)
        return F.normalize(z, dim=1)

# NT-Xent Loss
def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss used by SimCLR.

    Args:
        z1: projections of the first views, shape ``(N, D)``.
        z2: projections of the second views, shape ``(N, D)``; row ``i`` of ``z1``
            and row ``i`` of ``z2`` form a positive pair.
        temperature: softmax temperature.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2N`` anchors, where each
        anchor's negatives are the other ``2N - 2`` samples (never itself).
    """
    # Cosine similarity == dot product of L2-normalised rows.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    n = z.size(0)  # 2N
    sim = (z @ z.T) / temperature
    # Remove each anchor's similarity with itself. It is always the maximum
    # (cos = 1) and must not appear in the softmax denominator.
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))
    # The positive for anchor i is its other view at index (i + N) mod 2N.
    labels = (torch.arange(n, device=z.device) + n // 2) % n
    return F.cross_entropy(sim, labels)

# Training SimCLR
# Pre-train encoder + projection head with the contrastive objective.
model = SimCLR().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    progress = tqdm(train_loader, desc=f"SimCLR Epoch {epoch+1}")
    for view_a, view_b in progress:
        view_a = view_a.to(DEVICE)
        view_b = view_b.to(DEVICE)
        # Each view batch goes through the network separately (view_a first).
        loss = nt_xent_loss(model(view_a), model(view_b))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Loss={mean_loss:.4f}")

# Only the backbone weights are persisted; the projection head is not saved.
torch.save(model.encoder.state_dict(), "simclr_encoder.pth")

# Linear evaluation
class LinearEval(nn.Module):
    """Linear probe: one fully-connected layer mapping features to class logits.

    Args:
        feature_dim: width of the input feature vectors.
        num_classes: number of output logits.
    """

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, feature_dim)``."""
        logits = self.fc(x)
        return logits

# Load pretrained encoder
# The checkpoint is the state_dict of the nn.Sequential backbone saved after
# pre-training (keys such as "0.weight", "4.0.conv1.weight"). It must therefore be
# loaded into the same Sequential wrapper, not into the bare ResNet, whose keys
# are "conv1.weight", "layer1.0.conv1.weight", ... and would not match.
encoder = models.resnet50(pretrained=False)
encoder = nn.Sequential(*list(encoder.children())[:-1])
encoder.load_state_dict(torch.load("simclr_encoder.pth", map_location="cpu"))
for param in encoder.parameters():
    param.requires_grad = False
encoder.to(DEVICE)
# Frozen feature extractor: BatchNorm must use its running statistics and must
# not keep updating them while features are extracted for linear evaluation.
encoder.eval()

# Deterministic preprocessing for the labelled linear-evaluation data.
eval_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
full_dataset = datasets.ImageFolder(DATA_DIR, transform=eval_tf)
# NOTE(review): SimCLR pre-training read every image under DATA_DIR (without
# labels), so the encoder has already seen the images in this test split.
train_size = int(0.8 * len(full_dataset))
# A seeded generator makes the 80/20 split reproducible across runs.
train_set, test_set = random_split(
    full_dataset,
    [train_size, len(full_dataset) - train_size],
    generator=torch.Generator().manual_seed(42),
)
train_loader_eval = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader_eval = DataLoader(test_set, batch_size=64, shuffle=False)

# Size the probe from the dataset instead of hard-coding 38 classes.
# 2048 is the pooled feature width of the ResNet-50 encoder.
classifier = LinearEval(2048, len(full_dataset.classes)).to(DEVICE)

# Train linear classifier
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# The encoder is frozen: keep BatchNorm on running statistics and stop it from
# updating them during feature extraction.
encoder.eval()

for epoch in range(10):
    classifier.train()
    epoch_loss = 0.0
    for x, y in train_loader_eval:
        x, y = x.to(DEVICE), y.to(DEVICE)
        with torch.no_grad():
            # flatten(..., 1) keeps the batch dimension even for a 1-image batch,
            # where squeeze() would produce a 1-D tensor and break the loss.
            features = torch.flatten(encoder(x), 1)
        logits = classifier(features)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Linear eval epoch {epoch+1}: Loss={epoch_loss/len(train_loader_eval):.4f}")

# Evaluate
# Both networks in inference mode (BatchNorm on running stats).
encoder.eval()
classifier.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in test_loader_eval:
        x = x.to(DEVICE)
        # flatten(..., 1) keeps the batch dimension even for a 1-image batch.
        features = torch.flatten(encoder(x), 1)
        preds = torch.argmax(classifier(features), dim=1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(y.numpy())

# Report over every dataset class index, so rows and matrix axes line up with
# the class names even when a class is absent from the test split.
class_names = full_dataset.classes
class_ids = list(range(len(class_names)))

print("\nSimCLR Evaluation:")
print(classification_report(
    y_true, y_pred,
    labels=class_ids, target_names=class_names,
    digits=4, zero_division=0,
))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title("SimCLR Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()
