In [151]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [152]:
# Pick the best available accelerator: Apple Metal (MPS), then CUDA, else CPU,
# so the notebook also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed torch's global RNG for more repeatable runs (random crops/flips,
# DataLoader shuffling and the new classifier head all draw from it).
# Some accelerator kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Per-channel (RGB) ImageNet mean/std, matching the pretrained ResNet weights.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
writer = SummaryWriter("../runs/transfer_learning_experiment")

In [153]:
# Training applies random crop + horizontal flip for augmentation; validation
# uses a deterministic resize + center crop. Both normalize with `mean`/`std`.
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    ),
}
data_dir = "../data/hymenoptera_data"
sets = ["train", "val"]
BATCH_SIZE = 4
NUM_WORKERS = 4

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in sets
}
# ImageFolder derives class indices from the sub-folder names of each split;
# if the splits disagree, validation labels would be silently misinterpreted.
assert image_datasets["val"].classes == image_datasets["train"].classes, (
    "train/val class folders differ"
)

# Only the training split needs shuffling; validation is iterated in a fixed
# order so its outputs are comparable from epoch to epoch.
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=BATCH_SIZE,
        shuffle=(x == "train"),
        num_workers=NUM_WORKERS,
    )
    for x in sets
}
dataset_sizes = {x: len(image_datasets[x]) for x in sets}
class_names = image_datasets["train"].classes
class_names

['ants', 'bees']

In [154]:
def _run_phase(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    phase: str,
) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one full pass of ``model`` over ``dataloaders[phase]``.

    In the "train" phase gradients are enabled and ``optimizer`` is stepped
    once per batch; in any other phase the model is only evaluated.

    Returns:
        ``(epoch_loss, epoch_acc, labels, probs)``: mean loss and accuracy over
        ``dataset_sizes[phase]`` samples, the ground-truth class indices of
        every sample seen (CPU tensor, shape ``(N,)``), and the matching
        softmax class probabilities (CPU tensor, shape ``(N, C)``).
    """
    is_train = phase == "train"
    model.train(is_train)  # same as model.eval() when is_train is False

    running_loss = 0.0
    running_corrects = 0
    label_batches = []
    prob_batches = []

    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        # Store ground-truth labels (not predictions) for the PR curves, and
        # detach the probabilities so they do not keep every batch's autograd
        # graph alive for the whole epoch.
        label_batches.append(labels.cpu())
        prob_batches.append(F.softmax(outputs.detach(), dim=1).cpu())

        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = dataset_sizes[phase]
    return (
        running_loss / n_samples,
        running_corrects / n_samples,
        torch.cat(label_batches),
        torch.cat(prob_batches),
    )


def _log_pr_curves(
    phase: str, labels: torch.Tensor, probs: torch.Tensor, step: int
) -> None:
    """Write one one-vs-rest precision/recall curve per class to ``writer``."""
    for class_idx, class_name in enumerate(class_names):
        writer.add_pr_curve(
            f"{phase} pr_curve/{class_name}",
            labels == class_idx,  # binary truth: does the sample belong to this class?
            probs[:, class_idx],  # predicted probability of this class
            global_step=step,
        )


def train_model(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: lr_scheduler.LRScheduler,
    num_epochs: int = 25,
) -> nn.Module:
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a "train" pass (optimizer steps, then one
    ``scheduler.step()``) followed by a "val" pass. Per-phase loss, accuracy
    and per-class PR curves are logged to the module-level ``writer`` with the
    epoch as the step.

    Args:
        model: network to train (moved to ``device`` by the caller).
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of train+val epochs to run.

    Returns:
        The same ``model`` object, with the weights from the epoch that had the
        highest validation accuracy restored.
    """
    since = time.time()

    # Snapshot the starting weights so a valid state is restored even if the
    # validation accuracy never rises above 0.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print(f"Epoch {epoch}/{num_epochs-1}")
        print("-" * 10)

        for phase in ["train", "val"]:
            epoch_loss, epoch_acc, labels, probs = _run_phase(
                model, criterion, optimizer, phase
            )
            if phase == "train":
                # Step the scheduler once per epoch, after the optimizer steps.
                scheduler.step()

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
            writer.add_scalar(f"{phase} loss", epoch_loss, epoch)
            writer.add_scalar(f"{phase} accuracy", epoch_acc, epoch)
            _log_pr_curves(phase, labels, probs, epoch)

            # Keep a deep copy: state_dict() returns references to live tensors.
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best val Acc: {best_acc:.4f}")

    model.load_state_dict(best_model_wts)
    return model

In [155]:
NUM_EPOCHS = 2
LEARNING_RATE = 1e-3

model = models.resnet18(weights="ResNet18_Weights.DEFAULT")

# Feature extraction: freeze every pretrained parameter ...
for param in model.parameters():
    param.requires_grad = False

# ... then replace the classifier head. The new nn.Linear is created with
# requires_grad=True, so it is the only trainable part of the network.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=len(class_names))
model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the new head receives gradients, so only its parameters go to Adam.
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

# Multiply the learning rate by gamma=0.1 every 7 scheduler steps (epochs).
# With NUM_EPOCHS=2 the decay never triggers; it only matters for longer runs.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the new head on top of the frozen backbone
model = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=step_lr_scheduler,
    num_epochs=NUM_EPOCHS,
)

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 0/1
----------


In [None]:
# Log one training batch as an image grid, plus the model graph, to TensorBoard.
images, labels = next(iter(dataloaders["train"]))

# The loader yields Normalize(mean, std)'d tensors; invert that and clamp to
# [0, 1] so TensorBoard renders the images with their real colors.
mean_t = torch.as_tensor(mean, dtype=images.dtype).view(3, 1, 1)
std_t = torch.as_tensor(std, dtype=images.dtype).view(3, 1, 1)
grid = torchvision.utils.make_grid((images * std_t + mean_t).clamp(0, 1))
writer.add_image("Training Images", grid, 0)

# add_graph traces the model with one example (batch of 1) on the model's device.
sample_image = images[0].unsqueeze(0).to(device)
writer.add_graph(model, sample_image)
writer.close()