In [None]:
import os
import random
from PIL import Image

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing
from torch.utils.data import DataLoader, Subset
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from datasets import UserModeImgDataset, UserModeDataset, UserModeFeatDataset
from models import DVBPR
from trainers import ImgTrainer
from utils.data import extract_embedding

# Parameters
RNG_SEED = 0  # Seed shared by all RNGs; set to None to skip seeding
USE_GPU = True  # Request CUDA; the CPU is used when CUDA is unavailable
# NOTE(review): absolute, machine-specific locations -- adjust them to your
# environment before running the notebook.
BASE_PATH = '/home/pcerdam/VisualRecSys-Tutorial-IUI2021/'

# Input files, all located under BASE_PATH/data
TRAINING_PATH = os.path.join(BASE_PATH, "data", "naive-user-train.csv")
EMBEDDING_PATH = os.path.join(BASE_PATH, "data", "embedding-resnet50.npy")
VALIDATION_PATH = os.path.join(BASE_PATH, "data", "naive-user-validation.csv")

# Image directory (a single absolute path, so no join is needed)
IMAGES_PATH = '/mnt/data2/wikimedia/mini-images-224-224-v2'
CHECKPOINTS_DIR = os.path.join(BASE_PATH, "checkpoints")
version = 'DVBPR_wikimedia'  # Identifier handed to the trainer as `version`

# Parameters (training)
SETTINGS = {
    "dataloader:batch_size": 128,
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 0 so DataLoader loads batches in the main process instead
    # of failing on a None worker count.
    "dataloader:num_workers": os.cpu_count() or 0,
    "prev_checkpoint": False,
    "model:dim_visual": 100,
    "optimizer:lr": 0.001,
    "optimizer:weight_decay": 0.0001,
    "scheduler:factor": 0.6,
    "scheduler:patience": 2,
    "train:max_epochs": 5,
    "train:max_lrs": 5,
    "train:non_blocking": True,
    "train:train_per_valid_times": 1
}

In [None]:
%%time
# Freezing RNG seed if needed
# Python's `random`, PyTorch and NumPy are all seeded with the same value so
# that a fresh run of the notebook starts from the same RNG state.
if RNG_SEED is not None:
    print(f"\nUsing random seed...")
    random.seed(RNG_SEED)
    torch.manual_seed(RNG_SEED)
    np.random.seed(RNG_SEED)

# Load embedding from file
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code -- only load embedding files you trust.
print(f"\nLoading embedding from file... ({EMBEDDING_PATH})")
embedding = np.load(EMBEDDING_PATH, allow_pickle=True)

# Extract features and "id2index" mapping
# `embedding` is rebound here: the raw loaded object is replaced by the first
# value returned by extract_embedding, so this cell is not safe to re-run on
# its own without re-loading the file first (which the lines above do).
# NOTE(review): presumably id2index maps item ids to rows of `embedding` and
# index2fn maps rows to image file names -- confirm in utils.data.
print("\nExtracting data into variables...")
embedding, id2index, index2fn = extract_embedding(embedding, verbose=True)
print(f">> Features shape: {embedding.shape}")

In [None]:
# DataLoaders initialization
print("\nInitialize DataLoaders")


def _make_dataloader(dataset, sampler):
    """Wrap ``dataset`` in a DataLoader configured from ``SETTINGS``.

    Args:
        dataset: Dataset to draw batches from.
        sampler: Sampler that decides the iteration order. It is passed
            explicitly (instead of ``shuffle=``) because DataLoader treats
            the two options as mutually exclusive.

    Returns:
        A DataLoader using the configured batch size and worker count, with
        pinned memory enabled.
    """
    return DataLoader(
        dataset,
        batch_size=SETTINGS["dataloader:batch_size"],
        num_workers=SETTINGS["dataloader:num_workers"],
        sampler=sampler,
        pin_memory=True,
    )


# Training DataLoader: samples are visited in random order every epoch
train_dataset = UserModeImgDataset(
    csv_file=TRAINING_PATH,
    img_path=IMAGES_PATH,
    id2index=id2index,
    index2fn=index2fn
)
print(f">> Training dataset: {len(train_dataset)}")
train_sampler = RandomSampler(train_dataset)
train_dataloader = _make_dataloader(train_dataset, train_sampler)
print(f">> Training dataloader: {len(train_dataloader)}")

# Validation DataLoader: samples are visited in a fixed, sequential order.
# The sampler was previously created but never used, and the loader shuffled
# the validation set instead.
valid_dataset = UserModeImgDataset(
    csv_file=VALIDATION_PATH,
    img_path=IMAGES_PATH,
    id2index=id2index,
    index2fn=index2fn
)
print(f">> Validation dataset: {len(valid_dataset)}")
valid_sampler = SequentialSampler(valid_dataset)
valid_dataloader = _make_dataloader(valid_dataset, valid_sampler)
print(f">> Validation dataloader: {len(valid_dataloader)}")

In [None]:
# Model initialization
print("\nInitialize model")
# Use the first CUDA device only when it is both available and requested
device = torch.device("cuda:0" if torch.cuda.is_available() and USE_GPU else "cpu")
if torch.cuda.is_available() != USE_GPU:
    # Availability and the USE_GPU flag disagree, so training runs on the CPU
    print((f"\nNotice: Not using GPU - "
           f"Cuda available ({torch.cuda.is_available()}) "
           f"does not match USE_GPU ({USE_GPU})"
           ))
# Number of distinct entries in the training dataset's `ui` attribute
# NOTE(review): presumably these are user ids/indexes; if they are not
# contiguous in [0, N_USERS) a per-user lookup could go out of range -- confirm.
N_USERS = len(set(train_dataset.ui))
N_ITEMS = len(embedding)
print(f">> N_USERS = {N_USERS} | N_ITEMS = {N_ITEMS}")
# Report the embedding shape without copying the whole matrix into a new
# tensor; torch.Size keeps the printed format ("torch.Size([...])") unchanged.
print(torch.Size(embedding.shape))
model = DVBPR(
    N_USERS,  # Number of users and items
    N_ITEMS,
    SETTINGS["model:dim_visual"],  # Size of visual spaces
).to(device)

print(model)

In [None]:
# Training setup
print("\nSetting up training")

# Loss: binary cross-entropy on raw logits, summed (not averaged) over a batch
criterion = nn.BCEWithLogitsLoss(reduction="sum")

# Optimizer: Adam over every model parameter, with L2 weight decay
optimizer = optim.Adam(
    model.parameters(),
    weight_decay=SETTINGS["optimizer:weight_decay"],
    lr=SETTINGS["optimizer:lr"],
)

# Scheduler: multiply the learning rate by `factor` once the monitored metric
# has stopped increasing (mode="max") for more than `patience` epochs.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases -- confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=SETTINGS["scheduler:patience"],
    factor=SETTINGS["scheduler:factor"],
    verbose=True,
)

In [None]:
%%time
# Training
# Build the trainer around the model and its optimisation objects; the
# checkpoint directory and version label are forwarded as keyword arguments.
trainer = ImgTrainer(
    model,
    device,
    criterion,
    optimizer,
    scheduler,
    version=version,
    checkpoint_dir=CHECKPOINTS_DIR,
)

# Run training with the configured epoch / learning-rate limits and keep the
# four values the trainer reports back.
best_model, best_acc, best_loss, best_epoch = trainer.run(
    SETTINGS["train:max_epochs"],
    SETTINGS["train:max_lrs"],
    {"train": train_dataloader, "validation": valid_dataloader},
    use_checkpoint=SETTINGS["prev_checkpoint"],
    train_valid_loops=SETTINGS["train:train_per_valid_times"],
)

In [None]:
# Final result
# Print the summary line followed by the best model's representation, one per
# line, in a single call.
print(f"\nBest ACC {best_acc} reached at epoch {best_epoch}", best_model, sep="\n")