# In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm, trange

# Configuration (Feel free to modify these).
# Paths are relative to the working directory the notebook/script runs from.
CSV_FILE = '../../datasets/styles_fixed.csv'  # Path to your CSV file.
IMAGE_DIR = 'static/high_res_images'          # Directory containing images.

# Training hyper-parameters.
BATCH_SIZE = 1
LEARNING_RATE = 0.001
NUM_EPOCHS = 1

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(DEVICE))

# 1. Data Loading and Preprocessing

class ApparelDataset(Dataset):
    """Image dataset yielding ``(image, {attribute: label_index})`` pairs.

    Each DataFrame row describes one image stored as ``<image_dir>/<id>.jpg``.
    Target attribute columns are label-encoded to integer class indices when
    the dataset is constructed.
    """

    def __init__(self, dataframe, image_dir, transform=None, target_attributes=None,
                 label_encoders=None):
        """
        Args:
            dataframe (pd.DataFrame): Image metadata; must contain an ``id`` column.
                It is copied, so the caller's DataFrame is left untouched.
            image_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to each PIL image.
            target_attributes (list, optional): Attribute columns to predict
                (e.g., ['gender', 'articleType']). ``None`` means no targets.
            label_encoders (dict, optional): Pre-fitted ``{attr: LabelEncoder}``
                mapping, e.g. from :meth:`fit_label_encoders` on the full data.
                Pass the SAME mapping to the train/val/test datasets so every
                split uses identical class indices. Attributes without a
                provided encoder are fitted on this DataFrame alone.
        """
        # Copy so encoding the label columns does not mutate the caller's frame
        # (and does not trigger SettingWithCopyWarning on split subsets).
        self.dataframe = dataframe.copy()
        self.image_dir = image_dir
        self.transform = transform
        self.target_attributes = list(target_attributes) if target_attributes is not None else []
        self.label_encoders = {}

        provided = label_encoders or {}
        for attr in self.target_attributes:
            if attr not in self.dataframe.columns:
                print(f"Warning: Attribute '{attr}' not found in DataFrame.")
                continue
            labels = self._clean_labels(self.dataframe[attr])
            le = provided.get(attr)
            if le is None:
                le = LabelEncoder().fit(labels)
            self.dataframe[attr] = le.transform(labels)
            self.label_encoders[attr] = le

    @staticmethod
    def _clean_labels(series):
        """Return *series* as strings with missing values replaced by 'Unknown'."""
        # fillna must run before astype(str): afterwards NaN is the string 'nan'.
        return series.fillna('Unknown').astype(str)

    @classmethod
    def fit_label_encoders(cls, dataframe, target_attributes):
        """Fit one LabelEncoder per target attribute present in *dataframe*.

        Fit this on the full dataset before splitting and pass the result to
        every split, so class indices agree across train/val/test and no split
        contains an index outside the model's output range.

        Returns:
            dict: ``{attr: fitted LabelEncoder}``; attributes whose column is
            missing are omitted.
        """
        return {
            attr: LabelEncoder().fit(cls._clean_labels(dataframe[attr]))
            for attr in (target_attributes or [])
            if attr in dataframe.columns
        }

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_id = row['id']
        # An id column parsed with NaNs (or a row upcast to float) yields e.g.
        # 10017.0; map it back to 10017 so the file name is '10017.jpg'.
        if isinstance(img_id, float) and img_id.is_integer():
            img_id = int(img_id)
        img_name = os.path.join(self.image_dir, f"{img_id}.jpg")
        try:
            # Context manager releases the file handle once pixels are converted.
            with Image.open(img_name) as img:
                image = img.convert('RGB')
        except FileNotFoundError:
            print(f"Error: Image file not found: {img_name}")
            # Return a black image as a placeholder
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # One label tensor per target attribute; -1 marks an attribute whose
        # column was missing (the model has no head for it).
        labels = {}
        for attr in self.target_attributes:
            if attr in self.label_encoders:
                labels[attr] = torch.tensor(row[attr], dtype=torch.long)
            else:
                labels[attr] = torch.tensor(-1, dtype=torch.long)

        return image, labels


# Load the CSV. For faster prototyping you can pass e.g. nrows=1000 to read_csv.
try:
    data = pd.read_csv(CSV_FILE)
except FileNotFoundError:
    print(f"Error: CSV file not found at {CSV_FILE}")
    raise SystemExit(1) from None  # non-zero status: this is a failure
except pd.errors.ParserError:
    print(f"Error: Parsing error in CSV file {CSV_FILE}")
    raise SystemExit(1) from None

# Drop rows with missing IDs (they cannot be mapped to an image file).
data.dropna(subset=['id'], inplace=True)

# Check for the existence of the images directory
if not os.path.isdir(IMAGE_DIR):
    print(f"Error: Image directory '{IMAGE_DIR}' not found.")
    raise SystemExit(1)


# Select target attributes (you can add/remove attributes here)
target_attributes = ['gender', 'articleType', 'baseColour', 'season', 'usage']

# Fit the label encoders once on the FULL data, before splitting. Fitting them
# per split would give train/val/test different label -> index mappings, and a
# class absent from the training split would exceed the model's class count.
shared_label_encoders = ApparelDataset.fit_label_encoders(data, target_attributes)

# Split the data into training, validation, and test sets
train_df, test_df = train_test_split(data, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.25, random_state=42)  # 0.25 x 0.8 = 0.2


# Image transformations (augmentation only for training; ImageNet normalization)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Create datasets (all sharing the same label encoders)
train_dataset = ApparelDataset(train_df, IMAGE_DIR, data_transforms['train'], target_attributes,
                               shared_label_encoders)
val_dataset = ApparelDataset(val_df, IMAGE_DIR, data_transforms['val'], target_attributes,
                             shared_label_encoders)
test_dataset = ApparelDataset(test_df, IMAGE_DIR, data_transforms['test'], target_attributes,
                              shared_label_encoders)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)



# 2. Model Definition (Multi-Task Classifier)
class MultiTaskClassifier(nn.Module):
    """ResNet-18 feature extractor shared by one linear head per attribute."""

    def __init__(self, num_classes_dict, pretrained=True):
        """
        Args:
            num_classes_dict (dict): ``{attribute: number_of_classes}``; one
                ``nn.Linear`` head is created per entry.
            pretrained (bool): Initialise the backbone with torchvision's default
                ImageNet weights (the original behaviour). Pass False for random
                initialisation without a download.
        """
        super().__init__()
        self.base_model = self._build_backbone(pretrained)
        in_features = self.base_model.fc.in_features

        # Replace the ImageNet classifier with Identity so the backbone returns
        # pooled features that every head consumes.
        self.base_model.fc = nn.Identity()

        # Separate classification head for each attribute.
        self.heads = nn.ModuleDict({
            attr: nn.Linear(in_features, num_classes) for attr, num_classes in num_classes_dict.items()
        })

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision ResNet-18, using the weights API when available."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet18(pretrained=pretrained)
        # ``pretrained=`` is deprecated in newer torchvision; DEFAULT is the
        # same ImageNet checkpoint that pretrained=True used to load.
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Return ``{attribute: logits}`` for a batch of images *x*."""
        features = self.base_model(x)
        return {attr: head(features) for attr, head in self.heads.items()}

# Calculate the number of classes for each target attribute from the fitted
# label encoders; attributes without an encoder (missing column) get no head.
num_classes_dict = {
    name: len(train_dataset.label_encoders[name].classes_)
    for name in target_attributes
    if name in train_dataset.label_encoders
}
print("Number of classes for each attribute: {}".format(num_classes_dict))

# Instantiate the model and move its parameters to the selected device.
model = MultiTaskClassifier(num_classes_dict)
model = model.to(DEVICE)


# 3. Loss Function and Optimizer
# A single CrossEntropyLoss instance is applied to every classification head.
criterion = nn.CrossEntropyLoss()

# Adam over all parameters (backbone and heads).
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Learning Rate Scheduler (Optional - but generally good practice):
# multiplies the learning rate by 0.1 every 7 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 4. Training Loop
def _copy_state_dict(model):
    """Return a detached copy of *model*'s state_dict that later training cannot mutate."""
    # state_dict() returns references to the live tensors; without cloning,
    # a saved "best" snapshot silently tracks the latest weights.
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _run_phase(model, dataloader, criterion, optimizer, is_train):
    """Run one pass over *dataloader*; return ``(mean_loss, {attr: accuracy})``.

    Gradients are computed and the optimizer stepped only when *is_train* is
    true. Accuracy is reported only for attributes that have both a label and
    a model head, so attributes missing from the data are skipped safely.
    """
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    corrects = {}
    total_samples = 0

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(DEVICE)
        labels_dict = {attr: labels[attr].to(DEVICE) for attr in target_attributes if attr in labels}

        if is_train:
            optimizer.zero_grad()

        # Track history only in the training phase.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            losses = []
            for attr in target_attributes:
                if attr in labels_dict and attr in outputs:
                    losses.append(criterion(outputs[attr], labels_dict[attr]))
                    preds = outputs[attr].argmax(dim=1)
                    corrects[attr] = corrects.get(attr, 0) + (preds == labels_dict[attr]).sum().item()
            if not losses:
                raise RuntimeError(
                    "No target attribute has both labels and a model head; nothing to optimise.")
            loss = sum(losses)  # unweighted sum of per-head losses

            if is_train:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size

    mean_loss = running_loss / total_samples if total_samples else 0.0
    accuracies = {attr: corrects[attr] / total_samples for attr in target_attributes if attr in corrects}
    return mean_loss, accuracies


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train *model* on ``train_loader``, validating on ``val_loader`` each epoch.

    The checkpoint restored at the end is the one with the highest MEAN
    validation accuracy across attributes. The per-attribute accuracies
    printed as "Best val Acc" are those of that same epoch, so they describe
    the weights that are actually returned.

    Args:
        model: Multi-task model returning ``{attr: logits}``.
        criterion: Loss applied to each head's logits and labels.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training epoch.
        num_epochs (int): Number of epochs to run.

    Returns:
        The same ``model`` object with the best validation weights loaded.
    """
    since = time.time()

    best_model_wts = _copy_state_dict(model)
    best_score = float('-inf')
    best_acc = {}

    for epoch in trange(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase.
        for phase, dataloader in (('train', train_loader), ('val', val_loader)):
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(model, dataloader, criterion, optimizer, is_train)

            if is_train:
                scheduler.step()  # Update learning rate

            print('{} Loss: {:.4f} Acc: {}'.format(
                phase, epoch_loss, ", ".join(f"{attr}: {acc:.4f}" for attr, acc in epoch_acc.items())))

            # Snapshot the weights whenever the mean validation accuracy improves.
            if not is_train:
                score = sum(epoch_acc.values()) / len(epoch_acc) if epoch_acc else 0.0
                if score > best_score:
                    best_score = score
                    best_acc = dict(epoch_acc)
                    best_model_wts = _copy_state_dict(model)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {}'.format(", ".join(f"{attr}: {acc:.4f}" for attr, acc in best_acc.items())))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


model = train_model(model=model, criterion=criterion, optimizer=optimizer,
                    scheduler=scheduler, num_epochs=NUM_EPOCHS)  # returns model with best val weights



# 5. Evaluation

def evaluate_model(model, dataloader, *, device=None, attributes=None):
    """Compute and print the per-attribute accuracy of *model* on *dataloader*.

    Args:
        model: Multi-task model returning ``{attr: logits}``.
        dataloader: Iterable of ``(inputs, {attr: label_tensor})`` batches.
        device: Device for inputs and labels; defaults to the module-level ``DEVICE``.
        attributes: Attributes to score; defaults to the module-level
            ``target_attributes``.

    Returns:
        dict: ``{attr: accuracy}`` for every attribute that has both labels and
        a model head. Attributes without a head (e.g. a missing column) are
        skipped instead of crashing.
    """
    device = DEVICE if device is None else device
    attributes = target_attributes if attributes is None else attributes

    model.eval()  # Set the model to evaluation mode

    corrects = {}
    total_samples = 0

    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)

            for attr in attributes:
                if attr in labels and attr in outputs:
                    preds = outputs[attr].argmax(dim=1)
                    targets = labels[attr].to(device)
                    # .item() keeps the running count a plain int on the host.
                    corrects[attr] = corrects.get(attr, 0) + (preds == targets).sum().item()

            total_samples += inputs.size(0)

    # corrects is non-empty only when total_samples > 0, so no division by zero.
    accuracy = {attr: corrects[attr] / total_samples for attr in attributes if attr in corrects}
    print('Accuracy on the test set:')
    for attr, acc in accuracy.items():
        print(f'{attr}: {acc:.4f}')
    return accuracy


evaluate_model(model=model, dataloader=test_loader)  # report held-out test accuracy


# 6.  Prediction on a Single Image (Inference)
def predict_single_image(model, image_path, transform, label_encoders):
    """Predicts attributes of a single image.

    Args:
        model: Trained PyTorch model returning ``{attr: logits}``.
        image_path: Path to the image file.
        transform: Image transformation pipeline (e.g. the 'test' transform).
        label_encoders: Dictionary of LabelEncoders used during training.

    Returns:
        dict mapping each attribute in *label_encoders* to its decoded
        predicted label, or "N/A" when the model has no head for it;
        ``None`` if the image is missing or cannot be read.
    """
    model.eval()  # Set model to evaluation mode

    try:
        # Context manager closes the file once the pixels are converted.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
    except FileNotFoundError:
        print(f"Error: Image not found at {image_path}")
        return None
    except OSError as err:  # e.g. PIL.UnidentifiedImageError for non-image/corrupt files
        print(f"Error: Could not read image {image_path}: {err}")
        return None

    batch = transform(image).unsqueeze(0).to(DEVICE)  # Add batch dimension [1, C, H, W]

    with torch.no_grad():
        outputs = model(batch)

    predictions = {}
    for attr, encoder in label_encoders.items():
        if attr in outputs:  # Check if the attribute head exists.
            predicted_class = outputs[attr].argmax(dim=1)
            predictions[attr] = encoder.inverse_transform(predicted_class.cpu().numpy())[0]
        else:
            predictions[attr] = "N/A"  # Attribute not predicted by the model

    return predictions


# Example usage of single image prediction:
image_to_predict = 'images/10017.jpg'  # Replace with your image path
if os.path.exists(image_to_predict):  # check if the file exists
    predictions = predict_single_image(model, image_to_predict, data_transforms['test'], train_dataset.label_encoders)

    # predict_single_image returns None when the file cannot be read.
    if predictions is not None:
        print(f"Predictions for {image_to_predict}:")
        for attr, value in predictions.items():
            print(f"  {attr}: {value}")
else:
    print(f"Image file {image_to_predict} not found.  Skipping single image prediction.")


# 7. Save and load the model (Optional).
# Only the state_dict (parameter tensors) is written; the architecture itself
# must be rebuilt in code before loading.
torch.save(obj=model.state_dict(), f='apparel_classifier.pth')

# To load (rebuild the same architecture first):
# loaded_model = MultiTaskClassifier(num_classes_dict)  # Same architecture!
# loaded_model.load_state_dict(torch.load('apparel_classifier.pth', map_location=DEVICE))  # map_location lets GPU-saved weights load on CPU
# loaded_model.to(DEVICE)
# loaded_model.eval()  # Important: set to eval mode for inference!