# CS4701 Project: Personal Color Analysis

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import os
import torch
from torch.utils.data import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold

import random
import os
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from google.colab import drive

# Fix the seeds of the Python, NumPy and PyTorch global RNGs so runs are repeatable.
random.seed(0)
# NumPy is imported by this notebook but its global RNG was previously left unseeded;
# libraries that fall back to it (e.g. scikit-learn when no random_state is given)
# would otherwise differ from run to run.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device set to {device}")

# Step 1: Create Dataset

In [None]:
# First, define the path to the data.
# This must be the directory that contains the `RGB-M` folder: ColorSeasonDataset
# looks for images under <PATH_TO_DATA>/RGB-M/<split>/<season>/<subtype>/.
# The value below is a placeholder and has to be replaced before running the notebook.

PATH_TO_DATA = 'TODO: path to data'

### Define ColorSeasonDataset class

In [None]:
class ColorSeasonDataset(Dataset):
    """Image dataset labelled with one of 12 combined season/subtype classes.

    Images are discovered under ``<root_dir>/RGB-M/<split>/<season>/<subtype>/``.
    The label of a sample is ``season_idx * 3 + subtype_idx``, where the two
    indices are the positions of the season in ``self.seasons`` and of the
    subtype in ``self.subtypes[season]``.
    """

    # File extensions treated as images (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, split='train', transform=None, indices=None):
        """
        Args:
            root_dir (string): Root directory containing the RGB-M folder
            split (string): 'train' or 'test'
            transform (callable, optional): Transform to apply to images
            indices (list, optional): If provided, only use these indices.
                They index the full sample list, which is built in season
                order, then subtype order, then sorted file name.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.split = split

        # Define season and subtype mapping
        self.seasons = ['spring', 'summer', 'autumn', 'winter']
        self.subtypes = {
            'spring': ['warm', 'light', 'bright'],
            'summer': ['cool', 'light', 'soft'],
            'autumn': ['warm', 'deep', 'soft'],
            'winter': ['cool', 'deep', 'bright']
        }

        # Collect all image paths and corresponding labels
        self.image_paths = []
        self.labels = []

        base_path = os.path.join(root_dir, 'RGB-M', split)

        for season_idx, season in enumerate(self.seasons):
            season_path = os.path.join(base_path, season)
            # Missing season/subtype folders are skipped rather than treated as errors
            if not os.path.isdir(season_path):
                continue

            for subtype_idx, subtype in enumerate(self.subtypes[season]):
                subtype_path = os.path.join(season_path, subtype)
                if not os.path.isdir(subtype_path):
                    continue

                # os.listdir returns entries in arbitrary order; sort so that a
                # given index (e.g. from a k-fold split) always maps to the same file.
                for img_file in sorted(os.listdir(subtype_path)):
                    # Compare in lower case so files such as "photo.JPG" are not dropped
                    if not img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue

                    self.image_paths.append(os.path.join(subtype_path, img_file))
                    # Combined label (0-11 for 4 seasons x 3 subtypes)
                    self.labels.append(season_idx * 3 + subtype_idx)

        # If indices are provided, use only those
        if indices is not None:
            self.image_paths = [self.image_paths[i] for i in indices]
            self.labels = [self.labels[i] for i in indices]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as 3-channel RGB and passed through
        ``self.transform`` when one was given; the label is the integer
        combined season/subtype class.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Force RGB so grayscale, palette and RGBA files all yield three channels,
        # and use a context manager so the underlying file handle is closed promptly.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

### Define data transforms and create train/test datasets

**Note here**: 
- The authors split the data into 4000 images for training and 900 images for testing (roughly an 80-20 split)
- It's good practice to have a validation set, but I'm skipping it for now because I'll use stratified k-fold cross-validation later
- This is because the dataset is small, so I don't want to lose more data by splitting off a separate validation set

In [None]:
# Per-channel normalisation applied at the end of both pipelines
# (these are the commonly used ImageNet mean/std values).
_channel_norm = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training-time pipeline: random crop/flip and a mild colour jitter, then tensor + normalise.
# NOTE: transforms.RandAugment(2, 7) was considered but left out, because stronger
# augmentation may change colours and this is a colour classification task.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    _channel_norm,
]
train_transform = transforms.Compose(_train_steps)

# Test-time pipeline: fixed resize and centre crop, then tensor + normalise.
_test_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _channel_norm,
]
test_transform = transforms.Compose(_test_steps)


# One dataset per split, each paired with its own transform pipeline
train_dataset = ColorSeasonDataset(root_dir=PATH_TO_DATA, split='train', transform=train_transform)
test_dataset = ColorSeasonDataset(root_dir=PATH_TO_DATA, split='test', transform=test_transform)

# Wrap the datasets in batched loaders; only the training loader shuffles its samples
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)


# Step 2: Train Model

### Define StratifiedKFold (for validation)