# Dataset

In [26]:
import os
import time
import copy
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [27]:
base_dir = r"/Users/h383kim/pytorch/AlexNet/splitted"
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

BATCH_SIZE = 128

img_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])

train_dataset = ImageFolder(root=train_dir,
                            transform=img_transform)
val_dataset = ImageFolder(root=val_dir,
                          transform=img_transform)
test_dataset = ImageFolder(root=test_dir, 
                           transform=img_transform)

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=os.cpu_count())
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=os.cpu_count())
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=os.cpu_count())

# Pre-trained and Fine-tuning

In [28]:
from torchvision import models

ResNet152 = models.resnet152(weights="IMAGENET1K_V2", progress=True)

In [29]:
from torchsummary import summary

#summary(ResNet152, input_size=(3, 224, 224), device="cpu") 
print(ResNet152)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [30]:
NUM_CLASSES = 10
IN_FEATURES = ResNet152.fc.in_features
# Modifying the model
ResNet152.fc = nn.Linear(in_features=IN_FEATURES, out_features=NUM_CLASSES)
print(ResNet152.fc)

Linear(in_features=2048, out_features=10, bias=True)


In [31]:
for name, child in ResNet152.named_children():
    if name != "fc":
        for param in child.parameters():
            param.requires_grad = False

In [32]:
print("After Freezing...")
for name, param in ResNet152.named_parameters():
    print(f"Layer: {name} | requires_grad: {param.requires_grad}")

After Freezing...
Layer: conv1.weight | requires_grad: False
Layer: bn1.weight | requires_grad: False
Layer: bn1.bias | requires_grad: False
Layer: layer1.0.conv1.weight | requires_grad: False
Layer: layer1.0.bn1.weight | requires_grad: False
Layer: layer1.0.bn1.bias | requires_grad: False
Layer: layer1.0.conv2.weight | requires_grad: False
Layer: layer1.0.bn2.weight | requires_grad: False
Layer: layer1.0.bn2.bias | requires_grad: False
Layer: layer1.0.conv3.weight | requires_grad: False
Layer: layer1.0.bn3.weight | requires_grad: False
Layer: layer1.0.bn3.bias | requires_grad: False
Layer: layer1.0.downsample.0.weight | requires_grad: False
Layer: layer1.0.downsample.1.weight | requires_grad: False
Layer: layer1.0.downsample.1.bias | requires_grad: False
Layer: layer1.1.conv1.weight | requires_grad: False
Layer: layer1.1.bn1.weight | requires_grad: False
Layer: layer1.1.bn1.bias | requires_grad: False
Layer: layer1.1.conv2.weight | requires_grad: False
Layer: layer1.1.bn2.weight | req

In [33]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ResNet152.parameters()), lr=0.00001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [34]:
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE

'mps'

In [35]:
def train(model: torch.nn.Module,
          dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module):
    
    model.train()
    train_loss, train_acc, correct = 0, 0, 0
        
    for X, y in dataloader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        # Forward pass
        preds_prob = model(X) # Shape of preds_prob = (batch_size, num_classes)
        # Calculate the loss
        loss = loss_fn(preds_prob, y) # Shape of loss = [float] (i.e. scalar tensor containing the average loss over the batch)
        train_loss += loss.item()
        # Optimizer zero_grad
        optimizer.zero_grad()
        # Loss backward
        loss.backward()
        # Update
        optimizer.step()

        preds = torch.argmax(preds_prob, dim=1) # max values over the num_classes of (batch_size, num_classes)
        correct += preds.eq(y.view_as(preds)).sum().item()

    train_loss /= len(dataloader)
    train_acc = 100. * correct / len(dataloader.dataset)

    return train_loss, train_acc

In [36]:
def evaluate(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module):
    model.eval()
    val_loss, val_acc, correct = 0, 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            # Forward pass
            preds_prob = model(X)
            # Calculate the loss
            loss = loss_fn(preds_prob, y)
            val_loss += loss.item()

            preds = torch.argmax(preds_prob, dim=1)
            correct += preds.eq(y.view_as(preds)).sum().item()

        val_loss /= len(dataloader)
        val_acc = 100. * correct / len(dataloader.dataset)

    return val_loss, val_acc

In [37]:
def train_baseline(model: torch.nn.Module,
                   train_dataloader: torch.utils.data.DataLoader,
                   val_dataloader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer,
                   loss_fn: torch.nn.Module,
                   num_epochs: int):

    best_acc = 0
    best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(1, num_epochs + 1):
        start = time.time()
        train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
        test_loss, test_acc = evaluate(model, val_dataloader, loss_fn)

        if test_acc > best_acc:
            best_acc = test_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        end = time.time()
        time_elapsed = end - start
        print(f"------------ epoch {epoch} ------------")
        print(f"Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}%")
        print(f"Val loss: {test_loss:.4f} | Val acc: {test_acc:2f}%")
        print(f"Time taken: {time_elapsed / 60:.0f}min {time_elapsed % 60:.0f}s")

    model.load_state_dict(best_model_wts)
    return model

In [38]:
fine_tuned = train_baseline(model=ResNet152.to("mps"),
                            train_dataloader=train_dataloader,
                            val_dataloader=val_dataloader,
                            optimizer=optimizer,
                            loss_fn=loss_fn,
                            num_epochs=3)

------------ epoch 1 ------------
Train loss: 2.2608 | Train acc: 14.86%
Val loss: 2.1816 | Val acc: 23.909717%
Time taken: 7min 31s
------------ epoch 2 ------------
Train loss: 2.1127 | Train acc: 29.22%
Val loss: 2.0454 | Val acc: 36.801836%
Time taken: 7min 30s
------------ epoch 3 ------------
Train loss: 1.9846 | Train acc: 41.06%
Val loss: 1.9223 | Val acc: 47.130834%
Time taken: 7min 31s
