In [1]:
import copy
import os
import numpy as np
from torchvision.datasets.folder import ImageFolder
from torchvision.datasets.folder import default_loader
import timm
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from sklearn.utils.class_weight import compute_class_weight
from torch.cuda.amp import autocast, GradScaler

In [2]:
# Parameters
img_height, img_width = 512, 512
batch_size = 16

# Data Loaders without normalization to calculate mean and std
calc_transforms = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor()
])

def get_mean_and_std(loader):
    # Var[X] = E[X^2] - (E[X])^2
    # So we will first calculate E[X] and E[X^2] to determine mean and std dev
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0

    for data, _ in loader:
        # Shape of data is [batch_size, 3, height, width]
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data**2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    return mean, std

# Calculate mean and std for dataset
train_data_for_calc = datasets.ImageFolder(root='Train_Original/Train/Train', transform=calc_transforms)
train_loader_for_calc = DataLoader(train_data_for_calc, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)
mean, std = get_mean_and_std(train_loader_for_calc)
print(f'Calculated Mean: {mean}, Calculated Std: {std}')

Calculated Mean: tensor([0.4966, 0.4965, 0.4964]), Calculated Std: tensor([0.2105, 0.2105, 0.2106])


In [3]:
# Setting up data transformations and custom dataset class
def get_transforms():
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4965, 0.4965, 0.4964],
                             std=[0.2105, 0.2105, 0.2106])
    ])

In [4]:
class CustomDataset(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        sample = sample.convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

In [5]:
# Dataset and DataLoader setup
def get_dataloaders(train_dir, val_dir, batch_size):
    train_transforms = get_transforms()
    val_transforms = get_transforms()

    train_dataset = CustomDataset(root=train_dir, transform=train_transforms)
    val_dataset = CustomDataset(root=val_dir, transform=val_transforms)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    return train_loader, val_loader, train_dataset

In [6]:
def setup_model(model, device, num_classes):
    # Load the ResNet-50 model pre-trained on ImageNet from timm
    model = timm.create_model(model, pretrained=True, num_classes=num_classes)
    
    # No need to modify the final layer manually as in torchvision, 
    # timm handles this with the num_classes parameter in create_model
    model = model.to(device)
    
    return model

In [7]:
def calculate_class_weights(train_dataset):
    # Count the number of occurrences of each class
    class_counts = torch.zeros((len(train_dataset.classes)), dtype=torch.float)
    for _, label in train_dataset:
        class_counts[label] += 1
    
    # Compute class weights (inverse frequency or another method)
    class_weights = 1. / class_counts
    class_weights = (class_weights / class_weights.sum()) * len(train_dataset.classes)
    return class_weights

In [8]:
# Validation function
def validate_model(model, val_loader, device):
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_corrects = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
            
            val_loss += loss.item() * inputs.size(0)
            val_corrects += torch.sum(preds == labels.data)

    epoch_loss = val_loss / len(val_loader.dataset)
    epoch_acc = val_corrects.double() / len(val_loader.dataset)
    return epoch_loss, epoch_acc

In [9]:
# Training loop with mixed precision
def train_model(model_name, model, train_loader, device, epochs, class_weights):
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=3e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.01)
    scaler = GradScaler()

    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * inputs.size(0)

        scheduler.step()
        epoch_loss = running_loss / len(train_loader.dataset)
        val_loss, val_acc = validate_model(model, val_loader, device)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
        
    # Save the best model found during training
    if best_model_state:
        torch.save(best_model_state, f'{model_name}_final.pth')
        print("Saved the best model based on validation loss.")

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {torch.cuda.get_device_name(0)}')

val_dir = 'Valid/'
train_dir = 'Cleaned_Images/'
batch_size = 32
num_classes = 3
epochs = 100

train_loader, val_loader, train_dataset = get_dataloaders(train_dir, val_dir, batch_size)

class_weights = calculate_class_weights(train_dataset)
class_weights = class_weights.to(device)

model_resnet = setup_model('resnet50', device, num_classes)
# model_efficientnet = setup_model('tf_efficientnet_b7', device, num_classes)
train_model('resnet50', model_resnet, train_loader, device, epochs, class_weights)
# train_model('tf_efficientnet_b7', model_efficientnet, train_loader, device, epochs, class_weights)

Using device: NVIDIA A100 80GB PCIe
Epoch 1/100, Train Loss: 0.8850, Val Loss: 0.8710, Val Acc: 0.6150
Epoch 2/100, Train Loss: 0.4257, Val Loss: 0.7312, Val Acc: 0.5800
Epoch 3/100, Train Loss: 0.2303, Val Loss: 0.7763, Val Acc: 0.6200
Epoch 4/100, Train Loss: 0.1720, Val Loss: 0.7440, Val Acc: 0.6350
Epoch 5/100, Train Loss: 0.1479, Val Loss: 0.7792, Val Acc: 0.6350
Epoch 6/100, Train Loss: 0.1245, Val Loss: 0.7412, Val Acc: 0.6500
Epoch 7/100, Train Loss: 0.1109, Val Loss: 0.7403, Val Acc: 0.6700
Epoch 8/100, Train Loss: 0.0972, Val Loss: 0.7263, Val Acc: 0.6600
Epoch 9/100, Train Loss: 0.0830, Val Loss: 0.7802, Val Acc: 0.6550
Epoch 10/100, Train Loss: 0.0699, Val Loss: 0.8230, Val Acc: 0.6700
Epoch 11/100, Train Loss: 0.0513, Val Loss: 0.9345, Val Acc: 0.6650
Epoch 12/100, Train Loss: 0.0481, Val Loss: 1.0397, Val Acc: 0.6600
Epoch 13/100, Train Loss: 0.0379, Val Loss: 1.0415, Val Acc: 0.6450
Epoch 14/100, Train Loss: 0.0347, Val Loss: 1.0412, Val Acc: 0.6750
Epoch 15/100, Train L

  class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)


In [None]:
from azureml.core import Workspace
ws = Workspace.from_config()  # Or use .get() with explicit parameters
compute_target = ws.compute_targets['akrishn21']
compute_target.stop(show_output=True)