### Initialize Data

In [1]:
import torch
import torch.optim as optim
from focal_loss import FocalLoss
from mock_dataset import MockOutfitDataset
from outfit_model import OutfitCompatibilityModel
import torch.nn as nn
from utils import save_checkpoint

# Build the mock dataset; it contains the list of all outfits.
mock_dataset = MockOutfitDataset()

# Wrap the dataset in a DataLoader that serves shuffled mini-batches.
#   - a sample          = one outfit
#   - batch_size        = number of samples processed in one iteration
#   - number of batches = total samples divided by batch_size, rounded up
#                         (the final batch may hold fewer samples)
dataloader = torch.utils.data.DataLoader(
    dataset=mock_dataset,
    batch_size=50,
    shuffle=True,
)

### Visualize Data

In [None]:
from matplotlib import pyplot as plt
import torchvision.transforms.functional as F


def show_images(images, labels):
    """Plot every image of a batch in a single row, each titled with its label.

    Args:
        images: Batch of images; indexed along its first dimension, and each
            entry is converted with ``F.to_pil_image`` before plotting.
        labels: Indexable collection holding one label per image; the label
            is formatted into the subplot title.
    """
    count = images.size(0)
    for position in range(count):
        picture = F.to_pil_image(images[position])
        # Subplot indices are 1-based, hence the +1.
        plt.subplot(1, count, position + 1)
        plt.imshow(picture)
        plt.title(f"Label: {labels[position]}")
        plt.axis("off")
    plt.show()


# Preview the first three batches: print a text summary, then plot the images.
for batch_idx, batch in enumerate(dataloader):
    images, texts, labels = batch
    summary = f"Batch {batch_idx + 1} - Shape of images: {images.shape}, Texts: {texts}, Labels: {labels}"
    print(summary)

    # Visualize the images
    show_images(images, labels)

    # Stop right after the third batch (indices 0, 1, 2) so no further
    # batch is pulled from the dataloader.
    if batch_idx == 2:
        break

### Init Model

In [2]:
# Set up the three training components: network, loss criterion, optimizer.
model = OutfitCompatibilityModel()

# NOTE(review): despite its name, `focal_loss` is bound to BCEWithLogitsLoss,
# not to the FocalLoss class imported at the top — confirm this is intended.
focal_loss = nn.BCEWithLogitsLoss()

# Adam over all model parameters with a learning rate of 1e-5.
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)



### Training

In [3]:
# Training loop: for every epoch, take one optimizer step per batch, save a
# checkpoint of the model weights, and halve the learning rate after every
# 10th epoch.
# NOTE: with num_epochs = 3 the halving branch below never triggers; it only
# takes effect once num_epochs is raised to 10 or more.
num_epochs = 3
for epoch in range(num_epochs):
    for batch in dataloader:
        images, texts, labels = batch  # Adjust this based on your dataset structure
        optimizer.zero_grad()
        outputs = model(images, texts)

        # The loss expects targets with the same shape as the logits and a
        # floating dtype. Add the trailing dimension and cast the labels to
        # the dtype of the outputs; integer labels would otherwise make the
        # loss raise, and the cast is a no-op when the dtypes already match.
        # Presumably the model emits one logit per outfit, i.e. (batch, 1)
        # — TODO confirm against the model definition.
        targets = labels.unsqueeze(1).to(outputs.dtype)
        loss = focal_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        # Print or log the loss if needed
        print(f"Epoch {epoch + 1}, Batch loss: {loss.item()}")

    # Persist the weights at the end of every epoch.
    save_checkpoint(model.state_dict(), "mock", f"model_epoch_{epoch + 1}.pth")

    # Adjust the learning rate as needed (reduce by half in steps of 10)
    if (epoch + 1) % 10 == 0:
        for param_group in optimizer.param_groups:
            param_group["lr"] = param_group["lr"] / 2

item_image: tensor([[[6.1862e-01, 3.5895e-01, 1.3058e-01,  ..., 1.5398e-01,
          1.7951e-01, 7.4827e-01],
         [2.9103e-01, 5.4234e-01, 1.2398e-01,  ..., 3.4051e-01,
          2.8241e-01, 8.4956e-01],
         [5.9702e-01, 6.4334e-01, 2.3772e-01,  ..., 8.0319e-01,
          4.2572e-01, 6.9479e-01],
         ...,
         [1.4980e-01, 6.9386e-01, 7.9831e-01,  ..., 2.4417e-01,
          6.3358e-01, 7.4379e-01],
         [7.7392e-01, 9.1118e-01, 4.0588e-01,  ..., 4.3054e-01,
          4.0429e-01, 4.9255e-01],
         [7.0293e-01, 5.3127e-01, 9.9506e-01,  ..., 1.3217e-02,
          2.3647e-01, 7.1585e-05]],

        [[8.3040e-01, 3.1672e-01, 3.8315e-01,  ..., 8.1234e-03,
          5.4081e-01, 9.8516e-02],
         [2.3183e-01, 5.5794e-01, 5.3545e-01,  ..., 6.1714e-01,
          6.8902e-01, 5.4077e-01],
         [4.2142e-01, 8.7561e-02, 5.2564e-01,  ..., 6.4588e-01,
          8.3411e-01, 9.5979e-02],
         ...,
         [7.8803e-01, 4.4939e-01, 8.3936e-02,  ..., 7.3855e-01,
   

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 20, 224, 224] to have 3 channels, but got 20 channels instead