# Build and train a ResNet-50 on CIFAR-100 with PyTorch

In [None]:
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from accelerate import Accelerator
from evaluate import load
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
# Set device.
# Accelerate selects the device for this process. With device_placement=True,
# every object passed through `accelerator.prepare(...)` is moved onto it, so
# no manual `.to(device)` calls are needed in the cells below.
accelerator = Accelerator(device_placement=True)

# Kept for inspection/convenience; placement itself is handled by prepare().
device = accelerator.device

In [None]:
# Define hyperparameters
num_epochs = 70   # number of full passes over the training set
batch_size = 128  # images per mini-batch (== 2 ** 7)

In [None]:
# weights = models.ResNet50_Weights.DEFAULT

In [None]:
# Per-channel (R, G, B) mean and std commonly quoted for the CIFAR-100
# training images; used to standardize every input.
cifar100_mean = (0.5071, 0.4867, 0.4408)
cifar100_std = (0.2675, 0.2565, 0.2761)
norm = transforms.Normalize(cifar100_mean, cifar100_std)

# Tail shared by both pipelines: image -> float tensor -> standardized tensor.
to_normalized_tensor = [transforms.ToTensor(), norm]

# Training pipeline prepends light augmentation: a random 32x32 crop taken
# from the image padded by 4 pixels on each side, then a random horizontal flip.
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    + to_normalized_tensor
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose(list(to_normalized_tensor))

In [None]:
# Load CIFAR-100 dataset.
# Both splits are stored under the same root and downloaded on first use.
cifar_root = './data'

train_dataset = datasets.CIFAR100(
    root=cifar_root, train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR100(
    root=cifar_root, train=False, download=True, transform=transform_test
)

# Shuffle the training order each epoch; keep the test order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Preview one random sample.
# The *test* split is loaded here without a transform, so indexing it returns
# a raw image that matplotlib can display directly.
preview_ds = datasets.CIFAR100(root='./data', train=False, download=True)
print(f'{preview_ds.data.shape}')  # shape of the raw image array held by the dataset

# Randomly select an image from the test split.
index = np.random.randint(0, len(preview_ds))
image, label = preview_ds[index]
target = preview_ds.classes[label]

# Plot the image
plt.imshow(image)
plt.title(f"Label: {target}")
plt.axis('off')
plt.show()
print(target)

# Drop the temporaries, but NOT `plt`: deleting the imported module alias
# would make this cell (or any later plotting) fail with NameError on re-run.
del index, image, label, target, preview_ds

In [None]:
# Load model architecture: ResNet-50 trained from scratch (no pretrained weights).
model = models.resnet50(weights=None)

# Replace the classifier head with a 100-way layer for the CIFAR-100 classes.
# NOTE(review): torchvision's stock stem (7x7 stride-2 conv + max-pool) reduces
# 32x32 inputs to 8x8 before the first residual stage; CIFAR-style ResNets
# usually use a 3x3 stride-1 first conv and no max-pool. Left unchanged here.
model.fc = nn.Linear(model.fc.in_features, 100)  # resnet

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=.001)

# Cosine schedule stepped once per epoch by the training loop, so T_max must
# equal the number of epochs. A larger value (e.g. 200 for a 70-epoch run)
# leaves the learning rate at ~73% of its initial value when training ends.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# The scheduler is deliberately NOT passed to accelerator.prepare(). Accelerate's
# scheduler wrapper advances the schedule once per process on every step() call.
# That is intended for per-batch schedules, and it would make this per-epoch
# schedule run num_processes times too fast in multi-GPU runs. The scheduler
# still updates the LR of the underlying torch optimizer that the prepared
# optimizer wraps.
model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, test_dataloader
)

In [None]:
# Load evaluation metrics from the `evaluate` library: accuracy and F1
# (the training loop requests macro-averaged F1 over the classes).
accuracy, f1 = load("accuracy"), load("f1")

In [None]:
# One metrics dict per epoch; the results-table cell turns this into a DataFrame.
training_result = []

# Training loop
for epoch in range(num_epochs):
    tic = datetime.now()

    # ---- Train ----
    model.train()

    train_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        optimizer.step()

        # NOTE: these running train statistics only cover the batches seen by
        # this process; they are not aggregated across processes.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss /= len(train_dataloader)
    train_acc = correct / total

    # ---- Evaluate ----
    model.eval()
    test_loss = 0.0
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()  # local sum; averaged over local batches below

            preds = torch.argmax(outputs, dim=1)
            # gather_for_metrics (unlike plain gather) drops the samples the
            # distributed sampler duplicates to pad the final batch, so each
            # test image is counted exactly once in accuracy/F1.
            all_preds, all_labels = accelerator.gather_for_metrics((preds, labels))
            test_preds.extend(all_preds.cpu().numpy())
            test_labels.extend(all_labels.cpu().numpy())

    test_loss /= len(test_dataloader)
    test_acc = accuracy.compute(references=test_labels, predictions=test_preds)["accuracy"]
    test_f1 = f1.compute(references=test_labels, predictions=test_preds, average="macro")["f1"]

    # Learning rate actually used during this epoch. Read it before the
    # scheduler advances it to the next epoch's value.
    epoch_lr = optimizer.param_groups[0]["lr"]

    # Advance the cosine schedule by one epoch (it depends only on the epoch
    # count, not on validation loss).
    scheduler.step()

    # Time calculation
    elapsed_time = datetime.now() - tic
    elapsed_time_in_hh_mm_ss = str(elapsed_time).split('.')[0]

    # accelerator.print only prints on the main process, avoiding duplicate logs.
    accelerator.print(
        f"Epoch [{epoch + 1}/{num_epochs}]: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, "
        f'Train Accuracy: {train_acc:.3f}, '
        f"Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}, "
        f'lr: {epoch_lr}, '
        f'Elapsed Time: {elapsed_time_in_hh_mm_ss}\n'
    )

    training_result.append({
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_acc': train_acc,
        'test_acc': test_acc,
        'lr': epoch_lr,
    })

In [None]:
# Tabulate the per-epoch metrics. The DataFrame's default 0-based row index is
# the epoch number; the CSV keeps that index as its first column.
result_columns = ['train_loss', 'test_loss', 'train_acc', 'test_acc', 'lr']
tr = pd.DataFrame(training_result, columns=result_columns)
tr.to_csv('resnet_cifar100_result.csv')

# Display the table as the cell output.
tr

In [None]:
# Save only the trained weights. Unwrap the prepared model first so the keys
# carry no wrapper prefix (e.g. DDP's "module."), letting the file load into a
# plain models.resnet50 with a 100-class fc. accelerator.save writes the file
# from the main process only, instead of once per process.
accelerator.wait_for_everyone()
accelerator.save(accelerator.unwrap_model(model).state_dict(), 'resnet_cifar100.pth')

# Get the size of the saved model file. Nothing in this notebook prunes the
# model, so report it as the plain model size.
if accelerator.is_main_process:
    model_size = os.path.getsize('resnet_cifar100.pth') / (1024 * 1024)  # Size in MB
    print(f"Model size: {model_size:.2f} MB")