In [43]:
import os
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Dataset

In [45]:
base_dir = r"/Users/h383kim/pytorch/AlexNet/splitted"
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

BATCH_SIZE = 128

img_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])

train_dataset = ImageFolder(root=train_dir,
                            transform=img_transform)
val_dataset = ImageFolder(root=val_dir,
                          transform=img_transform)
test_dataset = ImageFolder(root=test_dir, 
                           transform=img_transform)

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=os.cpu_count())
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            num_workers=os.cpu_count())
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=os.cpu_count())

# Import the pre-trained VGG19_BN

In [4]:
from torchvision import models

vgg19 = models.vgg19_bn(weights="IMAGENET1K_V1", progress=True)

Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /Users/h383kim/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth
100%|████████████████████████████████████████| 548M/548M [00:12<00:00, 47.4MB/s]


# Check model architecture

In [5]:
print(vgg19)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [9]:
# Classifier Part
print(vgg19.classifier)

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)


# Modify the last layer
# Pre-trained model classifies into 1000 labels while we want to classify 10 labels

In [20]:
NUM_CLASSES = 10
IN_FEATURES = vgg19.classifier[-1].in_features

vgg19.classifier[-1] = nn.Linear(in_features=IN_FEATURES, out_features=NUM_CLASSES)
print(vgg19.classifier)

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=10, bias=True)
)


# Freeze all layers except for the last two nn.Linear

In [33]:
for name, child in vgg19.named_children():
    if name in ["features", "avgpool"]:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child[0].parameters():
            param.requires_grad = False

# Optimizer, Loss

In [34]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, vgg19.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Fine-tuning

In [35]:
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE

'mps'

In [36]:
def train(model: torch.nn.Module,
          dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module=nn.CrossEntropyLoss()):
    # Put the model into train mode
    model.train()
    train_loss, train_acc, correct = 0, 0, 0

    for batch, (X, y) in enumerate(dataloader):
        # Send the data into cpu or gpu
        X, y = X.to(DEVICE), y.to(DEVICE)

        # Forward pass
        preds_prob = model(X)

        # Calculate the loss
        loss = loss_fn(preds_prob, y)
        train_loss += loss.item()

        # Optimizer zero_grad
        optimizer.zero_grad()

        # Backpropagtion
        loss.backward()
        # Optimizer step
        optimizer.step()

        pred = torch.argmax(preds_prob, dim=1)
        correct += pred.eq(y.view_as(pred)).sum().item()

    train_loss /= len(dataloader)
    train_acc = 100. * correct / len(dataloader.dataset)
    return train_loss, train_acc

In [37]:
def evaluate(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             optimizer: torch.optim.Optimizer,
             loss_fn: torch.nn.Module=nn.CrossEntropyLoss()):
    # Put the model into eval mode
    model.eval()
    test_loss, test_acc, correct = 0, 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(DEVICE), y.to(DEVICE)

            # Forward pass
            preds_prob = model(X)

            # Calculate the loss
            loss = loss_fn(preds_prob, y).item()
            test_loss += loss

            pred = torch.argmax(preds_prob, dim=1)
            correct += pred.eq(y.view_as(pred)).sum().item()

        test_loss /= len(dataloader)
        test_acc = 100. * correct / len(dataloader.dataset)

    return test_loss, test_acc

In [38]:
import time
import copy

def train_baseline(model: torch.nn.Module, 
                   train_dataloader: torch.utils.data.DataLoader, 
                   val_dataloader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer,
                   loss_fn: torch.nn.Module=nn.CrossEntropyLoss(),
                   num_epochs: int=30):
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        # Train the model and print save the results
        train_loss, train_acc = train(model=model,
                                      dataloader=train_dataloader, 
                                      optimizer=optimizer,
                                      loss_fn=loss_fn)
        
        val_loss, val_acc = evaluate(model=model,
                                     dataloader=val_dataloader,
                                     optimizer=optimizer,
                                     loss_fn=loss_fn)
        
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            
        end_time = time.time()
        time_elapsed = end_time - start_time
        print(f"------------ epoch {epoch} ------------")
        print(f"Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}%")
        print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:2f}%")
        print(f"Time taken: {time_elapsed / 60:.0f}min {time_elapsed % 60:.0f}s")
        
    model.load_state_dict(best_model_wts)
    return model  

In [47]:
fine_tuned = train_baseline(model=vgg19.to(DEVICE),
                            train_dataloader=train_dataloader,
                            val_dataloader=val_dataloader,
                            optimizer=optimizer,
                            loss_fn=loss_fn,
                            num_epochs=10)

------------ epoch 1 ------------
Train loss: 0.2376 | Train acc: 92.87%
Val loss: 0.1199 | Val acc: 96.327467%
Time taken: 6min 13s
------------ epoch 2 ------------
Train loss: 0.1366 | Train acc: 95.79%
Val loss: 0.1271 | Val acc: 96.327467%
Time taken: 6min 16s
------------ epoch 3 ------------
Train loss: 0.1184 | Train acc: 96.35%
Val loss: 0.1132 | Val acc: 96.824790%
Time taken: 6min 25s
------------ epoch 4 ------------
Train loss: 0.1056 | Train acc: 96.73%
Val loss: 0.1132 | Val acc: 97.092578%
Time taken: 7min 44s
------------ epoch 5 ------------
Train loss: 0.0952 | Train acc: 97.11%
Val loss: 0.1191 | Val acc: 96.901301%
Time taken: 7min 55s
------------ epoch 6 ------------
Train loss: 0.0819 | Train acc: 97.44%
Val loss: 0.1260 | Val acc: 96.977812%
Time taken: 7min 50s


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x12410bf70>
Traceback (most recent call last):
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/h383kim/miniforge3/envs/env/lib/python3.8/selectors.py", line 415, in select
    fd_eve

KeyboardInterrupt: 