In [1]:
import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import ResNet18_Weights, resnet18
from tqdm import tqdm

In [2]:
weights = ResNet18_Weights.DEFAULT
preprocess = weights.transforms()
torch.set_float32_matmul_precision("medium")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train = "data/training/brands-classification-splits/train"
val = "data/training/brands-classification-splits/val"
test = "data/training/brands-classification-splits/test"

train_dataset = ImageFolder(train, transform=preprocess)
val_dataset = ImageFolder(val, transform=preprocess)
test_dataset = ImageFolder(test, transform=preprocess)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=False, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False, num_workers=4)

In [3]:
num_classes = len(train_dataset.classes)
model = resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

In [5]:
def train_epoch(model, train_dataloader, criterion, optimizer):
    model.train()
    running_loss = 0.0

    for data in tqdm(train_dataloader):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_dataloader)


def eval_epoch(model, val_dataloader, criterion):
    model.eval()
    running_loss = 0.0
    total = 0
    correct = 0

    with torch.inference_mode():
        for data in tqdm(val_dataloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total

        return running_loss / len(train_dataloader), accuracy

In [6]:
wandb.init(project="sneakers_ml")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mseara[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
num_epochs = 30


def train(model, train_dataloader, criterion, optimizer, val_dataloader):
    for _ in range(num_epochs):
        loss = train_epoch(model, train_dataloader, criterion, optimizer)
        loss, acc = eval_epoch(model, val_dataloader, criterion)
        wandb.log({"val_accuracy": acc, "val_loss": loss, "train_loss": loss})

In [8]:
train(model, train_dataloader, criterion, optimizer, val_dataloader)

100%|██████████| 25/25 [00:25<00:00,  1.00s/it]
100%|██████████| 9/9 [00:08<00:00,  1.06it/s]
100%|██████████| 25/25 [00:25<00:00,  1.01s/it]
100%|██████████| 9/9 [00:10<00:00,  1.12s/it]
100%|██████████| 25/25 [00:26<00:00,  1.05s/it]
100%|██████████| 9/9 [00:08<00:00,  1.12it/s]
100%|██████████| 25/25 [00:23<00:00,  1.04it/s]
100%|██████████| 9/9 [00:07<00:00,  1.17it/s]
100%|██████████| 25/25 [00:25<00:00,  1.04s/it]
100%|██████████| 9/9 [00:08<00:00,  1.09it/s]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]
100%|██████████| 9/9 [00:08<00:00,  1.06it/s]
100%|██████████| 25/25 [00:23<00:00,  1.04it/s]
100%|██████████| 9/9 [00:08<00:00,  1.10it/s]
100%|██████████| 25/25 [00:25<00:00,  1.02s/it]
100%|██████████| 9/9 [00:07<00:00,  1.13it/s]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]
100%|██████████| 9/9 [00:08<00:00,  1.10it/s]
100%|██████████| 25/25 [00:24<00:00,  1.01it/s]
100%|██████████| 9/9 [00:08<00:00,  1.09it/s]
100%|██████████| 25/25 [00:25<00:00,  1.02s/it]
100%|███████

In [9]:
wandb.finish()



VBox(children=(Label(value='0.269 MB of 0.291 MB uploaded\r'), FloatProgress(value=0.9236918175918825, max=1.0…

0,1
train_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▅▆▆▇▇▇▇▇▇▇█████████████████
val_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.41759
val_accuracy,0.62879
val_loss,0.41759
