# Class bias with data augmentation

In this notebook we analyze and try to reproduce the results of [this paper](https://arxiv.org/pdf/2204.03632.pdf): data augmentation can raise the average accuracy while lowering the accuracy of some individual classes. We then explore a possible mitigation.
We will use the tiny-imagenet dataset.

### Mount google drive, load and prepare the data

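Note: Google Drive must already be mounted at `/content/drive` (for example with `google.colab.drive.mount('/content/drive')`). The archive `tiny-imagenet-200.zip` must be in the root folder of your Drive (`MyDrive`).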

In [None]:
import os
import shutil
# Copy the zipped dataset from Google Drive (assumed already mounted at
# /content/drive) to the local runtime and extract it there.
import zipfile

checkpoints = '/content/drive/MyDrive/'
archive_name = 'tiny-imagenet-200.zip'
# The archive extracts to ./tiny-imagenet-200 (the folder the next cell works
# on), so that folder's presence means the data is already available.
if not os.path.exists('tiny-imagenet-200'):
    print("Copying to local runtime...")
    shutil.copy(checkpoints + archive_name, './' + archive_name)
    print("Uncompressing...")
    # zipfile instead of `!unzip`: no interactive overwrite prompt, and it also
    # works outside IPython.
    with zipfile.ZipFile('./' + archive_name) as archive:
        archive.extractall('.')
print("Data ready!")

In [None]:
import io
import glob
import os
from shutil import move
from os.path import join
from os import listdir, rmdir

target_folder = './tiny-imagenet-200/val/'
test_folder   = './tiny-imagenet-200/testLabel/'


def split_val_folder(root='./tiny-imagenet-200', per_class=25):
    """Reorganize the flat tiny-imagenet val/ folder into an ImageFolder layout.

    Labels are read from ``<root>/val/val_annotations.txt`` (tab separated:
    file name, class folder, ...). For every class, the first ``per_class``
    images in sorted file-name order go to ``<root>/val/<class>/images``.
    The remaining images go to ``<root>/testLabel/<class>/images``.

    Does nothing when ``<root>/val/images`` no longer exists, so re-running
    the cell is safe.
    """
    val_dir = os.path.join(root, 'val')
    test_dir = os.path.join(root, 'testLabel')
    images_dir = os.path.join(val_dir, 'images')
    if not os.path.isdir(images_dir):
        print("Validation folder already reorganized, skipping.")
        return

    val_dict = {}
    with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            split_line = line.split('\t')
            val_dict[split_line[0]] = split_line[1]

    # Sorted so the val/testLabel split is the same on every machine
    # (glob order depends on the filesystem).
    paths = sorted(glob.glob(os.path.join(images_dir, '*')))
    moved_per_class = {}
    for path in paths:
        file = os.path.basename(path)
        folder = val_dict[file]
        if folder not in moved_per_class:
            os.makedirs(os.path.join(val_dir, folder, 'images'), exist_ok=True)
            os.makedirs(os.path.join(test_dir, folder, 'images'), exist_ok=True)
            moved_per_class[folder] = 0
        # A counter replaces re-globbing the destination folder for every file.
        dest_root = val_dir if moved_per_class[folder] < per_class else test_dir
        moved_per_class[folder] += 1
        move(path, os.path.join(dest_root, folder, 'images', file))

    rmdir(images_dir)


split_val_folder()

## Single ResNet18

This trains a ResNet18 from scratch for each of 13 crop levels (the lower bound of the `RandomResizedCrop` scale). It then plots the accuracy of each class for each crop level.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

# Experiment configuration for the ResNet18 crop sweep.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_class = 200                                            # tiny-imagenet-200 classes
crop_range = np.linspace(start=0.08, stop=1.0, num=13)   # RandomResizedCrop scale lower bounds
epochs = 20

def train(net, trainloader, criterion, optimizer, epoch=None, log_every=100):
    """Run one training epoch and return the mean mini-batch loss.

    Args:
        net: model to optimize; it is switched to train mode.
        trainloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss applied to ``(net(inputs), labels)``.
        optimizer: optimizer over ``net``'s parameters.
        epoch: optional 0-based epoch index. It is used only in the progress
            log, and the total number of epochs comes from the notebook-level
            ``epochs``. The function no longer reads a caller's loop variable.
        log_every: print the current batch loss every ``log_every`` batches.

    Returns:
        Mean loss over the batches (0.0 for an empty loader).
    """
    net.train()
    # Send batches to wherever the model lives instead of a global device.
    target_device = next(net.parameters()).device
    num_batches = len(trainloader)
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % log_every == 0:
            prefix = f'Epoch [{epoch+1}/{epochs}], ' if epoch is not None else ''
            print(f'{prefix}Batch [{i+1}/{num_batches}], Loss: {loss.item():.6f}')

    return running_loss / max(num_batches, 1)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

# Evaluation pipeline: resize + center crop to 64x64, then map each channel
# from [0, 1] (ToTensor) to [-1, 1].
transform_test = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

testset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/val/',
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=256,
    shuffle=False,
    num_workers=14,
)

# Train one ResNet18 from scratch for each random crop lower bound and
# evaluate it on the center-cropped test set.
import os

# One entry per crop level: crop lower bound (%), overall test accuracy and
# the list of per-class test accuracies.
lower_bound = []
test_acc = []
class_acc = []

save_dir = '/content/drive/My Drive/Colab Notebooks/model/ResNet18/'
os.makedirs(save_dir, exist_ok=True)

# Scanning the training folder does not depend on the crop level, so it is
# done once; only the transform changes inside the loop.
trainset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/train/')

for crop in crop_range:
    # Same seed for every run, so the crop level is the only thing that changes.
    torch.manual_seed(0)

    # Data augmentation pipeline with the current random crop lower bound
    transform_train = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.RandomResizedCrop(64, scale=(crop, 1.0), ratio=(0.8, 1.25)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset.transform = transform_train
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=14)

    model = models.resnet18(pretrained=False, num_classes=n_class)
    # Adapt the model to 64x64 images: 3x3 stride-1 stem and no max-pool
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # The loop variable keeps the name `e` for compatibility with train()'s
    # progress print in this cell.
    for e in range(epochs):
        train_loss = train(model, trainloader, criterion, optimizer)
        print(f'Crop {crop:.5f} - epoch [{e+1}/{epochs}] mean loss: {train_loss:.6f}')

    # Test the model on the center crop of the test set
    acc, class_accCrop = test(model, testloader)
    print(f'Crop {crop:.5f} - test accuracy: {acc:.2f}%')

    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

    path = save_dir + f'test{str(round(crop, 5))}.pth'
    torch.save(model.state_dict(), path)

# Plot the per-class accuracy curves in batches of 10 classes. Each figure
# also shows the average test accuracy across crop levels.
classes_per_plot = 10
if not class_acc or len(test_acc) != len(lower_bound) or len(class_acc) != len(lower_bound):
    print("Nothing to plot: test_acc/class_acc must hold one entry per crop level.")
else:
    for start in range(0, n_class, classes_per_plot):
        end = min(start + classes_per_plot, n_class)
        fig, ax = plt.subplots(figsize=(12, 8))
        # Black and thicker so it is not confused with tab10's first (blue) colour.
        ax.plot(lower_bound, test_acc, '-o', label='Average Test Accuracy', color='black', linewidth=2)

        for i in range(start, end):
            ax.plot(lower_bound, [crop_class_acc[i] for crop_class_acc in class_acc], '-o',
                    label='Class {} Accuracy'.format(i), color=plt.cm.tab10(i % 10))

        ax.set_xlabel('Lower Bound on Random Crop Parameter (%)', fontsize=14)
        ax.set_ylabel('Accuracy (%)', fontsize=14)
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.legend(loc='lower right', fontsize=8)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_title('Accuracy vs Lower Bound on Random Crop Parameter', fontsize=16)
        plt.show()
        plt.close(fig)

## Single ResNet34

Same as before but using a ResNet34

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

# Experiment configuration for the ResNet34 crop sweep.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_class = 200                                            # tiny-imagenet-200 classes
crop_range = np.linspace(start=0.08, stop=1.0, num=13)   # RandomResizedCrop scale lower bounds
epochs = 20

def train(net, trainloader, criterion, optimizer, epoch=None, log_every=100):
    """Run one training epoch and return the mean mini-batch loss.

    Args:
        net: model to optimize; it is switched to train mode.
        trainloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss applied to ``(net(inputs), labels)``.
        optimizer: optimizer over ``net``'s parameters.
        epoch: optional 0-based epoch index. It is used only in the progress
            log, and the total number of epochs comes from the notebook-level
            ``epochs``. The function no longer reads a caller's loop variable.
        log_every: print the current batch loss every ``log_every`` batches.

    Returns:
        Mean loss over the batches (0.0 for an empty loader).
    """
    net.train()
    # Send batches to wherever the model lives instead of a global device.
    target_device = next(net.parameters()).device
    num_batches = len(trainloader)
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % log_every == 0:
            prefix = f'Epoch [{epoch+1}/{epochs}], ' if epoch is not None else ''
            print(f'{prefix}Batch [{i+1}/{num_batches}], Loss: {loss.item():.6f}')

    return running_loss / max(num_batches, 1)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

# Evaluation pipeline: resize + center crop to 64x64, then map each channel
# from [0, 1] (ToTensor) to [-1, 1].
transform_test = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

testset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/val/',
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=256,
    shuffle=False,
    num_workers=14,
)

# Train one ResNet34 from scratch for each random crop lower bound and
# evaluate it on the center-cropped test set.
import os

# One entry per crop level: crop lower bound (%), overall test accuracy and
# the list of per-class test accuracies.
lower_bound = []
test_acc = []
class_acc = []

save_dir = '/content/drive/My Drive/Colab Notebooks/model/ResNet34/'
os.makedirs(save_dir, exist_ok=True)

# Scanning the training folder does not depend on the crop level, so it is
# done once; only the transform changes inside the loop.
trainset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/train/')

for crop in crop_range:
    # Same seed for every run, so the crop level is the only thing that changes.
    torch.manual_seed(0)

    # Data augmentation pipeline with the current random crop lower bound
    transform_train = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.RandomResizedCrop(64, scale=(crop, 1.0), ratio=(0.8, 1.25)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset.transform = transform_train
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=14)

    model = models.resnet34(pretrained=False, num_classes=n_class)
    # Adapt the model to 64x64 images: 3x3 stride-1 stem and no max-pool
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # The loop variable keeps the name `e` for compatibility with train()'s
    # progress print in this cell.
    for e in range(epochs):
        train_loss = train(model, trainloader, criterion, optimizer)
        print(f'Crop {crop:.5f} - epoch [{e+1}/{epochs}] mean loss: {train_loss:.6f}')

    # Test the model on the center crop of the test set
    acc, class_accCrop = test(model, testloader)
    print(f'Crop {crop:.5f} - test accuracy: {acc:.2f}%')

    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

    path = save_dir + f'test{str(round(crop, 5))}.pth'
    torch.save(model.state_dict(), path)

# Plot the per-class accuracy curves in batches of 10 classes. Each figure
# also shows the average test accuracy across crop levels.
classes_per_plot = 10
if not class_acc or len(test_acc) != len(lower_bound) or len(class_acc) != len(lower_bound):
    print("Nothing to plot: test_acc/class_acc must hold one entry per crop level.")
else:
    for start in range(0, n_class, classes_per_plot):
        end = min(start + classes_per_plot, n_class)
        fig, ax = plt.subplots(figsize=(12, 8))
        # Black and thicker so it is not confused with tab10's first (blue) colour.
        ax.plot(lower_bound, test_acc, '-o', label='Average Test Accuracy', color='black', linewidth=2)

        for i in range(start, end):
            ax.plot(lower_bound, [crop_class_acc[i] for crop_class_acc in class_acc], '-o',
                    label='Class {} Accuracy'.format(i), color=plt.cm.tab10(i % 10))

        ax.set_xlabel('Lower Bound on Random Crop Parameter (%)', fontsize=14)
        ax.set_ylabel('Accuracy (%)', fontsize=14)
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.legend(loc='lower right', fontsize=8)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_title('Accuracy vs Lower Bound on Random Crop Parameter', fontsize=16)
        plt.show()
        plt.close(fig)

## Combined ResNet18 trained from scratch

This is the first version of the combined net. It combines two separate ResNet18 trunks by concatenating their features and feeding them to a single final Linear layer. The dataset gives the combined net two views of each training image: the original one (first trunk) and an augmented one (second trunk).

In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Experiment configuration for the combined model trained from scratch.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_class = 200
crop_range = np.linspace(start=0.08, stop=1.0, num=13)
epochs = 20

# Checkpoint folder and the ImageFolder train / test splits
checkpoints = '/content/drive/MyDrive/Colab Notebooks/model/CombinedModel/checkpoints/'
trainPath = './tiny-imagenet-200/train/'
testPath = './tiny-imagenet-200/val/'

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals ``dilation``.

    With ``stride=1`` the spatial size is therefore preserved.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


class BasicBlock(nn.Module):
    """ResNet basic residual block.

    Two conv3x3 + BatchNorm layers form the residual branch. The shortcut is
    the identity, or ``downsample(x)`` when a projection is given. A final
    ReLU is applied after the sum.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Registration order is kept so state_dict keys and random init match.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)


class CombinedResnet18(nn.Module):
    """Two independent ResNet18-style trunks fused by a single Linear classifier.

    ``resnet1`` and ``resnet2`` share the same architecture (``block`` repeated
    ``layers[k]`` times in each of four stages) but have separate weights. The
    training loop calls ``net(originalInput, augmentedInput)``, so ``resnet1``
    sees the un-augmented image and ``resnet2`` the augmented one. Their
    flattened features are concatenated and classified by ``fc``.

    NOTE(review): the stem is adapted to small inputs: a 3x3 stride-1 conv,
    with ``nn.Identity`` where ResNet has its max-pool. The final pooling,
    however, is ``AvgPool2d(7, stride=1)`` rather than a global average pool.
    For the 64x64 inputs used in this notebook the trunk output is 512x8x8, so
    each branch yields 512x2x2 = 2048 features and the concatenation 4096. That
    is what ``fc`` is sized for; other input sizes will not match ``fc``.
    """
    def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=200):
        # _make_layer reads and updates self.inplanes (input channels of the
        # next stage). It is a plain int, so it can be set before Module.__init__.
        self.inplanes = 64
        super(CombinedResnet18, self).__init__()

        # Trunk 1: stem, four stages (the last three downsample by 2), pooling.
        self.resnet1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Identity(),  # placeholder where ResNet's max-pool would be
            self._make_layer(block, 64, layers[0]),
            self._make_layer(block, 128, layers[1], stride=2),
            self._make_layer(block, 256, layers[2], stride=2),
            self._make_layer(block, 512, layers[3], stride=2),
            nn.AvgPool2d(7, stride=1)
        )

        # Building trunk 1 left self.inplanes at 512; reset it so trunk 2 gets
        # the same channel progression.
        self.inplanes = 64

        # Trunk 2: same architecture as trunk 1, independent weights.
        self.resnet2 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Identity(),  # placeholder where ResNet's max-pool would be
            self._make_layer(block, 64, layers[0]),
            self._make_layer(block, 128, layers[1], stride=2),
            self._make_layer(block, 256, layers[2], stride=2),
            self._make_layer(block, 512, layers[3], stride=2),
            nn.AvgPool2d(7, stride=1)
        )
        
        # 4096 = 2 branches x 2048 flattened features (see class docstring).
        self.fc = nn.Linear(4096 * block.expansion, num_classes)

        # Conv weights: normal(0, sqrt(2 / (k*k*out_channels))), i.e. He/Kaiming
        # normal init with fan_out. BatchNorm starts as the identity affine map.
        # fc keeps PyTorch's default Linear initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        When the first block changes the stride or channel count, its shortcut
        uses a 1x1 conv + BatchNorm projection. Side effect: sets
        ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Return logits for the input pair.

        ``x1`` goes through ``resnet1`` and ``x2`` through ``resnet2``. Both
        outputs are flattened per sample, concatenated along dim 1 and passed
        to ``fc``.
        """
        x1 = self.resnet1(x1)
        x1 = x1.view(x1.size(0), -1)

        x2 = self.resnet2(x2)
        x2 = x2.view(x2.size(0), -1)

        x = torch.cat((x1, x2), 1)
        x = self.fc(x)

        return x
    

def train(net, trainloader, criterion, optimizer, state=None, checkpoint_path=None, start_epoch=0, num_epochs=None):
    """Train a two-input model, optionally resuming from a checkpoint.

    Each batch from ``trainloader`` is ``(originalInput, augmentedInput, labels)``
    and the model is called as ``net(originalInput, augmentedInput)``.

    Args:
        net, criterion, optimizer: model, loss and optimizer.
        trainloader: iterable of ``(originalInput, augmentedInput, labels)``.
        state: optional checkpoint dict with keys ``'net'``, ``'optimizer'``
            and ``'epoch'`` to resume from; its epoch overrides ``start_epoch``.
        checkpoint_path: if given, a checkpoint is saved after every epoch to
            ``checkpoint_path + 'checkpoint-<epoch>.pkl'``.
        start_epoch: first (0-based) epoch to run when not resuming.
        num_epochs: total number of epochs; defaults to the notebook-level ``epochs``.

    Returns:
        Mean batch loss of the last epoch that ran, or ``nan`` if no epoch ran
        (for example when resuming from the checkpoint of the final epoch).
    """
    if num_epochs is None:
        num_epochs = epochs
    # Load previous training state
    if state:
        net.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']

    # Send batches to wherever the model lives instead of a global device.
    target_device = next(net.parameters()).device
    last_epoch_loss = float('nan')
    for epoch in range(start_epoch, num_epochs):
        net.train()
        running_loss = 0.0
        for i, (originalInput, augmentedInput, labels) in enumerate(trainloader):
            originalInput = originalInput.to(target_device)
            augmentedInput = augmentedInput.to(target_device)
            labels = labels.to(target_device)

            optimizer.zero_grad()
            outputs = net(originalInput, augmentedInput)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.6f}')

        last_epoch_loss = running_loss / max(len(trainloader), 1)

        if checkpoint_path:
            state = {'epoch': epoch+1, 'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'losses': running_loss}
            torch.save(state, checkpoint_path + 'checkpoint-%d.pkl' % (epoch+1))
            print("Checkpoint %d saved" % (epoch+1))

    return last_epoch_loss

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images, images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

class CombinedDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each sample yields two views of the same image.

    ``dataset[index]`` must return ``(img, label)``. The wrapped item is
    ``(transformOriginal(img), transformAugmented(img), label)``, with the
    original view computed first.
    """

    def __init__(self, dataset, transformOriginal, transformAugmented):
        self.dataset = dataset
        self.transformOriginal = transformOriginal
        self.transformAugmented = transformAugmented

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, target = self.dataset[index]
        original_view = self.transformOriginal(image)
        augmented_view = self.transformAugmented(image)
        return original_view, augmented_view, target


# Evaluation pipeline: resize + center crop to 64x64, then normalize.
transform_test = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Un-augmented view fed to the first branch during training (resize and
# normalize only). The augmented view's transform_train is built per crop
# level in the training loop.
transform_trainOriginal = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

testset = torchvision.datasets.ImageFolder(root=testPath, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=14)


# Train one combined model from scratch for each random crop lower bound and
# evaluate it on the center-cropped test set.
import os

# One entry per crop level: crop lower bound (%), overall test accuracy and
# the list of per-class test accuracies.
lower_bound = []
test_acc = []
class_acc = []

save_dir = '/content/drive/My Drive/Colab Notebooks/model/CombinedModel/'
os.makedirs(save_dir, exist_ok=True)

# Optionally resume from a previous checkpoint, e.g.
#   state = torch.load(checkpoints + 'crop0.08/checkpoint-19.pkl')
# The state is only applied to the first crop level of the sweep.
state = None

# Scanning the training folder does not depend on the crop level: do it once.
dataset = torchvision.datasets.ImageFolder(root=trainPath)

for crop in crop_range:
    # Same seed for every run, so the crop level is the only thing that changes.
    torch.manual_seed(0)
    crop_tag = str(round(crop, 5))

    # Data augmentation pipeline for the second branch with the current crop
    transform_train = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.RandomResizedCrop(64, scale=(crop, 1.0), ratio=(0.8, 1.25)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = CombinedDataset(dataset=dataset, transformOriginal=transform_trainOriginal, transformAugmented=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=14)

    model = CombinedResnet18()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # One checkpoint folder per crop level, so runs do not overwrite each other.
    crop_checkpoints = checkpoints + f'crop{crop_tag}/'
    os.makedirs(crop_checkpoints, exist_ok=True)
    train_loss = train(model, trainloader, criterion, optimizer, checkpoint_path=crop_checkpoints, state=state)

    # Test the model on the center crop of the test set
    acc, class_accCrop = test(model, testloader)
    print(f'Crop {crop_tag} - last epoch mean loss: {train_loss:.6f}, test accuracy: {acc:.2f}%')

    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

    torch.save(model.state_dict(), save_dir + f'test{crop_tag}.pth')

    # A resumed checkpoint only applies to the first crop level.
    state = None


# Plot the per-class accuracy curves in batches of 10 classes. Each figure
# also shows the average test accuracy across crop levels.
# matplotlib is imported here because this cell's import block does not include it.
import matplotlib.pyplot as plt

classes_per_plot = 10
if not class_acc or len(test_acc) != len(lower_bound) or len(class_acc) != len(lower_bound):
    print("Nothing to plot: test_acc/class_acc must hold one entry per crop level.")
else:
    for start in range(0, n_class, classes_per_plot):
        end = min(start + classes_per_plot, n_class)
        fig, ax = plt.subplots(figsize=(12, 8))
        # Black and thicker so it is not confused with tab10's first (blue) colour.
        ax.plot(lower_bound, test_acc, '-o', label='Average Test Accuracy', color='black', linewidth=2)

        for i in range(start, end):
            ax.plot(lower_bound, [crop_class_acc[i] for crop_class_acc in class_acc], '-o',
                    label='Class {} Accuracy'.format(i), color=plt.cm.tab10(i % 10))

        ax.set_xlabel('Lower Bound on Random Crop Parameter (%)', fontsize=14)
        ax.set_ylabel('Accuracy (%)', fontsize=14)
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.legend(loc='lower right', fontsize=8)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_title('Accuracy vs Lower Bound on Random Crop Parameter', fontsize=16)
        plt.show()
        plt.close(fig)


## Combined ResNet18 pre-trained

In this part we build the combined net from two previously trained ResNet18s: one trained with data augmentation (from the experiment above) and one trained without it. Both backbones are frozen, and only the shared final Linear layer is fine-tuned.

In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Only the crop level 0.46333 (= round(np.linspace(0.08, 1.0, num=13)[5], 5))
# is fine-tuned here. Use np.linspace(0.08, 1.0, num=13) to run the full sweep;
# a matching augmented ResNet18 checkpoint must exist for every value.
crop_range = [0.46333]
n_class = 200
epochs = 10

checkpoints = '/content/drive/MyDrive/Colab Notebooks/model/CombinedModelPreTrained/checkpoints/'
trainPath = './tiny-imagenet-200/train/'
testPath = './tiny-imagenet-200/val/'
# ResNet18 trained without data augmentation (first, frozen branch)
basicModelPath = '/content/drive/My Drive/Colab Notebooks/model/CombinedModelPreTrained/basicModel.pth'

class CombinedResnet18(nn.Module):
    """Late-fusion classifier on top of two frozen, pre-trained backbones.

    The ``fc`` layer of ``basicModel`` and ``augmentedModel`` is replaced by
    ``nn.Identity``; note this mutates the models passed in. Their outputs are
    flattened, concatenated and classified by a new trainable ``fc``. All
    backbone parameters are frozen.
    """
    def __init__(self, basicModel, augmentedModel, num_classes=200):
        super(CombinedResnet18, self).__init__()

        self.resnet1 = basicModel
        self.resnet1.fc = nn.Identity()
        for param in self.resnet1.parameters():
            param.requires_grad = False

        self.resnet2 = augmentedModel
        self.resnet2.fc = nn.Identity()
        for param in self.resnet2.parameters():
            param.requires_grad = False

        # 2 x 512: assumes each backbone yields 512 features once its fc is
        # removed (ResNet18) — adjust for other backbones.
        self.fc = nn.Linear(1024, num_classes)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen backbones in eval mode.

        Without this, ``net.train()`` puts the backbones' BatchNorm layers in
        training mode. Their running statistics would then drift during
        fine-tuning even though the weights are frozen.
        """
        super(CombinedResnet18, self).train(mode)
        self.resnet1.eval()
        self.resnet2.eval()
        return self

    def forward(self, x1, x2):
        """Return logits: ``x1`` goes through ``resnet1``, ``x2`` through ``resnet2``."""
        x1 = self.resnet1(x1)
        x1 = x1.view(x1.size(0), -1)

        x2 = self.resnet2(x2)
        x2 = x2.view(x2.size(0), -1)

        x = torch.cat((x1, x2), 1)
        return self.fc(x)

class CombinedDataset(torch.utils.data.Dataset):
    """Pair an un-augmented and an augmented view of each sample.

    Item ``index`` is ``(transformOriginal(img), transformAugmented(img), label)``,
    where ``(img, label) = dataset[index]``; the original view is built first.
    """

    def __init__(self, dataset, transformOriginal, transformAugmented):
        self.dataset = dataset
        self.transformOriginal = transformOriginal
        self.transformAugmented = transformAugmented

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        first = self.transformOriginal(sample)
        second = self.transformAugmented(sample)
        return first, second, target

def train(net, trainloader, criterion, optimizer, state=None, checkpoint_path=None, start_epoch=0, num_epochs=None):
    """Train a two-input model, optionally resuming from a checkpoint.

    Each batch from ``trainloader`` is ``(originalInput, augmentedInput, labels)``
    and the model is called as ``net(originalInput, augmentedInput)``.

    Args:
        net, criterion, optimizer: model, loss and optimizer.
        trainloader: iterable of ``(originalInput, augmentedInput, labels)``.
        state: optional checkpoint dict with keys ``'net'``, ``'optimizer'``
            and ``'epoch'`` to resume from; its epoch overrides ``start_epoch``.
        checkpoint_path: if given, a checkpoint is saved after every epoch to
            ``checkpoint_path + 'checkpoint-<epoch>.pkl'``.
        start_epoch: first (0-based) epoch to run when not resuming.
        num_epochs: total number of epochs; defaults to the notebook-level ``epochs``.

    Returns:
        Mean batch loss of the last epoch that ran, or ``nan`` if no epoch ran
        (for example when resuming from the checkpoint of the final epoch).
    """
    if num_epochs is None:
        num_epochs = epochs
    # Load previous training state
    if state:
        net.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']

    # Send batches to wherever the model lives instead of a global device.
    target_device = next(net.parameters()).device
    last_epoch_loss = float('nan')
    for epoch in range(start_epoch, num_epochs):
        net.train()
        running_loss = 0.0
        for i, (originalInput, augmentedInput, labels) in enumerate(trainloader):
            originalInput = originalInput.to(target_device)
            augmentedInput = augmentedInput.to(target_device)
            labels = labels.to(target_device)

            optimizer.zero_grad()
            outputs = net(originalInput, augmentedInput)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.6f}')

        last_epoch_loss = running_loss / max(len(trainloader), 1)

        if checkpoint_path:
            state = {'epoch': epoch+1, 'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'losses': running_loss}
            torch.save(state, checkpoint_path + 'checkpoint-%d.pkl' % (epoch+1))
            print("Checkpoint %d saved" % (epoch+1))

    return last_epoch_loss


# Backbone trained without data augmentation. It gets the same 64x64
# adaptation as the models trained above (3x3 stride-1 stem, no max-pool),
# so the saved state_dict keys and shapes match.
basicModel = models.resnet18(pretrained=False, num_classes=n_class)
basicModel.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
basicModel.maxpool = nn.Identity()
basicModel.load_state_dict(torch.load(basicModelPath, map_location=device))

# Un-augmented training view fed to the first branch (resize + normalize
# only). The augmented view's transform_train is built per crop level below.
transform_trainOriginal = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


# Fine-tune the fusion layer of the pre-trained combined model for each crop
# level, then evaluate it on the center-cropped test set.
import os

# One entry per crop level: crop lower bound (%), overall test accuracy and
# the list of per-class test accuracies.
lower_bound = []
test_acc = []
class_acc = []

save_dir = '/content/drive/My Drive/Colab Notebooks/model/CombinedModelPreTrained/'
os.makedirs(save_dir, exist_ok=True)

# Optionally resume from a previous checkpoint, e.g.
#   state = torch.load(checkpoints + 'crop0.46333/checkpoint-7.pkl')
# The state is only applied to the first crop level of the sweep.
state = None

# Center-crop evaluation pipeline; the same test image is fed to both branches.
transform_test = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
testset = torchvision.datasets.ImageFolder(root=testPath, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=14)


def _evaluate_combined(net, loader, num_classes):
    """Return (overall accuracy, per-class accuracies) in percent for a two-input net.

    Each test image is passed to both branches. A class without samples gets nan.
    """
    net.eval()
    correct = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            predicted = net(images, images).argmax(dim=1).cpu()
            labels = labels.cpu()
            total += torch.bincount(labels, minlength=num_classes)
            correct += torch.bincount(labels[predicted == labels], minlength=num_classes)
    n_total = total.sum().item()
    overall = 100 * correct.sum().item() / n_total if n_total else float('nan')
    per_class = [100 * c / t if t else float('nan') for c, t in zip(correct.tolist(), total.tolist())]
    return overall, per_class


# Scanning the training folder does not depend on the crop level: do it once.
dataset = torchvision.datasets.ImageFolder(root=trainPath)

for crop in crop_range:
    # Same seed for every run, so the crop level is the only thing that changes.
    torch.manual_seed(0)
    crop_tag = str(round(crop, 5))

    # Data augmentation pipeline for the second branch with the current crop
    transform_train = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.RandomResizedCrop(64, scale=(crop, 1.0), ratio=(0.8, 1.25)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = CombinedDataset(dataset=dataset, transformOriginal=transform_trainOriginal, transformAugmented=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=14)

    # ResNet18 trained with this crop level's augmentation (second branch)
    augmentedModel = models.resnet18(pretrained=False, num_classes=200)
    augmentedModel.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    augmentedModel.maxpool = nn.Identity()
    augmentedModelPath = f'/content/drive/My Drive/Colab Notebooks/model/ResNet18ImageNet/test{crop_tag}.pth'
    augmentedModel.load_state_dict(torch.load(augmentedModelPath, map_location=device))

    model = CombinedResnet18(basicModel, augmentedModel)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # One checkpoint folder per crop level, so runs do not overwrite each other.
    crop_checkpoints = checkpoints + f'crop{crop_tag}/'
    os.makedirs(crop_checkpoints, exist_ok=True)
    train_loss = train(model, trainloader, criterion, optimizer, checkpoint_path=crop_checkpoints, state=state)

    # Evaluate; without this, test_acc / class_acc stay empty and the plot below fails.
    acc, class_accCrop = _evaluate_combined(model, testloader, n_class)
    print(f'Crop {crop_tag} - last epoch mean loss: {train_loss:.6f}, test accuracy: {acc:.2f}%')

    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

    torch.save(model.state_dict(), save_dir + f'test{crop_tag}.pth')

    # A resumed checkpoint only applies to the first crop level.
    state = None

# Plot the per-class accuracy curves in batches of 10 classes. Each figure
# also shows the average test accuracy across crop levels.
classes_per_plot = 10
if not class_acc or len(test_acc) != len(lower_bound) or len(class_acc) != len(lower_bound):
    print("Nothing to plot: test_acc/class_acc must hold one entry per crop level.")
else:
    for start in range(0, n_class, classes_per_plot):
        end = min(start + classes_per_plot, n_class)
        fig, ax = plt.subplots(figsize=(12, 8))
        # Black and thicker so it is not confused with tab10's first (blue) colour.
        ax.plot(lower_bound, test_acc, '-o', label='Average Test Accuracy', color='black', linewidth=2)

        for i in range(start, end):
            ax.plot(lower_bound, [crop_class_acc[i] for crop_class_acc in class_acc], '-o',
                    label='Class {} Accuracy'.format(i), color=plt.cm.tab10(i % 10))

        ax.set_xlabel('Lower Bound on Random Crop Parameter (%)', fontsize=14)
        ax.set_ylabel('Accuracy (%)', fontsize=14)
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.legend(loc='lower right', fontsize=8)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_title('Accuracy vs Lower Bound on Random Crop Parameter', fontsize=16)
        plt.show()
        plt.close(fig)

## Evaluate the results

These are the metrics we used to evaluate the results of our experiments:
- CVaR 5%: the mean accuracy of the worst 5% of classes.
- Top 5%: the mean accuracy of the best 5% of classes.

### ResNet18

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from tabulate import tabulate

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

crop_range = np.linspace(start=0.08, stop=1.0, num=13)
n_class = 200

# Same 64x64 adaptation as at training time (3x3 stride-1 stem, Identity in
# place of the max-pool), so the saved state_dicts load into this model.
model = models.resnet18(pretrained=False, num_classes=n_class)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
model = model.to(device)

# Evaluation pipeline: resize to 64x64, center crop, normalize to [-1, 1].
transform_test = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

testset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/val/',
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=256,
    shuffle=False,
    num_workers=14,
)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

# One entry per crop level: crop lower bound (%), overall test accuracy and
# the list of per-class test accuracies.
lower_bound = []
test_acc = []
class_acc = []

# NOTE(review): the ResNet18 training cell saves its weights under
# model/ResNet18/, while this loads them from model/ResNet18ImageNet/.
# Confirm that the checkpoints were copied/renamed there.
model_dir = '/content/drive/My Drive/Colab Notebooks/model/ResNet18ImageNet/'

for crop in crop_range:
    path = model_dir + f'test{str(round(crop, 5))}.pth'
    # map_location lets GPU-saved weights load on a CPU-only runtime.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    acc, class_accCrop = test(model, testloader)

    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)


# Summary of the per-class accuracy distribution for every crop level: mean
# accuracy, variance/std across classes, CVaR 5% (mean of the worst 5% of
# classes) and Top 5% (mean of the best 5% of classes).
data = []
tens = torch.tensor(test_acc, dtype=torch.float32)
tensClass = torch.tensor(class_acc, dtype=torch.float32)
# Number of classes in each 5% tail (at least one).
fivePerc = max(1, int(n_class * 0.05))

for i, crop in enumerate(crop_range):
    sortedValues, _ = torch.sort(tensClass[i])
    data.append([
        float(100 * crop),                     # the column is labelled "Crop %"
        tens[i].item(),
        tensClass[i].var().item(),
        tensClass[i].std().item(),
        sortedValues[:fivePerc].mean().item(),
        sortedValues[-fivePerc:].mean().item(),
    ])

print(tabulate(data, headers=["Crop %", "Mean %", "Var", "Std", "CVaR 5%", "Top 5%"], tablefmt="psql"))


### ResNet34

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
from tabulate import tabulate

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

crop_range = np.linspace(start=0.08, stop=1.0, num=13)
n_class = 200

# Same 64x64 adaptation as at training time (3x3 stride-1 stem, Identity in
# place of the max-pool), so the saved state_dicts load into this model.
model = models.resnet34(pretrained=False, num_classes=n_class)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
model = model.to(device)

# Evaluation pipeline: resize to 64x64, center crop, normalize to [-1, 1].
transform_test = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

testset = torchvision.datasets.ImageFolder(
    root='./tiny-imagenet-200/val/',
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=256,
    shuffle=False,
    num_workers=14,
)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

lower_bound = []
test_acc = []
class_acc = []

for crop in crop_range:
    # One checkpoint per crop value; the file name embeds the rounded value.
    path = f'/content/drive/My Drive/Colab Notebooks/model/ResNet34/test{str(round(crop, 5))}.pth'
    # map_location lets checkpoints saved from a GPU session load on a CPU-only runtime.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    acc, class_accCrop = test(model, testloader)

    # Record the crop value (in %) with the overall and per-class test accuracy
    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)


# Summary table, one row per crop value:
#   Mean %  - overall test accuracy
#   Var/Std - spread of the per-class accuracies
#   CVaR 5% - mean accuracy of the worst 5% of classes
#   Top 5%  - mean accuracy of the best 5% of classes
data = []
tens = torch.Tensor(test_acc)
tensClass = torch.Tensor(class_acc)
# Keep at least one class in each tail: int(n_class * 0.05) is 0 for
# n_class < 20, which makes [:0] empty (nan) and [-0:] select every class.
tail = max(1, int(n_class * 0.05))

for i, crop in enumerate(crop_range):
    sortedValues = tensClass[i].sort().values  # per-class accuracies, ascending
    data.append([
        100 * crop,  # header says "Crop %", so report a percentage
        tens[i],
        tensClass[i].var(),
        tensClass[i].std(),
        sortedValues[:tail].mean(),
        sortedValues[-tail:].mean(),
    ])


print(tabulate(data, headers=["Crop %", "Mean %", "Var", "Std", "CVaR 5%", "Top 5%"], tablefmt="psql"))


### Combined ResNet18 trained from scratch

In [None]:
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tabulate import tabulate

# Evaluation settings: number of classes, the 13 crop values from 0.08 to 1.0
# (one checkpoint each), and the compute device.
n_class = 200
crop_range = np.linspace(start=0.08, stop=1.0, num=13)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals ``dilation``.

    With ``padding == dilation`` the spatial size is preserved at stride 1
    and becomes ``ceil(size / 2)`` at stride 2.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """ResNet basic residual block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The result is added to
    a shortcut (the input itself, or ``downsample(x)`` when a projection is
    supplied) and passed through a final ReLU.
    """

    # Output channels are planes * expansion; always 1 for this block type.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute order is kept stable: it fixes module registration order,
        # and with it the order in which weights are initialised.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut

        return self.relu(out)


class CombinedResnet18(nn.Module):
    """Two independent ResNet-18 trunks whose features are concatenated.

    ``forward(x1, x2)`` runs ``x1`` through ``resnet1`` and ``x2`` through
    ``resnet2``, flattens both outputs, concatenates them along dim 1 and
    classifies the result with one linear layer.

    Each trunk uses a small-image stem (3x3 stride-1 conv, no max-pool). For
    64x64 inputs the last stage yields 8x8 maps, so ``AvgPool2d(7, stride=1)``
    leaves a 2x2 grid: 512 * 2 * 2 = 2048 features per trunk, 4096 in total,
    which is the width ``fc`` expects.
    """

    def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=200):
        super().__init__()

        self.resnet1 = self._make_trunk(block, layers)
        self.resnet2 = self._make_trunk(block, layers)

        self.fc = nn.Linear(4096 * block.expansion, num_classes)

        # He-normal (fan-out) init for convolutions, identity affine for BN.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.out_channels * module.kernel_size[0] * module.kernel_size[1]
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_trunk(self, block, layers):
        """Build one trunk, starting again from the 64-channel stem.

        The Sequential indices determine the state_dict keys (for example
        ``resnet1.4.0.conv1.weight``) that the saved checkpoints use, so the
        placeholder Identity must stay at index 3.
        """
        self.inplanes = 64
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Identity(),  # stands in for the removed max-pool
            self._make_layer(block, 64, layers[0]),
            self._make_layer(block, 128, layers[1], stride=2),
            self._make_layer(block, 256, layers[2], stride=2),
            self._make_layer(block, 512, layers[3], stride=2),
            nn.AvgPool2d(7, stride=1),
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first changes stride/width."""
        out_channels = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_channels
        downsample = (
            nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x1, x2):
        feats1 = self.resnet1(x1)
        feats2 = self.resnet2(x2)

        joint = torch.cat(
            (feats1.view(feats1.size(0), -1), feats2.view(feats2.size(0), -1)), 1
        )
        return self.fc(joint)

import os

# Two-trunk model built from scratch; its weights come from the checkpoints
# loaded in the evaluation loop.
model = CombinedResnet18()
model = model.to(device)

# Deterministic evaluation transform. CenterCrop(64) after Resize((64, 64)) is
# a no-op; it is kept so the pipeline matches the other evaluation cells.
transform_test = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

testset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/val/', transform=transform_test)
# Cap workers at the available CPUs: DataLoader warns when asked for more
# workers than cores, and hosted runtimes often expose only a few.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=256, shuffle=False, num_workers=min(14, os.cpu_count() or 1)
)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images, images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

lower_bound = []
test_acc = []
class_acc = []

for crop in crop_range:
    # One checkpoint per crop value; the file name embeds the rounded value.
    path = f'/content/drive/My Drive/Colab Notebooks/model/CombinedModel/test{str(round(crop, 5))}.pth'
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    acc, class_accCrop = test(model, testloader)

    # Record the crop value (in %) with the overall and per-class test accuracy
    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

# Summary table, one row per crop value:
#   Mean %  - overall test accuracy
#   Var/Std - spread of the per-class accuracies
#   CVaR 5% - mean accuracy of the worst 5% of classes
#   Top 5%  - mean accuracy of the best 5% of classes
data = []
tens = torch.Tensor(test_acc)
tensClass = torch.Tensor(class_acc)
# Keep at least one class in each tail: int(n_class * 0.05) is 0 for
# n_class < 20, which makes [:0] empty (nan) and [-0:] select every class.
tail = max(1, int(n_class * 0.05))

for i, crop in enumerate(crop_range):
    sortedValues = tensClass[i].sort().values  # per-class accuracies, ascending
    data.append([
        100 * crop,  # header says "Crop %", so report a percentage
        tens[i],
        tensClass[i].var(),
        tensClass[i].std(),
        sortedValues[:tail].mean(),
        sortedValues[-tail:].mean(),
    ])


print(tabulate(data, headers=["Crop %", "Mean %", "Var", "Std", "CVaR 5%", "Top 5%"], tablefmt="psql"))

### Combined ResNet18 pre-trained

In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate

# Shared evaluation settings for the combined model with frozen backbones.
n_class = 200
crop_range = np.linspace(0.08, 1.0, 13)  # crop values 0.08 .. 1.0, one checkpoint each
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class CombinedResnet18(nn.Module):
    """Late-fusion classifier over two frozen backbones.

    Each backbone's ``fc`` head is replaced by ``nn.Identity`` (the models
    passed in are modified in place) and all of its parameters are frozen,
    so only the new fusion layer ``fc`` is trainable. ``forward(x1, x2)``
    feeds ``x1`` to ``basicModel`` and ``x2`` to ``augmentedModel``, flattens
    and concatenates the two feature vectors, and applies ``fc``.

    NOTE(review): ``requires_grad = False`` does not stop BatchNorm running
    statistics in the backbones from updating if the model is put in train mode.
    """

    def __init__(self, basicModel, augmentedModel, num_classes=200):
        super(CombinedResnet18, self).__init__()

        # Read the feature widths before the heads are removed instead of
        # hard-coding them; two ResNet-18 backbones give 512 + 512 = 1024.
        in_features = self._feature_dim(basicModel) + self._feature_dim(augmentedModel)

        self.resnet1 = basicModel
        self.resnet1.fc = nn.Identity()

        for param in self.resnet1.parameters():
            param.requires_grad = False

        self.resnet2 = augmentedModel
        self.resnet2.fc = nn.Identity()

        for param in self.resnet2.parameters():
            param.requires_grad = False

        self.fc = nn.Linear(in_features, num_classes)

    @staticmethod
    def _feature_dim(backbone, default=512):
        """Width of the features that feed ``backbone.fc``.

        Uses ``fc.in_features`` when the head is an ``nn.Linear``; otherwise
        (e.g. the head was already replaced) falls back to ``default``, the
        ResNet-18 feature width.
        """
        head = getattr(backbone, 'fc', None)
        return head.in_features if isinstance(head, nn.Linear) else default

    def forward(self, x1, x2):
        x1 = self.resnet1(x1)
        x1 = x1.view(x1.size(0), -1)

        x2 = self.resnet2(x2)
        x2 = x2.view(x2.size(0), -1)

        x = torch.cat((x1, x2), 1)
        x = self.fc(x)

        return x


import os


def _make_backbone():
    """Randomly initialised torchvision ResNet-18 with a small-image stem.

    The first conv becomes 3x3 stride 1 and the max-pool is removed, so 64x64
    inputs are not downsampled in the stem. All weights are overwritten by the
    combined checkpoints loaded in the evaluation loop.
    """
    backbone = models.resnet18(pretrained=False, num_classes=n_class)
    backbone.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    backbone.maxpool = nn.Identity()
    return backbone


basicModel = _make_backbone()
augmentedModel = _make_backbone()

model = CombinedResnet18(basicModel, augmentedModel)
model = model.to(device)

# Deterministic evaluation transform. CenterCrop(64) after Resize((64, 64)) is
# a no-op; it is kept so the pipeline matches the other evaluation cells.
transform_test = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

testset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/val/', transform=transform_test)
# Cap workers at the available CPUs: DataLoader warns when asked for more
# workers than cores, and hosted runtimes often expose only a few.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=256, shuffle=False, num_workers=min(14, os.cpu_count() or 1)
)

def test(net, testloader):
    net.eval()
    correct = 0
    total = 0
    class_correct = list(0. for i in range(n_class))
    class_total = list(0. for i in range(n_class))

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images, images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()

            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    acc = 100 * correct / total
    class_acc = [100 * class_correct[i] / class_total[i] for i in range(n_class)]

    return acc, class_acc

lower_bound = []
test_acc = []
class_acc = []

for crop in crop_range:
    # One checkpoint per crop value; the file name embeds the rounded value.
    path = f'/content/drive/My Drive/Colab Notebooks/model/CombinedModelPreTrained/test{str(round(crop, 5))}.pth'
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    acc, class_accCrop = test(model, testloader)

    # Record the crop value (in %) with the overall and per-class test accuracy
    lower_bound.append(100 * crop)
    test_acc.append(acc)
    class_acc.append(class_accCrop)

# Summary table, one row per crop value:
#   Mean %  - overall test accuracy
#   Var/Std - spread of the per-class accuracies
#   CVaR 5% - mean accuracy of the worst 5% of classes
#   Top 5%  - mean accuracy of the best 5% of classes
data = []
tens = torch.Tensor(test_acc)
tensClass = torch.Tensor(class_acc)
# Keep at least one class in each tail: int(n_class * 0.05) is 0 for
# n_class < 20, which makes [:0] empty (nan) and [-0:] select every class.
tail = max(1, int(n_class * 0.05))

for i, crop in enumerate(crop_range):
    sortedValues = tensClass[i].sort().values  # per-class accuracies, ascending
    data.append([
        100 * crop,  # header says "Crop %", so report a percentage
        tens[i],
        tensClass[i].var(),
        tensClass[i].std(),
        sortedValues[:tail].mean(),
        sortedValues[-tail:].mean(),
    ])


print(tabulate(data, headers=["Crop %", "Mean %", "Var", "Std", "CVaR 5%", "Top 5%"], tablefmt="psql"))