# In [None]:
# === BYOL IMPLEMENTATION ===
import torch, torch.nn as nn, torch.nn.functional as F, copy
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np, matplotlib.pyplot as plt, seaborn as sns
from tqdm import tqdm

# Config
# Dataset root; it is handed to ImageFolder-based datasets further down.
DATA_DIR = '/kaggle/input/riceds-original/Original'
# Mini-batch size and number of self-supervised pretraining epochs.
BATCH_SIZE, EPOCHS = 64, 100
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Data augmentations (two strong views)
class BYOLTransform:
    """Callable that produces two independently augmented views of one image.

    The same random augmentation pipeline is applied twice to the input, so
    the two returned tensors differ only by the random draws of each pass.
    """

    def __init__(self, size=224):
        # Pipeline order: crop -> flip -> colour jitter -> grayscale -> blur
        # -> tensor conversion -> per-channel normalisation.
        augmentations = [
            transforms.RandomResizedCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=9),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(augmentations)

    def __call__(self, x):
        """Return ``(view_1, view_2)``, each the result of one pipeline pass over ``x``."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

class BYOLDataset(datasets.ImageFolder):
    """ImageFolder variant that ignores labels and returns only the transformed image."""

    def __getitem__(self, index):
        """Load sample ``index``, force RGB, and return whatever ``self.transform`` yields."""
        # Each entry of ``samples`` pairs a file path with a label; the label is unused here.
        image_path = self.samples[index][0]
        rgb_image = self.loader(image_path).convert('RGB')
        return self.transform(rgb_image)

# Unlabelled two-view dataset for BYOL pretraining.
train_dataset = BYOLDataset(DATA_DIR, transform=BYOLTransform())
# drop_last=True: the projector/predictor MLPs contain BatchNorm1d, which
# raises in training mode when a batch holds a single sample. Discarding the
# trailing incomplete batch avoids a crash whenever
# len(train_dataset) % BATCH_SIZE == 1.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Encoder
class Encoder(nn.Module):
    """ResNet-50 trunk without its classification layer.

    Attributes:
        backbone: all child modules of a torchvision ResNet-50 except the
            last one (the final fully connected layer), wrapped in
            ``nn.Sequential``.
        feature_dim: width of the feature vector returned per image (2048).
    """

    def __init__(self):
        super().__init__()
        # No arguments -> randomly initialised weights. This is the default in
        # both the legacy ``pretrained=`` and the current ``weights=`` API, so
        # it avoids the deprecation warning without pinning a torchvision version.
        base = models.resnet50()
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.feature_dim = 2048

    def forward(self, x):
        """Return features of shape ``(batch, feature_dim)`` for an image batch ``x``."""
        # flatten(start_dim=1) keeps the batch axis. A bare squeeze() would
        # also remove it when the batch holds one image, yielding a 1-D tensor
        # that breaks BatchNorm1d / argmax(dim=1) downstream.
        return torch.flatten(self.backbone(x), 1)

# MLP heads
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Layer order (and therefore the state_dict indices 0..3) is fixed here.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` with last dimension ``in_dim`` to an output with last dimension ``out_dim``."""
        out = self.net(x)
        return out

# BYOL model
class BYOL(nn.Module):
    """BYOL model: a trainable online network plus a frozen, slowly moving target.

    The online branch is encoder -> projector -> predictor. The target branch
    is a deep copy of the online encoder and projector whose parameters never
    receive gradients; they are moved towards the online weights only through
    :meth:`update_target`.

    Args:
        encoder: feature extractor exposing a ``feature_dim`` attribute.
        proj_dim: output width of the projector and the predictor.
        hidden_dim: hidden width of the projector and predictor MLPs.
    """

    def __init__(self, encoder, proj_dim=128, hidden_dim=512):
        super().__init__()
        self.online_encoder = encoder
        self.online_projector = MLP(encoder.feature_dim, hidden_dim, proj_dim)
        self.online_predictor = MLP(proj_dim, hidden_dim, proj_dim)

        # The target starts as an exact copy of the online encoder/projector.
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)

        # Freeze the target: it is updated by EMA, never by backpropagation.
        self.target_encoder.requires_grad_(False)
        self.target_projector.requires_grad_(False)

    @torch.no_grad()
    def update_target(self, m=0.996):
        """Move target parameters towards the online ones: ``t <- m * t + (1 - m) * o``.

        Only parameters are averaged; buffers (for example BatchNorm running
        statistics) are left untouched. The update is done in place under
        ``no_grad`` so no new tensors are allocated and autograd is bypassed
        without touching ``.data``.

        Args:
            m: momentum in the update rule above; larger values move the
                target more slowly.
        """
        module_pairs = (
            (self.online_encoder, self.target_encoder),
            (self.online_projector, self.target_projector),
        )
        for online, target in module_pairs:
            for o, t in zip(online.parameters(), target.parameters()):
                t.mul_(m).add_(o, alpha=1.0 - m)

    def forward(self, x1, x2):
        """Return ``(p1, p2, t1, t2)`` for two augmented views ``x1`` and ``x2``.

        ``p1``/``p2`` are the online predictions for each view and carry
        gradients; ``t1``/``t2`` are the target projections, computed under
        ``no_grad`` so they are already detached from the graph.
        """
        z1 = self.online_projector(self.online_encoder(x1))
        z2 = self.online_projector(self.online_encoder(x2))
        p1 = self.online_predictor(z1)
        p2 = self.online_predictor(z2)

        with torch.no_grad():
            t1 = self.target_projector(self.target_encoder(x1))
            t2 = self.target_projector(self.target_encoder(x2))
        return p1, p2, t1, t2

def byol_loss(p, z):
    """Return the mean BYOL regression loss ``2 - 2 * cos(p, z)``.

    Both inputs are L2-normalised along the last dimension and the cosine
    similarity is taken along that same dimension, then averaged over all
    remaining dimensions. The result is a scalar tensor in ``[0, 4]``:
    0 for identical directions, 4 for opposite ones.

    Args:
        p: online predictions; feature vectors lie along the last dimension.
        z: target projections with the same shape as ``p``.
    """
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    # Reduce over the same axis that was normalised (-1, not a hard-coded 1),
    # so 1-D and >2-D inputs are handled consistently with the 2-D case.
    return 2 - 2 * (p * z).sum(dim=-1).mean()

# Pretraining
encoder = Encoder().to(DEVICE)
model = BYOL(encoder).to(DEVICE)
# Give the optimizer only the parameters that receive gradients; the frozen
# target network is moved by model.update_target() instead.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0  # batches actually used this epoch
    for (x1, x2) in tqdm(train_loader, desc=f"BYOL Epoch {epoch+1}"):
        # BatchNorm1d in the projector/predictor cannot compute batch
        # statistics from one sample in training mode, so skip such a batch
        # instead of crashing at the end of an epoch.
        if x1.size(0) < 2:
            continue
        x1, x2 = x1.to(DEVICE), x2.to(DEVICE)
        p1, p2, t1, t2 = model(x1, x2)
        # Symmetrised loss: each view's prediction regresses the other view's target.
        loss = byol_loss(p1, t2) + byol_loss(p2, t1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # EMA step for the target network after every optimizer step.
        model.update_target()
        total_loss += loss.item()
        num_batches += 1
    # Average over the batches that were actually processed; max() guards
    # against division by zero when the loader produced no usable batch.
    print(f"Epoch {epoch+1}: Loss={total_loss / max(num_batches, 1):.4f}")

# Only the online encoder is kept for the downstream linear evaluation.
torch.save(model.online_encoder.state_dict(), "byol_encoder.pth")

# === Linear Evaluation ===
class LinearClassifier(nn.Module):
    """Single fully connected layer mapping feature vectors to class logits."""

    def __init__(self, feat_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=feat_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class scores for features ``x``."""
        logits = self.fc(x)
        return logits

# Freeze encoder
eval_encoder = Encoder().to(DEVICE)
# map_location lets the checkpoint load even if it was written on another
# device (e.g. saved on GPU, evaluated on CPU).
eval_encoder.load_state_dict(torch.load("byol_encoder.pth", map_location=DEVICE))
for param in eval_encoder.parameters(): param.requires_grad = False
# eval() so BatchNorm layers use their running statistics during feature extraction.
eval_encoder.eval()

# Deterministic preprocessing for linear evaluation (no random augmentation).
eval_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
dataset = datasets.ImageFolder(DATA_DIR, transform=eval_tf)
train_size = int(0.8 * len(dataset))
# A dedicated seeded generator makes the 80/20 split identical across runs, so
# reported metrics are comparable between executions.
split_rng = torch.Generator().manual_seed(42)
train_ds, test_ds = random_split(
    dataset, [train_size, len(dataset)-train_size], generator=split_rng
)
train_loader_eval = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader_eval = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Size the linear head from the encoder and the dataset rather than from
# hard-coded numbers, so it always matches the labels ImageFolder produces.
classifier = LinearClassifier(eval_encoder.feature_dim, len(dataset.classes)).to(DEVICE)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Train classifier
for epoch in range(10):
    classifier.train()
    for x, y in train_loader_eval:
        x, y = x.to(DEVICE), y.to(DEVICE)
        # The encoder is frozen, so features are computed without a graph.
        # reshape(batch, -1) instead of squeeze(): squeeze() also removes the
        # batch axis when the batch holds a single image, which makes the
        # logits 1-D and breaks the loss against the (1,)-shaped target.
        with torch.no_grad():
            features = eval_encoder(x).reshape(x.size(0), -1)
        logits = classifier(features)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate
classifier.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in test_loader_eval:
        x = x.to(DEVICE)
        # Keep the batch axis explicitly; squeeze() would drop it for a final
        # batch of one image and argmax(dim=1) would then fail.
        features = eval_encoder(x).reshape(x.size(0), -1)
        preds = torch.argmax(classifier(features), dim=1)
        y_pred.extend(preds.cpu().tolist())
        y_true.extend(y.tolist())

print("\n BYOL Evaluation:")
print(classification_report(y_true, y_pred, digits=4))
# Confusion matrix indexed by the class indices present in y_true / y_pred.
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Greens')
plt.title("BYOL Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()
