In [None]:
from pathlib import Path

import numpy as np
import torch
import wandb
from PIL import Image
from scipy.special import softmax
from torch import nn
from torch.utils.data import DataLoader
from torcheval.metrics.functional import multiclass_f1_score
from torchvision.datasets import ImageFolder
from torchvision.models import ResNet152_Weights, resnet152
from tqdm import tqdm

from sneakers_ml.models.onnx_utils import get_session, predict, save_torch_model

In [None]:
# Pretrained-weights bundle; `preprocess` is the transform pipeline shipped with
# those weights and is applied to every image of every split.
weights = ResNet152_Weights.DEFAULT
preprocess = weights.transforms()
# Allow reduced-precision float32 matmuls (see the PyTorch docs for the exact trade-off).
torch.set_float32_matmul_precision("medium")

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The split paths deliberately use *_dir names: a function called `train` is
# defined further down, and binding a path string to the same name means that
# re-running this cell would silently replace the function with a str.
SPLITS_ROOT = "data/training/brands-classification-splits"
train_dir = f"{SPLITS_ROOT}/train"
val_dir = f"{SPLITS_ROOT}/val"
test_dir = f"{SPLITS_ROOT}/test"

# Loader settings shared by all three splits.
BATCH_SIZE = 128
NUM_WORKERS = 4

train_dataset = ImageFolder(train_dir, transform=preprocess)
val_dataset = ImageFolder(val_dir, transform=preprocess)
test_dataset = ImageFolder(test_dir, transform=preprocess)

# Only the training split is shuffled; validation and test keep a fixed order.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=NUM_WORKERS
)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=NUM_WORKERS
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=NUM_WORKERS
)

In [None]:
# Persist the class-name -> index mapping of the training split so that
# inference code can translate predicted indices back into class names.
path = "data/models/brands-classification/resnet152-finetune-classes.npy"
class_to_idx = train_dataset.class_to_idx

save_path = Path(path)
save_path.parent.mkdir(parents=True, exist_ok=True)

# One (class name, index) row per class. Saved with allow_pickle=False; the
# loader in predict_resnet casts the index column back with astype(int).
class_index_pairs = np.array(list(class_to_idx.items()))
with save_path.open("wb") as save_file:
    np.save(save_file, class_index_pairs, allow_pickle=False)

In [None]:
class ResNet152Classifier(nn.Module):
    """Pretrained ResNet-152 split into a frozen stem and a trainable tail.

    The backbone's children are partitioned into three sequential stages:

    * ``feature_extractor`` - every child except the last three; frozen and
      always kept in eval mode.
    * ``trainable_bottleneck`` - the third- and second-to-last children
      (presumably the final residual stage and the pooling layer - confirm
      against the torchvision ResNet definition).
    * ``classifier`` - a fresh ``nn.Linear`` replacing the original ``fc``
      head, with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits produced by the classifier head.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

        backbone = resnet152(weights=ResNet152_Weights.DEFAULT)
        # Swap the pretrained head for one sized to this task.
        backbone.fc = nn.Linear(backbone.fc.in_features, self.num_classes)

        children = list(backbone.children())
        self.feature_extractor = nn.Sequential(*children[:-3])
        self.trainable_bottleneck = nn.Sequential(*children[-3:-1])
        self.classifier = nn.Sequential(children[-1])

        # Freeze the stem explicitly: no gradients, and eval mode so that
        # normalisation layers use (and keep) their pretrained statistics.
        self.feature_extractor.requires_grad_(False)
        self.feature_extractor.eval()

    def train(self, mode: bool = True) -> "ResNet152Classifier":
        """Switch the trainable stages to ``mode`` while keeping the stem in eval.

        Without this override a plain ``model.train()`` would put the frozen
        stem's BatchNorm layers back into training mode, and their running
        statistics would drift even though the forward pass runs under
        ``torch.no_grad()``.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of preprocessed images ``x``."""
        # The stem is frozen, so skip building an autograd graph for it.
        with torch.no_grad():
            x = self.feature_extractor(x)
        x = self.trainable_bottleneck(x)
        # Collapse everything after the batch dimension before the linear head.
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
# Size the classifier head from the classes discovered in the training split,
# then move the network to the selected device.
num_classes = len(train_dataset.classes)
model = ResNet152Classifier(num_classes=num_classes)
model.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()

# Only the two trainable stages are handed to the optimizer; the feature
# extractor's parameters are never updated.
trainable_param_groups = [
    {"params": model.trainable_bottleneck.parameters()},
    {"params": model.classifier.parameters()},
]
optimizer = torch.optim.Adam(trainable_param_groups, lr=0.001)

In [None]:
def calculate_metrics(y_pred: torch.Tensor, y_true: torch.Tensor):
    """Compute macro, micro and weighted multiclass F1 for the given predictions.

    Uses the notebook-level ``num_classes`` as the number of classes.

    Args:
        y_pred: Predicted class indices.
        y_true: Ground-truth class indices.

    Returns:
        A ``(f1_macro, f1_micro, f1_weighted)`` tuple of Python floats.
    """
    scores = [
        multiclass_f1_score(y_pred, y_true, num_classes=num_classes, average=averaging)
        for averaging in ("macro", "micro", "weighted")
    ]
    macro, micro, weighted = (score.item() for score in scores)
    return macro, micro, weighted


def train_epoch(model, train_dataloader, criterion, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Only the model's ``trainable_bottleneck`` and ``classifier`` stages are
    switched to training mode; the rest of the model is left as it is.

    Args:
        model: Network exposing ``trainable_bottleneck`` and ``classifier``.
        train_dataloader: Iterable of batches whose first two items are
            inputs and labels.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The average of the per-batch loss values.
    """
    for trainable_stage in (model.trainable_bottleneck, model.classifier):
        trainable_stage.train()

    loss_sum = 0.0
    for batch in tqdm(train_dataloader):
        images = batch[0].to(device)
        targets = batch[1].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(images), targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    return loss_sum / len(train_dataloader)


def eval_epoch(model, val_dataloader, criterion):
    """Evaluate ``model`` on every batch of ``val_dataloader``.

    The model's ``trainable_bottleneck`` and ``classifier`` stages are put in
    eval mode and the whole pass runs under ``torch.inference_mode()``.

    Args:
        model: Network exposing ``trainable_bottleneck`` and ``classifier``.
        val_dataloader: Iterable of batches whose first two items are inputs
            and labels. Works for any split (validation or test).
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, f1_macro, f1_micro, f1_weighted)`` where ``mean_loss``
        is the average of the per-batch loss values over *this* dataloader.
    """
    running_loss = 0.0
    y_true = []
    y_pred = []

    model.trainable_bottleneck.eval()
    model.classifier.eval()

    with torch.inference_mode():
        for data in tqdm(val_dataloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            # Predicted class = index of the largest logit in each row.
            y_true.append(labels.cpu())
            y_pred.append(outputs.argmax(dim=1).cpu())

        y_true = torch.cat(y_true)
        y_pred = torch.cat(y_pred)
        f1_macro, f1_micro, f1_weighted = calculate_metrics(y_pred, y_true)

        # Normalise by the number of batches actually evaluated. Dividing by
        # the global training loader's length would mis-scale the loss
        # whenever the two loaders differ in size.
        return running_loss / len(val_dataloader), f1_macro, f1_micro, f1_weighted

In [None]:
# Start a Weights & Biases run in the "sneakers_ml" project; the training loop
# below reports its per-epoch metrics through wandb.log.
wandb.init(project="sneakers_ml")

In [None]:
num_epochs = 5


def train(model, train_dataloader, criterion, optimizer, val_dataloader):
    """Train for ``num_epochs`` epochs, validating and logging after each one.

    Every epoch runs ``train_epoch`` followed by ``eval_epoch`` on the
    validation loader, then sends both losses and the three validation F1
    scores to Weights & Biases via ``wandb.log``.

    Args:
        model: Network to optimise.
        train_dataloader: Batches used for optimisation.
        criterion: Loss function shared by training and validation.
        optimizer: Optimizer updating the model's trainable parameters.
        val_dataloader: Batches used for validation.
    """
    for _epoch in range(num_epochs):
        train_loss = train_epoch(model, train_dataloader, criterion, optimizer)
        val_loss, *val_f1_scores = eval_epoch(model, val_dataloader, criterion)
        f1_macro, f1_micro, f1_weighted = val_f1_scores

        epoch_metrics = {
            "val_f1_macro": f1_macro,
            "val_f1_micro": f1_micro,
            "val_f1_weighted": f1_weighted,
            "val_loss": val_loss,
            "train_loss": train_loss,
        }
        wandb.log(epoch_metrics)

In [None]:
# Run the full training loop (num_epochs epochs with validation after each).
train(model, train_dataloader, criterion, optimizer, val_dataloader)

In [None]:
# Close the Weights & Biases run now that training has finished.
wandb.finish()

In [None]:
# Final evaluation on the held-out test split, reusing the validation routine.
test_results = eval_epoch(model, test_dataloader, criterion)
loss, f1_macro, f1_micro, f1_weighted = test_results

test_report = {
    "test_f1_macro": f1_macro,
    "test_f1_micro": f1_micro,
    "test_f1_weighted": f1_weighted,
    "test_loss": loss,
}
print(test_report)

In [None]:
# Export the fine-tuned model from the CPU in eval mode.
path = "data/models/brands-classification/resnet152-finetune.onnx"
model.eval().to("cpu")

# Single random 3x224x224 image; presumably the example input that
# save_torch_model needs for the export - confirm against onnx_utils.
torch_input = torch.randn(1, 3, 224, 224)
save_torch_model(model, torch_input, path)

In [None]:
def predict_resnet(images: "list[Image.Image]") -> "tuple[np.ndarray, np.ndarray]":
    """Classify a collection of PIL images with the exported ONNX model.

    Args:
        images: Iterable of PIL images. May be empty.

    Returns:
        ``(predictions, string_predictions)``: an integer array of predicted
        class indices and an array with the matching class names, both with
        one entry per input image.
    """
    images = list(images)
    if not images:
        # torch.stack cannot handle an empty list; return empty results
        # instead of failing.
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=str)

    with Path("data/models/brands-classification/resnet152-finetune-classes.npy").open("rb") as file:
        class_to_idx_numpy = np.load(file, allow_pickle=False)
    # The file stores (class name, index) rows; invert into index -> name.
    idx_to_class = dict(zip(class_to_idx_numpy[:, 1].astype(int), class_to_idx_numpy[:, 0]))

    preprocess = ResNet152_Weights.DEFAULT.transforms()
    # Force three channels so grayscale / RGBA inputs are preprocessed the
    # same way as the RGB-loaded training images.
    preprocessed_images = torch.stack([preprocess(image.convert("RGB")) for image in images])

    onnx_session = get_session("data/models/brands-classification/resnet152-finetune.onnx", "cpu")

    pred = predict(onnx_session, preprocessed_images)
    softmax_pred = softmax(pred, axis=1)
    predictions = np.argmax(softmax_pred, axis=1)
    string_predictions = np.array([idx_to_class[int(index)] for index in predictions])
    return predictions, string_predictions

In [None]:
# Image.open is lazy and keeps the file handle open; load the pixels inside a
# context manager and keep an RGB copy so the handle is released right away.
with Image.open("data/training/brands-classification-splits/train/adidas/1.jpeg") as image_file:
    image = image_file.convert("RGB")

In [None]:
# Smoke-test the exported model on a two-image batch (the same image twice);
# the cell displays the (indices, class names) pair returned by predict_resnet.
predict_resnet([image, image])