### All Shapes and Colors - Kaggle Challenge ###

In [27]:
# Imports
import os, ast
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import matplotlib.pyplot as plt

# Device selection (prototyped on a MacBook Air M1).
# Preference order: Apple-silicon MPS, then CUDA (keeps the notebook usable on Colab), then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='mps')

After looking at the problem, I can see we have the following constraints:
- Each image contains k objects where 1 <= k <= 9, since no image contains two objects with the same (shape, color) pair.
- We can have multiple objects of the same shape, and multiple objects of the same color, in one image.

I'll use a multi-label target over all 9 possible pairs. For each image, the target is a 9-dimensional multi-hot vector with a 1 for each (shape, color) pair present in the image.

In [8]:
SHAPES = ['circle', 'square', 'triangle']
COLORS = ['red', 'green', 'blue']

# Class index of every "<shape>_<color>" pair, laid out shape-major:
# index = (position of shape) * len(COLORS) + (position of color)
PAIR_TO_IDX = {
    f"{shape}_{color}": shape_pos * len(COLORS) + color_pos
    for shape_pos, shape in enumerate(SHAPES)
    for color_pos, color in enumerate(COLORS)
}

# Inverse lookup (class index -> pair name), used when decoding predictions
IDX_TO_PAIR = dict(zip(PAIR_TO_IDX.values(), PAIR_TO_IDX.keys()))

print(PAIR_TO_IDX)
print(IDX_TO_PAIR)

{'circle_red': 0, 'circle_green': 1, 'circle_blue': 2, 'square_red': 3, 'square_green': 4, 'square_blue': 5, 'triangle_red': 6, 'triangle_green': 7, 'triangle_blue': 8}
{0: 'circle_red', 1: 'circle_green', 2: 'circle_blue', 3: 'square_red', 4: 'square_green', 5: 'square_blue', 6: 'triangle_red', 7: 'triangle_green', 8: 'triangle_blue'}


Step one is to set up the label representation.

In [None]:
# Encoding the pairs to multi-hot vector
def encode_pairs(pairs):
    """
    Build the multi-hot training target for one image.

    pairs: iterable of (shape, color) string tuples; matching is case-insensitive
    returns: float32 torch tensor with one slot per entry of PAIR_TO_IDX,
             holding 1.0 for every pair given and 0.0 elsewhere
    raises: KeyError if a pair is not a key of PAIR_TO_IDX
    """
    multi_hot = np.zeros(len(PAIR_TO_IDX), dtype=np.float32)
    # Normalise case lazily so lookup keys match the "<shape>_<color>" format
    keys = (f"{shape.lower()}_{color.lower()}" for shape, color in pairs)
    for key in keys:
        multi_hot[PAIR_TO_IDX[key]] = 1.0
    return torch.from_numpy(multi_hot)

# pairs = [("circle","red"), ("triangle","blue")]
# y = encode_pairs(pairs)
# print(y)        
# print(y.sum())  

def decode_vec(y):
    """
    Turn a vector of raw model outputs (logits) into (shape, color) pairs.

    y: 1-D torch tensor of logits, one per entry of IDX_TO_PAIR
       (NOT an already-thresholded 0/1 vector: a sigmoid is applied here)
    returns: list of (shape, color) tuples whose sigmoid probability is >= 0.5
    """
    probabilities = torch.sigmoid(y)  # squash logits into (0, 1)
    # Indices of the classes that clear the 0.5 decision threshold
    active = torch.nonzero(probabilities >= 0.5, as_tuple=True)[0].tolist()
    # Pair names are stored as "<shape>_<color>"; split them back into tuples
    name_parts = (IDX_TO_PAIR[index].split("_") for index in active)
    return [(shape, color) for shape, color in name_parts]


# logits = torch.tensor([3.0, -1.0, 0.2, 0.0, 0.0, 2.5, -2.0, 0.0, 0.0])
# decoded = decode_vec(logits)
# print(decoded)


`encode_pairs` makes the labels usable for training: it represents each image's labels as the 9-dim multi-hot tensor mentioned earlier.

`decode_vec` does the reverse for the tensor output by the model at inference time: it applies a sigmoid to get probabilities, thresholds them at 0.5, then turns the result back into (shape, color) tuples.

Next thing to do is load the data.

In [None]:
def parse_label_string(string): # need this since the CSV has labels as strings
    """
    Parse a label string from the CSV into a list of (shape, color) pairs.

    string: str holding a Python-literal list of 2-tuples of strings,
        e.g. "[('circle', 'red'), ('triangle', 'blue')]" or "[]".
        None, NaN (what pandas yields for an empty CSV cell) and
        empty / whitespace-only strings all mean "no labels".
    returns: list of tuples (shape, color), each part stripped and lower-cased
    raises: ValueError or SyntaxError (from ast.literal_eval) if the text is
        not a valid Python literal
    """
    if string is None:
        return []
    # Empty CSV cells arrive as float NaN; NaN is the only float that is not
    # equal to itself, so this detects it without needing an extra import.
    if isinstance(string, float) and string != string:
        return []

    text = string.strip().lower()
    if text in ("", "[]"):  # nothing to parse (also covers whitespace-only cells)
        return []

    # literal_eval only accepts Python literals, so it is safe on file contents
    parsed = ast.literal_eval(text)
    # normalize each tuple: trim inner whitespace and force lower case
    return [(shape.strip().lower(), color.strip().lower()) for shape, color in parsed]

# Datasets
class ShapesColorsDatasetTrain(Dataset):
    """
    Labelled dataset used for training and validation.

    Each item is (image, target_vector). The CSV is read positionally:
    column 0 is the image path (relative to img_dir) and column 1 is the
    label string.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        # csv_file: path (or buffer) handed to pandas.read_csv
        # img_dir: directory prepended to every image path from the CSV
        # transform: optional callable applied to the loaded PIL image
        self.transform = transform
        self.img_dir = img_dir
        self.df = pd.read_csv(csv_file)

    def __len__(self):
        # One sample per CSV row
        return self.df.shape[0]

    def __getitem__(self, idx):
        # An index may arrive as a tensor; convert it to plain Python first
        idx = idx.tolist() if torch.is_tensor(idx) else idx

        # Load the image and force 3 channels
        full_path = os.path.join(self.img_dir, self.df.iloc[idx, 0])
        image = Image.open(full_path).convert("RGB")

        # Label string -> list of (shape, color) pairs -> multi-hot target
        target = encode_pairs(parse_label_string(self.df.iloc[idx, 1]))

        if not self.transform:
            return image, target
        return self.transform(image), target