# Load data

In [36]:
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# Dataset location and mini-batch size used by the loader below.
data_path = 'data/casia-webface'
batch_size = 20

# Per-image preprocessing: resize to 224x224, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# ImageFolder treats every sub-directory of `data_path` as one class.
train_data = datasets.ImageFolder(root=data_path, transform=transform)

# shuffle=False: batches are produced in dataset order.
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

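Make sure the dataset is loaded correctly: the loader is not shuffled, so the first samples of the first batch should share the same label.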

In [37]:
# Fetch the first batch from the (unshuffled) loader.
iterator = iter(trainloader)
image, label = next(iterator)

# Sanity check: the first two samples of the batch must share a label.
assert label[0] == label[1]

# Show the first image tensor, followed by the labels of the first four samples.
print("image", image[0])
for sample_idx in range(4):
    print("label", label[sample_idx])

image tensor([[[0.4667, 0.4706, 0.4784,  ..., 0.5216, 0.5294, 0.5333],
         [0.4667, 0.4706, 0.4784,  ..., 0.5176, 0.5255, 0.5294],
         [0.4667, 0.4706, 0.4784,  ..., 0.5137, 0.5216, 0.5255],
         ...,
         [0.7137, 0.7176, 0.7255,  ..., 0.2392, 0.2392, 0.2353],
         [0.7294, 0.7333, 0.7373,  ..., 0.2314, 0.2353, 0.2353],
         [0.7373, 0.7412, 0.7451,  ..., 0.2275, 0.2314, 0.2353]],

        [[0.4039, 0.4078, 0.4157,  ..., 0.2667, 0.2706, 0.2706],
         [0.4039, 0.4078, 0.4157,  ..., 0.2627, 0.2667, 0.2667],
         [0.4039, 0.4078, 0.4157,  ..., 0.2588, 0.2627, 0.2627],
         ...,
         [0.1961, 0.2000, 0.2078,  ..., 0.0667, 0.0667, 0.0627],
         [0.1961, 0.2000, 0.2039,  ..., 0.0588, 0.0627, 0.0627],
         [0.1961, 0.2000, 0.2039,  ..., 0.0549, 0.0588, 0.0627]],

        [[0.4157, 0.4196, 0.4275,  ..., 0.3843, 0.3882, 0.3922],
         [0.4157, 0.4196, 0.4275,  ..., 0.3804, 0.3843, 0.3882],
         [0.4157, 0.4196, 0.4235,  ..., 0.3686, 0.38

# Initialize the model

In [40]:
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn, optim
import os
# Alternative: start from ImageNet-pretrained weights instead of random init.
# model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model = resnet18()

# Take the class count from the dataset itself rather than from
# os.listdir(data_path): a raw directory listing also counts stray entries
# (e.g. .DS_Store on macOS), which would oversize the classifier head.
num_classes = len(train_data.classes)
print(num_classes)

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

# Replace the final fully connected layer so it emits one score per class;
# the input width is read from the existing layer instead of being hard-coded.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)


10572
mps


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Train the model

In [None]:
import time

# Training hyper-parameters.
num_epochs = 10
learning_rate = 1e-3
momentum = 0.9
weight_decay = 0.0

# Cross-entropy loss applied to the model's class scores.
loss_fn = nn.CrossEntropyLoss()

# SGD with momentum over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

def get_train_accuracy(model: nn.Module) -> float:
    """Return the fraction of training samples that `model` classifies correctly.

    Iterates once over the module-level `trainloader`, moves every batch to the
    module-level `device`, and counts how often the highest-scoring class
    matches the label.

    Args:
        model: Network to evaluate. It is called as `model(imgs)` and its
            output is reduced over dimension 1 to pick the predicted class.

    Returns:
        `correct / total` as a float, or 0.0 when the loader yields no samples.

    Note:
        The model is switched to eval mode and is left in that mode on return.
    """
    correct = 0
    total = 0
    model.eval()  # set once; nothing inside the loop changes the mode
    with torch.no_grad():  # no gradients are needed for evaluation
        for imgs, labels in trainloader:
            # Move the batch to the same device as the model.
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs)  # raw scores are enough; softmax keeps the argmax
            pred = output.argmax(dim=1)  # index of the highest score per sample
            correct += pred.eq(labels).sum().item()
            total += imgs.shape[0]
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

# ImageFolder lists its samples class by class, so iterating the unshuffled
# `trainloader` would feed SGD batches that each contain essentially one class.
# Train from a shuffled loader over the same dataset instead; `trainloader`
# itself is left untouched.
shuffled_trainloader = torch.utils.data.DataLoader(
    trainloader.dataset, batch_size=trainloader.batch_size, shuffle=True)

for epoch in range(num_epochs):
    # get_train_accuracy() switches the model to eval mode, so training mode
    # has to be re-enabled at the start of every epoch.
    model.train()
    for n, (imgs, labels) in enumerate(shuffled_trainloader, start=1):
        tic = time.perf_counter()
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        out = model(imgs)
        loss = loss_fn(out, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        toc = time.perf_counter()

        # Report progress every 10 iterations; .item() prints the loss as a
        # plain Python float rather than a tensor repr.
        if n % 10 == 0:
            print('epoch: {}, iter: {}, loss: {}, time: {}'.format(epoch, n, loss.item(), toc - tic))
    # Accuracy over the whole training set after each epoch.
    print(get_train_accuracy(model))