In [44]:
import copy
import time
import os
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Dataset

In [63]:
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
base_dir = r"/Users/h383kim/pytorch/AlexNet/splitted"
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

# Per-channel normalisation constants shared by every split.
# Presumably the statistics of the training images -- TODO confirm.
NORM_MEAN = [0.5108, 0.4829, 0.3989]
NORM_STD = [0.2632, 0.2587, 0.2706]

# Training pipeline: includes random augmentation (horizontal flip).
img_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Evaluation pipeline: identical preprocessing but NO random augmentation, so
# validation/test metrics are deterministic and comparable between epochs.
eval_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

train_dataset = ImageFolder(root=train_dir,
                            transform=img_transform)
val_dataset = ImageFolder(root=val_dir,
                          transform=eval_transform)
test_dataset = ImageFolder(root=test_dir,
                           transform=eval_transform)

BATCH_SIZE = 32
# os.cpu_count() may return None; DataLoader requires a non-negative int.
NUM_WORKERS = os.cpu_count() or 0

# Only the training split is shuffled; the order of evaluation samples does
# not affect the reported loss/accuracy.
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS)
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=NUM_WORKERS)
test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS)

# Model

In [46]:
# Architecture table: variant name -> (residual block kind, number of blocks
# in each of the four residual stages conv2..conv5).
ResNet_type = dict(
    ResNet18=("Basic_Conv", [2, 2, 2, 2]),
    ResNet34=("Basic_Conv", [3, 4, 6, 3]),
    ResNet50=("BottleNeck", [3, 4, 6, 3]),
    ResNet101=("BottleNeck", [3, 4, 23, 3]),
    ResNet152=("BottleNeck", [3, 8, 36, 3]),
)

In [47]:
class Basic_Conv(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm layers with a skip connection.

    This is the block kind selected for the "ResNet18" and "ResNet34" entries
    of ``ResNet_type``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut); values > 1 downsample spatially.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        # First 3x3 conv carries the (optional) spatial downsampling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        # Second 3x3 conv; no ReLU here because it is applied after the residual add.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        # Identity shortcut by default (an empty Sequential returns its input).
        self.shortcut = nn.Sequential()

        # A projection is needed whenever the residual branch changes the
        # tensor shape: spatially (stride > 1) OR in channel count. Checking
        # only the stride makes the addition in forward() fail for a stride-1
        # block whose in_channels differ from out_channels.
        if stride > 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                # 1x1 projection shortcut: matches channels and, via stride, spatial dimensions
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + shortcut(x))``."""
        out = self.conv1(x)
        out = self.conv2(out)

        identity = self.shortcut(x)

        out = out + identity
        out = self.relu(out)

        return out
        

In [48]:
class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    The block outputs ``out_channels * 4`` channels. The ``stride`` is applied
    in the 3x3 convolution and in the projection shortcut.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Width of the two inner convolutions; the block output
            has ``out_channels * self.out_mult`` channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        # Channel expansion factor of the final 1x1 convolution.
        self.out_mult = 4
        expanded = out_channels * self.out_mult

        residual_layers = [
            # 1x1 conv: reduce channels
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 3x3 conv: spatial processing (carries the stride)
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 1x1 conv: expand channels
            nn.Conv2d(in_channels=out_channels, out_channels=expanded, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(expanded),
        ]
        self.conv = nn.Sequential(*residual_layers)

        # Use a 1x1 projection when the residual branch changes the shape,
        # otherwise an identity (empty Sequential).
        needs_projection = stride > 1 or in_channels != expanded
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=expanded, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``."""
        residual = self.conv(x)
        skip = self.shortcut(x)
        return self.relu(residual + skip)

In [49]:
class ResNet(nn.Module):
    """ResNet classifier assembled from the ``ResNet_type`` architecture table.

    Layout: a 7x7 stride-2 stem with max-pooling (``conv1``), four residual
    stages (``conv2`` .. ``conv5``), adaptive average pooling to 1x1 and a
    linear classification head.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output logits vector.
        model: Key into ``ResNet_type`` selecting block kind and stage depths.
    """

    def __init__(self, in_channels=3, num_classes=10, model="ResNet50"):
        super().__init__()
        # Running count of channels entering the next residual block;
        # _make_layers updates it as blocks are appended.
        self.CHANNELS = 64

        stem = [
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem)

        block_kind, stage_depths = ResNet_type[model]
        self.conv2 = self._make_layers(block_kind, stage_depths[0], 64, 1)
        self.conv3 = self._make_layers(block_kind, stage_depths[1], 128, 2)
        self.conv4 = self._make_layers(block_kind, stage_depths[2], 256, 2)
        self.conv5 = self._make_layers(block_kind, stage_depths[3], 512, 2)

        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # Bottleneck variants expand the last stage's 512 channels by 4.
        if block_kind == "Basic_Conv":
            self.OUTCHANNELS = 512
        else:
            self.OUTCHANNELS = 512*4

        self.fc = nn.Linear(self.OUTCHANNELS, num_classes)

    def _make_layers(self, block, num_blocks, out_channels, stride):
        """Build one residual stage as an ``nn.Sequential``.

        Args:
            block: Block kind name; "Basic_Conv" selects ``Basic_Conv``, any
                other value selects ``BottleNeck``.
            num_blocks: Number of blocks in the stage.
            out_channels: Base width passed to every block of the stage.
            stride: Stride of the first block; the remaining blocks use 1.
        """
        if block == "Basic_Conv":
            block_cls, expansion = Basic_Conv, 1
        else:
            block_cls, expansion = BottleNeck, 4

        # Only the first block of a stage may downsample.
        per_block_strides = [stride] + [1]*(num_blocks - 1)

        stage = []
        for block_stride in per_block_strides:
            stage.append(block_cls(self.CHANNELS, out_channels, block_stride))
            # The next block consumes this block's output channels.
            self.CHANNELS = out_channels * expansion

        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, num_classes)."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        x = self.avg(x)
        # Flatten the pooled features to (batch, features) for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [None]:
from torchsummary import summary

# Build the ResNet-34 variant on the CPU and print its layer-by-layer summary
# for a single 3x224x224 input.
Res = ResNet(model="ResNet34").to("cpu")
summary(Res, input_size=(3, 224, 224), device="cpu")

In [33]:
from torchsummary import summary

# Build the ResNet-50 variant on the CPU and print its layer-by-layer summary
# for a single 3x224x224 input.
Res = ResNet(model="ResNet50").to("cpu")
summary(Res, input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [27]:
from torchvision import models

# Reference model: torchvision's ResNet-50 with the "IMAGENET1K_V1" weights,
# loaded for the layer-summary cell that follows.
# NOTE: the name `ResNet50` is rebound further down in this notebook to the
# custom ResNet implementation that is actually trained.
ResNet50 = models.resnet50(weights="IMAGENET1K_V1", progress=True)

In [28]:
# Print the layer summary of whatever `ResNet50` is currently bound to, on the
# CPU, for a single 3x224x224 input.
summary(ResNet50.to("cpu"), input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

# Train / Evaluate

In [57]:
def train(model: torch.nn.Module,
          dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          device=None):
    """Run one training epoch and return ``(mean_batch_loss, accuracy_percent)``.

    Args:
        model: Network to optimise; switched to train mode.
        dataloader: Yields ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable mapping ``(logits, targets)`` to a scalar loss tensor.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function does not depend on a notebook-level global.

    Returns:
        Tuple of the loss averaged over batches and the accuracy in percent,
        computed over the samples actually processed. Both are 0.0 when the
        dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    train_loss, correct = 0.0, 0
    num_batches, num_samples = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass; logits have shape (batch_size, num_classes)
        logits = model(X)
        # Scalar loss for this batch
        loss = loss_fn(logits, y)
        train_loss += loss.item()
        # Reset gradients, backpropagate, update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit
        preds = torch.argmax(logits, dim=1)
        correct += preds.eq(y.view_as(preds)).sum().item()

        num_batches += 1
        # Count the samples really seen so accuracy stays correct with
        # drop_last=True or a sampler that visits only part of the dataset.
        num_samples += y.size(0)

    # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
    train_loss /= max(num_batches, 1)
    train_acc = 100. * correct / max(num_samples, 1)

    return train_loss, train_acc

In [65]:
def evaluate(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             device=None):
    """Evaluate ``model`` on ``dataloader`` and return ``(mean_batch_loss, accuracy_percent)``.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Yields ``(inputs, targets)`` batches.
        loss_fn: Callable mapping ``(logits, targets)`` to a scalar loss tensor.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function does not depend on a notebook-level global.

    Returns:
        Tuple of the loss averaged over batches and the accuracy in percent,
        computed over the samples actually processed. Both are 0.0 when the
        dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    val_loss, correct = 0.0, 0
    num_batches, num_samples = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # Forward pass
            logits = model(X)
            # Scalar loss for this batch
            loss = loss_fn(logits, y)
            val_loss += loss.item()

            # Predicted class = index of the largest logit
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(y.view_as(preds)).sum().item()

            num_batches += 1
            # Count the samples really seen so accuracy stays correct with
            # drop_last=True or a sampler that visits only part of the dataset.
            num_samples += y.size(0)

    # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
    val_loss /= max(num_batches, 1)
    val_acc = 100. * correct / max(num_samples, 1)

    return val_loss, val_acc

In [None]:
def train_baseline(model: torch.nn.Module,
                   train_dataloader: torch.utils.data.DataLoader,
                   val_dataloader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer,
                   loss_fn: torch.nn.Module,
                   num_epochs: int,
                   scheduler=None):
    """Train for ``num_epochs`` epochs and return the model restored to its best weights.

    After every epoch the model is evaluated on ``val_dataloader``; the
    weights with the highest validation accuracy are kept and loaded back
    into ``model`` before returning.

    Args:
        model: Network to train (modified in place).
        train_dataloader: Batches used for optimisation.
        val_dataloader: Batches used for model selection.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Loss callable passed to ``train`` and ``evaluate``.
        num_epochs: Number of passes over ``train_dataloader``.
        scheduler: Optional learning-rate scheduler whose ``step()`` takes no
            arguments; stepped once per epoch. ``None`` (default) keeps the
            learning rate under the optimizer's control only.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    best_acc = 0
    # Snapshot so there is always something to restore, even if no epoch
    # improves on the initial best_acc.
    best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(1, num_epochs + 1):
        start = time.time()
        train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
        val_loss, val_acc = evaluate(model, val_dataloader, loss_fn)

        if scheduler is not None:
            scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc
            # deepcopy: state_dict() returns references to the live tensors.
            best_model_wts = copy.deepcopy(model.state_dict())

        # Whole minutes and remaining seconds (floor, not rounding).
        minutes, seconds = divmod(int(time.time() - start), 60)
        print(f"------------ epoch {epoch} ------------")
        print(f"Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}%")
        print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.2f}%")
        print(f"Time taken: {minutes}min {seconds}s")

    model.load_state_dict(best_model_wts)
    return model

# Train baseline model

In [60]:
# Pick the compute backend: CUDA GPU first, then Apple-silicon MPS, then CPU.
# getattr guards against torch builds that have no `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Device : {DEVICE}")

Device : mps


In [66]:
# Baseline configuration: cross-entropy loss, the custom ResNet-50 for 10
# classes on DEVICE, and Adam at lr 1e-3.
loss_fn = nn.CrossEntropyLoss()
ResNet50 = ResNet(model="ResNet50", in_channels=3, num_classes=10).to(DEVICE)
optimizer = torch.optim.Adam(ResNet50.parameters(), lr=1e-3)
# Multiplies the learning rate by 0.3 every 10 scheduler steps.
# NOTE(review): the training call in this notebook never passes or steps
# `scheduler`, so the learning rate stays constant -- confirm whether that is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.3, step_size=10)

In [None]:
# Train the baseline for 5 epochs; `base` is the model with the weights of
# the best validation epoch loaded.
base = train_baseline(
    model=ResNet50,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=5,
)