In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

from tqdm import tqdm

In [2]:
!nvidia-smi

Mon Apr 14 18:39:48 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:01:00.0 Off |                   On |
| N/A   24C    P0             49W /  500W |      88MiB /  81920MiB |     N/A      Default |
|                                         |                        |            Disabled* |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
device = "cuda" if torch.cuda.is_available() else "cpu"

## Data Loading

In [4]:
# # LOCAL CLUSTER DIRECTORY FILE LOADING
# import pickle

# # Function to unpickle the dataset files
# def unpickle(file):
#     with open(file, 'rb') as fo:
#         data_dict = pickle.load(fo, encoding='bytes')
#     return data_dict

# # Function to load all data batches
# def load_data(batch_files):
#     all_data = []
#     all_labels = []
#     for batch_file in batch_files:
#         batch_file = DATASET_PATH + batch_file
#         batch = unpickle(batch_file)
#         all_data.append(batch[b'data'])  # Image data
#         all_labels.append(batch[b'labels'])  # Labels
#     data = np.concatenate(all_data)
#     labels = np.concatenate(all_labels)
#     return data, labels

# # Function to reshape image data
# def reshape_images(data):
#     # Reshape from (10000, 3072) to (10000, 32, 32, 3)
#     data = data.reshape(-1, 3, 32, 32)  # (10000, 3, 32, 32)
#     data = data.transpose(0, 2, 3, 1)  # (10000, 32, 32, 3)
#     return data

# # Load metadata (label names)
# def load_meta(file):
#     file = DATASET_PATH + file
#     meta = unpickle(file)
#     label_names = [name.decode('utf-8') for name in meta[b'label_names']]  # Decode bytes to strings
#     return label_names

# # Example usage
# # Files for the training batches and metadata
# batch_files = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
# DATASET_PATH = "/datasets/CS747/cifar-10-batches-py/"
# test_batch_file = 'test_batch'
# meta_file = 'batches.meta'

# # Load training data
# data, labels = load_data(batch_files)
# data = reshape_images(data)  # Reshape image data into (10000, 32, 32, 3)

# # Load test data
# test_data, test_labels = load_data([test_batch_file])
# test_data = reshape_images(test_data)  # Reshape test data into (10000, 32, 32, 3)

# # Load label names (metadata)
# label_names = load_meta(meta_file)

# # Example of accessing one image and its label:
# image_index = 3
# image = data[image_index]  # 32x32x3 image
# label = labels[image_index]
# label_name = label_names[label]

# print(f"Image shape: {image.shape}, Label: {label}, Label name: {label_name}")

In [5]:
# setup data loader
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
])


# LOAD FROM TORCHVISION

train_mnist = datasets.MNIST(root='./data', download=True, train=True, transform=transforms.ToTensor())
test_mnist = datasets.MNIST(root='./data', download=True, train=False, transform=transforms.ToTensor())


#transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
full_train_cifar = datasets.CIFAR10(root='./data', download=True, train=True, transform=transform_train)
test_cifar = datasets.CIFAR10(root='./data', download=True, train=False, transform=transform_test)

In [6]:
BATCH_SIZE = 128

In [7]:
# Calculate sizes for split (e.g., 90% training, 10% validation)
num_train = int(0.9 * len(full_train_cifar))  # 45,000 images
num_val = len(full_train_cifar) - num_train    # 5,000 images

train_dataset, val_dataset = random_split(full_train_cifar, [num_train, num_val])

train_loader_cifar = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader_cifar = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE*2, shuffle=False)
test_loader_cifar = DataLoader(dataset=test_cifar, batch_size=BATCH_SIZE*2, shuffle=False)

train_loader_mnist = DataLoader(dataset=train_mnist, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# MNIST CNN CONFIGS

mnist_cnn_config = {
    "learning_rate" : 0.01, # learning rate
    "weight_decay" : 3.5e-5,
    }

cifar_config = {
    "learning_rate" : 0.01,
    "weight_decay" : 3.5e-5,
    "epochs": 100
}



In [9]:
from collections import OrderedDict
import torch.nn as nn


class SmallCNN(nn.Module):
    def __init__(self, drop=0.5):
        super(SmallCNN, self).__init__()

        self.num_channels = 1
        self.num_labels = 10

        activ = nn.ReLU(True)

        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', activ),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', activ),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', activ),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', activ),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', activ),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', activ),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]))

        for m in self.modules():
            if isinstance(m, (nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input):
        features = self.feature_extractor(input)
        #print(features.shape)
        logits = self.classifier(features.view(-1, 64 * 4 * 4))
        return logits

In [10]:
## Resnet Model

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.residual = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = out + self.residual(x)
        out = F.relu(out, inplace=True)
        return out


class ResNet(nn.Module):
    def __init__(self, block: BasicBlock, num_blocks: list[int], num_classes: int=10):
        super().__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer0 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer1 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer2 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer3 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)


    
    def make_layer(self, block, out_channels, num_blocks, stride):
        layers = []

        layers.append(block(self.in_channels, out_channels, stride=stride))
        self.in_channels = out_channels * block.expansion

        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer0(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        logits = self.linear(out)
        return logits

def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])


In [11]:
#model_test = ResNet18().to(device)

In [12]:
# x_rand = torch.randn((1, 3, 32, 32)).cuda()
# print(x_rand.shape)
# output = model_test(x_rand)
# print(output.shape)

In [13]:
def train_one_epoch(model,  dataloader, criterion, optimizer, device, max_grad_norm=None):
    model.train()

    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    for images, labels in tqdm(dataloader, desc="Training", leave=False):

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = torch.argmax(outputs.detach(), -1)
        running_corrects += torch.sum(preds == labels.data)
        total_samples += images.size(0)

    epoch_loss = running_loss / total_samples
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc = 'Validation', leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)

            preds = torch.argmax(outputs, -1)
            running_corrects += torch.sum((preds == labels)).item()
            total_samples += labels.size(0)
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples

        return epoch_loss, epoch_acc

# def adjust_learning_rate(optimizer, epoch, config):
#     """decrease the learning rate"""
#     lr = config["learning_rate"]

#     if epoch >= 100:
#         lr = config["learning_rate"] * 0.001
#     elif epoch >= 90:
#         lr = config["learning_rate"] * 0.01
#     elif epoch >= 75:
#         lr = config["learning_rate"] * 0.1

#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr

In [14]:
def save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_dir, top_checkpoints, max_checkpoints=5):
    """
    Save a checkpoint if it's among the top best ones.
    
    Args:
        model (torch.nn.Module): The model to save.
        optimizer (torch.optim.Optimizer): The optimizer state to save.
        epoch (int): The current epoch.
        val_loss (float): The validation loss metric used for ranking.
        checkpoint_dir (str): Directory to save the checkpoints.
        top_checkpoints (list): List of tuples (val_loss, filename) for current top checkpoints.
        max_checkpoints (int, optional): Maximum number of checkpoints to save. Defaults to 5.
    """
    # Ensure the checkpoint directory exists.
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Construct a filename that contains the epoch and validation loss.
    checkpoint_filename = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}_valloss_{val_loss:.4f}.pth")
    
    # Create the checkpoint dictionary
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }
    
    # Save checkpoint to disk
    torch.save(checkpoint, checkpoint_filename)
    
    # Add the checkpoint to your top checkpoints list (we assume lower loss is better)
    top_checkpoints.append((val_loss, checkpoint_filename))
    
    # Sort the checkpoints by validation loss (ascending order)
    top_checkpoints.sort(key=lambda x: x[0])
    
    # If more than max_checkpoints, delete the worst checkpoint (last one in sorted list)
    if len(top_checkpoints) > max_checkpoints:
        worst_checkpoint = top_checkpoints.pop()  # pop the highest loss checkpoint
        if os.path.exists(worst_checkpoint[1]):
            os.remove(worst_checkpoint[1])
            print(f"Removed checkpoint: {worst_checkpoint[1]}")

In [15]:
def train_model(model, train_loader, val_loader, config, scheduler=None, device="cuda"):
    num_epochs = config["epochs"]
    
    top_checkpoints = []
    checkpoint_dir = config.get("checkpoint_dir", "./checkpoints")

    history = {
        "train_loss" : [],
        "train_acc" : [],
        "val_loss" : [],
        "val_acc" : []
    }

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(), 
                    lr = config["learning_rate"], 
                    momentum= 0.9,
                    weight_decay=config["weight_decay"])

    scheduler = MultiStepLR(optimizer=optimizer, milestones= [75, 90, 100], gamma=0.1)
    
    for epoch in range(1, num_epochs+1):
        #print(f"Epoch {epoch}/{num_epochs}")

        #adjust_learning_rate(optimizer, epoch, config=config)

        train_loss, train_acc = train_one_epoch(
            model = model, dataloader = train_loader, 
            criterion=criterion, optimizer=optimizer, 
            device=device, max_grad_norm=1.0
        )

        val_loss, val_acc = validate(model=model, dataloader=val_loader, criterion=criterion, device=device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Epoch [{epoch}/{num_epochs}] | "
              f"LR: {scheduler.get_last_lr()[0]:.3f} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

        if scheduler:
            scheduler.step()

        save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_dir, top_checkpoints, max_checkpoints=5)
        
    

    return model, history, top_checkpoints

In [16]:
def main():
    model = ResNet18().to(device)

    model, history = train_model(model, 
                                 train_loader=train_loader_cifar, 
                                 val_loader=val_loader_cifar,
                                 config= cifar_config
    )

    print("Training Done")
    print('================================================================')

    test_loss, test_acc = validate(model, dataloader=test_loader_cifar, criterion=nn.CrossEntropyLoss(), device=device)
    print(f"Test accuracy: {test_acc}, Test loss: {test_loss}")

    return history


In [65]:
if __name__ == '__main__':
    history = main()

Training:   0%|          | 0/352 [00:00<?, ?it/s]

                                                           

Epoch [1/100] | LR: 0.010 | Train Loss: 1.5923 | Train Acc: 41.05% | Val Loss: 1.3970 | Val Acc: 47.84%


                                                           

Epoch [2/100] | LR: 0.010 | Train Loss: 1.1824 | Train Acc: 57.12% | Val Loss: 1.5524 | Val Acc: 49.64%


                                                           

Epoch [3/100] | LR: 0.010 | Train Loss: 0.9979 | Train Acc: 64.37% | Val Loss: 1.1023 | Val Acc: 60.52%


                                                           

Epoch [4/100] | LR: 0.010 | Train Loss: 0.8744 | Train Acc: 68.76% | Val Loss: 0.9641 | Val Acc: 65.74%


                                                           

Epoch [5/100] | LR: 0.010 | Train Loss: 0.7826 | Train Acc: 72.30% | Val Loss: 0.8617 | Val Acc: 70.32%


                                                           

Epoch [6/100] | LR: 0.010 | Train Loss: 0.6932 | Train Acc: 75.62% | Val Loss: 0.9013 | Val Acc: 68.76%
Removed checkpoint: ./checkpoints/model_epoch_3_valloss_1.5524.pth


                                                           

Epoch [7/100] | LR: 0.010 | Train Loss: 0.6305 | Train Acc: 77.91% | Val Loss: 0.6746 | Val Acc: 76.16%
Removed checkpoint: ./checkpoints/model_epoch_2_valloss_1.3970.pth


                                                           

Epoch [8/100] | LR: 0.010 | Train Loss: 0.5816 | Train Acc: 79.88% | Val Loss: 0.9722 | Val Acc: 68.02%
Removed checkpoint: ./checkpoints/model_epoch_4_valloss_1.1023.pth


                                                           

Epoch [9/100] | LR: 0.010 | Train Loss: 0.5369 | Train Acc: 81.25% | Val Loss: 0.7006 | Val Acc: 76.18%
Removed checkpoint: ./checkpoints/model_epoch_9_valloss_0.9722.pth


                                                           

Epoch [10/100] | LR: 0.010 | Train Loss: 0.4976 | Train Acc: 82.53% | Val Loss: 0.5604 | Val Acc: 80.62%
Removed checkpoint: ./checkpoints/model_epoch_5_valloss_0.9641.pth


                                                           

Epoch [11/100] | LR: 0.010 | Train Loss: 0.4661 | Train Acc: 83.73% | Val Loss: 0.5593 | Val Acc: 80.24%
Removed checkpoint: ./checkpoints/model_epoch_7_valloss_0.9013.pth


                                                           

Epoch [12/100] | LR: 0.010 | Train Loss: 0.4401 | Train Acc: 84.71% | Val Loss: 0.5446 | Val Acc: 81.28%
Removed checkpoint: ./checkpoints/model_epoch_6_valloss_0.8617.pth


                                                           

Epoch [13/100] | LR: 0.010 | Train Loss: 0.4122 | Train Acc: 85.65% | Val Loss: 0.6718 | Val Acc: 77.42%
Removed checkpoint: ./checkpoints/model_epoch_10_valloss_0.7006.pth


                                                           

Epoch [14/100] | LR: 0.010 | Train Loss: 0.3885 | Train Acc: 86.37% | Val Loss: 0.5752 | Val Acc: 80.70%
Removed checkpoint: ./checkpoints/model_epoch_8_valloss_0.6746.pth


                                                           

Epoch [15/100] | LR: 0.010 | Train Loss: 0.3659 | Train Acc: 87.35% | Val Loss: 0.5365 | Val Acc: 82.02%
Removed checkpoint: ./checkpoints/model_epoch_14_valloss_0.6718.pth


                                                           

Epoch [16/100] | LR: 0.010 | Train Loss: 0.3505 | Train Acc: 87.95% | Val Loss: 0.6019 | Val Acc: 80.12%
Removed checkpoint: ./checkpoints/model_epoch_17_valloss_0.6019.pth


                                                           

Epoch [17/100] | LR: 0.010 | Train Loss: 0.3266 | Train Acc: 88.68% | Val Loss: 0.4566 | Val Acc: 84.00%
Removed checkpoint: ./checkpoints/model_epoch_15_valloss_0.5752.pth


                                                           

Epoch [18/100] | LR: 0.010 | Train Loss: 0.3142 | Train Acc: 89.07% | Val Loss: 0.4780 | Val Acc: 83.96%
Removed checkpoint: ./checkpoints/model_epoch_11_valloss_0.5604.pth


                                                           

Epoch [19/100] | LR: 0.010 | Train Loss: 0.2924 | Train Acc: 89.85% | Val Loss: 0.5462 | Val Acc: 82.44%
Removed checkpoint: ./checkpoints/model_epoch_12_valloss_0.5593.pth


                                                           

Epoch [20/100] | LR: 0.010 | Train Loss: 0.2866 | Train Acc: 90.02% | Val Loss: 0.5038 | Val Acc: 83.52%
Removed checkpoint: ./checkpoints/model_epoch_20_valloss_0.5462.pth


                                                           

Epoch [21/100] | LR: 0.010 | Train Loss: 0.2680 | Train Acc: 90.73% | Val Loss: 0.5630 | Val Acc: 82.72%
Removed checkpoint: ./checkpoints/model_epoch_22_valloss_0.5630.pth


                                                           

Epoch [22/100] | LR: 0.010 | Train Loss: 0.2571 | Train Acc: 90.91% | Val Loss: 0.5002 | Val Acc: 83.82%
Removed checkpoint: ./checkpoints/model_epoch_13_valloss_0.5446.pth


                                                           

Epoch [23/100] | LR: 0.010 | Train Loss: 0.2447 | Train Acc: 91.57% | Val Loss: 0.4816 | Val Acc: 84.78%
Removed checkpoint: ./checkpoints/model_epoch_16_valloss_0.5365.pth


                                                           

Epoch [24/100] | LR: 0.010 | Train Loss: 0.2326 | Train Acc: 91.65% | Val Loss: 0.4467 | Val Acc: 85.82%
Removed checkpoint: ./checkpoints/model_epoch_21_valloss_0.5038.pth


                                                           

Epoch [25/100] | LR: 0.010 | Train Loss: 0.2237 | Train Acc: 92.23% | Val Loss: 0.4381 | Val Acc: 85.84%
Removed checkpoint: ./checkpoints/model_epoch_23_valloss_0.5002.pth


                                                           

Epoch [26/100] | LR: 0.010 | Train Loss: 0.2120 | Train Acc: 92.50% | Val Loss: 0.5084 | Val Acc: 84.52%
Removed checkpoint: ./checkpoints/model_epoch_27_valloss_0.5084.pth


                                                           

Epoch [27/100] | LR: 0.010 | Train Loss: 0.2007 | Train Acc: 92.94% | Val Loss: 0.4134 | Val Acc: 86.58%
Removed checkpoint: ./checkpoints/model_epoch_24_valloss_0.4816.pth


                                                           

Epoch [28/100] | LR: 0.010 | Train Loss: 0.1877 | Train Acc: 93.34% | Val Loss: 0.4866 | Val Acc: 85.06%
Removed checkpoint: ./checkpoints/model_epoch_29_valloss_0.4866.pth


                                                           

Epoch [29/100] | LR: 0.010 | Train Loss: 0.1817 | Train Acc: 93.64% | Val Loss: 0.6008 | Val Acc: 83.70%
Removed checkpoint: ./checkpoints/model_epoch_30_valloss_0.6008.pth


                                                           

Epoch [30/100] | LR: 0.010 | Train Loss: 0.1779 | Train Acc: 93.79% | Val Loss: 0.4546 | Val Acc: 86.38%
Removed checkpoint: ./checkpoints/model_epoch_19_valloss_0.4780.pth


                                                           

Epoch [31/100] | LR: 0.010 | Train Loss: 0.1679 | Train Acc: 94.09% | Val Loss: 0.4727 | Val Acc: 86.16%
Removed checkpoint: ./checkpoints/model_epoch_32_valloss_0.4727.pth


                                                           

Epoch [32/100] | LR: 0.010 | Train Loss: 0.1586 | Train Acc: 94.34% | Val Loss: 0.4929 | Val Acc: 85.54%
Removed checkpoint: ./checkpoints/model_epoch_33_valloss_0.4929.pth


                                                           

Epoch [33/100] | LR: 0.010 | Train Loss: 0.1580 | Train Acc: 94.36% | Val Loss: 0.4285 | Val Acc: 87.30%
Removed checkpoint: ./checkpoints/model_epoch_18_valloss_0.4566.pth


                                                           

Epoch [34/100] | LR: 0.010 | Train Loss: 0.1460 | Train Acc: 94.86% | Val Loss: 0.5633 | Val Acc: 84.60%
Removed checkpoint: ./checkpoints/model_epoch_35_valloss_0.5633.pth


                                                           

Epoch [35/100] | LR: 0.010 | Train Loss: 0.1426 | Train Acc: 95.01% | Val Loss: 0.4684 | Val Acc: 85.90%
Removed checkpoint: ./checkpoints/model_epoch_36_valloss_0.4684.pth


                                                           

Epoch [36/100] | LR: 0.010 | Train Loss: 0.1360 | Train Acc: 95.18% | Val Loss: 0.5018 | Val Acc: 84.96%
Removed checkpoint: ./checkpoints/model_epoch_37_valloss_0.5018.pth


                                                           

Epoch [37/100] | LR: 0.010 | Train Loss: 0.1253 | Train Acc: 95.61% | Val Loss: 0.4286 | Val Acc: 87.24%
Removed checkpoint: ./checkpoints/model_epoch_31_valloss_0.4546.pth


                                                           

Epoch [38/100] | LR: 0.010 | Train Loss: 0.1240 | Train Acc: 95.62% | Val Loss: 0.4194 | Val Acc: 87.42%
Removed checkpoint: ./checkpoints/model_epoch_25_valloss_0.4467.pth


                                                           

Epoch [39/100] | LR: 0.010 | Train Loss: 0.1219 | Train Acc: 95.75% | Val Loss: 0.4666 | Val Acc: 86.66%
Removed checkpoint: ./checkpoints/model_epoch_40_valloss_0.4666.pth


                                                           

Epoch [40/100] | LR: 0.010 | Train Loss: 0.1157 | Train Acc: 95.93% | Val Loss: 0.4387 | Val Acc: 87.52%
Removed checkpoint: ./checkpoints/model_epoch_41_valloss_0.4387.pth


                                                           

Epoch [41/100] | LR: 0.010 | Train Loss: 0.1074 | Train Acc: 96.15% | Val Loss: 0.4681 | Val Acc: 86.78%
Removed checkpoint: ./checkpoints/model_epoch_42_valloss_0.4681.pth


                                                           

Epoch [42/100] | LR: 0.010 | Train Loss: 0.1053 | Train Acc: 96.31% | Val Loss: 0.4729 | Val Acc: 86.78%
Removed checkpoint: ./checkpoints/model_epoch_43_valloss_0.4729.pth


                                                           

Epoch [43/100] | LR: 0.010 | Train Loss: 0.1036 | Train Acc: 96.36% | Val Loss: 0.5232 | Val Acc: 85.66%
Removed checkpoint: ./checkpoints/model_epoch_44_valloss_0.5232.pth


                                                           

Epoch [44/100] | LR: 0.010 | Train Loss: 0.0979 | Train Acc: 96.55% | Val Loss: 0.4658 | Val Acc: 87.34%
Removed checkpoint: ./checkpoints/model_epoch_45_valloss_0.4658.pth


                                                           

Epoch [45/100] | LR: 0.010 | Train Loss: 0.0951 | Train Acc: 96.73% | Val Loss: 0.4496 | Val Acc: 87.16%
Removed checkpoint: ./checkpoints/model_epoch_46_valloss_0.4496.pth


                                                           

Epoch [46/100] | LR: 0.010 | Train Loss: 0.0932 | Train Acc: 96.69% | Val Loss: 0.6077 | Val Acc: 85.28%
Removed checkpoint: ./checkpoints/model_epoch_47_valloss_0.6077.pth


                                                           

Epoch [47/100] | LR: 0.010 | Train Loss: 0.0900 | Train Acc: 96.78% | Val Loss: 0.4750 | Val Acc: 87.60%
Removed checkpoint: ./checkpoints/model_epoch_48_valloss_0.4750.pth


                                                           

Epoch [48/100] | LR: 0.010 | Train Loss: 0.0890 | Train Acc: 96.78% | Val Loss: 0.4472 | Val Acc: 87.68%
Removed checkpoint: ./checkpoints/model_epoch_49_valloss_0.4472.pth


                                                           

Epoch [49/100] | LR: 0.010 | Train Loss: 0.0816 | Train Acc: 97.12% | Val Loss: 0.5545 | Val Acc: 86.24%
Removed checkpoint: ./checkpoints/model_epoch_50_valloss_0.5545.pth


                                                           

Epoch [50/100] | LR: 0.010 | Train Loss: 0.0834 | Train Acc: 97.11% | Val Loss: 0.5834 | Val Acc: 85.06%
Removed checkpoint: ./checkpoints/model_epoch_51_valloss_0.5834.pth


                                                           

Epoch [51/100] | LR: 0.010 | Train Loss: 0.0779 | Train Acc: 97.30% | Val Loss: 0.4439 | Val Acc: 88.62%
Removed checkpoint: ./checkpoints/model_epoch_52_valloss_0.4439.pth


                                                           

Epoch [52/100] | LR: 0.010 | Train Loss: 0.0751 | Train Acc: 97.35% | Val Loss: 0.4739 | Val Acc: 88.42%
Removed checkpoint: ./checkpoints/model_epoch_53_valloss_0.4739.pth


                                                           

Epoch [53/100] | LR: 0.010 | Train Loss: 0.0729 | Train Acc: 97.45% | Val Loss: 0.5144 | Val Acc: 87.54%
Removed checkpoint: ./checkpoints/model_epoch_54_valloss_0.5144.pth


                                                           

Epoch [54/100] | LR: 0.010 | Train Loss: 0.0716 | Train Acc: 97.40% | Val Loss: 0.5253 | Val Acc: 87.32%
Removed checkpoint: ./checkpoints/model_epoch_55_valloss_0.5253.pth


                                                           

Epoch [55/100] | LR: 0.010 | Train Loss: 0.0706 | Train Acc: 97.44% | Val Loss: 0.5555 | Val Acc: 85.96%
Removed checkpoint: ./checkpoints/model_epoch_56_valloss_0.5555.pth


                                                           

Epoch [56/100] | LR: 0.010 | Train Loss: 0.0688 | Train Acc: 97.59% | Val Loss: 0.5357 | Val Acc: 87.02%
Removed checkpoint: ./checkpoints/model_epoch_57_valloss_0.5357.pth


                                                           

Epoch [57/100] | LR: 0.010 | Train Loss: 0.0644 | Train Acc: 97.69% | Val Loss: 0.4884 | Val Acc: 87.88%
Removed checkpoint: ./checkpoints/model_epoch_58_valloss_0.4884.pth


                                                           

Epoch [58/100] | LR: 0.010 | Train Loss: 0.0655 | Train Acc: 97.67% | Val Loss: 0.4522 | Val Acc: 88.94%
Removed checkpoint: ./checkpoints/model_epoch_59_valloss_0.4522.pth


                                                           

Epoch [59/100] | LR: 0.010 | Train Loss: 0.0602 | Train Acc: 97.84% | Val Loss: 0.4370 | Val Acc: 88.88%
Removed checkpoint: ./checkpoints/model_epoch_26_valloss_0.4381.pth


                                                           

Epoch [60/100] | LR: 0.010 | Train Loss: 0.0631 | Train Acc: 97.80% | Val Loss: 0.5179 | Val Acc: 87.90%
Removed checkpoint: ./checkpoints/model_epoch_61_valloss_0.5179.pth


                                                           

Epoch [61/100] | LR: 0.010 | Train Loss: 0.0621 | Train Acc: 97.82% | Val Loss: 0.4662 | Val Acc: 88.38%
Removed checkpoint: ./checkpoints/model_epoch_62_valloss_0.4662.pth


                                                           

Epoch [62/100] | LR: 0.010 | Train Loss: 0.0576 | Train Acc: 97.97% | Val Loss: 0.5007 | Val Acc: 88.22%
Removed checkpoint: ./checkpoints/model_epoch_63_valloss_0.5007.pth


                                                           

Epoch [63/100] | LR: 0.010 | Train Loss: 0.0568 | Train Acc: 98.03% | Val Loss: 0.4974 | Val Acc: 88.34%
Removed checkpoint: ./checkpoints/model_epoch_64_valloss_0.4974.pth


                                                           

Epoch [64/100] | LR: 0.010 | Train Loss: 0.0565 | Train Acc: 97.96% | Val Loss: 0.5001 | Val Acc: 88.26%
Removed checkpoint: ./checkpoints/model_epoch_65_valloss_0.5001.pth


                                                           

Epoch [65/100] | LR: 0.010 | Train Loss: 0.0533 | Train Acc: 98.09% | Val Loss: 0.5265 | Val Acc: 87.52%
Removed checkpoint: ./checkpoints/model_epoch_66_valloss_0.5265.pth


                                                           

Epoch [66/100] | LR: 0.010 | Train Loss: 0.0493 | Train Acc: 98.29% | Val Loss: 0.5054 | Val Acc: 88.66%
Removed checkpoint: ./checkpoints/model_epoch_67_valloss_0.5054.pth


                                                           

Epoch [67/100] | LR: 0.010 | Train Loss: 0.0529 | Train Acc: 98.16% | Val Loss: 0.4860 | Val Acc: 88.86%
Removed checkpoint: ./checkpoints/model_epoch_68_valloss_0.4860.pth


                                                           

Epoch [68/100] | LR: 0.010 | Train Loss: 0.0500 | Train Acc: 98.21% | Val Loss: 0.5265 | Val Acc: 87.74%
Removed checkpoint: ./checkpoints/model_epoch_69_valloss_0.5265.pth


                                                           

Epoch [69/100] | LR: 0.010 | Train Loss: 0.0507 | Train Acc: 98.19% | Val Loss: 0.5325 | Val Acc: 88.20%
Removed checkpoint: ./checkpoints/model_epoch_70_valloss_0.5325.pth


                                                           

Epoch [70/100] | LR: 0.010 | Train Loss: 0.0489 | Train Acc: 98.27% | Val Loss: 0.4866 | Val Acc: 88.54%
Removed checkpoint: ./checkpoints/model_epoch_71_valloss_0.4866.pth


                                                           

Epoch [71/100] | LR: 0.010 | Train Loss: 0.0472 | Train Acc: 98.40% | Val Loss: 0.5315 | Val Acc: 88.34%
Removed checkpoint: ./checkpoints/model_epoch_72_valloss_0.5315.pth


                                                           

Epoch [72/100] | LR: 0.010 | Train Loss: 0.0484 | Train Acc: 98.24% | Val Loss: 0.5775 | Val Acc: 87.18%
Removed checkpoint: ./checkpoints/model_epoch_73_valloss_0.5775.pth


                                                           

Epoch [73/100] | LR: 0.010 | Train Loss: 0.0474 | Train Acc: 98.33% | Val Loss: 0.5114 | Val Acc: 88.74%
Removed checkpoint: ./checkpoints/model_epoch_74_valloss_0.5114.pth


                                                           

Epoch [74/100] | LR: 0.010 | Train Loss: 0.0450 | Train Acc: 98.44% | Val Loss: 0.5152 | Val Acc: 88.84%
Removed checkpoint: ./checkpoints/model_epoch_75_valloss_0.5152.pth


                                                           

Epoch [75/100] | LR: 0.010 | Train Loss: 0.0458 | Train Acc: 98.38% | Val Loss: 0.4955 | Val Acc: 88.78%
Removed checkpoint: ./checkpoints/model_epoch_76_valloss_0.4955.pth


                                                           

Epoch [76/100] | LR: 0.001 | Train Loss: 0.0203 | Train Acc: 99.32% | Val Loss: 0.4125 | Val Acc: 90.46%
Removed checkpoint: ./checkpoints/model_epoch_60_valloss_0.4370.pth


                                                           

Epoch [77/100] | LR: 0.001 | Train Loss: 0.0110 | Train Acc: 99.69% | Val Loss: 0.3907 | Val Acc: 90.60%
Removed checkpoint: ./checkpoints/model_epoch_38_valloss_0.4286.pth


                                                           

Epoch [78/100] | LR: 0.001 | Train Loss: 0.0099 | Train Acc: 99.70% | Val Loss: 0.4019 | Val Acc: 90.52%
Removed checkpoint: ./checkpoints/model_epoch_34_valloss_0.4285.pth


                                                           

Epoch [79/100] | LR: 0.001 | Train Loss: 0.0081 | Train Acc: 99.78% | Val Loss: 0.4111 | Val Acc: 90.66%
Removed checkpoint: ./checkpoints/model_epoch_39_valloss_0.4194.pth


                                                           

Epoch [80/100] | LR: 0.001 | Train Loss: 0.0067 | Train Acc: 99.84% | Val Loss: 0.4019 | Val Acc: 91.08%
Removed checkpoint: ./checkpoints/model_epoch_28_valloss_0.4134.pth


                                                           

Epoch [81/100] | LR: 0.001 | Train Loss: 0.0065 | Train Acc: 99.84% | Val Loss: 0.3964 | Val Acc: 90.92%
Removed checkpoint: ./checkpoints/model_epoch_77_valloss_0.4125.pth


                                                           

Epoch [82/100] | LR: 0.001 | Train Loss: 0.0060 | Train Acc: 99.86% | Val Loss: 0.3893 | Val Acc: 91.02%
Removed checkpoint: ./checkpoints/model_epoch_80_valloss_0.4111.pth


                                                           

Epoch [83/100] | LR: 0.001 | Train Loss: 0.0054 | Train Acc: 99.86% | Val Loss: 0.3908 | Val Acc: 91.26%
Removed checkpoint: ./checkpoints/model_epoch_81_valloss_0.4019.pth


                                                           

Epoch [84/100] | LR: 0.001 | Train Loss: 0.0045 | Train Acc: 99.90% | Val Loss: 0.3982 | Val Acc: 91.06%
Removed checkpoint: ./checkpoints/model_epoch_79_valloss_0.4019.pth


                                                           

Epoch [85/100] | LR: 0.001 | Train Loss: 0.0047 | Train Acc: 99.91% | Val Loss: 0.4135 | Val Acc: 91.08%
Removed checkpoint: ./checkpoints/model_epoch_86_valloss_0.4135.pth


                                                           

Epoch [86/100] | LR: 0.001 | Train Loss: 0.0043 | Train Acc: 99.93% | Val Loss: 0.3925 | Val Acc: 91.06%
Removed checkpoint: ./checkpoints/model_epoch_85_valloss_0.3982.pth


                                                           

Epoch [87/100] | LR: 0.001 | Train Loss: 0.0037 | Train Acc: 99.93% | Val Loss: 0.4181 | Val Acc: 90.74%
Removed checkpoint: ./checkpoints/model_epoch_88_valloss_0.4181.pth


                                                           

Epoch [88/100] | LR: 0.001 | Train Loss: 0.0037 | Train Acc: 99.94% | Val Loss: 0.4232 | Val Acc: 90.94%
Removed checkpoint: ./checkpoints/model_epoch_89_valloss_0.4232.pth


                                                           

Epoch [89/100] | LR: 0.001 | Train Loss: 0.0034 | Train Acc: 99.95% | Val Loss: 0.4161 | Val Acc: 91.08%
Removed checkpoint: ./checkpoints/model_epoch_90_valloss_0.4161.pth


                                                           

Epoch [90/100] | LR: 0.001 | Train Loss: 0.0033 | Train Acc: 99.95% | Val Loss: 0.4019 | Val Acc: 91.02%
Removed checkpoint: ./checkpoints/model_epoch_91_valloss_0.4019.pth


                                                           

Epoch [91/100] | LR: 0.000 | Train Loss: 0.0030 | Train Acc: 99.95% | Val Loss: 0.4311 | Val Acc: 91.02%
Removed checkpoint: ./checkpoints/model_epoch_92_valloss_0.4311.pth


                                                           

Epoch [92/100] | LR: 0.000 | Train Loss: 0.0032 | Train Acc: 99.94% | Val Loss: 0.4220 | Val Acc: 91.32%
Removed checkpoint: ./checkpoints/model_epoch_93_valloss_0.4220.pth


                                                           

Epoch [93/100] | LR: 0.000 | Train Loss: 0.0030 | Train Acc: 99.95% | Val Loss: 0.4215 | Val Acc: 91.00%
Removed checkpoint: ./checkpoints/model_epoch_94_valloss_0.4215.pth


                                                           

Epoch [94/100] | LR: 0.000 | Train Loss: 0.0029 | Train Acc: 99.95% | Val Loss: 0.4081 | Val Acc: 90.94%
Removed checkpoint: ./checkpoints/model_epoch_95_valloss_0.4081.pth


                                                           

Epoch [95/100] | LR: 0.000 | Train Loss: 0.0032 | Train Acc: 99.94% | Val Loss: 0.4125 | Val Acc: 91.12%
Removed checkpoint: ./checkpoints/model_epoch_96_valloss_0.4125.pth


                                                           

Epoch [96/100] | LR: 0.000 | Train Loss: 0.0028 | Train Acc: 99.96% | Val Loss: 0.4069 | Val Acc: 91.34%
Removed checkpoint: ./checkpoints/model_epoch_97_valloss_0.4069.pth


                                                           

Epoch [97/100] | LR: 0.000 | Train Loss: 0.0030 | Train Acc: 99.95% | Val Loss: 0.3978 | Val Acc: 91.40%
Removed checkpoint: ./checkpoints/model_epoch_98_valloss_0.3978.pth


                                                           

Epoch [98/100] | LR: 0.000 | Train Loss: 0.0031 | Train Acc: 99.94% | Val Loss: 0.4304 | Val Acc: 90.88%
Removed checkpoint: ./checkpoints/model_epoch_99_valloss_0.4304.pth


                                                           

Epoch [99/100] | LR: 0.000 | Train Loss: 0.0028 | Train Acc: 99.95% | Val Loss: 0.4165 | Val Acc: 90.82%
Removed checkpoint: ./checkpoints/model_epoch_100_valloss_0.4165.pth


                                                           

Epoch [100/100] | LR: 0.000 | Train Loss: 0.0027 | Train Acc: 99.95% | Val Loss: 0.4339 | Val Acc: 90.74%
Removed checkpoint: ./checkpoints/model_epoch_101_valloss_0.4339.pth


ValueError: too many values to unpack (expected 2)

In [None]:
# _, test_acc = validate(model, dataloader=test_loader_cifar, criterion=nn.CrossEntropyLoss(), device=device)
# print(f"Test accuracy: {test_acc}")

## Adver training

In [None]:
def create_adv_samples(model, inputs, labels, config, random_start=True):
    epsilon = config["epsilon"]
    step_size = config["step_size"]
    num_steps = config["num_steps"]
    #norm = config["norm"]

    model.eval()
    # Keep original clean samples
    x_clean = inputs.clone().detach()


    if random_start:
        delta = torch.rand_like(inputs, device=device) * 2 * epsilon - epsilon
    else:
        delta = torch.zeros_like(inputs, device=device)
    
    delta = torch.clamp(delta, -epsilon, epsilon)
    delta.requires_grad = True

    for step in range(num_steps):
        x_adv = x_clean + delta
        
        # for stable batch norm stats and to disable dropout
        with torch.enable_grad():
            outputs = model(x_adv)
        loss = F.cross_entropy(outputs, labels)

        grad = torch.autograd.grad(loss, delta)[0]
        delta = delta.detach() + step_size * torch.sign(grad)
        delta = torch.clamp(delta, -epsilon, epsilon)

        # clamp adversarial sample to valid pixel range        
        x_adv_clamped = torch.clamp(x_clean + delta, 0.0, 1.0)
        
        # get accurate delat from clamped x
        delta = x_adv_clamped - x_clean
        delta = delta.detach()
        
        if step < num_steps - 1:
            delta.requires_grad = True
    
    # final adversarial sample
    x_adv_batch = (x_clean + delta).detach()

    return x_adv_batch


def mart_loss(logits_clean, logits_adv, labels, lambda_reg):
    # BCE LOSS = standard cross entropy + margin maximization
    # get the probabilty distribution for the adversarial logits
    probs_adv = F.softmax(logits_adv, dim=-1)

    # sort it and get the two highest values/prediction
    tmp1 = torch.argsort(input=probs_adv, dim=-1)[-2:]
    
    # get the max probability for the incorrect class prediction. basically if the highest probability is for the correct class, get the next higest probability
    labels_new = torch.where(tmp1[:,-1] == labels, input=tmp1[:,-2], other=tmp1[:,-1])

    bce_loss = F.cross_entropy(logits_adv, labels) + F.nll_loss(torch.log(1 - probs_adv + 1e-12), labels_new)


    # KL term
    kl = nn.KLDivLoss(reduction='none')
    probs_clean = F.softmax(logits_clean, dim=-1)
    probs_true = torch.gather(probs_clean, dim=1, index=(labels.unsqueeze(1)).long()).squeeze()
    reg_loss = torch.sum(torch.sum(kl(torch.log(probs_adv + 1e-12), probs_clean), dim=1) * 1.000001 - probs_true) / labels.size[0]

    loss = bce_loss + float(lambda_reg) * reg_loss

    return loss

    # reg_loss = torch.sum(torch.sum(kl(torch.log(probs_adv + 1e-12), probs_nat), dim=1) * (1 - true_probs)) / inputs.size(0)

    # loss_robust = (1.0 / batch_size) * torch.sum(torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs))
    pass



def adv_train_one_epoch(model,  dataloader, criterion, optimizer, config, device="cuda", max_grad_norm=None):
    #model.train()

    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    for images, labels in tqdm(dataloader, desc="Training", leave=False):

        images = images.to(device)
        labels = labels.to(device)


        model.train()
        x_adv = create_adv_samples(model, inputs=images, labels=labels, config=config, random_start=True)

        optimizer.zero_grad()

        # Forward pass
        logits_clean = model(images)
        logits_adv = model(x_adv)

        loss = mart_loss(logits_clean, logits_adv, labels, lambda_reg=config["lambda_reg"])
        loss.backward()
        
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = torch.argmax(outputs.detach(), -1)
        running_corrects += torch.sum(preds == labels.data)
        total_samples += images.size(0)

    epoch_loss = running_loss / total_samples
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc