# Library Imports

In [1]:
# General imports
import numpy as np
import matplotlib.pyplot as plt
import os
import csv
import ast
import cv2
# Pytorch imports
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

# Prep Dataset

In [2]:
# square = 0, circle = 1, triangle = 2
# red = 0, green = 1, blue = 2
LABELS = ['square', 'circle', 'triangle', 'NONE']
COLORS = ['red', 'green', 'blue', 'NONE']

def customCollate(batch):
    images = torch.stack([item[0] for item in batch])

    shapes = torch.stack([item[1]["shape"] for item in batch])
    colors = torch.stack([item[1]["color"] for item in batch])

    label_dict = {"shape": shapes, "color": colors}
    return images, label_dict


class CustomDataset(Dataset):
    def __init__(self, input_dir, label_csv, transform=transforms.ToTensor(), max_shapes=10):
        self.max_shapes = max_shapes
        self.transform = transform
        self.label_to_num = {"square" : 0, "circle" : 1, "triangle" : 2}
        self.color_to_num = {"red" : 0, "green" : 1, "blue" : 2}
        self.img_filenames = []
        self.img_labels = []
        for img in os.listdir(input_dir):
            self.img_filenames.append(os.path.join(input_dir, img))
        with open(label_csv, 'r') as file:
            reader = csv.reader(file)
            first = True
            for row in reader:
                if first:
                    first = False
                    continue
                self.img_labels.append(ast.literal_eval(row[1]))

    def __len__(self):
        return len(self.img_filenames)
    
    def __getitem__(self, idx):
        img_filename = self.img_filenames[idx]
        get_labels = self.img_labels[idx]

        image = cv2.imread(img_filename)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image)

        shape_list = []
        color_list = []
        for i in range(self.max_shapes):
            if i < len(get_labels):
                shape, col = get_labels[i]
                shape_list.append(self.label_to_num[shape])
                color_list.append(self.color_to_num[col])
            else:
                shape_list.append(3)
                color_list.append(3)
        label_dict = {"shape": torch.from_numpy(np.array(shape_list)), "color": torch.from_numpy(np.array(color_list))}

        return image, label_dict



In [3]:
train_dataset = CustomDataset(r"dataset_v3\train_dataset", r"dataset_v3\train.csv")
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, collate_fn=customCollate, pin_memory=True)

# Custom Loss

In [4]:
def customLoss(pred, labels):
    pred_shapes = pred["shape"]
    pred_colors = pred["color"]
    label_shapes = labels["shape"]
    label_colors = labels["color"]

    criterion_shape = nn.CrossEntropyLoss()
    criterion_color = nn.CrossEntropyLoss()

    shape_loss = criterion_shape(pred_shapes.reshape(-1, 4), label_shapes.reshape(-1))
    color_loss = criterion_color(pred_colors.reshape(-1, 4), label_colors.reshape(-1))

    return (shape_loss + color_loss) / 2.0


# Model Architecture

In [5]:
class ShapeColorModel(nn.Module):
    def __init__(self, max_shapes=10):
        super().__init__()
        self.max_shapes = max_shapes

        # Get resnet50 but remove the last layer 
        self.backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # To get output size (batch, 2048)
        self.backbone.fc = nn.Identity()
        self.backbone_dim = 2048

        # One head to predict each of shape and color
        self.color_head = nn.Linear(self.backbone_dim, self.max_shapes * 4)
        self.shape_head = nn.Linear(self.backbone_dim, self.max_shapes * 4)
    
    def forward(self, x):
        features = self.backbone(x)
    
        # Get shape and color preds for each slot
        out_shape = self.shape_head(features).view(-1, self.max_shapes, 4) # Get to shape (batch_size, max_shapes, 4)
        out_color = self.color_head(features).view(-1, self.max_shapes, 4) # Get to shape (batch_size, max_shapes, 4)

        return {"shape" : out_shape, "color" : out_color}


# Training Loop

In [None]:
NUM_EPOCHS = 10

model = ShapeColorModel()
model.cuda()

print(next(model.shape_head.parameters()).device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

device = torch.device('cuda')

for epoch in range(NUM_EPOCHS):
    epoch_loss = []
    print("Entering batches")
    for i, (batch, labels) in enumerate(train_dataloader):
        print("Starting batch to cuda")
        batch = batch.to(device, non_blocking=True)
        print("Done with batch loading")
        labels = {k: v.to(device, non_blocking=True) for k, v in labels.items()}

        print("Forward pass")
        pred = model(batch)
        optimizer.zero_grad()
        loss = customLoss(pred, labels)
        print("Post loss")
        loss.backward()
        print("Post backward pass")
        optimizer.step()
        
        epoch_loss.append(loss)

    print(f"Avg Epoch Loss: {sum(epoch_loss) / len(epoch_loss)}")


cuda:0
Entering batches
