In [None]:
import copy
import os

import numpy as np
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
# Image preprocessing pipelines, looked up by split name ('train' / 'val').
# NOTE(review): the Normalize mean/std values look like the usual ImageNet
# channel statistics — confirm they match the pretrained weights in use.
data_transforms = dict(
    # Training split: mild random augmentation, then tensor conversion and
    # per-channel normalization.
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Validation split: deterministic resize + center crop, same normalization.
    val=transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

class MoCo(nn.Module):
    """Momentum-contrast wrapper with an online (query) and a target (key) branch.

    The online branch (encoder + projector) is trained by back-propagation.
    The target branch is a separate, frozen copy whose parameters follow the
    online branch through an exponential moving average. L2-normalized keys
    produced by the target branch are stored in a fixed-size ring buffer
    (``queue``) and used as negatives.

    Args:
        base_model: Encoder module. Its output is flattened from dim 1 onward
            before it is passed to the projector.
        feature_dim: Width of the projected embedding.
        queue_size: Number of key vectors kept as negatives.
        momentum: EMA coefficient applied to the target branch.
        temperature: Divisor applied to the logits (previously fixed at 0.07).
        encoder_dim: Flattened output width of ``base_model`` (previously
            fixed at 1000).
    """

    def __init__(self, base_model, feature_dim=128, queue_size=65536, momentum=0.999,
                 temperature=0.07, encoder_dim=1000):
        super(MoCo, self).__init__()
        self.online_encoder = base_model
        # The target must be an independent copy. Aliasing the same module
        # would freeze the online encoder as well (requires_grad is set to
        # False on the target) and would turn the momentum update into a no-op.
        self.target_encoder = copy.deepcopy(base_model)
        self.momentum = momentum
        self.queue_size = queue_size
        self.feature_dim = feature_dim
        self.temperature = temperature

        self.online_projector = nn.Sequential(
            nn.Linear(encoder_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, feature_dim)
        )
        # Same architecture and starting weights as the online projector.
        self.target_projector = copy.deepcopy(self.online_projector)

        self._initialize_target_encoder()

        # Ring buffer of negatives, one unit-norm column per stored key.
        self.register_buffer("queue", F.normalize(torch.randn(feature_dim, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def _paired_parameters(self):
        """Yield (online, target) parameter pairs for the encoder and the projector."""
        yield from zip(self.online_encoder.parameters(), self.target_encoder.parameters())
        yield from zip(self.online_projector.parameters(), self.target_projector.parameters())

    def _initialize_target_encoder(self):
        """Copy the online weights into the target branch and freeze the target branch."""
        for param_q, param_k in self._paired_parameters():
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    @torch.no_grad()
    def _momentum_update_target_encoder(self):
        """EMA step: target = momentum * target + (1 - momentum) * online."""
        for param_q, param_k in self._paired_parameters():
            param_k.data.mul_(self.momentum).add_(param_q.data, alpha=1.0 - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write ``keys`` (batch, feature_dim) into the queue, overwriting the oldest slots.

        The write wraps around the end of the buffer, so the batch size does
        not have to divide ``queue_size``.
        """
        batch_size = keys.shape[0]
        if batch_size == 0:
            return
        if batch_size > self.queue_size:
            # Only the most recent ``queue_size`` keys can be stored.
            keys = keys[-self.queue_size:]
            batch_size = self.queue_size

        ptr = int(self.queue_ptr)
        slots = (ptr + torch.arange(batch_size, device=self.queue.device)) % self.queue_size
        self.queue[:, slots] = keys.T.to(self.queue.dtype)
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, x_q, x_k):
        """Compute contrastive logits for a query batch and a key batch.

        Args:
            x_q: Input batch for the online branch.
            x_k: Input batch for the target branch (same batch size as ``x_q``).

        Returns:
            ``(logits, labels)``: ``logits`` has shape (batch, 1 + queue_size)
            with the positive pair in column 0; ``labels`` is an all-zero long
            tensor on the same device as ``logits``.

        Side effects: performs one momentum update of the target branch and
        enqueues the new keys.
        """
        q = self.online_projector(torch.flatten(self.online_encoder(x_q), 1))
        q = normalize(q, dim=1)

        with torch.no_grad():
            self._momentum_update_target_encoder()
            k = self.target_projector(torch.flatten(self.target_encoder(x_k), 1))
            k = normalize(k, dim=1)

        # Snapshot the negatives: the buffer is modified in place below, while
        # autograd still needs the values used for these logits.
        queue = self.queue.clone().detach()
        logits_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        logits_neg = torch.einsum('nc,ck->nk', [q, queue])

        logits = torch.cat([logits_pos, logits_neg], dim=1) / self.temperature

        # Build the labels on the logits' device instead of forcing CUDA.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        self._dequeue_and_enqueue(k)

        return logits, labels

def moco_loss_fn(logits, labels):
    """Return the cross-entropy of ``logits`` against the target indices ``labels``."""
    loss = F.cross_entropy(input=logits, target=labels)
    return loss

# Backbone for MoCo. The complete torchvision ResNet-50 is kept, including its
# final fully connected layer, so it emits one flat 1000-dimensional vector per
# image; that is the input width of MoCo's projection heads. Dropping the last
# child module (as was done before) yields a (N, 2048, 1, 1) tensor that the
# heads' nn.Linear(1000, 4096) cannot consume.
resnet_encoder = models.resnet50(pretrained=True)

moco_model = MoCo(resnet_encoder)
# Only trainable parameters go to the optimizer; frozen ones never receive gradients.
optimizer = optim.Adam([p for p in moco_model.parameters() if p.requires_grad], lr=1e-3)

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader`` and report metrics.

    Args:
        model: Module called as ``model(images)``; put into train mode here.
        train_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a tensor
            that supports ``.backward()`` and ``.item()``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)`` where ``avg_loss`` is
        the mean of the per-batch losses and precision/recall/F1 are
        macro-averaged.
    """
    model.train()
    running_loss = 0.0
    targets_seen = []
    predictions_made = []

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_outputs = model(batch_images)
        batch_loss = criterion(batch_outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Predicted class = index of the largest value along dim 1.
        batch_preds = torch.max(batch_outputs, 1).indices
        targets_seen.extend(batch_labels.cpu().numpy())
        predictions_made.extend(batch_preds.cpu().numpy())

    return (
        running_loss / len(train_loader),
        accuracy_score(targets_seen, predictions_made),
        precision_score(targets_seen, predictions_made, average='macro'),
        recall_score(targets_seen, predictions_made, average='macro'),
        f1_score(targets_seen, predictions_made, average='macro'),
    )

# Validation Loop
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(images)``; put into eval mode here.
        val_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a tensor
            that supports ``.item()``.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)`` where ``avg_loss`` is
        the mean of the per-batch losses and precision/recall/F1 are
        macro-averaged.
    """
    model.eval()
    running_loss = 0.0
    targets_seen = []
    predictions_made = []

    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            batch_outputs = model(batch_images)
            running_loss += criterion(batch_outputs, batch_labels).item()

            # Predicted class = index of the largest value along dim 1.
            batch_preds = torch.max(batch_outputs, 1).indices
            targets_seen.extend(batch_labels.cpu().numpy())
            predictions_made.extend(batch_preds.cpu().numpy())

    return (
        running_loss / len(val_loader),
        accuracy_score(targets_seen, predictions_made),
        precision_score(targets_seen, predictions_made, average='macro'),
        recall_score(targets_seen, predictions_made, average='macro'),
        f1_score(targets_seen, predictions_made, average='macro'),
    )

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train and validate ``model`` for ``num_epochs`` epochs.

    Each epoch calls ``train_epoch`` and ``validate_epoch``, prints their
    metrics, and records loss and accuracy. Every fifth epoch the curves
    collected so far are passed to ``plot_curves``.

    Args:
        model: Model handed to ``train_epoch`` / ``validate_epoch``.
        train_loader: Training batches.
        val_loader: Validation batches.
        criterion: Loss callable.
        optimizer: Optimizer used by ``train_epoch``.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.

    Returns:
        dict with keys ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and
        ``'val_acc'``, each a list with one entry per epoch. (The function
        previously returned ``None`` and discarded this history.)
    """
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}:")
        train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_prec, val_rec, val_f1 = validate_epoch(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Recall: {train_rec:.4f}, F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Recall: {val_rec:.4f}, F1: {val_f1:.4f}")

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        # Periodic snapshot of the learning curves.
        if epoch % 5 == 0:
            plot_curves(epoch, train_losses, val_losses, train_accuracies, val_accuracies)

    return {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accuracies,
        'val_acc': val_accuracies,
    }

def plot_curves(epoch, train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot loss and accuracy curves for epochs 1..``epoch``, save and show them.

    Args:
        epoch: Number of epochs completed; each sequence below is plotted
            against ``range(1, epoch + 1)``.
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accuracies: Per-epoch training accuracies.
        val_accuracies: Per-epoch validation accuracies.

    Side effects: writes ``accuracy_loss_epoch_{epoch}.png`` to the current
    working directory, displays the figure, then closes it so repeated calls
    during training do not accumulate open figures.
    """
    epochs = range(1, epoch + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curves
    loss_ax.plot(epochs, train_losses, label='Train Loss')
    loss_ax.plot(epochs, val_losses, label='Validation Loss')
    loss_ax.set_title('Loss over epochs')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Accuracy curves
    acc_ax.plot(epochs, train_accuracies, label='Train Accuracy')
    acc_ax.plot(epochs, val_accuracies, label='Validation Accuracy')
    acc_ax.set_title('Accuracy over epochs')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.savefig(f'accuracy_loss_epoch_{epoch}.png')
    plt.show()
    # Release the figure; harmless if the backend already closed it on show().
    plt.close(fig)

In [None]:
import torch
import torchvision.models as models

# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# torchvision EfficientNet-B0 created with pretrained=True and moved to `device`.
base_model = models.efficientnet_b0(pretrained=True).to(device)

In [None]:
# Wrap the shared backbone in BYOL and place it on `device`.
# NOTE(review): BYOL is not defined in the cells visible here — confirm it is
# defined elsewhere in the notebook before this cell runs.
byol_model = BYOL(base_model)
byol_model = byol_model.to(device)

# Adam over the BYOL model's parameters; `optimizer` is rebound by a later cell.
optimizer = optim.Adam(params=byol_model.parameters(), lr=1e-3)

# Classification loss reused by the training helpers.
criterion = nn.CrossEntropyLoss()

In [None]:
# Build the classifier on top of the shared backbone and place it on `device`.
# NOTE(review): Classifier, train_loader and val_loader are not defined in the
# cells visible here — confirm they exist before this cell runs.
classifier_model = Classifier(base_model)
classifier_model = classifier_model.to(device)

# Fresh Adam optimizer for the classifier's parameters.
optimizer = optim.Adam(params=classifier_model.parameters(), lr=1e-3)

# Ten epochs of supervised training with per-epoch validation.
train_model(
    model=classifier_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    device=device,
)