In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter

from torchvision.datasets import ImageFolder
from torchvision import transforms

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import logging
from utils.train import train
from utils.checkpoint import (
    load_checkpoint,
    create_checkpoint,
    save_weights,
)

from utils import get_model

In [None]:
torch._logging.set_logs(all=logging.ERROR)
torch.multiprocessing.set_start_method("spawn", force=True)

In [None]:
# --- Run configuration ------------------------------------------------------

# Reproducibility: seeds torch's RNG and the stratified train/test split.
SEED = 11

# Architecture name handed to get_model(); also used in log/checkpoint paths.
MODEL = "resnet50"

# Optimisation hyper-parameters (AdamW learning rate and weight decay).
LR = 0.01
WEIGHT_DECAY = 0.0001

# Training schedule: DataLoader batch size and the epoch count given to train().
BATCH_SIZE = 128
EPOCHS = 100

# Checkpoint resume: whether to try loading, and which checkpoint file name.
LOAD = True
CHECKPOINT_NAME = "checkpoint.pt"

In [None]:
# Seed torch's global RNG before anything stochastic (augmentations,
# shuffling, weight init in later cells) is created.
torch.manual_seed(SEED)

# Training-time pipeline: resize, then random flip / rotation / crop
# augmentations, then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
    ]
)

# Evaluation-time pipeline: same input/output sizes but fully deterministic,
# so held-out metrics are not perturbed by random augmentation.
eval_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
    ]
)

# Two views over the same image folder that differ only in their transform.
dataset = ImageFolder(
    "data/images/",
    transform=transform,
)
eval_dataset = ImageFolder(
    "data/images/",
    transform=eval_transform,
)

# The split indices below are computed once and applied to both views, so
# both must enumerate exactly the same samples in the same order.
assert eval_dataset.samples == dataset.samples, "dataset views are misaligned"

# Derive the label set from the dataset instead of hard-coding a class count;
# compute_class_weight raises if `classes` and the labels in `y` disagree.
class_weight = compute_class_weight(
    "balanced",
    classes=np.arange(len(dataset.classes)),
    y=dataset.targets,
)

print(dataset.class_to_idx)
print(class_weight)
# Stored as a float32 tensor; a later cell moves it to the device and passes
# it as the CrossEntropyLoss weight.
class_weight = torch.tensor(class_weight, dtype=torch.float32)

# Stratified 80/20 split over sample indices, reproducible via SEED.
train_indices, test_indices = train_test_split(
    range(len(dataset.targets)),
    test_size=0.2,
    stratify=dataset.targets,
    random_state=SEED,
)

# Training samples go through the augmenting view, held-out samples through
# the deterministic one.
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(eval_dataset, test_indices)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network named by MODEL, move it to the device, then compile it
# in place with static shapes and the "max-autotune" mode.
model = get_model(MODEL)
model = model.to(device)
model.compile(mode="max-autotune", dynamic=False)

# Cross-entropy loss weighted per class by the weights computed in the data
# cell.
criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))

optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Multiply the learning rate by 0.1 at each of the two milestones.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[30, 60],
    gamma=0.1,
)

# Gradient scaler, handed to train() and to the checkpoint helpers.
scaler = GradScaler(device=device)

# TensorBoard writer with one log directory per model name.
writer = SummaryWriter(f"logs/{MODEL}")

# Starting epoch; the resume cell may overwrite this from a checkpoint.
epoch = 0

In [None]:
# Optionally resume from a saved checkpoint. This is best-effort: if loading
# fails for any reason, training continues from the current `epoch` value,
# but the failure is reported instead of being silently swallowed.
if LOAD:
    try:
        resumed_epoch = load_checkpoint(
            dir=f"data/checkpoints/{MODEL}",
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            scheduler=scheduler,
            name=CHECKPOINT_NAME,
        )
        # Assign rather than accumulate so that re-running this cell does not
        # add the checkpoint's epoch count a second time.
        epoch = int(resumed_epoch)
    except Exception as exc:
        # `Exception` (not a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate so the cell can still be interrupted.
        # NOTE(review): if load_checkpoint fails part-way, model/optimizer
        # state may already have been partially restored — confirm.
        print(
            f"Could not load checkpoint {CHECKPOINT_NAME!r} from "
            f"data/checkpoints/{MODEL}: {exc!r}; continuing from epoch {epoch}."
        )

In [None]:
# Run the training loop. `epoch` is rebound to whatever train() returns
# (presumably the last completed epoch — confirm in utils.train) and is then
# passed to create_checkpoint in the next cell.
epoch = train(
    # What to optimise and how.
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    # Data and placement.
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    # Schedule, logging and checkpointing.
    start_epoch=epoch,
    epochs=EPOCHS,
    writer=writer,
    checkpoint_dir=f"data/checkpoints/{MODEL}",
)

In [None]:
# Write a checkpoint after training: the model plus the optimizer, scheduler
# and scaler objects, tagged with the epoch value returned by train().
create_checkpoint(
    dir=f"data/checkpoints/{MODEL}",
    epoch=epoch,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
)

In [None]:
# Export the trained weights under the model's name. The TensorBoard writer
# is closed in a `finally` so it is released even if saving the weights
# raises.
try:
    save_weights(model, MODEL)
finally:
    writer.close()