In [None]:
'''
This notebook reuses several files from the ResNeSt authors' repository:
https://github.com/zhanghang1989/ResNeSt/tree/master/resnest/torch/models
Some import paths in those files were changed so that they work from this directory, e.g.
from .resnet import * -> from resnet import *
'''


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import time
from torchvision import transforms
import torchvision
import numpy as np
from resnet import ResNet, Bottleneck

class Trainer:
    """Small supervised training loop that records per-epoch metrics.

    After every epoch the mean training loss, mean validation loss and
    validation accuracy are appended to ``train_losses``, ``val_losses`` and
    ``val_accuracies`` respectively.
    """

    def __init__(self, name, model, criterion, optimizer, device):
        """
        Initializes the Trainer.

        Args:
            name (str): Label of this run; used as the prefix of the plot title.
            model (nn.Module): The PyTorch model to train. It is moved to ``device``.
            criterion (nn.Module): The loss function.
            optimizer (torch.optim.Optimizer): The optimizer.
            device (torch.device): The device to train on (e.g., 'cuda' or 'cpu').
        """
        self.name = name
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def train_epoch(self, dataloader):
        """
        Trains the model for one epoch.

        Args:
            dataloader (Iterable): Yields ``(inputs, labels)`` batches for the
                training set. It does not need to implement ``__len__``.

        Returns:
            float: The training loss averaged over the batches of the epoch.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        # Batches are counted while iterating so that loaders without a
        # length (iterable-style datasets, generators) are supported.
        num_batches = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_epoch received a dataloader that yielded no batches")

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate_epoch(self, dataloader):
        """
        Evaluates the model on the validation set for one epoch.

        Args:
            dataloader (Iterable): Yields ``(inputs, labels)`` batches for the
                validation set. It does not need to implement ``__len__``.

        Returns:
            float: The validation loss averaged over the batches of the epoch.
            float: The fraction of validation samples classified correctly.

        Raises:
            ValueError: If ``dataloader`` yields no batches or no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        num_batches = 0

        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                num_batches += 1

                # The predicted class is the index of the largest output along dim 1.
                _, predicted = torch.max(outputs, 1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()

        if num_batches == 0 or total_samples == 0:
            raise ValueError("validate_epoch received a dataloader that yielded no samples")

        avg_loss = total_loss / num_batches
        accuracy = correct_predictions / total_samples
        self.val_losses.append(avg_loss)
        self.val_accuracies.append(accuracy)
        return avg_loss, accuracy

    def train(self, train_dataloader, val_dataloader, num_epochs):
        """
        Trains the model for a specified number of epochs, validating after each
        one and printing the epoch's training loss, validation loss and accuracy.

        Args:
            train_dataloader (DataLoader): The DataLoader for the training set.
            val_dataloader (DataLoader): The DataLoader for the validation set.
            num_epochs (int): The number of training epochs.
        """
        print(f"Training on {self.device}")
        for epoch in range(num_epochs):
            train_loss = self.train_epoch(train_dataloader)
            val_loss, val_acc = self.validate_epoch(val_dataloader)
            print(f"Epoch [{epoch+1}/{num_epochs}], "
                  f"Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}, "
                  f"Val Acc: {val_acc:.4f}")

    def plot_losses_accuracies(self):
        """
        Plots the training loss, validation loss, and validation accuracy against the number of epochs.

        All three curves share one y-axis, so the axis is labelled for both quantities.
        """
        epochs = range(1, len(self.train_losses) + 1)

        plt.figure()
        plt.plot(epochs, self.train_losses, label='Train Loss')
        plt.plot(epochs, self.val_losses, label='Validation Loss')
        plt.plot(epochs, self.val_accuracies, label='Validation Accuracy')
        plt.xlabel('Epoch')
        # Accuracy is drawn on the same axis as the losses, so 'Loss' alone would mislabel it.
        plt.ylabel('Loss / Accuracy')
        plt.title(self.name + ' Training and Validation Loss')
        plt.legend()
        plt.show()

# The only preprocessing step is the conversion to tensors.
trans = transforms.Compose([transforms.ToTensor()])

# CIFAR-10: the train split is used for training, the test split for validation.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=trans,
)
val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=trans,
)

# Batches of 256; only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=256)
val_dataloader = DataLoader(val_dataset, batch_size=256)

@torch.no_grad()
def measure_inference_time_with_warmup(model, data_loader, warmup_steps=10):
    """
    Measures per-sample latency and throughput of ``model`` over ``data_loader``.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode; both changes persist after the call. Up to ``warmup_steps``
    untimed batches are run first, then one full pass over ``data_loader`` is
    timed. The timed interval includes fetching batches from the loader and
    copying them to the device.

    Args:
        model (nn.Module): The model to benchmark.
        data_loader (Iterable): Yields ``(inputs, labels)`` batches; it is
            iterated twice (warm-up, then the timed pass).
        warmup_steps (int): Number of untimed warm-up batches; ``0`` or less
            skips the warm-up.

    Returns:
        tuple: ``(avg_latency_per_sample, throughput)`` in seconds per sample
        and samples per second.

    Raises:
        ValueError: If the timed pass sees no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    def _wait_for_device():
        # CUDA work is queued asynchronously; without a sync the host clock
        # would be read before the queued kernels have finished.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # Warm-up
    if warmup_steps > 0:
        for i, (inputs, _) in enumerate(data_loader):
            inputs = inputs.to(device)
            _ = model(inputs)
            if i >= warmup_steps - 1:
                break

    num_samples = 0
    _wait_for_device()  # make sure warm-up work is not billed to the timed pass
    start_time = time.perf_counter()  # Start timer after warm-up

    for inputs, _ in data_loader:
        inputs = inputs.to(device)
        _ = model(inputs)
        num_samples += inputs.size(0)

    _wait_for_device()  # include all outstanding GPU work in the measurement
    end_time = time.perf_counter()

    if num_samples == 0:
        raise ValueError("data_loader yielded no samples to time")

    total_inference_time = end_time - start_time
    throughput = num_samples / total_inference_time
    avg_latency_per_sample = total_inference_time / num_samples
    print(f"Average Latency per Sample (with warm-up): {avg_latency_per_sample * 1000:.2f} ms")
    print(f"Throughput (with warm-up): {throughput:.2f} samples/second")
    return avg_latency_per_sample, throughput

def fit_one(name, model, epochs, train_dataloader, val_dataloader, lr=1e-3):
    """
    Trains one model, plots its curves and benchmarks its inference speed.

    The model is trained with cross-entropy loss and SGD (momentum 0.9,
    weight decay 1e-4) on CUDA when available, otherwise on CPU.

    Args:
        name (str): Label of the run, forwarded to the Trainer.
        model (nn.Module): The model to train.
        epochs (int): Number of training epochs.
        train_dataloader (DataLoader): Training batches.
        val_dataloader (DataLoader): Validation batches; also used for the
            inference-time measurement.
        lr (float): SGD learning rate. It is passed explicitly so that the
            result does not depend on the installed torch version's default.

    Returns:
        tuple: ``(avg_latency, max_acc)`` - the average per-sample latency
        reported by ``measure_inference_time_with_warmup`` and the highest
        validation accuracy seen over all epochs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
    trainer = Trainer(
        name=name,
        model=model,
        criterion=nn.CrossEntropyLoss(),
        optimizer=optimizer,
        device=device,
    )
    trainer.train(train_dataloader, val_dataloader, epochs)
    trainer.plot_losses_accuracies()
    max_acc = np.max(trainer.val_accuracies)
    # The throughput is printed by the helper; only the latency is returned here.
    avg_latency, _ = measure_inference_time_with_warmup(model, val_dataloader)
    return avg_latency, max_acc

In [None]:
# Ablation study: every variant below is trained for 20 epochs with fit_one, and
# (name, average latency, best validation accuracy) is collected per model.
result_list = []
model_dict = {
    'resnet14_1s1x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
    ),
    'resnest14_fast_0s1x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
        radix=0, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=True,
    ),
    'resnest14_fast_1s1x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
        radix=1, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=True,
    ),
    'resnest14_fast_1s1x40d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
        radix=1, groups=1, bottleneck_width=40,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=True,
    ),
    'resnest14_fast_1s2x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
        radix=1, groups=2, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=True,
    ),
    'resnet14_0s1x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2, radix=0,
    ),
    'resnest14_fast_2s1x64d': ResNet(
        Bottleneck, [1, 1, 1, 1], num_classes=10, final_drop=0.2,
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=True,
    ),
}
for name, model in model_dict.items():
    avg_latency, max_acc = fit_one(name, model, 20, train_dataloader, val_dataloader)
    result_list.append((name, avg_latency, max_acc))
print(result_list)

def count_parameters(model, trainable_only=False):
    """
    Counts the parameters of a PyTorch model, expressed in millions.

    Args:
        model (torch.nn.Module): The PyTorch model.
        trainable_only (bool): When True, only parameters with
            ``requires_grad=True`` are counted. Defaults to False (count all).

    Returns:
        float: The number of parameters divided by 1e6.
    """
    total = sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
    return total / 1e6

# Print each model's name next to its size; count_parameters reports the size
# in millions of parameters.
for name,model in model_dict.items():
    print(name,count_parameters(model))

In [None]:
# Model comparison 1: four networks built with identical settings that differ
# only in the number of blocks per stage. Each is trained for 20 epochs.
result_list = []
model_dict = {
    label: ResNet(
        Bottleneck, stage_blocks,
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=False,
        num_classes=10, final_drop=0.2,
    )
    for label, stage_blocks in (
        ('resnest_14', [1, 1, 1, 1]),
        ('resnest_32', [2, 3, 3, 2]),
        ('resnest_50', [3, 4, 6, 3]),
        ('resnest_101', [3, 4, 23, 3]),
    )
}
for name, model in model_dict.items():
    avg_latency, max_acc = fit_one(name, model, 20, train_dataloader, val_dataloader)
    result_list.append((name, avg_latency, max_acc))
print(result_list)

In [None]:
# Model comparison 2: torchvision EfficientNet-B0..B3 and ResNet-18, each created
# with a 10-class output and trained for 20 epochs with the same fit_one routine.
from torchvision.models import efficientnet_b0,efficientnet_b1,efficientnet_b2,efficientnet_b3,resnet18
result_list = []
model_dict = {
    'efficientnet_b0': efficientnet_b0(num_classes=10, dropout=0.2),
    'efficientnet_b1': efficientnet_b1(num_classes=10, dropout=0.2),
    'efficientnet_b2': efficientnet_b2(num_classes=10, dropout=0.2),
    'efficientnet_b3': efficientnet_b3(num_classes=10, dropout=0.2),
    'resnet18': resnet18(num_classes=10),
}
for name, model in model_dict.items():
    avg_latency, max_acc = fit_one(name, model, 20, train_dataloader, val_dataloader)
    result_list.append((name, avg_latency, max_acc))
print(result_list)

In [None]:
import torch
import torchvision
from d2l import torch as d2l
import os
from torch import nn
from IPython.core.interactiveshell import InteractiveShell
# Notebook display option (IPython's ast_node_interactivity setting).
InteractiveShell.ast_node_interactivity = "all"


# Register the hotdog archive with d2l's data hub, then download and extract it.
# The second tuple element is presumably the archive's checksum - confirm in d2l.
d2l.DATA_HUB['hotdog'] = (
    d2l.DATA_URL + 'hotdog.zip',
    'fba480ffa8aa7e0febbb511d181409f899b9baa5',
)
data_dir = d2l.download_extract('hotdog')

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))


# Per-channel (RGB) mean and standard deviation used to standardise the images.
# These are the ImageNet statistics, the usual choice when fine-tuning a model
# pretrained on ImageNet. When training from scratch on another dataset, the
# statistics can be recomputed from that dataset instead.
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)


# Training pipeline: random resized crop to 224, random horizontal flip,
# tensor conversion, normalisation.
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize,
])

# Test pipeline: resize to 256x256, centre crop to 224, tensor conversion,
# normalisation.
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize,
])


# If `param_group=True`, parameters in output layer are updated using learning rate 10 times greater
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=5, param_group=True):
    """Fine-tune ``net`` on the hotdog image folders with SGD via ``d2l.train_ch13``.

    Args:
        net: Model to fine-tune. With ``param_group=True`` it must expose its
            output layer as ``net.fc``.
        learning_rate: Base SGD learning rate.
        batch_size: Batch size of both data loaders.
        num_epochs: Number of epochs passed to ``d2l.train_ch13``.
        param_group: If True, the ``fc`` parameters get ``learning_rate * 10``
            while all other parameters use ``learning_rate``; if False, every
            parameter uses ``learning_rate``.
    """
    train_folder = torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs)
    train_iter = torch.utils.data.DataLoader(
        train_folder, batch_size=batch_size, shuffle=True)
    test_folder = torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs)
    test_iter = torch.utils.data.DataLoader(test_folder, batch_size=batch_size)

    devices = d2l.try_all_gpus()
    # Unreduced (per-sample) losses are handed to d2l.train_ch13.
    loss = nn.CrossEntropyLoss(reduction="none")

    if not param_group:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.001)
    else:
        head_names = ["fc.weight", "fc.bias"]
        # Everything except the output layer keeps the base learning rate.
        backbone_params = [
            tensor for key, tensor in net.named_parameters()
            if key not in head_names
        ]
        groups = [
            {'params': backbone_params},
            {'params': net.fc.parameters(), 'lr': learning_rate * 10},
        ]
        trainer = torch.optim.SGD(groups, lr=learning_rate, weight_decay=0.001)

    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

def transform_fc_train(model, num_classes=2, learning_rate=5e-5):
    """Replace the model's ``fc`` head and fine-tune the whole model.

    Args:
        model: Network whose output layer is exposed as ``model.fc``; the
            layer is replaced in place.
        num_classes: Number of outputs of the new ``fc`` layer (default 2).
        learning_rate: Base learning rate passed to ``train_fine_tuning``
            (default 5e-5).
    """
    # The new head keeps the input width of the layer it replaces.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    nn.init.xavier_uniform_(model.fc.weight)
    train_fine_tuning(model, learning_rate)
# Fine-tuning, using ResNeSt-50 as an example.
# List the entry points published by the ResNeSt hub repository.
# NOTE(review): force_reload=True presumably bypasses the local hub cache on
# every run - confirm against the torch.hub documentation.
torch.hub.list('zhanghang1989/ResNeSt', force_reload=True)

# Load the pretrained ResNeSt-50 from the hub, then replace its fc layer and
# fine-tune it with transform_fc_train.
resnest50_pretrained = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
transform_fc_train(resnest50_pretrained)

In [None]:
# Run the same fine-tuning routine on torchvision's ResNet-50, created with the
# IMAGENET1K_V2 weights.
from torchvision.models import resnet50,ResNet50_Weights
resnet50_pretrained = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
transform_fc_train(resnet50_pretrained)