### Initialize Data

In [1]:
import torch
import torch.optim as optim
from torchvision import transforms
from focal_loss import FocalLoss
from mock_dataset import MockOutfitDataset
from outfit_model import OutfitCompatibilityModel
from outfit_dataset import OutfitDataset
import torch.nn as nn
from utils import save_checkpoint
import logging

# DEBUG - INFO - WARNING - ERROR
logging.basicConfig(level=logging.DEBUG)
data_dir = "data"

# Should be disjoint/nondisjoint
polyvore_split = "nondisjoint"

# Should be traim/valid/test
split = "valid"

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]
)


# Organizes your dataset into batches. If we don't do this, our texts matrix will be transposed
# Batch size = number of samples processed in one iteration
# Number of batches = total samples divided by batch_size
# Each this case, a sample = an outfit
def custom_collate(batch):
    outfits_images = []
    outfits_texts = []
    outfits_labels = []

    # Find the maximum number of items in any outfit in this batch
    max_items = max(len(outfit["outfit_images"]) for outfit in batch)

    for outfit in batch:
        # Pad or truncate the number of items to match max_items
        padded_images = torch.zeros((max_items,) + outfit["outfit_images"].shape[1:])
        padded_images[: outfit["outfit_images"].shape[0]] = outfit["outfit_images"]

        # Similarly, pad or truncate the number of texts
        padded_texts = outfit["outfit_texts"] + [""] * (
            max_items - len(outfit["outfit_texts"])
        )

        outfits_images.append(padded_images)
        outfits_texts.append(padded_texts)
        outfits_labels.append(outfit["outfit_label"])

    return {
        "outfit_images": torch.stack(outfits_images),
        "outfit_texts": outfits_texts,
        "outfit_labels": torch.tensor(outfits_labels, dtype=torch.float),
    }


real_dataset = OutfitDataset(data_dir, polyvore_split, split, transform)
real_dataloader = torch.utils.data.DataLoader(
    real_dataset, batch_size=5, shuffle=True, collate_fn=custom_collate
)

# Instantiate the mock dataset and dataloader
# Contains list of all outfits
mock_dataset = MockOutfitDataset()

dataloader = torch.utils.data.DataLoader(
    mock_dataset, batch_size=15, shuffle=True, collate_fn=custom_collate
)

DEBUG:root:OutfitDataset - itemIdentifier2ItemId's 1st 10 items: {'224930161_1': '213343990', '224930161_2': '206270853', '224930161_3': '202059322', '224930161_4': '53908391', '208756998_1': '182121462', '208756998_2': '184300769', '208756998_3': '185818820', '208756998_4': '185817365', '208756998_5': '185818053', '218698690_1': '206948767'}
DEBUG:root:imageNames 1st 10 items: ['202321215', '200886508', '209234827', '148482437', '208219955', '52192148', '163215721', '187916901', '159655041', '33446168']
DEBUG:root:OutfitDataset - itemIdToIndex's 1st 10 items: {'202321215': 0, '200886508': 1, '209234827': 2, '148482437': 3, '208219955': 4, '52192148': 5, '163215721': 6, '187916901': 7, '159655041': 8, '33446168': 9}
DEBUG:root:OutfitDataset - itemIdToDescription's 1st 10 items: {'202321215': 'saint laurent pink medium monogram', '200886508': 'h&m leather trousers', '209234827': 'steve madden baddison handbag backpack', '148482437': 'bloomingdales cashmere ribbed gloves', '208219955': '

MockOutfitDataset; images: torch.Size([20, 5, 3, 224, 224]), texts: [['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock description 1', 'mock description 2', 'mock description 3', 'mock description 4'], ['mock description 0', 'mock descri

### Visualize Data

In [2]:
# CURRENTLY NOT USABLE

# from matplotlib import pyplot as plt
# import torchvision.transforms.functional as F


# def show_images(images, labels):
#     for i in range(images.size(1)):
#         image = F.to_pil_image(images[:, i, ...])
#         plt.subplot(1, images.size(1), i + 1)
#         plt.imshow(image)
#         plt.title(f"Label: {labels[i]}")
#         plt.axis("off")
#     plt.show()


# for batch_idx, (images, texts, labels) in enumerate(dataloader):
#     print(
#         f"Batch {batch_idx + 1} - Shape of images: {images.shape}, Texts: {texts}, Labels: {labels}"
#     )

#     # Visualize the images
#     show_images(images, labels)
    
#     if batch_idx == 2:  # Print information for the first 3 batches
#         break

### Init Model

In [3]:
# Instantiate the model and components for training (loss function, optimizer)
model = OutfitCompatibilityModel()
focal_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /sentence-transformers/bert-base-nli-mean-tokens/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json HTTP/1.1" 200 0


### Training

In [4]:
# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    for batch in real_dataloader:
        images = batch["outfit_images"]
        texts = batch["outfit_texts"]
        labels = batch["outfit_labels"]

        print(f"batch - images.shape: {images.shape}")
        print(f"batch - texts: {texts}")
        print(f"batch - labels: {labels}")

        optimizer.zero_grad()
        outputs = model(images, texts)
        loss = focal_loss(
            outputs, labels
        )  # Ensure labels have the right dimension
        loss.backward()
        optimizer.step()

        # Print or log the loss if needed
        print(f"Epoch {epoch + 1}, Batch loss: {loss.item()}")

    save_checkpoint(model.state_dict(), "mock", f"model_epoch_{epoch + 1}.pth")

    # Adjust the learning rate as needed (reduce by half in steps of 10)
    if (epoch + 1) % 10 == 0:
        for param_group in optimizer.param_groups:
            param_group["lr"] = param_group["lr"] / 2

DEBUG:root:!!!!!!!!!!
DEBUG:root:OutfitCompatibilityModel - START
DEBUG:root:OutfitCompatibilityModel - intial images.shape: torch.Size([5, 7, 3, 224, 224])
DEBUG:root:@@@@@@@@@@
DEBUG:root:[START LOOP] OUTFIT - 0
DEBUG:root:OutfitCompatibilityModel - outfit_images.shape: torch.Size([7, 3, 224, 224])
DEBUG:root:##########
DEBUG:root:[START LOOP] ITEM - 0
DEBUG:root:OutfitCompatibilityModel - item_index: 0 - item_image.shape: torch.Size([1, 3, 224, 224]) item_text.shape: push lock cross body sequin
DEBUG:root:ImageEncoder - after fc_layer x's shape: torch.Size([1, 64])
DEBUG:root:----------
DEBUG:root:----------


batch - images.shape: torch.Size([5, 7, 3, 224, 224])
batch - texts: [['push lock cross body sequin', 'glitter stiletto heel peep toe', 'glitter mini heel peep toe', 'gurhan small 24k gold amulet', 'signature gold heart bracelet in', 'feather print fit flare mini', 'philosophy di lorenzo serafini mink'], ['burberry leather bowling bag', 'zara studded cowboy ankle boot', 'kain classic modal and silk-blend tank', 'h&m shirt', 'patek philippe gold automatic watch', 'illesteva leonard round-frame matte-acetate sunglasses', 'ted baker whistel - dark rinse skinny denim'], ['mini borsa metropolis in tessuto', 'alberta ferretti denim beaded mules', 'womens helene berman denim ruffle', "marques'almeida denim tank top", 'fendi womens ff0177 sunglasses', 'bliss and mischief shadow flower-embroidered denim shorts', ''], ['rebecca minkoff julian backpack with', 'adidas originals stan smith crackled leather sneakers - white/navy', 'stone black gold white diamonds', 'yohji yamamoto regular shirt', 'm

DEBUG:root:TextEncoder - x's shape after fc_layer: torch.Size([1, 64])
DEBUG:root:----------
DEBUG:root:OutfitCompatibilityModel - item_features.shape: torch.Size([1, 128])
DEBUG:root:[END LOOP] ITEM - 0
DEBUG:root:##########
DEBUG:root:##########
DEBUG:root:[START LOOP] ITEM - 1
DEBUG:root:OutfitCompatibilityModel - item_index: 1 - item_image.shape: torch.Size([1, 3, 224, 224]) item_text.shape: glitter stiletto heel peep toe
DEBUG:root:ImageEncoder - after fc_layer x's shape: torch.Size([1, 64])
DEBUG:root:----------
DEBUG:root:----------
DEBUG:root:TextEncoder - x's shape after fc_layer: torch.Size([1, 64])
DEBUG:root:----------
DEBUG:root:OutfitCompatibilityModel - item_features.shape: torch.Size([1, 128])
DEBUG:root:[END LOOP] ITEM - 1
DEBUG:root:##########
DEBUG:root:##########
DEBUG:root:[START LOOP] ITEM - 2
DEBUG:root:OutfitCompatibilityModel - item_index: 2 - item_image.shape: torch.Size([1, 3, 224, 224]) item_text.shape: glitter mini heel peep toe
DEBUG:root:ImageEncoder - af

RuntimeError: result type Float can't be cast to the desired output type Long