In [None]:
import os
import json
import pandas as pd
from typing import List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchmetrics.functional import dice
from transformers import Mask2FormerForUniversalSegmentation
import torchvision.transforms as T
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from functools import partial
import gc
from tqdm import tqdm

# Device setup: prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
# NOTE(review): despite the original "seed setup" wording, no random seed is set in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

### Data Pre-processing

In [None]:
class IDMapper:
    """Bidirectional mapping between original IDs and consecutive IDs starting at 0.

    Attributes:
        original_to_consecutive: dict mapping an original ID to its consecutive ID.
        consecutive_to_original: dict mapping a consecutive ID back to the original ID.
        next_id: the consecutive ID that will be handed out next (== number of mapped IDs).
        name: free-form label stored alongside the mapping when it is saved.
    """

    def __init__(self, name=''):
        self.original_to_consecutive = {}
        self.consecutive_to_original = {}
        self.next_id = 0
        self.name = name

    @staticmethod
    def _to_builtin(value):
        """Convert NumPy scalars to plain Python values (json cannot serialise NumPy scalars)."""
        return value.item() if isinstance(value, np.generic) else value

    def fit(self, id_lists):
        """Register every ID found in `id_lists`.

        Args:
            id_lists: iterable of IDs; each entry may itself be a list/set/tuple
                of IDs (nested form used for attributes) or a single ID (flat
                form used for categories).

        IDs that are already mapped keep their consecutive ID, so calling
        `fit` again only appends the unseen IDs (in sorted order).
        """
        unique_ids = set()
        for entry in id_lists:
            members = entry if isinstance(entry, (list, set, tuple)) else (entry,)
            unique_ids.update(self._to_builtin(member) for member in members)

        # Only unseen IDs get a new slot; re-assigning known IDs would leave
        # stale entries behind in consecutive_to_original.
        unseen_ids = sorted(
            original_id for original_id in unique_ids
            if original_id not in self.original_to_consecutive
        )
        for original_id in unseen_ids:
            self.original_to_consecutive[original_id] = self.next_id
            self.consecutive_to_original[self.next_id] = original_id
            self.next_id += 1

    def transform(self, ids):
        """Map original ID(s) to consecutive ID(s).

        A list/set/tuple yields a list; anything else is treated as a single ID.
        Raises KeyError for an ID that was never fitted.
        """
        if isinstance(ids, (list, set, tuple)):
            return [self.original_to_consecutive[id_] for id_ in ids]
        return self.original_to_consecutive[ids]

    def inverse_transform(self, consecutive_ids):
        """Map consecutive ID(s) back to original ID(s); mirrors `transform`."""
        if isinstance(consecutive_ids, (list, set, tuple)):
            return [self.consecutive_to_original[id_] for id_ in consecutive_ids]
        return self.consecutive_to_original[consecutive_ids]

    def save(self, filepath):
        """Write the mapping to `filepath` as JSON (dictionary keys are stored as strings)."""
        mapping_data = {
            'original_to_consecutive': {
                str(k): self._to_builtin(v) for k, v in self.original_to_consecutive.items()
            },
            'consecutive_to_original': {
                str(k): self._to_builtin(v) for k, v in self.consecutive_to_original.items()
            },
            'next_id': self.next_id,
            'name': self.name
        }
        with open(filepath, 'w') as f:
            json.dump(mapping_data, f, indent=2)

    def load(self, filepath):
        """Replace the current mapping with the one stored in `filepath`.

        Keys of both dictionaries are converted back with `int`, so the
        original IDs are expected to be integers.
        """
        with open(filepath, 'r') as f:
            mapping_data = json.load(f)

        self.original_to_consecutive = {int(k): v for k, v in mapping_data['original_to_consecutive'].items()}
        self.consecutive_to_original = {int(k): v for k, v in mapping_data['consecutive_to_original'].items()}
        self.next_id = mapping_data['next_id']
        self.name = mapping_data.get('name', '')

    @property
    def num_ids(self):
        """Total number of mapped IDs."""
        return self.next_id

class FashionpediaDataset(Dataset):
    """Per-image dataset built from a segmentation CSV and a folder of JPEGs.

    The CSV must provide the columns ImageId, Height, Width, EncodedPixels,
    CategoryId and AttributesIds, with one row per segment. Rows are grouped
    by ImageId so each dataset item describes one image with all its segments.

    Args:
        csv_file: path of the annotation CSV.
        img_dir: directory containing `<ImageId>.jpg` files.
        category_mapping_file: optional JSON path; loaded if it exists,
            otherwise the fitted category mapping is saved there.
        attribute_mapping_file: same as above for the attribute mapping.
    """

    # (height, width) every image and mask is resized to.
    IMAGE_SIZE = (384, 384)

    def __init__(self, csv_file: str, img_dir: str, category_mapping_file: str = None, attribute_mapping_file: str = None):
        self.csv_data = self.read_csv(csv_file)
        self.img_dir = img_dir

        # Built once here instead of on every __getitem__ call.
        self.transform = T.Compose([
            T.Resize(self.IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

        # Initialize mappers
        self.category_mapper = IDMapper(name='category')
        self.attribute_mapper = IDMapper(name='attribute')

        # Category IDs form one flat list, attribute IDs one list per segment.
        self._prepare_mapper(
            self.category_mapper,
            category_mapping_file,
            [category_id for item in self.csv_data for category_id in item['CategoryId']],
        )
        self._prepare_mapper(
            self.attribute_mapper,
            attribute_mapping_file,
            [attr_list for item in self.csv_data for attr_list in item['AttributesIds']],
        )

        # Transform IDs to consecutive ones
        for item in self.csv_data:
            item['CategoryId'] = self.category_mapper.transform(item['CategoryId'])
            item['AttributesIds'] = [self.attribute_mapper.transform(attr_list) for attr_list in item['AttributesIds']]

    @staticmethod
    def _prepare_mapper(mapper, mapping_file, ids):
        """Load `mapper` from `mapping_file` if it exists, else fit it on `ids` and save it (when a path is given)."""
        if mapping_file and os.path.exists(mapping_file):
            mapper.load(mapping_file)
            return
        mapper.fit(ids)
        if mapping_file:
            mapper.save(mapping_file)

    @staticmethod
    def _parse_attribute_ids(raw) -> List[int]:
        """Parse one AttributesIds cell into a list of ints.

        A comma-separated string yields its integers, a missing value (NaN/None)
        yields an empty list, and a bare number yields a one-element list.
        """
        if isinstance(raw, str):
            return [int(token) for token in raw.split(',') if token.strip()]
        if pd.isna(raw):
            return []
        return [int(raw)]

    @staticmethod
    def _decode_rle(encoded_pixels, height: int, width: int) -> np.ndarray:
        """Decode a run-length string into a binary (height, width) float32 mask.

        `encoded_pixels` holds space-separated `start length` pairs with
        1-indexed starts over pixels numbered top-to-bottom, then
        left-to-right (column-major), which is the convention of the
        iMaterialist/Fashionpedia CSV. Runs may continue into the next column.
        Anything that is not a string (e.g. NaN) produces an all-zero mask.
        """
        flat = np.zeros(height * width, dtype=np.float32)
        if isinstance(encoded_pixels, str):
            values = [int(x) for x in encoded_pixels.split()]
            for start, length in zip(values[0::2], values[1::2]):
                flat[start - 1:start - 1 + length] = 1.0
        # order='F' places flat pixel p at (row = p % height, col = p // height).
        return np.ascontiguousarray(flat.reshape((height, width), order='F'))

    def read_csv(self, csv_file: str) -> List[Dict[str, Any]]:
        """Read the CSV and return one dict per image with per-segment lists."""
        df = pd.read_csv(csv_file)
        data = []
        # Group by ImageId to handle multiple segmentations per image
        for _, group in df.groupby('ImageId'):
            first_row = group.iloc[0]
            item = {
                'ImageId': first_row['ImageId'],
                'Height': first_row['Height'],
                'Width': first_row['Width'],
                'EncodedPixels': [],
                'CategoryId': [],
                'AttributesIds': []
            }

            # Collect data for each segmentation
            for encoded, category, attributes in zip(
                group['EncodedPixels'], group['CategoryId'], group['AttributesIds']
            ):
                item['EncodedPixels'].append(encoded)
                item['CategoryId'].append(int(category))
                item['AttributesIds'].append(self._parse_attribute_ids(attributes))

            data.append(item)

        return data

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return the tensors for one image, each with a leading dimension of size 1.

        Keys: 'pixel_values' (normalised image), 'pixel_masks' (one binary mask
        per segment), 'category_labels' (one consecutive category ID per
        segment) and 'attribute_labels' (multi-hot attribute vector per segment).
        """
        inputs = {}
        img_data = self.csv_data[idx]

        # Prepare pixel values; the context manager closes the file handle.
        img_path = os.path.join(self.img_dir, f"{img_data['ImageId']}.jpg")
        with Image.open(img_path) as raw_image:
            image = self.transform(raw_image.convert("RGB"))
        inputs['pixel_values'] = image.unsqueeze(0)

        # One binary mask per segment; nearest-neighbour resizing keeps it binary.
        height, width = int(img_data['Height']), int(img_data['Width'])
        masks = []
        for encoded_pixels in img_data['EncodedPixels']:
            mask = torch.from_numpy(self._decode_rle(encoded_pixels, height, width))
            mask = F.interpolate(mask[None, None], size=self.IMAGE_SIZE, mode='nearest')
            masks.append(mask[0, 0])
        inputs['pixel_masks'] = torch.stack(masks).unsqueeze(0)

        # Prepare category labels
        inputs['category_labels'] = torch.tensor(img_data['CategoryId']).unsqueeze(0)

        # Prepare multi-hot encoded attribute labels
        num_attributes = self.attribute_mapper.num_ids
        num_segments = len(img_data['AttributesIds'])
        one_hot_attributes = torch.zeros(num_segments, num_attributes)
        for segment_idx, attr_ids in enumerate(img_data['AttributesIds']):
            if attr_ids:  # segments without attributes keep an all-zero row
                one_hot_attributes[segment_idx, attr_ids] = 1
        inputs['attribute_labels'] = one_hot_attributes.unsqueeze(0)

        return inputs  # Inputs for one image

def collate_fn(batch):
    """Collate a single-image batch into the tensors the model consumes.

    Each dataset item carries a leading dimension of size 1 on every tensor.
    'pixel_values' keeps that dimension (it acts as the batch axis); it is
    removed from the masks and labels so they are indexed by segment.

    Args:
        batch: list holding exactly one item dict from the dataset.

    Raises:
        ValueError: if `batch` does not contain exactly one item. Images have
            different numbers of segments, so larger batches cannot be stacked
            and would otherwise produce mis-shaped tensors.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn supports batch_size=1 only, got a batch of {len(batch)} items"
        )

    sample = batch[0]
    return {
        'pixel_values': sample['pixel_values'],
        'pixel_masks': sample['pixel_masks'].squeeze(0),
        'category_labels': sample['category_labels'].squeeze(0),
        'attribute_labels': sample['attribute_labels'].squeeze(0)
    }

### Loss Functions

In [None]:
class UncertaintyWeights(nn.Module):
    """Learnable per-task log-variances used to weight the three task losses.

    Holds one scalar parameter (shape `(1,)`, initialised to zero) for each
    of the category, attribute and mask tasks.
    """

    def __init__(self):
        super().__init__()
        # One log-variance per task; exp(-log_var) is the task's precision.
        self.log_var_category = nn.Parameter(torch.zeros(1))
        self.log_var_attribute = nn.Parameter(torch.zeros(1))
        self.log_var_mask = nn.Parameter(torch.zeros(1))

    def forward(self):
        """Return the precision of every task plus the summed log-variances.

        Returns:
            dict with keys 'category', 'attribute', 'mask' (each
            `exp(-log_var)`) and 'regularization' (sum of the log-variances).
        """
        log_variances = {
            'category': self.log_var_category,
            'attribute': self.log_var_attribute,
            'mask': self.log_var_mask,
        }
        weights = {task: torch.exp(-log_var) for task, log_var in log_variances.items()}
        weights['regularization'] = (
            self.log_var_category + self.log_var_attribute + self.log_var_mask
        )
        return weights

def compute_miou(mask_logits, pixel_masks):
    """Mean IoU of the best prediction/target mask pairs.

    Args:
        mask_logits: predicted mask logits, one mask per leading index.
        pixel_masks: ground-truth masks, one per leading index. Any value above
            0.5 counts as foreground, so masks that store a segment index or
            soft (interpolated) values are treated as binary.

    Returns:
        (miou, best_indices): the mean of the `G` largest pairwise IoUs, where
        `G` is the number of ground-truth masks, and the prediction index of
        each selected pair.

    NOTE(review): the selection is a global top-k over all prediction/target
    pairs, not a one-to-one (Hungarian) assignment, so one target or one
    prediction can be selected more than once.
    """
    # Convert logits to binary predictions
    mask_preds = (torch.sigmoid(mask_logits) > 0.5).float()

    # Binarise the targets so intersection/union count pixels, not label values.
    pixel_masks = (pixel_masks > 0.5).float()

    # Ensure pixel_masks are the same size as predictions
    if pixel_masks.shape[-2:] != mask_preds.shape[-2:]:
        pixel_masks = F.interpolate(
            pixel_masks.unsqueeze(1) if pixel_masks.dim() == 3 else pixel_masks,
            size=mask_preds.shape[-2:],
            mode='nearest'  # Using nearest neighbor to preserve binary values
        )
        pixel_masks = pixel_masks.squeeze(1) if pixel_masks.dim() == 4 else pixel_masks

    preds_flat = mask_preds.reshape(mask_preds.shape[0], -1)
    targets_flat = pixel_masks.reshape(pixel_masks.shape[0], -1)
    num_targets = targets_flat.shape[0]

    # Pairwise intersections, shape (num_predictions, num_targets).
    intersection = torch.matmul(preds_flat, targets_flat.T)

    # |A| + |B| from the per-mask areas; this avoids materialising a
    # (num_predictions, num_targets, num_pixels) tensor.
    areas = preds_flat.sum(dim=1, keepdim=True) + targets_flat.sum(dim=1).unsqueeze(0)
    union = areas - intersection

    iou = intersection / (union + 1e-6)

    topk_ious, topk_indices = torch.topk(torch.flatten(iou), k=num_targets, dim=0)

    miou = torch.mean(topk_ious)
    # A flat index into the (prediction, target) grid -> prediction (row) index.
    best_indices = torch.div(topk_indices, num_targets, rounding_mode='floor')

    return miou, best_indices

def compute_category_loss(category_logits, category_labels):
    """Cross-entropy between every target label and its best-fitting prediction.

    Args:
        category_logits: 2-D tensor (num_predictions, num_classes).
        category_labels: 1-D tensor of class indices (num_labels,).

    Returns:
        Scalar tensor: for each target label the smallest cross-entropy over
        all predictions is taken, and those minima are averaged.

    NOTE(review): the minimum runs over predictions (dim=0), i.e. one value per
    label, whereas compute_attribute_loss takes the minimum over labels
    (dim=1). Confirm which reduction is intended.
    """
    num_predictions, _ = category_logits.shape
    num_labels = category_labels.shape[0]

    # Enumerate all (prediction, label) pairs in prediction-major order.
    logits_per_pair = category_logits.repeat_interleave(num_labels, dim=0)
    labels_per_pair = category_labels.repeat(num_predictions)

    pairwise_losses = F.cross_entropy(
        logits_per_pair, labels_per_pair, reduction='none'
    ).view(num_predictions, num_labels)

    # Best prediction for every label, then the mean over labels.
    return pairwise_losses.min(dim=0).values.mean()

def compute_attribute_loss(attribute_logits, attribute_labels):
    """Binary cross-entropy between every prediction and its best-fitting label row.

    Args:
        attribute_logits: 2-D tensor (num_predictions, num_attributes).
        attribute_labels: 2-D multi-hot tensor (num_labels, num_attributes).

    Returns:
        Scalar tensor: the BCE of every prediction/label pair is averaged over
        the attributes, each prediction keeps its smallest value over the
        label rows, and those minima are averaged over the predictions.
    """
    num_predictions, _ = attribute_logits.shape
    num_labels = attribute_labels.shape[0]

    # Broadcast both operands to (num_predictions, num_labels, num_attributes).
    logits_grid = attribute_logits[:, None, :].expand(-1, num_labels, -1)
    labels_grid = attribute_labels[None, :, :].expand(num_predictions, -1, -1)

    per_attribute = F.binary_cross_entropy_with_logits(
        logits_grid, labels_grid, reduction='none'
    )

    # Mean over attributes gives one loss per (prediction, label) pair.
    per_pair = per_attribute.mean(dim=-1)

    # Best label row for every prediction, then the mean over predictions.
    return per_pair.min(dim=1).values.mean()

def compute_weighted_loss(predictions, targets, uncertainty_weights):
    """Combine the mask, category and attribute losses with learned uncertainty weights.

    Args:
        predictions: tuple (category_logits, attribute_logits, mask_logits);
            a leading dimension of size 1 is squeezed from each.
        targets: tuple (category_labels, attribute_labels, pixel_masks).
        uncertainty_weights: callable returning a dict with the keys
            'category', 'attribute', 'mask' and 'regularization'.

    Returns:
        (total_loss, category_loss, attribute_loss, mask_loss), where
        total_loss is the precision-weighted sum of the three task losses plus
        half of the regularization term.
    """
    raw_category_logits, raw_attribute_logits, raw_mask_logits = predictions
    category_labels, attribute_labels, pixel_masks = targets

    # Mask loss: one minus the mean IoU returned by compute_miou.
    miou, _ = compute_miou(raw_mask_logits.squeeze(0), pixel_masks)
    mask_loss = 1 - miou

    # Category and attribute losses on the per-query logits.
    category_loss = compute_category_loss(raw_category_logits.squeeze(0), category_labels)
    attribute_loss = compute_attribute_loss(raw_attribute_logits.squeeze(0), attribute_labels)

    # Scale every task loss by its learned precision.
    weights = uncertainty_weights()
    weighted_category_loss = weights['category'] * category_loss
    weighted_attribute_loss = weights['attribute'] * attribute_loss
    weighted_mask_loss = weights['mask'] * mask_loss

    # The regularization term keeps the log-variances from growing without bound.
    regularization = 0.5 * weights['regularization']
    total_loss = (
        weighted_category_loss
        + weighted_attribute_loss
        + weighted_mask_loss
        + regularization
    )

    return total_loss, category_loss, attribute_loss, mask_loss

### Custom Mask2Former

In [None]:
class CustomMask2FormerForFashionpedia(nn.Module):
    """Frozen pretrained Mask2Former with linear category and attribute heads.

    All Mask2Former parameters have `requires_grad` disabled; only the two
    linear classifiers and the uncertainty weights remain trainable.

    NOTE(review): the backbone is frozen but still follows `train()`/`eval()`
    of this module, so any dropout-like layers inside it stay active while
    training. Confirm that this is intended.

    Args:
        num_categories: output size of the category classifier.
        num_attributes: output size of the attribute classifier.
    """

    def __init__(self, num_categories: int, num_attributes: int):
        super().__init__()
        self.mask2former = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-coco-instance")

        # Freeze the whole pretrained model.
        self.mask2former.requires_grad_(False)

        hidden_size = self.mask2former.config.hidden_size
        self.category_classifier = nn.Linear(hidden_size, num_categories)
        self.attribute_classifier = nn.Linear(hidden_size, num_attributes)

        self.uncertainty_weights = UncertaintyWeights()

    def forward(
        self,
        pixel_values,
        pixel_masks,
        category_labels,
        attribute_labels
    ):
        """Run the model and compute the uncertainty-weighted loss.

        Returns:
            (total_loss, category_loss, attribute_loss, mask_loss,
            category_logits, attribute_logits, mask_logits)
        """
        category_logits, attribute_logits, mask_logits = self.predict(pixel_values)

        # Calculate uncertainty weighted loss
        losses = compute_weighted_loss(
            (category_logits, attribute_logits, mask_logits),
            (category_labels, attribute_labels, pixel_masks),
            self.uncertainty_weights
        )
        total_loss, category_loss, attribute_loss, mask_loss = losses

        return total_loss, category_loss, attribute_loss, mask_loss, category_logits, attribute_logits, mask_logits

    def predict(self, pixel_values):
        """Return (category_logits, attribute_logits, mask_logits) without computing a loss.

        The mask logits are Mask2Former's `masks_queries_logits`; the two
        classifiers are applied to its `transformer_decoder_last_hidden_state`.
        """
        outputs = self.mask2former(pixel_values=pixel_values)
        mask_logits = outputs.masks_queries_logits

        decoder_state = outputs.transformer_decoder_last_hidden_state
        category_logits = self.category_classifier(decoder_state)
        attribute_logits = self.attribute_classifier(decoder_state)

        return category_logits, attribute_logits, mask_logits

### Training

In [None]:
def load_weights(model, checkpoint_id):
    """Load model weights from the checkpoint with the given id, if it exists.

    Args:
        model: module whose `load_state_dict` receives the stored weights.
        checkpoint_id: id used in the checkpoint file name
            `mask2former_checkpoints/mask2former_checkpoint_<id>.pth`.

    Returns:
        The same `model`, updated in place when the checkpoint file was found
        and left untouched otherwise.
    """
    checkpoint_path = f'mask2former_checkpoints/mask2former_checkpoint_{checkpoint_id}.pth'
    if os.path.isfile(checkpoint_path):
        print(f" - Loading checkpoint from {checkpoint_path}")
        # map_location='cpu' lets checkpoints written on a GPU load on machines
        # without one; load_state_dict copies the values onto the model's device.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print(f" - No checkpoint found at {checkpoint_path}")
    return model

def _batch_to_device(batch, device):
    """Return the tensors of a collated batch moved to `device`, in model-argument order."""
    return tuple(
        batch[key].to(device)
        for key in ('pixel_values', 'pixel_masks', 'category_labels', 'attribute_labels')
    )


def train_model(model, train_loader, val_loader, optimizer, num_epochs, checkpoint_id, device, gradient_accumulation_steps=4):
    """Train with gradient accumulation, validate, and save a checkpoint every epoch.

    Weights of checkpoint `checkpoint_id - 1` are loaded first (if present).
    After every epoch the model/optimizer state and the mean train/validation
    losses are written to
    `mask2former_checkpoints/mask2former_checkpoint_<checkpoint_id>.pth`;
    the same file is overwritten by each epoch of one call.

    Args:
        model: module called as `model(pixel_values, pixel_masks,
            category_labels, attribute_labels)` whose first output is the loss.
        train_loader, val_loader: loaders yielding dicts with those four keys.
        optimizer: optimizer stepped every `gradient_accumulation_steps` batches.
        num_epochs: number of epochs to run.
        checkpoint_id: id of the checkpoint file to write.
        device: device the batches are moved to.
        gradient_accumulation_steps: batches accumulated per optimizer step.

    Raises:
        ValueError: if either loader is empty (checked before any training).
    """
    # Fail fast instead of dividing by zero after a full pass over the data.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    # Make sure the checkpoint directory exists before spending an epoch on
    # training; torch.save does not create missing directories.
    os.makedirs('mask2former_checkpoints', exist_ok=True)

    model = load_weights(model, checkpoint_id - 1)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()  # Zero gradients at start of epoch

        for batch_idx, batch in enumerate(train_loader):
            inputs = _batch_to_device(batch, device)

            # Forward pass; only the total loss is needed here.
            total_loss = model(*inputs)[0]

            # Scale loss by gradient accumulation steps
            scaled_loss = total_loss / gradient_accumulation_steps
            scaled_loss.backward()

            train_loss += total_loss.item()

            # Step optimizer after accumulating gradients
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Clear memory for the batch
            del inputs, total_loss, scaled_loss
            gc.collect()
            torch.cuda.empty_cache()

        # Handle any remaining gradients at end of epoch
        if (batch_idx + 1) % gradient_accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                inputs = _batch_to_device(batch, device)
                total_loss = model(*inputs)[0]
                val_loss += total_loss.item()

                # Clear memory after each validation batch
                del inputs, total_loss
                gc.collect()
                torch.cuda.empty_cache()

        val_loss /= len(val_loader)

        print(f"-----CHECKPOINT {checkpoint_id}-----")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
        print(f"Memory cached: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")
        print("----------------------")

        # Free cached memory before serialising the checkpoint.
        gc.collect()
        torch.cuda.empty_cache()

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }
        torch.save(checkpoint, f'mask2former_checkpoints/mask2former_checkpoint_{checkpoint_id}.pth')

        # Force garbage collection at end of epoch
        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()

def plot_loss_curves(checkpoint_dir):
    """Plot the train/validation losses stored in the `.pth` files of `checkpoint_dir`.

    Checkpoints are ordered by the trailing integer of their file name
    (`..._<id>.pth`); files without one come last in alphabetical order. The
    figure is saved as 'mask2former_loss_plot.png' and shown. If the directory
    holds no `.pth` file a message is printed and nothing is plotted.

    Args:
        checkpoint_dir: directory containing checkpoints with the keys
            'train_loss' and 'val_loss'.
    """
    def checkpoint_order(filename):
        # os.listdir returns entries in arbitrary order, so sort by numeric id.
        suffix = os.path.splitext(filename)[0].rsplit('_', 1)[-1]
        if suffix.isdecimal():
            return (0, int(suffix), filename)
        return (1, 0, filename)

    checkpoint_files = sorted(
        (name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')),
        key=checkpoint_order,
    )
    if not checkpoint_files:
        print(f" - No checkpoint files found in '{checkpoint_dir}'.")
        return

    train_losses = []
    val_losses = []
    for filename in checkpoint_files:
        checkpoint_path = os.path.join(checkpoint_dir, filename)
        # Only the stored losses are needed; keep the tensors on the CPU.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        train_losses.append(checkpoint['train_loss'])
        val_losses.append(checkpoint['val_loss'])
        del checkpoint  # release the model weights before loading the next file

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    plt.savefig('mask2former_loss_plot.png')
    plt.show()

### Evaluation

In [None]:
def load_best_model(model, checkpoint_dir):
    """Load the weights of the checkpoint with the lowest validation loss.

    Args:
        model: module whose `load_state_dict` receives the stored weights.
        checkpoint_dir: directory scanned for `.pth` checkpoints containing
            the keys 'val_loss' and 'model_state_dict'.

    Returns:
        The same `model`; it is left untouched if no checkpoint was found.
    """
    best_model_path = None
    best_val_loss = None

    # Keep each path next to its loss so other files in the directory cannot
    # shift the index of the best checkpoint. Sorting makes ties deterministic.
    for filename in sorted(os.listdir(checkpoint_dir)):
        if not filename.endswith(".pth"):
            continue
        checkpoint_path = os.path.join(checkpoint_dir, filename)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        val_loss = checkpoint['val_loss']
        del checkpoint  # only the loss is needed while scanning
        if best_val_loss is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = checkpoint_path

    if best_model_path is not None:
        best_checkpoint = torch.load(best_model_path, map_location='cpu')
        print(f" - Loading best model from {best_model_path}")
        model.load_state_dict(best_checkpoint['model_state_dict'])
    else:
        print(" - No checkpoints found.")

    return model

def compute_category_dice(category_predictions, category_labels):
    """Mean over predictions of the best Dice score against any target label.

    Args:
        category_predictions: 2-D tensor with one row of class scores per prediction.
        category_labels: tensor with one class label per leading index.

    Returns:
        Scalar tensor: every prediction is scored with `dice` against every
        label, each prediction keeps its highest score, and those maxima are
        averaged.
    """
    num_predictions, _ = category_predictions.shape
    num_labels = category_labels.shape[0]

    # Score table indexed as [prediction, label].
    pairwise_dice_scores = torch.zeros((num_predictions, num_labels))

    for pred_idx, prediction in enumerate(category_predictions):
        prediction_row = prediction.unsqueeze(0)
        for label_idx, label in enumerate(category_labels):
            pairwise_dice_scores[pred_idx, label_idx] = dice(
                prediction_row,
                label.unsqueeze(0).int(),
                mdmc_average='global'
            )

    # Best label for every prediction, then the mean over predictions.
    return pairwise_dice_scores.max(dim=1).values.mean()

def compute_attribute_dice(attribute_predictions, attribute_labels):
    """Mean over predictions of the best Dice score against any target attribute row.

    Args:
        attribute_predictions: 2-D tensor with one row of attribute scores per prediction.
        attribute_labels: tensor with one attribute row per leading index.

    Returns:
        Scalar tensor: for every label row, `dice` is evaluated between all
        predictions and that row repeated once per prediction; each prediction
        keeps its highest score over the label rows and those maxima are
        averaged.
    """
    num_predictions, _ = attribute_predictions.shape
    num_labels = attribute_labels.shape[0]

    # Score table indexed as [prediction, label], kept on the predictions' device.
    pairwise_dice = torch.zeros((num_predictions, num_labels), device=attribute_predictions.device)

    for label_idx, label_row in enumerate(attribute_labels):
        # Repeat the label row so it lines up with every prediction.
        tiled_labels = label_row.unsqueeze(0).expand(num_predictions, -1).int()
        pairwise_dice[:, label_idx] = dice(
            attribute_predictions,
            tiled_labels,
            mdmc_average='samplewise'
        )

    # Best label row for every prediction, then the mean over predictions.
    return pairwise_dice.max(dim=1).values.mean()

def evaluate_model(model, test_loader, device, batch_size=32, max_samples=1024):
    """Evaluate mask mIoU, category Dice and attribute Dice on `test_loader`.

    Every image is scored against its own ground truth only; predictions and
    targets of different images are never matched with each other. Metrics
    are computed in groups of `batch_size` images (including a final, smaller
    group) and averaged over all evaluated images. The results are written to
    'mask2former_results.txt' and printed.

    Args:
        model: module providing `predict(pixel_values)` that returns
            (category_logits, attribute_logits, mask_logits).
        test_loader: loader yielding one image per batch with the keys
            'pixel_values', 'pixel_masks', 'category_labels', 'attribute_labels'.
        device: device the batches are moved to.
        batch_size: number of images accumulated before metrics are computed
            and cached memory is released.
        max_samples: maximum number of images to evaluate; None evaluates the
            whole loader.

    Returns:
        (final_category_dice, final_attribute_dice, final_miou)

    Raises:
        ValueError: if `max_samples` is smaller than 1 or no image was evaluated.
    """
    if max_samples is not None and max_samples < 1:
        raise ValueError("max_samples must be a positive integer or None")

    model.eval()
    batch_results = []
    current_batch = []  # per-image prediction/target dicts awaiting scoring

    def process_batch_metrics(samples):
        # Score each image separately, then average within the group.
        mious, category_dices, attribute_dices = [], [], []
        for sample in samples:
            miou, _ = compute_miou(sample['mask_predictions'], sample['pixel_masks'])
            mious.append(miou)
            category_dices.append(
                compute_category_dice(sample['category_predictions'], sample['category_labels'].int())
            )
            attribute_dices.append(
                compute_attribute_dice(sample['attribute_predictions'], sample['attribute_labels'].int())
            )
        count = len(samples)
        return {
            'miou': sum(mious) / count,
            'category_dice': sum(category_dices) / count,
            'attribute_dice': sum(attribute_dices) / count,
            'batch_size': count
        }

    def flush_current_batch():
        batch_results.append(process_batch_metrics(current_batch))
        current_batch.clear()
        gc.collect()
        torch.cuda.empty_cache()

    total = len(test_loader) if max_samples is None else min(len(test_loader), max_samples)

    with torch.no_grad(), tqdm(total=total, desc="Evaluating") as progress_bar:
        for i, batch in enumerate(test_loader):
            pixel_values = batch['pixel_values'].to(device)
            category_logits, attribute_logits, mask_logits = model.predict(pixel_values)

            current_batch.append({
                'mask_predictions': mask_logits.squeeze(0),
                'pixel_masks': batch['pixel_masks'].to(device),
                'category_predictions': category_logits.squeeze(0),
                'category_labels': batch['category_labels'].to(device),
                'attribute_predictions': attribute_logits.squeeze(0),
                'attribute_labels': batch['attribute_labels'].to(device),
            })

            # Clear immediate memory
            del pixel_values, category_logits, attribute_logits, mask_logits
            gc.collect()
            torch.cuda.empty_cache()

            progress_bar.update(1)

            # Process the group once it holds batch_size images
            if len(current_batch) >= batch_size:
                flush_current_batch()

            # Stop once the requested number of images has been evaluated
            if max_samples is not None and i + 1 >= max_samples:
                break

        # Score the trailing images that did not fill a whole group.
        if current_batch:
            flush_current_batch()

    if not batch_results:
        raise ValueError("evaluate_model received no samples to evaluate")

    # Weight every group by its size so the result is the mean over images.
    total_samples = sum(result['batch_size'] for result in batch_results)

    final_miou = sum(result['miou'] * result['batch_size'] for result in batch_results) / total_samples
    final_category_dice = sum(result['category_dice'] * result['batch_size'] for result in batch_results) / total_samples
    final_attribute_dice = sum(result['attribute_dice'] * result['batch_size'] for result in batch_results) / total_samples

    # Save results
    with open("mask2former_results.txt", "w") as f:
        f.write("-----RESULTS-----\n")
        f.write(f"Mean IoU: {final_miou:.4f}\n")
        f.write(f"Category Dice Coefficient: {final_category_dice:.4f}\n")
        f.write(f"Attributes Dice Coefficient: {final_attribute_dice:.4f}\n")
        f.write("-----------------")

    with open("mask2former_results.txt", "r") as f:
        print(f.read())

    # Clear final memory
    del batch_results, current_batch
    gc.collect()
    torch.cuda.empty_cache()

    return final_category_dice, final_attribute_dice, final_miou

### Sample Image Segmentation

In [None]:
def load_category_attribute_names(json_path):
    """Read the category and attribute names from a label JSON file.

    Args:
        json_path: path of a JSON file whose 'categories' and 'attributes'
            entries are lists of objects with 'id' and 'name'.

    Returns:
        (categories, attributes): two dicts mapping `str(id)` to the name.
    """
    with open(json_path, 'r') as f:
        label_data = json.load(f)

    def names_by_id(entries):
        # IDs are keyed as strings.
        return {str(entry['id']): entry['name'] for entry in entries}

    categories = names_by_id(label_data['categories'])
    attributes = names_by_id(label_data['attributes'])
    return categories, attributes

def compute_iou(mask1, mask2):
    """Intersection-over-union of two masks as a Python number.

    Non-zero entries count as foreground. Returns 0 when both masks are empty.
    """
    overlap = torch.logical_and(mask1, mask2).sum()
    combined = torch.logical_or(mask1, mask2).sum()
    if combined > 0:
        return (overlap / combined).item()
    return 0

def non_max_suppression(masks, confidence_scores, iou_threshold=0.5):
    """Greedy mask NMS: keep the most confident masks and drop overlapping ones.

    Args:
        masks: indexable collection of masks accepted by `compute_iou`.
        confidence_scores: tensor with one score per mask.
        iou_threshold: a remaining mask survives only if its IoU with the
            kept mask is strictly below this value.

    Returns:
        List of kept mask indices, ordered from highest to lowest confidence
        (empty list when `masks` is empty).
    """
    if len(masks) == 0:
        return []

    # Candidate indices from the highest to the lowest score.
    order = confidence_scores.cpu().numpy().argsort()[::-1]
    num_candidates = len(order)
    suppressed = [False] * num_candidates
    kept_indices = []

    for position in range(num_candidates):
        if suppressed[position]:
            continue

        # The best remaining candidate is always kept.
        candidate = order[position]
        kept_indices.append(candidate)

        remaining = [
            later for later in range(position + 1, num_candidates)
            if not suppressed[later]
        ]
        if not remaining:
            continue

        # Drop every remaining mask that overlaps the kept one too much.
        reference_mask = masks[candidate]
        for later in remaining:
            overlap = compute_iou(reference_mask, masks[order[later]])
            if not (overlap < iou_threshold):
                suppressed[later] = True

    return kept_indices

def visualize_sample(model, dataset, idx, device, labels_file,
                     save_path='mask2former_sample_segmentation.png',
                     initial_top_k=100, final_top_k=15, iou_threshold=0.5,
                     mask_threshold=0.5, max_legend_attributes=5):
    """Run the model on one dataset sample and save an annotated segmentation figure.

    The figure has two panels: the original image with the selected predicted
    masks overlaid, and a legend listing the predicted category and attributes
    for each mask.

    Args:
        model: Object exposing ``eval()`` and ``predict(pixel_values)``; the
            latter must return ``(category_logits, attribute_logits, mask_logits)``.
            NOTE(review): each output is ``squeeze(0)``-ed, so a leading
            size-1 (presumably batch) dimension is expected - confirm.
        dataset: Must support ``dataset[idx]`` returning a mapping with
            ``'pixel_values'`` and expose ``csv_data``, ``img_dir``,
            ``category_mapper`` and ``attribute_mapper``.
        idx: Index of the sample, used for both ``dataset[idx]`` and
            ``dataset.csv_data[idx]``.
        device: Torch device the input tensor is moved to.
        labels_file: JSON file read by ``load_category_attribute_names``.
        save_path: Destination of the saved figure.
        initial_top_k: Number of highest-confidence masks passed to NMS.
        final_top_k: Maximum number of masks drawn after NMS.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        mask_threshold: Probability above which a pixel belongs to a mask.
        max_legend_attributes: Maximum number of attribute names listed per
            mask in the legend.

    Returns:
        None. The figure is written to ``save_path`` and left open.
    """
    # Get sample and prepare for model
    sample = dataset[idx]
    pixel_values = sample['pixel_values'].to(device)

    # Load the original image; masks are resized to its resolution below.
    img_data = dataset.csv_data[idx]
    img_path = os.path.join(dataset.img_dir, f"{img_data['ImageId']}.jpg")
    # Context manager releases the file handle as soon as the pixels are copied.
    with Image.open(img_path) as source_image:
        image = source_image.convert("RGB")
    img_width, img_height = image.size

    # Get model predictions
    model.eval()
    with torch.no_grad():
        category_logits, attribute_logits, mask_logits = model.predict(pixel_values)

    # Rank masks by their mean per-pixel probability and keep the best
    # candidates; NMS will filter some of them out afterwards.
    mask_logits = mask_logits.squeeze(0)
    mask_confidence = torch.sigmoid(mask_logits).mean(dim=[-1, -2])
    top_mask_indices = mask_confidence.argsort(descending=True)[:initial_top_k]
    mask_logits = mask_logits[top_mask_indices]

    # F.interpolate needs a channel dimension, so add one for 3-D input and
    # drop it again afterwards.
    mask_logits = F.interpolate(
        mask_logits.unsqueeze(1) if mask_logits.dim() == 3 else mask_logits,
        size=(img_height, img_width),
        mode='bilinear'
    )
    mask_logits = mask_logits.squeeze(1) if mask_logits.dim() == 4 else mask_logits
    predicted_masks = torch.sigmoid(mask_logits) > mask_threshold

    # Apply NMS to filter overlapping masks
    kept_indices = non_max_suppression(
        predicted_masks,
        mask_confidence[top_mask_indices],
        iou_threshold=iou_threshold
    )

    # Update predictions based on NMS results, then cap the number of masks.
    # NMS returns positions in descending confidence, so slicing keeps the best.
    predicted_masks = predicted_masks[kept_indices][:final_top_k]
    top_mask_indices = top_mask_indices[kept_indices][:final_top_k]

    # Process categories
    category_logits = category_logits.squeeze(0)[top_mask_indices]
    category_probs = F.softmax(category_logits, dim=-1)
    predicted_categories = torch.argmax(category_probs, dim=-1)

    # Process attributes (multi-label: each attribute thresholded independently)
    attribute_logits = attribute_logits.squeeze(0)[top_mask_indices]
    attribute_probs = torch.sigmoid(attribute_logits)
    ohe_predicted_attributes = (attribute_probs > 0.5).float()

    # Retrieve original category and attribute names
    category_map, attribute_map = load_category_attribute_names(labels_file)

    # Reverse attributes one hot encoding: positions set to 1 are the
    # consecutive attribute ids predicted for that mask.
    predicted_attribute_ids = [
        [position for position, value in enumerate(ohe_attr) if value == 1]
        for ohe_attr in ohe_predicted_attributes.tolist()
    ]

    # Map consecutive ids back to the original ids, then to readable names.
    predicted_category_ids = dataset.category_mapper.inverse_transform(predicted_categories.tolist())
    predicted_category_names = [category_map[str(pred_cat_id)] for pred_cat_id in predicted_category_ids]

    predicted_attribute_ids = [dataset.attribute_mapper.inverse_transform(attr_list) for attr_list in predicted_attribute_ids]
    predicted_attribute_names = [[attribute_map[str(pred_attr_id)] for pred_attr_id in attr_list] for attr_list in predicted_attribute_ids]

    # Create figure with custom layout
    fig = plt.figure(figsize=(13, 7))
    gs = plt.GridSpec(1, 2)

    # Main image plot
    ax_main = fig.add_subplot(gs[0])
    ax_main.imshow(image)

    # One distinct colour per mask, sampled evenly from the rainbow colormap
    num_masks = len(predicted_masks)
    colors = plt.cm.rainbow(np.linspace(0, 1, num_masks))

    # Plot masks on main image
    for mask, color in zip(predicted_masks, colors):
        mask_array = mask.cpu().numpy()
        mask_overlay = np.zeros((img_height, img_width, 4))
        mask_overlay[mask_array] = (*color[:3], 0.3)  # RGB + alpha
        ax_main.imshow(mask_overlay)

    ax_main.axis('off')
    ax_main.set_title(f'Image {idx}', pad=10)

    # Side panel for legend
    ax_legend = fig.add_subplot(gs[1])
    ax_legend.axis('off')

    # Create legend content: one coloured line per mask, labelled in the same order
    legend_text = []
    for category, attributes, color in zip(predicted_category_names,
                                           predicted_attribute_names,
                                           colors):
        limited_attributes = attributes[:max_legend_attributes]

        # Create a line-like object for the legend
        line = Line2D([0], [0], color=color, lw=4, alpha=0.3)
        ax_legend.add_line(line)
        legend_text.append(f"Category: {category} | Attributes: {', '.join(limited_attributes)}")

    # Create the legend
    ax_legend.legend(legend_text, loc='center right', frameon=False)

    # Save the visualization; the figure is not closed here.
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=300)

### Main

In [None]:
# Root folder of the dataset artefacts; change it here to relocate the data.
DATA_DIR = 'fashionpedia_data'

# Inputs handed to FashionpediaDataset.
csv_file = f'{DATA_DIR}/annotations.csv'
img_dir = f'{DATA_DIR}/images'
category_mapping_file = f'{DATA_DIR}/consecutive_category_mapping.json'
attribute_mapping_file = f'{DATA_DIR}/consecutive_attribute_mapping.json'

# JSON with category/attribute names, used when visualising a sample.
labels_file = f'{DATA_DIR}/labels.json'

# Folder passed to plot_loss_curves and load_best_model.
checkpoint_dir = 'mask2former_checkpoints'

In [None]:
# Create dataset
dataset = FashionpediaDataset(csv_file, img_dir, category_mapping_file, attribute_mapping_file)
print(' - Dataset created.')

# Split dataset into train, validation, and test sets (70% / 15% / remainder)
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# Use a dedicated, seeded generator so the split is identical on every run.
# Without it the membership of each subset changes after a kernel restart, and
# a model restored from a checkpoint could be evaluated on images it was
# trained on in an earlier session.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(' - Dataset split into train, val and test.')

# Create data loaders. collate_fn is passed as-is: wrapping it in
# partial() without bound arguments added nothing and re-bound the name.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

print(' - Train, val and test loaders initialized.')

In [None]:
# Initialize model: the sizes of the category and attribute heads are taken
# from the dataset's id mappers, then the model is moved to the target device.
model = CustomMask2FormerForFashionpedia(
    num_categories=dataset.category_mapper.num_ids,
    num_attributes=dataset.attribute_mapper.num_ids,
)
model = model.to(device)
print(' - CustomMask2FormerforFashionpedia model loaded.')

In [None]:
TRAIN = 1 # Toggle training

if TRAIN:
    # Settings for this training run
    checkpoint_id = 29
    num_epochs = 1
    gradient_accumulation_steps = 32

    # Set up optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print(' - Optimizer set up.')

    print(' - Model training initiated.')
    train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        num_epochs,
        checkpoint_id,
        device,
        gradient_accumulation_steps=gradient_accumulation_steps,
    )
    for status_line in (
        ' - Model training completed.',
        f' - Checkpoint {checkpoint_id} saved to mask2former_checkpoints.',
    ):
        print(status_line)

In [None]:
# Plot training vs validation loss from what is stored under checkpoint_dir
plot_loss_curves(checkpoint_dir)
for status_line in (
    ' - Loss curves plotted.',
    ' - Loss plot saved to mask2former_loss_plot.png',
):
    print(status_line)

In [None]:
EVALUATE = 1 # Toggle evaluation

if EVALUATE:
    # Replace the in-memory model with the best one found under checkpoint_dir
    model = load_best_model(model, checkpoint_dir)
    print(' - Best model loaded.')

    # Evaluate the model on the held-out test loader
    print(' - Model evaluation initiated.')
    category_f1, attribute_f1, miou = evaluate_model(
        model,
        test_loader,
        device,
    )
    for status_line in (
        ' - Model evaluation complete.',
        ' - Results saved to mask2former_results.txt',
    ):
        print(status_line)

In [None]:
SAMPLE = 1 # Toggle visualising sample

if SAMPLE:
    # Load the best model (repeated here so this cell also works when the
    # evaluation cell was skipped)
    model = load_best_model(model, checkpoint_dir)

    # Visualise an image segmentation sample; idx addresses the full,
    # unsplit dataset
    print(' - Generating sample segmentation.')
    visualize_sample(
        model,
        dataset,
        idx=5999,
        device=device,
        labels_file=labels_file,
    )