### Initialize Data

In [2]:
import torch
import torch.optim as optim
from torchvision import transforms
from focal_loss import FocalLoss
from mock_dataset import MockOutfitDataset
from outfit_model import OutfitCompatibilityModel
from outfit_dataset import OutfitDataset
import torch.nn as nn
from utils import save_checkpoint
import logging

# DEBUG - INFO - WARNING - ERROR
logging.basicConfig(level=logging.DEBUG)
data_dir = "data"

# Should be disjoint/nondisjoint
polyvore_split = "nondisjoint"

# Should be traim/valid/test
split = "valid"

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]
)


# Organizes your dataset into batches.
# Batch size = number of samples processed in one iteration
# Number of batches = total samples divided by batch_size
# Each this case, a sample = an outfit
def custom_collate(batch):
    images = torch.stack([item["outfit_images"] for item in batch], dim=0)
    texts = [item["outfit_texts"] for item in batch]
    labels = torch.stack([item["outfit_labels"] for item in batch], dim=0)

    return images, texts, labels


real_dataset = OutfitDataset(data_dir, polyvore_split, split, transform)
dataloader = torch.utils.data.DataLoader(
    real_dataset, batch_size=15, shuffle=True, collate_fn=custom_collate
)

# Instantiate the mock dataset and dataloader
# Contains list of all outfits
mock_dataset = MockOutfitDataset()

dataloader = torch.utils.data.DataLoader(
    mock_dataset, batch_size=15, shuffle=True, collate_fn=custom_collate
)

KeyError: 224

### Visualize Data

In [None]:
# CURRENTLY NOT USABLE

# from matplotlib import pyplot as plt
# import torchvision.transforms.functional as F


# def show_images(images, labels):
#     for i in range(images.size(1)):
#         image = F.to_pil_image(images[:, i, ...])
#         plt.subplot(1, images.size(1), i + 1)
#         plt.imshow(image)
#         plt.title(f"Label: {labels[i]}")
#         plt.axis("off")
#     plt.show()


# for batch_idx, (images, texts, labels) in enumerate(dataloader):
#     print(
#         f"Batch {batch_idx + 1} - Shape of images: {images.shape}, Texts: {texts}, Labels: {labels}"
#     )

#     # Visualize the images
#     show_images(images, labels)
    
#     if batch_idx == 2:  # Print information for the first 3 batches
#         break

### Init Model

In [None]:
# Instantiate the model, dataset, and dataloader
model = OutfitCompatibilityModel()
focal_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

### Training

In [None]:
# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    for batch in dataloader:
        images, texts, labels = batch
        # images = batch["outfit_images"]
        # texts = batch["outfit_texts"]
        # labels = batch["outfit_labels"]

        print(f"batch - images.shape: {images.shape}")
        print(f"batch - texts: {texts}")
        print(f"batch - labels: {labels}")

        optimizer.zero_grad()
        outputs = model(images, texts)
        loss = focal_loss(
            outputs, labels.unsqueeze(1)
        )  # Ensure labels have the right dimension
        loss.backward()
        optimizer.step()

        # Print or log the loss if needed
        print(f"Epoch {epoch + 1}, Batch loss: {loss.item()}")

    save_checkpoint(model.state_dict(), "mock", f"model_epoch_{epoch + 1}.pth")

    # Adjust the learning rate as needed (reduce by half in steps of 10)
    if (epoch + 1) % 10 == 0:
        for param_group in optimizer.param_groups:
            param_group["lr"] = param_group["lr"] / 2