In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset,DataLoader, Subset, random_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import StepLR
import torchvision.models as models

In [2]:
def set_device():
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    return device
device=set_device()
print(f"Using {device} device.")

Using mps device.


In [3]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
])

In [4]:
trainset = datasets.CIFAR10(root='./data', train=True, download=True)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
selected_classes = np.random.choice(range(10), 2, replace=False)
print(selected_classes)

[7 9]


In [6]:
model = models.vit_b_16(pretrained=True)

num_classes = 10
dropout_rate = 0.5 

model.heads = nn.Sequential(
    nn.Dropout(p=dropout_rate),  
    nn.Linear(model.heads[0].in_features, num_classes)
)



In [7]:
k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

In [8]:
best_val_accuracy = 0
best_model_weights = None

In [9]:
def train(model, device, train_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [10]:
def validate(model, device, val_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(val_loader):
            data, targets = data.to(device), targets.to(device)
            output = model(data)
            loss = criterion(output, targets)

            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [11]:
class TransformSubset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

In [12]:
for fold, (train_ids, val_ids) in enumerate(kfold.split(trainset)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_indices = [i for i, (_, label) in enumerate(trainset) if label in selected_classes]

    N = 25  
    class_counts = {label: 0 for label in selected_classes}
    filtered_train_indices = []

    for i in train_indices:
        _, label = trainset[i]
        if class_counts[label] < N:
            filtered_train_indices.append(i)
            class_counts[label] += 1
    
    
    np.random.seed(42)  
    np.random.shuffle(filtered_train_indices)  
    split = int(0.8 * len(filtered_train_indices))  
    train_idx, val_idx = filtered_train_indices[:split], filtered_train_indices[split:]
    
    train_subset = Subset(trainset, train_idx)
    transformed_train_subset = TransformSubset(train_subset, transform=train_transform)

    val_subset = Subset(trainset, val_idx)
    transformed_val_subset = TransformSubset(val_subset, transform=test_transform)
    
    train_loader = DataLoader(transformed_train_subset, batch_size=5, shuffle=True)
    val_loader = DataLoader(transformed_val_subset, batch_size=5, shuffle=False)
    
    
    # Init the neural network
    model = model.to(device)
    
    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=2e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

    best_val_loss = float('inf')
    patience_counter = 0
    patience = 25

    
    # Run the training loop for defined number of epochs
    num_epochs = 50
    for epoch in range(num_epochs):
        # Print epoch
        print(f'Starting epoch {epoch+1}')
        
        # Perform training and validation
        train_loss, train_accuracy = train(model, device, train_loader, optimizer, criterion)
        val_loss, val_accuracy = validate(model, device, val_loader, criterion)
        
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
        
        # Save the model if it has the best val accuracy so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_weights = model.state_dict().copy()  # Save the best model weights
            patience_counter = 0  # Reset patience
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break  # Stop training if no improvement
        
        # Step the learning rate scheduler
        scheduler.step()
            
    print('--------------------------------')
    
# Save the best model weights
torch.save(best_model_weights, 'c2_best_model.pth')
print(f'Best validation accuracy of {best_val_accuracy:.2f}% achieved, model saved as c2_best_model.pth')

FOLD 0
--------------------------------
Starting epoch 1
Train Loss: 1.7833, Train Acc: 42.50%, Val Loss: 0.9103, Val Acc: 90.00%
Starting epoch 2
Train Loss: 0.5530, Train Acc: 90.00%, Val Loss: 0.3101, Val Acc: 100.00%
Starting epoch 3
Train Loss: 0.2313, Train Acc: 95.00%, Val Loss: 0.2122, Val Acc: 100.00%
Starting epoch 4
Train Loss: 0.1644, Train Acc: 97.50%, Val Loss: 0.1252, Val Acc: 100.00%
Starting epoch 5
Train Loss: 0.0898, Train Acc: 100.00%, Val Loss: 0.0841, Val Acc: 100.00%
Starting epoch 6
Train Loss: 0.0416, Train Acc: 100.00%, Val Loss: 0.0652, Val Acc: 100.00%
Starting epoch 7
Train Loss: 0.0311, Train Acc: 100.00%, Val Loss: 0.0541, Val Acc: 100.00%
Starting epoch 8
Train Loss: 0.0237, Train Acc: 100.00%, Val Loss: 0.0431, Val Acc: 100.00%
Starting epoch 9
Train Loss: 0.0178, Train Acc: 100.00%, Val Loss: 0.0364, Val Acc: 100.00%
Starting epoch 10
Train Loss: 0.0147, Train Acc: 100.00%, Val Loss: 0.0324, Val Acc: 100.00%
Starting epoch 11
Train Loss: 0.0119, Train 

In [13]:
model = model.to(device)
model.load_state_dict(torch.load('c2_best_model.pth'))

<All keys matched successfully>

In [14]:
test_indices = [i for i, (_, label) in enumerate(testset) if label in selected_classes]
N = 1000 
class_counts = {label: 0 for label in selected_classes}
filtered_train_indices = []

for i in test_indices:
    _, label = testset[i]
    if class_counts[label] < N:
        filtered_train_indices.append(i)
        class_counts[label] += 1

test_subset = Subset(testset, filtered_train_indices)
test_loader = DataLoader(test_subset, batch_size=50, shuffle=False)

In [15]:
# Ensure the model is in evaluation mode
model.eval()

# Assuming the test_loader and the device are already defined
# Define the loss function
criterion = nn.CrossEntropyLoss()

# Evaluation
test_loss = 0
correct = 0
total = 0

# No gradient is needed for evaluation
with torch.no_grad():
    for data, targets in test_loader:
        # Move data and targets to the correct device
        data = data.to(device)
        targets = targets.to(device)

        # Compute the model output
        output = model(data)
        loss = criterion(output, targets)
        
        # Accumulate the loss and calculate accuracy
        test_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# Calculate average loss and accuracy percentage
avg_loss = test_loss / len(test_loader)
accuracy = 100. * correct / total

# Print the test loss and accuracy
print(f'Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')


Test Loss: 0.1201, Test Accuracy: 97.75%
