In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import precision_score, recall_score
from torchinfo import summary

import json

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Display the selected device string.
device

##### 1. `Download and preprocess the data`

In [None]:

# Fetch the CIFAR-10 training and test splits into ./cifar10 (download=True).
train, test = (
    torchvision.datasets.CIFAR10(root="./cifar10", train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
# Class list exposed by the training dataset object, displayed for reference.
classes = train.classes
classes

In [None]:
# Pull images and labels out of the dataset objects as tensors:
# images through torch.from_numpy, labels through torch.tensor.
X_train, y_train = torch.from_numpy(train.data), torch.tensor(train.targets)
X_test, y_test = torch.from_numpy(test.data), torch.tensor(test.targets)

In [None]:
positive_class_idx = 1          # Automobiles
examples_per_class = 100        # samples drawn for the positive class and, separately, for all other classes
sz_img = 224                    # side length (pixels) handed to transforms.Resize


In [None]:
def divide_data(X, y, pos_cls, num, seed=42):
  """Build a balanced, shuffled one-vs-rest subset of ``(X, y)``.

  Draws up to ``num`` random samples whose label equals ``pos_cls`` and up to
  ``num`` random samples carrying any other label, then shuffles both groups
  together.

  Args:
    X: Tensor of samples, indexed along its first dimension.
    y: 1-D tensor of class labels aligned with ``X``.
    pos_cls: Label treated as the positive class.
    num: Number of samples to draw from each group. A group holding fewer
      than ``num`` samples contributes all of them.
    seed: Seed of the private sampling generator (default 42).

  Returns:
    ``(X_hold, y_hold)``: the selected samples and a boolean tensor that is
    True where the corresponding sample belongs to ``pos_cls``.
  """
  # A private generator keeps the sampling reproducible without resetting the
  # global torch RNG as a side effect of preparing data.
  gen = torch.Generator()

  pos_idx = (y == pos_cls).nonzero(as_tuple=True)[0]
  gen.manual_seed(seed)
  chosen_pos = pos_idx[torch.randperm(len(pos_idx), generator=gen)[:num]]

  neg_idx = (y != pos_cls).nonzero(as_tuple=True)[0]
  gen.manual_seed(seed)  # re-seed so this draw does not depend on the positive group's size
  chosen_neg = neg_idx[torch.randperm(len(neg_idx), generator=gen)[:num]]

  # Interleave positives and negatives so batches are not grouped by class.
  chosen = torch.cat([chosen_pos, chosen_neg])
  chosen = chosen[torch.randperm(len(chosen), generator=gen)]
  return X[chosen], (y[chosen] == pos_cls)

# Balanced training subset: up to examples_per_class positives plus as many negatives.
X_train100,y_train100 = divide_data(X_train,y_train,positive_class_idx,examples_per_class)

In [None]:
# Balanced evaluation subset: up to 1000 positives and up to 1000 negatives from the test split.
X_test1000,y_test1000 = divide_data(X_test,y_test,positive_class_idx,1000)


In [None]:
class MyDataset(torch.utils.data.Dataset):
  """Map-style dataset over paired samples ``X`` and labels ``y``.

  The optional callables ``transform`` and ``target_transform`` are applied to
  the sample and to the label, respectively, every time an item is fetched.
  """

  def __init__(self, X, y, transform=None, target_transform=None):
    self.X, self.y = X, y
    self.transform, self.target_transform = transform, target_transform

  def __len__(self):
    # The samples container defines the dataset length.
    return len(self.X)

  def __getitem__(self, idx):
    sample, target = self.X[idx], self.y[idx]
    # Falsy transforms (e.g. None) are skipped.
    sample = self.transform(sample) if self.transform else sample
    target = self.target_transform(target) if self.target_transform else target
    return sample, target

In [None]:
class MyTransform():
  """Scale an image tensor by 1/255, move channels first, then normalise per channel."""

  def __init__(self, mean, std):
    # Statistics are reshaped to (3, 1, 1) so they broadcast over the two trailing dims.
    self.mean, self.std = (torch.tensor(stat).view(3, 1, 1) for stat in (mean, std))

  def __call__(self, x):
    # uint8 input is cast to float first; any other dtype is divided as it is.
    scaled = (x.float() if x.dtype == torch.uint8 else x) / 255.0
    channels_first = scaled.permute(2, 0, 1)  # last dim (channels) moves to the front
    return (channels_first - self.mean) / self.std


mytransform = MyTransform(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

In [None]:
# Image pipeline: normalise with mytransform, then resize to sz_img x sz_img.
feature_transform = transforms.Compose(
    [mytransform, transforms.Resize((sz_img, sz_img))]
)
# Label pipeline: cast the target to float32.
label_transform = transforms.Compose([v2.ToDtype(torch.float32)])


In [None]:
# Evaluation dataset: images pass through feature_transform, labels through label_transform.
test_dataset = MyDataset(X_test1000,y_test1000,transform=feature_transform,target_transform=label_transform)

In [None]:
# Training dataset: images pass through feature_transform, labels through label_transform.
train_dataset100 = MyDataset(X_train100,y_train100,transform=feature_transform,target_transform=label_transform)


In [None]:
# Mini-batch loaders. Only the training loader reshuffles every epoch; the
# evaluation loader keeps a fixed order, since shuffling adds nothing to
# order-independent metrics and makes per-batch debugging harder.
train100_data_loader = torch.utils.data.DataLoader(train_dataset100, batch_size=16, shuffle=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)


In [None]:
torch.manual_seed(42)
class CustomBaselineModel(nn.Module):
    """Small CNN baseline that emits a single logit per image.

    Three 3x3 conv + ReLU stages (32, 64, 128 channels) with 2x2 max-pooling
    after the first two, adaptive average pooling to 1x1, then a
    128 -> 1024 -> 1 fully connected head.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # collapse every feature map to 1x1
        ]
        head_layers = [
            nn.Flatten(),  # leaves 128 features per sample
            nn.Linear(128, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
        ]
        # Convolutional stack is registered before the head, as the attribute
        # names and positions define the state_dict keys.
        self.cnns = nn.Sequential(*feature_layers)
        self.linear_layers = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.linear_layers(self.cnns(x))


# Instantiate the baseline and move it to the selected device.
basemodel = CustomBaselineModel().to(device)

In [None]:
import copy

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: optimise when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy, preds, labels)``; ``preds`` and ``labels``
    are 1-D int64 CPU tensors covering every sample seen during the pass.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_chunks, label_chunks = [], []

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device).float()

            if training:
                optimizer.zero_grad()
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean; weight it by batch size so the
            # epoch figure is a per-sample average even with a ragged last batch.
            loss_sum += loss.item() * inputs.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            targets = labels.long()
            n_correct += (preds == targets).sum().item()
            n_seen += labels.size(0)

            pred_chunks.append(preds.cpu())
            label_chunks.append(targets.cpu())

    if n_seen == 0:
        raise ValueError("data loader yielded no samples")

    return loss_sum / n_seen, n_correct / n_seen, torch.cat(pred_chunks), torch.cat(label_chunks)


def train_model_with_early_stopping(model, train_loader, test_loader, device, max_epochs=30, lr=1e-3, patience=5):
    """Train a single-logit binary classifier with Adam and early stopping.

    Every epoch runs one optimisation pass over ``train_loader`` and one
    evaluation pass over ``test_loader``. The weights of the epoch with the
    highest evaluation accuracy are loaded back into ``model`` before returning.

    NOTE(review): ``test_loader`` drives both checkpoint selection and early
    stopping, so metrics reported on it are optimistically biased. Pass a
    separate validation loader here when an unbiased test figure is needed.

    Args:
        model: ``nn.Module`` whose output is squeezed on dim 1 to one logit per
            sample; only parameters with ``requires_grad`` set are optimised.
        train_loader: Iterable of ``(inputs, labels)`` batches used for training.
        test_loader: Iterable of ``(inputs, labels)`` batches used for evaluation.
        device: Device the model and every batch are moved to.
        max_epochs: Upper bound on the number of epochs.
        lr: Adam learning rate.
        patience: Number of consecutive epochs without a new best evaluation
            accuracy after which training stops.

    Returns:
        dict with the per-epoch lists ``train_losses``, ``test_losses``,
        ``train_accuracies``, ``test_accuracies``, ``test_precisions`` and
        ``test_recalls``, plus the scalar ``best_test_accuracy``.

    Raises:
        ValueError: If either loader yields no samples.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_losses, test_losses = [], []
    train_accuracies, test_accuracies = [], []
    test_precisions, test_recalls = [], []

    best_acc = 0
    best_model_wts = copy.deepcopy(model.state_dict())
    wait = 0

    for epoch in range(max_epochs):
        avg_train_loss, train_acc, _, _ = _run_epoch(model, train_loader, criterion, device, optimizer)
        avg_test_loss, test_acc, eval_preds, eval_labels = _run_epoch(model, test_loader, criterion, device)

        # zero_division=0 keeps the score at 0 without a warning on epochs where
        # the model predicts (or the data contains) no positive sample.
        y_true, y_pred = eval_labels.numpy(), eval_preds.numpy()
        precision = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
        recall = recall_score(y_true, y_pred, pos_label=1, zero_division=0)

        train_losses.append(avg_train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(avg_test_loss)
        test_accuracies.append(test_acc)
        test_precisions.append(precision)
        test_recalls.append(recall)

        print(f"Epoch {epoch+1}/{max_epochs} | Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Test Loss: {avg_test_loss:.4f} | Test Acc: {test_acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f}")

        # Early stopping: only a strict improvement resets the patience counter.
        if test_acc > best_acc:
            best_acc = test_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best checkpoint so the caller's model matches best_test_accuracy.
    model.load_state_dict(best_model_wts)

    return {
        "train_losses": train_losses,
        "test_losses": test_losses,
        "train_accuracies": train_accuracies,
        "test_accuracies": test_accuracies,
        "test_precisions": test_precisions,
        "test_recalls": test_recalls,
        "best_test_accuracy": best_acc
    }
def save_metrics(metrics_dict, filename='model_metrics.json'):
    """Write a metrics dictionary to ``filename`` as JSON indented by 4 spaces.

    List values are stored as lists of floats; every other value is coerced
    with ``float`` (e.g. the scalar best_test_accuracy).
    """
    serializable_metrics = {
        name: [float(item) for item in value] if isinstance(value, list) else float(value)
        for name, value in metrics_dict.items()
    }

    with open(filename, 'w') as f:
        json.dump(serializable_metrics, f, indent=4)

def load_metrics(filename='model_metrics.json'):
    """Read a metrics JSON file and return its values as floats.

    List values come back as lists of floats; every other value is passed
    through ``float``.
    """
    with open(filename, 'r') as f:
        raw = json.load(f)

    def _as_float(value):
        return list(map(float, value)) if isinstance(value, list) else float(value)

    return {name: _as_float(value) for name, value in raw.items()}

In [None]:
# Train the baseline for at most 50 epochs; stops after 5 epochs without a new best accuracy on test_data_loader.
base100 = train_model_with_early_stopping(basemodel, train100_data_loader, test_data_loader, device, max_epochs=50, lr=1e-3, patience=5)

In [None]:
# Baseline model: training vs. test loss per epoch.
plt.figure(figsize=(10, 6))
for curve_key, curve_label in (("train_losses", "Train Loss"), ("test_losses", "Test Loss")):
    plt.plot(range(len(base100["train_losses"])), base100[curve_key], label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Baseline model: training vs. test accuracy per epoch.
plt.figure(figsize=(10, 6))
plt.plot([x for x in range(len(base100["train_losses"]))],base100["train_accuracies"], label="Train Accuracy")
plt.plot([x for x in range(len(base100["train_losses"]))],base100["test_accuracies"], label="Test Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")  # was "Loss", copied over from the loss plot
plt.legend()
plt.show()

In [None]:
# Persist the baseline's training history as JSON.
save_metrics(base100, filename='base100.json')

Fine-tune the pretrained ResNet18 model on the data (only `layer4` and the new `fc` head are trainable)

In [None]:
torch.manual_seed(42)
# ImageNet-pretrained ResNet18. The `weights` enum replaces the deprecated
# `pretrained=True` flag and selects the same IMAGENET1K_V1 checkpoint.
resnet18_pretrainedmodel = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the original classifier with a single-logit head for binary classification.
resnet18_pretrainedmodel.fc = nn.Linear(resnet18_pretrainedmodel.fc.in_features, 1)

# Freeze everything, then re-enable gradients only for the last residual
# stage (layer4) and the new head (fc).
for param in resnet18_pretrainedmodel.parameters():
    param.requires_grad = False
for trainable in (resnet18_pretrainedmodel.fc, resnet18_pretrainedmodel.layer4):
    for param in trainable.parameters():
        param.requires_grad = True

# Assigning the result keeps the cell from printing the whole module repr.
resnet18_pretrainedmodel = resnet18_pretrainedmodel.to(device)

In [None]:
# Fine-tune the pretrained model for at most 50 epochs; stops after 5 epochs without a new best accuracy on test_data_loader.
resnet18_pretrained100 = train_model_with_early_stopping(resnet18_pretrainedmodel, train100_data_loader, test_data_loader, device, max_epochs=50, lr=1e-3, patience=5)

In [None]:
# Pretrained ResNet18: training vs. test loss per epoch.
plt.figure(figsize=(10, 6))
for curve_key, curve_label in (("train_losses", "Train Loss"), ("test_losses", "Test Loss")):
    plt.plot(
        range(len(resnet18_pretrained100["train_losses"])),
        resnet18_pretrained100[curve_key],
        label=curve_label,
    )
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Pretrained ResNet18: training vs. test accuracy per epoch.
plt.figure(figsize=(10, 6))
for curve_key, curve_label in (("train_accuracies", "Train Accuracy"), ("test_accuracies", "Test Accuracy")):
    plt.plot(
        range(len(resnet18_pretrained100["train_accuracies"])),
        resnet18_pretrained100[curve_key],
        label=curve_label,
    )
plt.xlabel("Epochs")
plt.ylabel("Accuracies")
plt.legend()
plt.show()

In [None]:
# Persist the fine-tuning history as JSON.
save_metrics(resnet18_pretrained100, filename='resnet18_pretrained100.json')

In [None]:

torch.manual_seed(42)
class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in a residual connection.

    ``stride`` applies to the first convolution only. When ``downsample`` is
    given it is applied to the block input to produce the shortcut; otherwise
    the input itself is the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch, projected only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class MyResNet18(nn.Module):
    """ResNet-18 style network assembled from BasicBlock, ending in one logit.

    Stem (7x7 conv, batch norm, ReLU, 3x3 max-pool) followed by four stages of
    two BasicBlocks each (64, 128, 256, 512 channels), adaptive average
    pooling to 1x1 and a 512 -> 1 linear head.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # First stage keeps the channel count and resolution, so no projection is needed.
        self.layer1 = nn.Sequential(BasicBlock(64, 64), BasicBlock(64, 64))
        self.layer2 = self._strided_stage(64, 128)
        self.layer3 = self._strided_stage(128, 256)
        self.layer4 = self._strided_stage(256, 512)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 1)  # binary classification

    @staticmethod
    def _strided_stage(in_channels, out_channels):
        """Build a two-block stage whose first block uses stride 2 and a 1x1 projection shortcut."""
        # The projection is created before the blocks so parameters are
        # initialised in the same order as a hand-written stage would.
        projection = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        return nn.Sequential(
            BasicBlock(in_channels, out_channels, stride=2, downsample=projection),
            BasicBlock(out_channels, out_channels),
        )

    def forward(self, x):
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for module in stem + stages:
            x = module(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


# Instantiate the from-scratch ResNet18 and move it to the selected device.
customresnet18 = MyResNet18().to(device)

In [None]:
# Train the from-scratch ResNet18 for at most 50 epochs; stops after 5 epochs without a new best accuracy on test_data_loader.
customresnet18100 = train_model_with_early_stopping(customresnet18, train100_data_loader, test_data_loader, device, max_epochs=50, lr=1e-3, patience=5)
# Persist the history: the comparison cells further down reload 'customresnet18100.json',
# which no other cell writes.
save_metrics(customresnet18100, filename='customresnet18100.json')

In [None]:
# Custom ResNet18: training vs. test loss per epoch.
plt.figure(figsize=(10, 6))
for curve_key, curve_label in (("train_losses", "Train Loss"), ("test_losses", "Test Loss")):
    plt.plot(
        range(len(customresnet18100["train_losses"])),
        customresnet18100[curve_key],
        label=curve_label,
    )
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Custom ResNet18: training vs. test accuracy per epoch.
plt.figure(figsize=(10, 6))
plt.plot([x for x in range(len(customresnet18100["train_losses"]))],customresnet18100["train_accuracies"], label="Train Accuracy")
plt.plot([x for x in range(len(customresnet18100["train_losses"]))],customresnet18100["test_accuracies"], label="Test Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")  # was "Loss", copied over from the loss plot
plt.legend()
plt.show()

In [None]:
# Test loss of the three models on one figure; each curve spans its own number of epochs.
plt.figure(figsize=(10, 6))
for history, curve_label in (
    (base100, "Test Loss base200"),
    (resnet18_pretrained100, "Test Loss pretrained Resnet18"),
    (customresnet18100, "Test Loss Custom Resnet18"),
):
    plt.plot(range(len(history["test_losses"])), history["test_losses"], label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Test accuracy of the three models on one figure. The curves previously read
# "train_accuracies" while every legend entry said "Test Accuracy"; the data
# now matches the labels (and mirrors the test-loss comparison).
plt.figure(figsize=(10, 6))
plt.plot([x for x in range(len(base100["test_accuracies"]))],base100["test_accuracies"], label="Test Accuracy base200")
plt.plot([x for x in range(len(resnet18_pretrained100["test_accuracies"]))],resnet18_pretrained100["test_accuracies"], label="Test Accuracy pretrained Resnet18")
plt.plot([x for x in range(len(customresnet18100["test_accuracies"]))],customresnet18100["test_accuracies"], label="Test Accuracy Custom Resnet18")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

In [None]:
# load_metrics(filename='model_metrics.json')
def plot_graphs(lists, key, y_label, x_label, labels):
    """Overlay one curve per metrics dictionary on a single figure and show it.

    Args:
        lists: Iterable of metrics dictionaries.
        key: Dictionary key whose sequence is plotted for each entry.
        y_label: Text for the y axis.
        x_label: Text for the x axis.
        labels: Legend entries, indexed in step with ``lists``.
    """
    plt.figure(figsize=(10,7))
    for position, history in enumerate(lists):
        series = history[key]
        # x positions are simply 0 .. len(series) - 1.
        plt.plot(list(range(len(series))), series, label=labels[position])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()



In [None]:
# Reload the saved histories of every model / training-set-size combination.
# Files are read in the order listed; note that base100 is rebound to the
# values read back from disk.
base100, base1000, base5000 = (
    load_metrics(filename=path)
    for path in ('base100.json', 'base1000.json', 'base5000.json')
)
custom100, custom1000, custom5000 = (
    load_metrics(filename=path)
    for path in ('customresnet18100.json', 'customresnet181000.json', 'customresnet185000.json')
)
pretrained100, pretrained1000, pretrained5000 = (
    load_metrics(filename=path)
    for path in ('resnet18_pretrained100.json', 'resnet18_pretrained1000.json', 'resnet18_pretrained5000.json')
)

In [None]:
# One test-accuracy figure per training-set size, each comparing the three models.
for model_group in (
    [base100, pretrained100, custom100],
    [base1000, pretrained1000, custom1000],
    [base5000, pretrained5000, custom5000],
):
    plot_graphs(
        model_group,
        "test_accuracies",
        y_label="Test Accuracy",
        x_label="Epoch",
        labels=["base model Accuracy", "Pretrained Resnet18 Accuracy", "Custom Resnet18 Accuracy"],
    )


In [None]:
# One test-loss figure per training-set size, each comparing the three models.
for model_group in (
    [base100, pretrained100, custom100],
    [base1000, pretrained1000, custom1000],
    [base5000, pretrained5000, custom5000],
):
    plot_graphs(
        model_group,
        "test_losses",
        y_label="Test Loss",
        x_label="Epoch",
        labels=["base model Loss", "Pretrained Resnet18 Loss", "Custom Resnet18 Loss"],
    )


In [None]:
# Display the metric names stored in a history dictionary.
base100.keys()

In [None]:
# Best test accuracy of each model, ordered by training-set size (100, 1000, 5000).
base_best_acc = [run['best_test_accuracy'] for run in (base100, base1000, base5000)]
custom_best_acc = [run['best_test_accuracy'] for run in (custom100, custom1000, custom5000)]
pretrained_best_acc = [run['best_test_accuracy'] for run in (pretrained100, pretrained1000, pretrained5000)]


In [None]:
# Grouped bar chart: one group per training-set size, one bar per model.
# Groups sit at evenly spaced positions instead of at x = 100 / 1000 / 5000 on
# a linear axis, which crowded the two smaller sizes together and pushed the
# first bar below x = 0.
train_sizes = [100, 1000, 5000]
x = np.arange(len(train_sizes))
bar_width = 0.25
plt.figure(figsize=(8,5))
plt.bar(x - bar_width, base_best_acc, width=bar_width, label="Base Model")
plt.bar(x, custom_best_acc, width=bar_width, label="Custom ResNet18")
plt.bar(x + bar_width, pretrained_best_acc, width=bar_width, label="Pretrained ResNet18")
plt.xlabel("Number of Training Examples")
plt.ylabel("Top Accuracy")
plt.title("Top Accuracy vs Training Set Size for Different Models")
plt.xticks(x, [str(size) for size in train_sizes])
plt.ylim(0, 1)  # accuracies are fractions in [0, 1]
plt.legend()
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()