# Build and train a ResNet with PyTorch

In [None]:
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from accelerate import Accelerator
from evaluate import load
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
# Create the Accelerate handle (automatic device placement enabled) and
# expose the device it selected.
accelerator = Accelerator(
    device_placement=True,
)
device = accelerator.device

In [None]:
# Training hyperparameters
num_epochs = 200      # number of passes over the training loader
batch_size = 32       # samples per mini-batch (2 ** 5)
learning_rate = 1e-2  # learning rate handed to the optimizer

In [None]:
# Per-channel normalization shared by the train and test pipelines.
# The constants are presumably the CIFAR-100 channel means / standard
# deviations -- TODO confirm against the dataset statistics.
# NOTE: do not put a trailing comma after this call. A trailing comma turns
# `norm` into a 1-tuple, and Compose then fails with
# "'tuple' object is not callable" on the first sample.
norm = transforms.Normalize(
    mean=(0.5071, 0.4867, 0.4408),
    std=(0.2675, 0.2565, 0.2761),
)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    norm,
])

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    norm,
])

In [None]:
# CIFAR-100 train and test splits, stored under ./data (downloaded if missing)
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Mini-batch loaders: training data is shuffled, test data keeps its order
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Build a DenseNet-121 without pre-trained weights (weights=None) and replace
# its classifier head with a 100-output linear layer.
model = models.densenet121(weights=None)
model.classifier = nn.Linear(
    in_features=model.classifier.in_features,
    out_features=100,
)

# Cross-entropy loss, optimized with Adam at the configured learning rate
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Learning-rate schedule: the base LR is multiplied by 0.65 ** <step index>
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: 0.65 ** step,
)

# Hand the model, optimizer and both dataloaders to Accelerate
(
    model,
    optimizer,
    train_dataloader,
    test_dataloader,
) = accelerator.prepare(model, optimizer, train_dataloader, test_dataloader)

In [None]:
# Metric objects from the `evaluate` library, loaded in order: accuracy, F1
accuracy, f1 = (load(metric_name) for metric_name in ("accuracy", "f1"))

In [None]:
# Training loop: for every epoch, run one optimization pass over the training
# loader, then evaluate on the test loader and report loss / accuracy / F1.
for epoch in range(num_epochs):
    tic = datetime.now()

    # ---- Training pass ----
    model.train()
    train_loss = 0.0

    for images, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass goes through Accelerate rather than loss.backward()
        accelerator.backward(loss)
        optimizer.step()
        train_loss += loss.item()

    # Average of the per-batch losses
    train_loss /= len(train_dataloader)

    # ---- Evaluation pass ----
    model.eval()
    test_loss = 0.0
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            # gather_for_metrics collects results from all processes and drops
            # the samples duplicated to fill the last batch in distributed
            # runs. Convert to plain Python ints so the metric objects receive
            # host-side integers instead of 0-d (possibly device) tensors.
            test_preds.extend(accelerator.gather_for_metrics(preds).cpu().tolist())
            test_labels.extend(accelerator.gather_for_metrics(labels).cpu().tolist())

    test_loss /= len(test_dataloader)
    test_acc = accuracy.compute(references=test_labels, predictions=test_preds)["accuracy"]
    test_f1 = f1.compute(references=test_labels, predictions=test_preds, average="macro")["f1"]

    # Advance the LR schedule by one epoch. The scheduler is a LambdaLR, whose
    # step() does not take a metric: passing the loss would be interpreted as
    # the (deprecated) epoch argument and corrupt the schedule.
    scheduler.step()

    # Wall-clock time for the epoch, truncated to whole seconds (H:MM:SS)
    toc = datetime.now()
    elapsed_time = toc - tic
    elapsed_time_in_hh_mm_ss = str(elapsed_time).split('.')[0]

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, "
        f"Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}, "
        f'Elapsed Time: {elapsed_time_in_hh_mm_ss}\n'
    )