In [36]:
import os
import json
import torch
import PIL
from model.AlexNet import AlexNet
from model.ResNet34 import ResNet34
from model.MLP import MLP
from utils.dataset import build_dataloader
from utils.common import get_all_embeddings, get_accuracy, log_to_file
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from torchvision import datasets, transforms

In [37]:
from torchvision.datasets import MNIST

train_dataset = MNIST(
    root="data", train=True, download=True, transform=transforms.ToTensor()
)

val_dataset = MNIST(
    root="data", train=False, download=True, transform=transforms.ToTensor()
)

In [38]:
# Constants
EPOCHS = 300
BATCH_SIZE = 64

IMAGE_SIZE = 28
EMBEDDING_SIZE = 64

# MEAN = [0.485, 0.456, 0.406]
# STD = [0.229, 0.224, 0.225]

DEVICE = torch.device("cuda:0")
SAVE_PATH = "./weights"

# Image transformations
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # transforms.Normalize(MEAN, STD), # TODO: disccover
    ]
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False
)

In [42]:
model = AlexNet(input_size=IMAGE_SIZE, embedding_size=EMBEDDING_SIZE).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define triplet loss utility functions
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=0.2, distance=distance, type_of_triplets="semihard"
)



In [43]:
history = {"train": [], "val": [], "best_accuracy": 0.0}
os.makedirs(SAVE_PATH, exist_ok=True)
if os.path.exists("training.log"):
    os.remove("training.log")

In [44]:
for epoch in range(EPOCHS):
    epoch += 1

    # Model training
    model.to(DEVICE)
    model.train()

    for batch_idx, (data, labels) in enumerate(train_dataloader):
        data, labels = data.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()

        if batch_idx % 20 == 0:
            history["train"].append(
                {
                    "epoch": epoch,
                    "loss": loss.item(),
                    "triplets": mining_func.num_triplets,
                }
            )
            msg = f"Epoch [{epoch}/{EPOCHS}] Iter [{batch_idx}/{len(train_loader)}], Loss: {loss.item()}, Triplets: {mining_func.num_triplets}"
            log_to_file(msg)
            print(msg)

    # evaluate after n epochs
    if epoch % 2 == 0:
        # model validation
        model.eval()

        with torch.no_grad():
            # as all embeddings need to be stored in memory
            # you can set DEVICE = torch.device('cpu') in case gpu memory overflow occurs
            accuracy = get_accuracy(val_dataset, train_dataset, model, DEVICE)

            history["val"].append({"epoch": epoch, "accuracy": accuracy})
            msg = f"Val accuracy: {accuracy}"
            log_to_file(msg)
            print(msg)

            # save model
            torch.save(model.state_dict(), f"{SAVE_PATH}/model_latest.pth")

            if accuracy >= history["best_accuracy"]:
                history["best_accuracy"] = accuracy
                torch.save(model.state_dict(), f"{SAVE_PATH}/model_best.pth")

            with open("history.json", "w") as f:
                f.write(json.dumps(history))

RuntimeError: Given groups=1, weight of size [64, 3, 11, 11], expected input[64, 1, 28, 28] to have 3 channels, but got 1 channels instead

In [2]:
import random

In [4]:
import numpy as np

classes = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).tolist()

random.sample(classes, 2)

[3, 1]