------------------------------------------------------------------------
------------------------------------------------------------------------

# Cell 1: Import Libraries and Setup

This cell sets up our deep learning environment by importing necessary libraries and checking GPU availability.

## Key Components:
1. **Core Libraries**:
   - `torch`: PyTorch deep learning framework
   - `transformers`: Hugging Face library for pre-trained models
   - `PIL`: Python Imaging Library for image processing
   - `numpy`: Numerical computing library
   - `pandas`: Data manipulation library

2. **CUDA Check**:
   - Verifies if GPU acceleration is available
   - Prints GPU device information if available

3. **Constants**:
   - `NUM_CATEGORIES = 30`: Number of clothing categories
   - `NUM_ATTRIBUTES = 341`: Number of possible attributes

## Why These Numbers?
- 30 categories cover the main clothing types (shirts, pants, dresses, etc.) as well as accessories (hats, shoes, bags) and garment parts (hood, epaulette, bow)
- 341 attributes include various features like colors, patterns, materials

In [None]:
# Cell 1: Import Libraries and Test CUDA
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import DetrImageProcessor, DetrForSegmentation
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

# Test CUDA availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Constants
NUM_CATEGORIES = 30
NUM_ATTRIBUTES = 341

PyTorch version: 2.4.0
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4090


------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 2: Retrieve Data and Define Paths

This cell sets up file paths for our dataset components.

## Path Definitions:
1. `LABEL_FILE`: JSON file containing label descriptions
2. `CSV_FILE`: Training data annotations
3. `IMAGE_DIR`: Directory containing training images

## Important Note:
There are two ways to retrieve the data. Run only the cells for the method that matches your environment:
* (1) Kaggle Retrieval
* (2) GPU Cluster Retrieval
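If you switch between the two environments often, a small helper can pick the path set automatically. This is a minimal sketch, not part of the original pipeline: `choose_paths`, `KAGGLE_PATHS` and `CLUSTER_PATHS` are hypothetical names, the paths are copied from the two options below, and it assumes the `/content/kaggle` folder exists only on the Kaggle/Colab machine.

```python
import os

# Path sets copied from the two retrieval options in this cell
KAGGLE_PATHS = {
    "LABEL_FILE": "/content/kaggle/label_description_v3.json",
    "CSV_FILE": "/content/kaggle/train_swapaholic_v0.csv/train_swapaholic_v0.csv",
    "IMAGE_DIR": "/content/kaggle/train",
}
CLUSTER_PATHS = {
    "LABEL_FILE": "./imaterialist-2020/label_description_v3.json",
    "CSV_FILE": "./imaterialist-2020/train_swapaholic_v0.csv",
    "IMAGE_DIR": "./imaterialist-2020/train",
}


def choose_paths(kaggle_root="/content/kaggle"):
    """Hypothetical helper: use the Kaggle layout if its root folder exists, else the cluster layout."""
    return KAGGLE_PATHS if os.path.isdir(kaggle_root) else CLUSTER_PATHS


paths = choose_paths()
LABEL_FILE, CSV_FILE, IMAGE_DIR = paths["LABEL_FILE"], paths["CSV_FILE"], paths["IMAGE_DIR"]
print(f"Using CSV file: {CSV_FILE}")
```

The check is deliberately simple; if both layouts can exist on one machine, setting the paths by hand as in the cells below is safer.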

### Use this for Kaggle retrieval

In [None]:
# !mkdir -p ~/.kaggle  # create the .kaggle folder in your home directory
# !echo '{"username":"YOUR_USERNAME","key":"YOUR_KEY"}' > ~/.kaggle/kaggle.json  # write Kaggle API credentials to kaggle.json
# !chmod 600 ~/.kaggle/kaggle.json  # restrict the credentials file to your user
# !pip install kaggle  # install the Kaggle CLI

In [None]:
# !kaggle datasets list -s Fashionpediea

In [None]:
# !kaggle datasets download -d chiangkhenghe/fashionpediea-fyp-full-train -p /content/kaggle/

In [None]:
# ! unzip /content/kaggle/fashionpediea-fyp-full-train.zip -d /content/kaggle/

In [None]:
# Define Paths
LABEL_FILE = '/content/kaggle/label_description_v3.json'
CSV_FILE = '/content/kaggle/train_swapaholic_v0.csv/train_swapaholic_v0.csv'
IMAGE_DIR = '/content/kaggle/train'

### Use this for GPU cluster retrieval

In [None]:
# # Define Paths
# LABEL_FILE = r"./imaterialist-2020/label_description_v3.json"
# CSV_FILE = r"./imaterialist-2020/train_swapaholic_v0.csv"
# IMAGE_DIR = r"./imaterialist-2020/train"

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 3: Category and Attribute Mappings

This cell defines the classification structure for our fashion items.

## Components:

1. **CATEGORY_MAPPING**:
   - Dictionary mapping numeric IDs to clothing categories
   - Example: `0: "shirt, blouse"`, `1: "top, t-shirt, sweatshirt"`
   - Total of 30 distinct categories

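2. **SUPERCATEGORIES**:
   - Groups related categories together
   - Helps in hierarchical classification
   - Example: "Tops" includes shirts, t-shirts, and vests
   - Covers only a subset of the 30 IDs: categories 13-27 (glasses through hood) are not assigned to any group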

## Usage:
- Used for converting between numeric IDs and human-readable labels
- Helps in organizing and analyzing model predictions
- Useful for generating reports and evaluating model performance
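Converting IDs to readable labels usually means inverting `SUPERCATEGORIES` into a per-category lookup. The sketch below assumes that ungrouped IDs should fall back to a placeholder label; `build_category_to_super`, `describe` and the `"Ungrouped"` label are hypothetical, and the two small dicts are excerpts of the mappings defined in the cell below.

```python
def build_category_to_super(supercategories):
    """Invert {supercategory: [category ids]} into {category id: supercategory}."""
    lookup = {}
    for name, ids in supercategories.items():
        for cat_id in ids:
            lookup[cat_id] = name
    return lookup


def describe(cat_id, category_mapping, category_to_super, default="Ungrouped"):
    """Return 'name (supercategory)' for a category ID, with a fallback for ungrouped IDs."""
    name = category_mapping.get(cat_id, f"unknown id {cat_id}")
    return f"{name} ({category_to_super.get(cat_id, default)})"


# Small excerpts of CATEGORY_MAPPING and SUPERCATEGORIES from the cell below
example_mapping = {0: "shirt, blouse", 5: "vest", 13: "glasses"}
example_supers = {"Tops": [0, 1, 5], "Accessories": [28, 29]}

cat_to_super = build_category_to_super(example_supers)
print(describe(5, example_mapping, cat_to_super))   # vest (Tops)
print(describe(13, example_mapping, cat_to_super))  # glasses (Ungrouped)
```

With the full dictionaries, `build_category_to_super(SUPERCATEGORIES)` gives a lookup you can apply to predicted category IDs when building reports.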

In [None]:
# Category ID -> name mapping and supercategory groups
CATEGORY_MAPPING = {
    0: "shirt, blouse",
    1: "top, t-shirt, sweatshirt",
    2: "sweater",
    3: "cardigan",
    4: "jacket",
    5: "vest",
    6: "pants",
    7: "shorts",
    8: "skirt",
    9: "coat",
    10: "dress",
    11: "jumpsuit",
    12: "cape",
    13: "glasses",
    14: "hat",
    15: "headband, head covering, hair accessory",
    16: "tie",
    17: "glove",
    18: "watch",
    19: "belt",
    20: "leg warmer",
    21: "tights, stockings",
    22: "sock",
    23: "shoe",
    24: "bag, wallet",
    25: "scarf",
    26: "umbrella",
    27: "hood",
    28: "epaulette",
    29: "bow"
}

# Group categories by supercategory
SUPERCATEGORIES = {
    "Tops": [0, 1, 5],
    "Outerwear": [2, 3, 9, 12],
    "Blazers & Jackets": [4],
    "Pants, Trousers, Leggings": [6],
    "Shorts": [7],
    "Skirts": [8],
    "Dress": [10],
    "Jumpsuits & Playsuits": [11],
    "Accessories": [28, 29]
}

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 4: Data Loading and Verification

This cell loads and validates our dataset.

## Operations:
1. **Load CSV Data**:
   - Reads training data into pandas DataFrame
   - Displays basic dataset information

2. **Data Validation**:
   - Checks dataset shape (rows and columns)
   - Lists available columns
   - Shows sample data
   - Identifies missing values

## Why Is This Important?
- Ensures data quality before training
- Helps identify potential issues early
- Provides dataset statistics for reference
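Beyond missing values, a quick range check confirms that every label fits the vectors built later (`NUM_CATEGORIES` one-hot slots, `NUM_ATTRIBUTES` multi-hot slots). This is a sketch under two assumptions: `AttributesIds` is a comma-separated string or missing, and IDs index the vectors directly. `parse_attribute_ids`, `check_label_ranges` and the toy frame (made-up values) are hypothetical.

```python
import pandas as pd

NUM_CATEGORIES = 30
NUM_ATTRIBUTES = 341


def parse_attribute_ids(value):
    """'115,136,143' -> [115, 136, 143]; missing values (None/NaN) or '' -> []."""
    if pd.isna(value) or value == "":
        return []
    return [int(x) for x in str(value).split(",")]


def check_label_ranges(df):
    """Count rows whose CategoryId, or any of whose AttributesIds, is out of range."""
    bad_category = ~df["CategoryId"].between(0, NUM_CATEGORIES - 1)
    bad_attribute = df["AttributesIds"].apply(
        lambda v: any(a < 0 or a >= NUM_ATTRIBUTES for a in parse_attribute_ids(v))
    )
    return {"bad_category_rows": int(bad_category.sum()),
            "bad_attribute_rows": int(bad_attribute.sum())}


# Toy frame with the same label columns as the real CSV (values are made up)
toy = pd.DataFrame({
    "CategoryId": [6, 0, 28, 31],
    "AttributesIds": ["115,136,143", None, "163", "400"],
})
print(check_label_ranges(toy))  # {'bad_category_rows': 1, 'bad_attribute_rows': 1}
```

On the real data, `check_label_ranges(df)` should report zero for both counts; anything else would cause an index error when the dataset builds its label vectors.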

In [None]:
# Cell 4: Load and Check Data
# Load the training annotations (CSV_FILE is defined in Cell 2)
df = pd.read_csv(CSV_FILE)

print("Dataset shape:", df.shape)
print("\nColumns:", df.columns.tolist())
print("\nSample data:")
print(df.head())

# Check for missing values
print("\nMissing values:")
print(df.isnull().sum())

Dataset shape: (90731, 6)

Columns: ['ImageId', 'EncodedPixels', 'Height', 'Width', 'CategoryId', 'AttributesIds']

Sample data:
                            ImageId  \
0  00000663ed1ff0c4e0132b9b9ac53f6e   
1  00000663ed1ff0c4e0132b9b9ac53f6e   
2  00000663ed1ff0c4e0132b9b9ac53f6e   
3  00000663ed1ff0c4e0132b9b9ac53f6e   
4  00000663ed1ff0c4e0132b9b9ac53f6e   

                                       EncodedPixels  Height  Width  \
0  6068157 7 6073371 20 6078584 34 6083797 48 608...    5214   3676   
1  6323163 11 6328356 32 6333549 53 6338742 75 63...    5214   3676   
2  8521389 10 8526585 30 8531789 42 8537002 46 85...    5214   3676   
3  6421446 292 6426657 298 6431867 305 6437078 31...    5214   3676   
4  4566382 8 4571592 25 4576803 41 4582013 58 458...    5214   3676   

   CategoryId                       AttributesIds  
0           6     115,136,143,154,230,295,316,317  
1           0     115,136,142,146,225,295,316,317  
2          28                                 163  
3

In [None]:
# # Optional: Test Image Loading and RLE Decoding (uncomment to run)
# def test_rle_decode(mask_rle, shape):
#     """Test RLE decoding for a single mask"""
#     if pd.isna(mask_rle):
#         return np.zeros(shape, dtype=np.bool_)

#     s = mask_rle.split()
#     starts = np.array(s[0::2], dtype=np.int64) - 1
#     lengths = np.array(s[1::2], dtype=np.int64)
#     mask = np.zeros(shape[0] * shape[1], dtype=np.bool_)

#     for start, length in zip(starts, lengths):
#         if start + length <= len(mask):
#             mask[start:start + length] = True

#     return mask.reshape(shape, order='F')

# # Test with first image
# first_image_id = df['ImageId'].iloc[0]  # Adjust based on your column name
# first_mask_rle = df['EncodedPixels'].iloc[0]
# first_category = df['CategoryId'].iloc[0]
# first_attributes = df['AttributesIds'].iloc[0]

# # Load and display image
# first_image_path = f"{IMAGE_DIR}/{first_image_id}.jpg"
# image = Image.open(first_image_path).convert('RGB')
# width, height = image.size
# print(f"Image dimensions: {width}x{height}")

# # Test mask decoding
# mask = test_rle_decode(first_mask_rle, (height, width))
# print(f"Mask shape: {mask.shape}")
# print(f"Mask coverage: {mask.sum() / (height * width) * 100:.2f}%")

# # Visualize
# plt.figure(figsize=(15, 5))
# plt.subplot(131)
# plt.imshow(image)
# plt.title("Original Image")
# plt.axis('off')

# plt.subplot(132)
# plt.imshow(mask, cmap='gray')
# plt.title("Decoded Mask")
# plt.axis('off')

# plt.subplot(133)
# plt.imshow(image)
# plt.imshow(mask, alpha=0.5, cmap='Reds')
# plt.title("Overlay")
# plt.axis('off')

# plt.tight_layout()
# plt.show()

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 5: Dataset Class Implementation

This cell defines our custom dataset class for handling fashion images and their labels.

## Class Overview: `FashionMultiTaskDataset`

## Key Features:

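1. **Data Augmentation** (training mode, `augment=True`):
   - Random 90° rotations and horizontal flips
   - Random brightness/contrast adjustments
   - Random resized cropping to 800x800
   - Normalization with the DETR image processor's mean and std (also applied in validation mode)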

2. **Label Processing**:
   - Converts categories to one-hot encoding
   - Processes attributes as multi-hot encoding
   - Decodes RLE-encoded segmentation masks

3. **Image Processing**:
   - Resizes images to 800x800 pixels
   - Applies transformations consistently
   - Handles both training and validation modes
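The label processing above turns one category ID into a one-hot vector and a list of attribute IDs into a multi-hot vector. The NumPy sketch below mirrors that step; `encode_labels` is a hypothetical name, and the empty-list guard assumes some rows carry no attributes.

```python
import numpy as np

NUM_CATEGORIES = 30
NUM_ATTRIBUTES = 341


def encode_labels(category_id, attribute_ids):
    """Return (one-hot category vector, multi-hot attribute vector)."""
    category = np.zeros(NUM_CATEGORIES, dtype=np.float32)
    category[category_id] = 1.0

    attributes = np.zeros(NUM_ATTRIBUTES, dtype=np.float32)
    if attribute_ids:                      # rows without attributes stay all-zero
        attributes[list(attribute_ids)] = 1.0
    return category, attributes


cat_vec, attr_vec = encode_labels(6, [115, 136, 143])
print(cat_vec.argmax(), int(attr_vec.sum()))  # 6 3
```

Because each annotation row describes one segmented item, the category vector always has exactly one active entry, while the attribute vector can have any number.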

## Methods:
- `__init__`: Initializes dataset with paths and options
- `rle_decode`: Converts RLE masks to binary format
- `__getitem__`: Retrieves and processes single items
- `__len__`: Returns dataset size
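The RLE strings in `EncodedPixels` are pairs of (1-indexed start, run length) over pixels numbered top-to-bottom and then left-to-right, which is why `rle_decode` reshapes with `order='F'`. Because the indices refer to the original image, decoding must use the original size (the CSV's `Height`/`Width`), not a resized one. The sketch below repeats the decoding logic and adds a hypothetical `rle_encode` inverse, used only to check the round trip.

```python
import numpy as np


def rle_decode(mask_rle, shape):
    """Decode 'start length ...' (1-indexed starts, column-major order) into a bool mask."""
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=np.int64) - 1
    lengths = np.asarray(s[1::2], dtype=np.int64)
    flat = np.zeros(shape[0] * shape[1], dtype=bool)
    for start, length in zip(starts, lengths):
        if start + length <= flat.size:    # skip runs that overflow the image
            flat[start:start + length] = True
    return flat.reshape(shape, order="F")


def rle_encode(mask):
    """Hypothetical inverse: bool mask -> 'start length ...' string."""
    flat = np.concatenate([[0], mask.flatten(order="F").astype(np.int8), [0]])
    runs = np.flatnonzero(flat[1:] != flat[:-1]) + 1   # 1-indexed run starts and ends
    runs[1::2] -= runs[0::2]                            # turn run ends into lengths
    return " ".join(str(x) for x in runs)


mask = np.array([[0, 1],
                 [1, 1],
                 [1, 0]], dtype=bool)                   # height 3, width 2
print(rle_encode(mask))                                 # 2 4
print(np.array_equal(rle_decode("2 4", (3, 2)), mask))  # True
```

Reading the 3x2 mask column by column gives `0 1 1 | 1 1 0`, so pixels 2 to 5 form a single run of length 4.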

## Usage:
Used for both training and validation data loading with different augmentation settings
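Geometric augmentations must move the mask together with the image, or the segmentation target stops lining up with the pixels. Albumentations handles this when both arrays go through one call (`transform(image=image, mask=mask)`). The NumPy sketch below only illustrates the principle. It is not albumentations' implementation, and `joint_flip_rot90` is a hypothetical name.

```python
import numpy as np


def joint_flip_rot90(image, mask, rng):
    """Apply the SAME random horizontal flip and 90-degree rotation to image (H x W x C) and mask (H x W)."""
    if rng.random() < 0.5:                          # horizontal flip with p=0.5
        image, mask = image[:, ::-1], mask[:, ::-1]
    k = int(rng.integers(0, 4))                     # number of 90-degree turns
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(42)
mask = np.zeros((4, 6), dtype=bool)
mask[0, :2] = True                                  # small blob in the top-left corner
image = np.stack([mask * 255] * 3, axis=-1)         # image whose pixels echo the mask
aug_image, aug_mask = joint_flip_rot90(image, mask, rng)
print(np.array_equal(aug_image[..., 0], aug_mask * 255))  # True: still aligned
```

Color-only transforms such as brightness/contrast change the image alone, which is why albumentations applies them to `image` but not to `mask`.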

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset
from transformers import DetrImageProcessor
class FashionMultiTaskDataset(Dataset):
    def __init__(self, image_paths, masks, categories, attributes, image_processor, augment=False):
        """
        Initialize the dataset.

        Args:
            image_paths (list): List of paths to images
            masks (list): List of RLE encoded masks
            categories (list): List of category IDs
            attributes (list): List of attribute IDs (comma-separated strings)
            image_processor: DETR image processor
            augment (bool): Whether to apply data augmentation
        """
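        # Store plain lists so positional indexing works for any pandas slice
        # (a Series slice such as df.iloc[1000:2000] would otherwise be indexed by label)
        self.image_paths = list(image_paths)
        self.masks = list(masks)
        self.categories = list(categories)
        self.attributes = list(attributes)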
        self.image_processor = image_processor
        self.augment = augment
        self.target_size = (800, 800)  # fixed square input (DETR's processor defaults to shortest edge 800)

        if augment:
            self.transform = A.Compose([
                A.RandomRotate90(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.RandomResizedCrop(
                    height=self.target_size[0],
                    width=self.target_size[1],
                    scale=(0.8, 1.0)
                ),
                A.Normalize(
                    mean=image_processor.image_mean,
                    std=image_processor.image_std
                ),
                ToTensorV2()
            ])
        else:
            self.transform = A.Compose([
                A.Resize(
                    height=self.target_size[0],
                    width=self.target_size[1]
                ),
                A.Normalize(
                    mean=image_processor.image_mean,
                    std=image_processor.image_std
                ),
                ToTensorV2()
            ])

    def __len__(self):
        return len(self.image_paths)

    def rle_decode(self, mask_rle, shape):
        """
        Decode RLE encoded mask.

        Args:
            mask_rle (str): Run-length encoded mask
            shape (tuple): Image shape (height, width)

        Returns:
            np.ndarray: Binary mask
        """
        if pd.isna(mask_rle):
            return np.zeros(shape, dtype=np.bool_)

        s = mask_rle.split()
        starts = np.array(s[0::2], dtype=np.int64) - 1
        lengths = np.array(s[1::2], dtype=np.int64)
        mask = np.zeros(shape[0] * shape[1], dtype=np.bool_)

        for start, length in zip(starts, lengths):
            if start + length <= len(mask):
                mask[start:start + length] = True

        return mask.reshape(shape, order='F')

    def __getitem__(self, idx):
        """
        Get a single item from the dataset.

        Args:
            idx (int): Index

        Returns:
            dict: Dictionary containing:
                - pixel_values: Image tensor
                - pixel_mask: Attention mask
                - category_labels: Category label
                - attribute_labels: Attribute labels
                - mask_labels: Segmentation mask
        """
        # Load the image at its original resolution: RLE pixel indices refer to this size
        image = np.array(Image.open(self.image_paths[idx]).convert('RGB'))
        original_height, original_width = image.shape[:2]

        # Decode the mask BEFORE any resizing (decoding with a resized shape scrambles it)
        mask = self.rle_decode(self.masks[idx], (original_height, original_width)).astype(np.uint8)

        # Apply the same resize/augmentation pipeline (built in __init__) to image and mask
        transformed = self.transform(image=image, mask=mask)
        encoding = {
            'pixel_values': transformed['image'],                          # [3, 800, 800], normalized
            'pixel_mask': torch.ones(self.target_size, dtype=torch.long),  # no padding: all pixels valid
        }
        mask = transformed['mask'].float()                                 # [800, 800]

        # Convert category to one-hot
        category = torch.zeros(NUM_CATEGORIES)
        category[self.categories[idx]] = 1

        # Convert attributes to multi-hot (rows without attributes stay all-zero)
        attributes = torch.zeros(NUM_ATTRIBUTES)
        attr_value = self.attributes[idx]
        if isinstance(attr_value, str):
            attr_indices = [int(x) for x in attr_value.split(',')]
        elif attr_value is None or (isinstance(attr_value, float) and np.isnan(attr_value)):
            attr_indices = []
        else:
            attr_indices = list(attr_value)
        if attr_indices:
            attributes[attr_indices] = 1

        # Add labels to encoding
        encoding['category_labels'] = category
        encoding['attribute_labels'] = attributes
        encoding['mask_labels'] = mask

        return encoding


def custom_collate_fn(batch):
    """
    Custom collate function for DataLoader.

    Args:
        batch (list): List of samples from dataset

    Returns:
        dict: Batched samples
    """
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    pixel_mask = torch.stack([item['pixel_mask'] for item in batch])
    category_labels = torch.stack([item['category_labels'] for item in batch])
    attribute_labels = torch.stack([item['attribute_labels'] for item in batch])
    mask_labels = torch.stack([item['mask_labels'] for item in batch])

    return {
        'pixel_values': pixel_values,
        'pixel_mask': pixel_mask,
        'category_labels': category_labels,
        'attribute_labels': attribute_labels,
        'mask_labels': mask_labels
    }

# Test Dataset
image_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50-panoptic')

# Create small test dataset
image_paths = [f"{IMAGE_DIR}/{image_id}.jpg" for image_id in df['ImageId'].iloc[:5]]
test_dataset = FashionMultiTaskDataset(
    image_paths=image_paths,
    masks=df['EncodedPixels'].iloc[:5],
    categories=df['CategoryId'].iloc[:5],
    attributes=df['AttributesIds'].iloc[:5],
    image_processor=image_processor
)

# Test loading one item
test_item = test_dataset[0]
print("Dataset item keys:", test_item.keys())
print("\nShapes:")
for k, v in test_item.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")

Dataset item keys: dict_keys(['pixel_values', 'pixel_mask', 'category_labels', 'attribute_labels', 'mask_labels'])

Shapes:
pixel_values: torch.Size([3, 800, 800])
pixel_mask: torch.Size([800, 800])
category_labels: torch.Size([30])
attribute_labels: torch.Size([341])
mask_labels: torch.Size([800, 800])


------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 6: Model Architecture

This cell defines our core model architecture based on DETR (Detection Transformer).

## Class Overview: `FashionMultiTaskDETR`

## Architecture Components:

1. **Base Model**:
   - Uses pretrained DETR (facebook/detr-resnet-50-panoptic)
   - Leverages transformer architecture for image understanding
   - Hidden dimension: 256 (from DETR config)

2. **Category Head**:
   ```python
   self.category_head = nn.Sequential(
       nn.Linear(hidden_dim, hidden_dim),
       nn.BatchNorm1d(hidden_dim),
       nn.ReLU(),
       nn.Dropout(0.3),
       # ... additional layers
   )
   ```
   - Two hidden layers (256 → 256 → 128) before the 30-way output
   - Uses batch normalization for stable training (in training mode, BatchNorm1d needs more than one sample per batch)
   - Dropout (0.3) for regularization
   - Progressive dimension reduction

3. **Attribute Head**:
   - Simpler architecture for multi-label classification
   - Lower dropout (0.1) to maintain feature richness
   - Direct mapping to attribute space

4. **Forward Pass Logic**:
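   - Processes images through DETR backbone
   - Uses the first decoder query embedding as a global feature for classification
   - Generates predictions for all three tasks (the segmentation output is the per-pixel maximum over all query masks)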
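The shape flow in `forward` can be checked without loading DETR. The sketch uses random NumPy stand-ins. It assumes the shapes given in the Hugging Face `DetrSegmentationOutput` documentation: `last_hidden_state` is `[batch, num_queries, d_model]`, and `pred_masks` is `[batch, num_queries, height/4, width/4]`. The 1/4 resolution is why `MultitaskLoss` resizes the 800x800 target masks with nearest-neighbor interpolation before comparing them.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_queries, d_model = 2, 100, 256
height, width = 800, 800

# Random stand-ins with the shapes of the DETR outputs used in forward()
last_hidden_state = rng.standard_normal((batch, num_queries, d_model), dtype=np.float32)
pred_masks = rng.standard_normal((batch, num_queries, height // 4, width // 4), dtype=np.float32)

global_feature = last_hidden_state[:, 0]       # first object query -> (2, 256)
segmentation_mask = pred_masks.max(axis=1)     # per-pixel max over queries -> (2, 200, 200)
print(global_feature.shape, segmentation_mask.shape)  # (2, 256) (2, 200, 200)
```

DETR has no dedicated summary token, so the first query is an arbitrary choice. Averaging over queries (`last_hidden_state.mean(axis=1)`) is a common alternative that keeps the same output shape.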

## Why This Architecture?
- Balances complexity across tasks
- Prevents overfitting through regularization
- Maintains pretrained knowledge while adding fashion-specific capabilities
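To put "balances complexity" in numbers, the arithmetic below counts the trainable parameters of the two heads defined in the cell below, assuming `hidden_dim = 256` (DETR's `d_model`). `linear_params` and `batchnorm_params` are small illustrative helpers, and the BatchNorm count includes only the learnable scale and shift, not the running statistics.

```python
def linear_params(n_in, n_out):
    """nn.Linear: weight matrix plus bias."""
    return n_in * n_out + n_out


def batchnorm_params(n):
    """nn.BatchNorm1d: learnable gamma and beta (running mean/var are buffers)."""
    return 2 * n


hidden, num_categories, num_attributes = 256, 30, 341

category_head = (linear_params(hidden, hidden) + batchnorm_params(hidden)
                 + linear_params(hidden, hidden // 2) + batchnorm_params(hidden // 2)
                 + linear_params(hidden // 2, num_categories))
attribute_head = linear_params(hidden, hidden) + linear_params(hidden, num_attributes)
print(category_head, attribute_head)  # 103326 153429
```

Together the heads add roughly a quarter of a million parameters. That is small next to the pretrained DETR model, so most of the model's capacity stays in the shared backbone.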

In [None]:
# Cell 6: Model Class and Test

class FashionMultiTaskDETR(nn.Module):
    def __init__(self, num_categories=NUM_CATEGORIES, num_attributes=NUM_ATTRIBUTES):
        super().__init__()

        self.detr = DetrForSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic')
        hidden_dim = self.detr.config.d_model

        # Modified category head with batch norm and dropout
        self.category_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_categories)
        )

        self.attribute_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_attributes)
        )

    def forward(self, pixel_values, pixel_mask):
        # Get DETR outputs
        outputs = self.detr(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask
        )

        # Get last hidden state for classification tasks
        last_hidden_state = outputs.last_hidden_state  # Shape: [batch_size, 100, 256]
        global_feature = last_hidden_state[:, 0]  # Embedding of the first object query (DETR has no [CLS] token)

        # Task-specific predictions
        category_logits = self.category_head(global_feature)
        attribute_logits = self.attribute_head(global_feature)

        # Get segmentation predictions
        # pred_masks shape: [batch_size, num_queries, height, width]
        masks = outputs.pred_masks

        # Aggregate masks if needed (you might want to adjust this based on your needs)
        # Here we're taking the max across all queries
        segmentation_mask = torch.max(masks, dim=1)[0]  # Shape: [batch_size, height, width]

        return {
            'category_logits': category_logits,
            'attribute_logits': attribute_logits,
            'segmentation_masks': segmentation_mask
        }

# Test Model
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model and move to device
model = FashionMultiTaskDETR().to(device)

# Create sample batch
sample_batch = {
    'pixel_values': torch.randn(2, 3, 800, 800).to(device),
    'pixel_mask': torch.ones(2, 800, 800).to(device)
}

# Test forward pass
with torch.no_grad():
    outputs = model(**sample_batch)

print("\nModel output shapes:")
for k, v in outputs.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 7: Loss Function Implementation

This cell defines our multi-task loss function that balances learning across different objectives.

## Class Overview: `MultitaskLoss`

## Loss Components:

1. **Learnable Task Weights**:
   ```python
   self.log_vars = nn.Parameter(torch.FloatTensor([-0.5, 0.0, 0.5]))
   ```
   - Learns the importance of each task during training
   - Each task's weight is `exp(-log_var)`
   - The initial values (-0.5, 0.0, 0.5) start the category, attribute, and segmentation weights at about 1.65, 1.0, and 0.61

2. **Category Loss**:
   - Binary Cross Entropy with Logits
   - Enhanced positive weighting (5.0)
   - Handles class imbalance

3. **Attribute Loss**:
   - Similar to category loss but with lower positive weight (2.0)
   - Designed for multi-label scenario
   - Balanced for attribute frequency

4. **Segmentation Loss**:
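   - Standard BCE with logits for pixel-wise predictions
   - Equal weighting for all pixels; target masks are first resized (nearest neighbor) to the prediction resolution
   - Carries the same `0.5 * log_var` regularization term as the other two tasks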
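Written out, the weighting in `forward` below is as follows, with $s_c, s_a, s_s$ the three entries of `log_vars` and $\mathcal{L}_c, \mathcal{L}_a, \mathcal{L}_s$ the category, attribute, and segmentation losses:

$$
\mathcal{L}_{\text{total}} = \sum_{t \in \{c,\,a,\,s\}} \left( e^{-s_t}\,\mathcal{L}_t + \tfrac{1}{2}\, s_t \right),
\qquad
\frac{\partial \mathcal{L}_{\text{total}}}{\partial s_t} = -e^{-s_t}\,\mathcal{L}_t + \tfrac{1}{2} = 0
\;\Rightarrow\;
e^{-s_t} = \frac{1}{2\,\mathcal{L}_t}
$$

A task with a large loss therefore drifts toward a smaller weight. The $\tfrac{1}{2} s_t$ term stops $s_t$ from growing without bound. These weights only adapt if `log_vars` is passed to the optimizer.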

## Loss Calculation:
- Combines all three losses with learned weights
- Applies temperature scaling (T = 2) to the category logits for better calibration
- Adds a small L2 penalty (0.01 × the norm of the category logits) for stability
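The category term combines four pieces: `pos_weight`, temperature, softened targets, and the L2 penalty. The NumPy sketch below spells them out. `bce_with_logits` is a hypothetical re-implementation of the element-wise formula from the PyTorch `BCEWithLogitsLoss` documentation (`reduction='none'`), in which `pos_weight` scales only the positive term. The logits and targets are made-up values.

```python
import numpy as np


def bce_with_logits(x, y, pos_weight=1.0):
    """-[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], element-wise."""
    log_sig = -np.logaddexp(0.0, -x)            # log(sigmoid(x)), numerically stable
    log_one_minus_sig = -np.logaddexp(0.0, x)   # log(1 - sigmoid(x))
    return -(pos_weight * y * log_sig + (1 - y) * log_one_minus_sig)


logits = np.array([[4.0, -2.0, 0.0]])           # made-up category logits (3 classes for brevity)
targets = np.array([[1.0, 0.0, 0.0]]) * 0.9     # positive target softened to 0.9, as in the loss
temperature, l2_coeff = 2.0, 0.01

category_loss = (bce_with_logits(logits / temperature, targets, pos_weight=5.0).mean()
                 + l2_coeff * np.linalg.norm(logits))
print(round(float(category_loss), 4))
```

Dividing by the temperature shrinks the logits, which softens the sigmoid outputs. The 0.9 target keeps the model from pushing the positive logit toward infinity.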

## Why This Approach?
- Automatic task balancing
- Handles different scales of losses
- Adapts to changing task difficulties during training
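To see how the initial `log_vars` tilt the balance, the plain-Python arithmetic below applies the same formula as `forward` to made-up per-task losses. The loss values are illustrative only.

```python
import math

log_vars = [-0.5, 0.0, 0.5]      # initial values in MultitaskLoss (category, attribute, segmentation)
raw_losses = [0.8, 0.3, 0.6]     # made-up per-task losses for illustration

weights = [math.exp(-s) for s in log_vars]
total = sum(w * loss + 0.5 * s for w, loss, s in zip(weights, raw_losses, log_vars))

print([round(w, 4) for w in weights])  # [1.6487, 1.0, 0.6065]
print(round(total, 4))                 # 1.9829
```

At initialization the category loss counts about 2.7 times as much as the segmentation loss. Training then moves these weights toward roughly `1 / (2 * loss)` for each task.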

In [None]:
# Cell 7: Loss Function
class MultitaskLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize log variances with different values
        self.log_vars = nn.Parameter(torch.FloatTensor([-0.5, 0.0, 0.5]))

        # Category loss: stronger positive weighting (pos_weight=5.0)
        self.category_loss_fn = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(NUM_CATEGORIES) * 5.0,
            reduction='none'
        )

        # Attribute loss: milder positive weighting (pos_weight=2.0)
        self.attribute_loss_fn = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(NUM_ATTRIBUTES) * 2.0,
            reduction='none'
        )
        self.segmentation_loss_fn = nn.BCEWithLogitsLoss(reduction='none')

    # No custom .to() override is needed: BCEWithLogitsLoss registers pos_weight as a
    # buffer, so MultitaskLoss().to(device) moves both pos_weight tensors automatically.

    def forward(self, outputs, targets):
        device = outputs['category_logits'].device

        # Category loss: temperature scaling (T=2), positive targets softened to 0.9,
        # plus a small L2 penalty on the raw category logits
        temperature = 2.0
        l2_reg = 0.01 * torch.norm(outputs['category_logits'])
        category_loss = self.category_loss_fn(
            outputs['category_logits'] / temperature,
            targets['category_labels'].to(device).float() * 0.9
        ).mean() + l2_reg

        attribute_loss = self.attribute_loss_fn(
            outputs['attribute_logits'],
            targets['attribute_labels'].to(device)
        ).mean()

        pred_masks = outputs['segmentation_masks']
        target_masks = targets['mask_labels'].to(device).float()
        target_masks = nn.functional.interpolate(
            target_masks.unsqueeze(1),
            size=pred_masks.shape[-2:],
            mode='nearest'
        ).squeeze(1)

        segmentation_loss = self.segmentation_loss_fn(
            pred_masks,
            target_masks
        ).mean()

        # Calculate precision terms
        precision_category = torch.exp(-self.log_vars[0])
        precision_attribute = torch.exp(-self.log_vars[1])
        precision_segmentation = torch.exp(-self.log_vars[2])

        # Calculate weighted losses
        weighted_category_loss = precision_category * category_loss + 0.5 * self.log_vars[0]
        weighted_attribute_loss = precision_attribute * attribute_loss + 0.5 * self.log_vars[1]
        weighted_segmentation_loss = precision_segmentation * segmentation_loss + 0.5 * self.log_vars[2]

        # Store all components
        losses = {
            'total_loss': weighted_category_loss + weighted_attribute_loss + weighted_segmentation_loss,
            'raw_category_loss': category_loss.item(),
            'raw_attribute_loss': attribute_loss.item(),
            'raw_segmentation_loss': segmentation_loss.item(),
            'category_weight': precision_category.item(),
            'attribute_weight': precision_attribute.item(),
            'segmentation_weight': precision_segmentation.item()
        }

        return losses

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 8: MultitaskTrainer Implementation

This cell implements a comprehensive training framework for our multi-task fashion model.

## Class Overview: `MultitaskTrainer`

### Initialization Parameters
```python
def __init__(
    self,
    model,
    train_dataset,
    val_dataset,
    batch_size=4,
    learning_rate=2e-4,
    num_epochs=20,
    device=None,
    patience=5,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    checkpoint_dir=None,
    logger=None,
    gradient_accumulation_steps=4
):
```

### Key Components:

1. **Memory Management**:
   - GPU cache clearing
   - Lightweight DataLoaders (`num_workers=2`, `pin_memory=False`)
   - Gradient accumulation to reach a larger effective batch size
   ```python
   if torch.cuda.is_available():
       torch.cuda.empty_cache()
   ```

2. **Optimizer Configuration**:
   - Three parameter groups with different learning rates:
     - DETR backbone: 0.05× base learning rate
     - Category head: 2.0× base learning rate
     - Other components: Base learning rate
   ```python
   param_dicts = [
       {"params": [...], "lr": learning_rate * 0.05},  # DETR backbone
       {"params": [...], "lr": learning_rate * 2.0},   # Category head
       {"params": [...], "lr": learning_rate}          # Others
   ]
   ```

3. **Training Features**:
   - Checkpoint saving/loading
   - Early stopping
   - Learning rate scheduling
   - Gradient clipping
   - Comprehensive metrics tracking
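The defaults above imply a few derived quantities worth checking before a long run. This is plain arithmetic under one stated assumption: `batches_per_epoch` stands in for `len(train_loader)` and is a made-up number. The comparison of scheduler lengths matters only if `train_epoch` (not shown here) calls `scheduler.step()` once per optimizer step rather than once per batch. In that case, a schedule sized as `num_epochs * len(train_loader)` would finish just a fraction of its decay.

```python
import math

# MultitaskTrainer defaults; batches_per_epoch is a hypothetical len(train_loader)
batch_size, accumulation_steps = 4, 4
learning_rate, num_epochs = 2e-4, 20
batches_per_epoch = 1000

effective_batch = batch_size * accumulation_steps                 # samples per optimizer step
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)

# Per-group learning rates, as in param_dicts
group_lrs = {
    "detr": learning_rate * 0.05,
    "category_head": learning_rate * 2.0,
    "others": learning_rate,
}

# Scheduler length as configured in __init__ vs. one scheduler step per optimizer step
scheduler_steps_configured = num_epochs * batches_per_epoch
scheduler_steps_per_optimizer_step = num_epochs * optimizer_steps_per_epoch
print(effective_batch, optimizer_steps_per_epoch,
      scheduler_steps_configured, scheduler_steps_per_optimizer_step)  # 16 250 20000 5000
```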

### Key Methods:

1. **train()**:
   - Main training loop
   - Epoch-level training
   - Validation
   - Metrics logging
   - Checkpoint management

2. **train_epoch()**:
   - Single epoch training
   - Gradient accumulation
   - Memory optimization
   - Progress tracking

3. **evaluate()**:
   - Validation loop
   - Metrics computation
   - No gradient calculation
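Early stopping in `train()` relies on `best_val_loss`, `patience_counter`, and `patience`. The body of `train()` is cut off in this section, so the exact logic is not visible here. The class below is a hypothetical, self-contained sketch of the usual pattern built on those attributes, with an optional `min_delta` that the original does not have.

```python
class EarlyStopping:
    """Hypothetical helper mirroring best_val_loss / patience_counter in MultitaskTrainer."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one validation loss; return (is_best, should_stop)."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    is_best, stop = stopper.step(loss)
    if stop:
        print(f"stopping after epoch {epoch + 1}")  # stopping after epoch 4
        break
```

The `is_best` flag maps naturally onto `save_checkpoint(epoch, metrics, is_best=...)`, so the best model is saved on the same epochs that reset the patience counter.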

## Usage Example:
```python
trainer = MultitaskTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=4,
    learning_rate=2e-4,
    num_epochs=20
)
history = trainer.train()
```
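The trainer's learning rate follows `transformers.get_linear_schedule_with_warmup`. The rate rises linearly from 0 during warmup, then decays linearly back to 0 by the last step. The pure-Python function below is an illustrative re-implementation of that multiplier, not the library code. Each parameter group's actual rate is its base rate times this value, so the DETR group, for example, peaks at `2e-4 * 0.05 = 1e-5`.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier: linear warmup from 0 to 1, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


# 100 total steps with warmup_ratio=0.1 -> 10 warmup steps
print([linear_warmup_multiplier(s, 10, 100) for s in (0, 5, 10, 55, 100)])  # [0.0, 0.5, 1.0, 0.5, 0.0]
```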

In [None]:
import logging
from collections import defaultdict

from transformers import get_linear_schedule_with_warmup

# Cell 8: Updated Training Setup with Metrics
class MultitaskTrainer:
    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        batch_size=4,
        learning_rate=2e-4,
        num_epochs=20,
        device=None,
        patience=5,
        max_grad_norm=1.0,
        warmup_ratio=0.1,
        checkpoint_dir=None,
        logger=None,
        gradient_accumulation_steps=4
    ):
        """
        Initialize the trainer with memory optimization.
        """
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.patience = patience
        self.max_grad_norm = max_grad_norm
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger or logging.getLogger(__name__)

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.logger.info(f"GPU Memory After Cache Clear: {torch.cuda.memory_allocated()/1024**2:.2f}MB")

        # DataLoaders with pin_memory=False to reduce memory usage
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=False,
            collate_fn=custom_collate_fn
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=False,
            collate_fn=custom_collate_fn
        )

        # Loss and Optimizer
        self.criterion = MultitaskLoss().to(self.device)

        # Parameter groups with different learning rates
        param_dicts = [
            {
                "params": [p for n, p in model.named_parameters()
                          if "detr" in n and p.requires_grad],
                "lr": learning_rate * 0.05  # Reduce backbone learning rate
            },
            {
                "params": [p for n, p in model.named_parameters()
                          if "category_head" in n and p.requires_grad],
                "lr": learning_rate * 2.0  # Increase category head learning rate
            },
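            {
                # Remaining heads plus MultitaskLoss.log_vars: without the criterion's
                # parameters here, the learned task weights would never be updated
                "params": [p for n, p in model.named_parameters()
                          if "detr" not in n and "category_head" not in n and p.requires_grad]
                          + list(self.criterion.parameters()),
                "lr": learning_rate
            },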
        ]

        # Update optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            param_dicts,
            weight_decay=0.05  # Increase weight decay
        )

        # Learning rate scheduler with warmup
        num_training_steps = self.num_epochs * len(self.train_loader)
        num_warmup_steps = int(num_training_steps * warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Initialize metrics tracking
        self.train_metrics_history = []
        self.val_metrics_history = []

        # Log training setup
        self.logger.info(f"\nTraining Setup:")
        self.logger.info(f"Batch Size: {batch_size}")
        self.logger.info(f"Learning Rate: {learning_rate}")
        self.logger.info(f"Number of Epochs: {num_epochs}")
        self.logger.info(f"Training Steps per Epoch: {len(self.train_loader)}")
        self.logger.info(f"Total Training Steps: {num_training_steps}")
        self.logger.info(f"Warmup Steps: {num_warmup_steps}")

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """
        Save training checkpoint

        Args:
            epoch: Current epoch number
            metrics: Dictionary of current metrics
            is_best: Whether this checkpoint has the best validation loss
        """
        if self.checkpoint_dir is None:
            return

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'best_val_loss': self.best_val_loss
        }

        # Save latest checkpoint
        checkpoint_path = f'{self.checkpoint_dir}/checkpoint_latest.pth'
        torch.save(checkpoint, checkpoint_path)
        self.logger.info(f"Saved checkpoint: {checkpoint_path}")

        # Save best model separately
        if is_best:
            best_model_path = f'{self.checkpoint_dir}/model_best.pth'
            torch.save(checkpoint, best_model_path)
            self.logger.info(f"Saved best model: {best_model_path}")

        # Save epoch-specific checkpoint periodically
        if (epoch + 1) % 5 == 0:  # Save every 5 epochs
            epoch_checkpoint_path = f'{self.checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth'
            torch.save(checkpoint, epoch_checkpoint_path)
            self.logger.info(f"Saved epoch checkpoint: {epoch_checkpoint_path}")

    def load_checkpoint(self, checkpoint_path):
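        """
        Load a training checkpoint and restore model, optimizer and scheduler state

        Args:
            checkpoint_path: Path to the checkpoint file

        Returns:
            The 0-based epoch index stored in the checkpoint
        """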
        self.logger.info(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_loss = checkpoint['best_val_loss']

        return checkpoint['epoch']

    def train(self):
        """
        Main training loop with enhanced logging
        """
        train_losses = []
        val_losses = []

        for epoch in range(self.num_epochs):
            self.logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            # Training phase
            train_loss, train_metrics = self.train_epoch()
            train_losses.append(train_loss)
            self.train_metrics_history.append(train_metrics)

            # Validation phase
            val_loss, val_metrics = self.evaluate()
            val_losses.append(val_loss)
            self.val_metrics_history.append(val_metrics)

            # Enhanced logging
            self.logger.info("\nTraining Metrics:")
            self.logger.info(f"Total Loss: {train_loss:.4f}")
            self.logger.info("Task Losses:")
            for task in ['category', 'attribute', 'segmentation']:
                self.logger.info(f"  {task}: {train_metrics[f'{task}_loss']:.4f} "
                               f"(weight: {train_metrics[f'{task}_weight']:.4f})")
            self.logger.info("Performance Metrics:")
            for k, v in train_metrics.items():
                if not (k.endswith('_loss') or k.endswith('_weight')):
                    self.logger.info(f"  {k}: {v:.4f}")

            # Log metrics
            self.logger.info(f"Train Loss: {train_loss:.4f}")
            self.logger.info("Train Metrics: %s",
                           {k: f"{v:.4f}" for k, v in train_metrics.items()})
            self.logger.info(f"Val Loss: {val_loss:.4f}")
            self.logger.info("Val Metrics: %s",
                           {k: f"{v:.4f}" for k, v in val_metrics.items()})
            self.logger.info(f"Learning Rate: {self.scheduler.get_last_lr()[0]:.6f}")

            # Save checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            self.save_checkpoint(
                epoch=epoch,
                metrics={'train': train_metrics, 'val': val_metrics},
                is_best=is_best
            )

            # Early stopping
            if self.patience_counter >= self.patience:
                self.logger.info(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_metrics_history': self.train_metrics_history,
            'val_metrics_history': self.val_metrics_history
        }

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        total_metrics = defaultdict(float)
        total_raw_losses = defaultdict(float)
        total_weights = defaultdict(float)

        progress_bar = tqdm(self.train_loader, desc='Training')
        self.optimizer.zero_grad()

        for batch_idx, batch in enumerate(progress_bar):
            batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(
                pixel_values=batch['pixel_values'],
                pixel_mask=batch['pixel_mask']
            )

            losses = self.criterion(outputs, batch)
            loss = losses['total_loss'] / self.gradient_accumulation_steps

            # Track raw losses and weights as Python floats so the running
            # totals never hold on to a batch's autograd graph
            for task in ['category', 'attribute', 'segmentation']:
                total_raw_losses[f'{task}_loss'] += float(losses[f'raw_{task}_loss'])
                total_weights[f'{task}_weight'] += float(losses[f'{task}_weight'])

            loss.backward()

            metrics = self.compute_metrics(outputs, batch)
            total_loss += losses['total_loss'].item()

            for k, v in metrics.items():
                total_metrics[k] += v

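            # Step on every full accumulation group, and also on the last batch so a
            # trailing partial group is applied instead of being zeroed next epoch
            is_last_batch = (batch_idx + 1) == len(self.train_loader)
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or is_last_batch: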
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            if batch_idx % 500 == 0:
                self.logger.info(f"\nStep {batch_idx} log_vars: {self.criterion.log_vars.data.tolist()}")
                self.logger.info(f"Category pred mean: {metrics['cat_pred_mean']:.3f}, "
                                f"target mean: {metrics['cat_target_mean']:.3f}")

            # Update progress bar
            progress_bar.set_postfix({
                'total': f"{losses['total_loss']:.3f}",
                'raw_losses': f"c:{losses['raw_category_loss']:.3f}/a:{losses['raw_attribute_loss']:.3f}/s:{losses['raw_segmentation_loss']:.3f}",
                'weights': f"c:{losses['category_weight']:.3f}/a:{losses['attribute_weight']:.3f}/s:{losses['segmentation_weight']:.3f}",
                'metrics': f"c:{metrics['category_dice']:.3f}({metrics['cat_pred_mean']:.2f})/a:{metrics['attribute_dice']:.3f}/s:{metrics['segmentation_miou']:.3f}"
            })

            del outputs, losses, loss
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        num_batches = len(self.train_loader)
        avg_metrics = {k: v / num_batches for k, v in total_metrics.items()}
        avg_raw_losses = {k: v / num_batches for k, v in total_raw_losses.items()}
        avg_weights = {k: v / num_batches for k, v in total_weights.items()}

        # Combine all metrics
        avg_metrics.update(avg_raw_losses)
        avg_metrics.update(avg_weights)

        return total_loss / num_batches, avg_metrics

    def evaluate(self):
        """
        Validation loop with enhanced logging
        """
        self.model.eval()
        total_loss = 0
        total_metrics = defaultdict(float)
        total_task_losses = defaultdict(float)
        total_weights = defaultdict(float)

        with torch.no_grad():
            progress_bar = tqdm(self.val_loader, desc='Validating')
            for batch in progress_bar:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(
                    pixel_values=batch['pixel_values'],
                    pixel_mask=batch['pixel_mask']
                )

                losses = self.criterion(outputs, batch)
                total_loss += losses['total_loss'].item()

                metrics = self.compute_metrics(outputs, batch)
                for k, v in metrics.items():
                    total_metrics[k] += v

                # Track individual losses and weights
                for k in ['category', 'attribute', 'segmentation']:
                    total_task_losses[f'{k}_loss'] += losses[f'raw_{k}_loss']
                    total_weights[f'{k}_weight'] += losses[f'{k}_weight']

                progress_bar.set_postfix({
                    'total': f"{losses['total_loss']:.3f}",
                    'cat': f"{losses['raw_category_loss']:.3f}({losses['category_weight']:.2f})",
                    'attr': f"{losses['raw_attribute_loss']:.3f}({losses['attribute_weight']:.2f})",
                    'seg': f"{losses['raw_segmentation_loss']:.3f}({losses['segmentation_weight']:.2f})",
                    'metrics': ', '.join([f"{k[:3]}: {v:.3f}" for k, v in metrics.items()])
                })

            num_batches = len(self.val_loader)
            avg_metrics = {k: v / num_batches for k, v in total_metrics.items()}
            avg_losses = {k: v / num_batches for k, v in total_task_losses.items()}
            avg_weights = {k: v / num_batches for k, v in total_weights.items()}

            # Add losses and weights to metrics for logging
            avg_metrics.update(avg_losses)
            avg_metrics.update(avg_weights)

            return total_loss / num_batches, avg_metrics

    def compute_metrics(self, outputs, targets):
        device = outputs['category_logits'].device
        metrics = {}

        # Category Dice Score
        cat_probs = torch.sigmoid(outputs['category_logits'])
        cat_preds = (cat_probs > 0.5).float()
        cat_targets = targets['category_labels'].to(device).float()

        # Calculate Dice score per sample and then average
        cat_intersection = (cat_preds * cat_targets).sum(dim=1)
        cat_union = cat_preds.sum(dim=1) + cat_targets.sum(dim=1)
        cat_dice = (2 * cat_intersection + 1e-8) / (cat_union + 1e-8)
        metrics['category_dice'] = cat_dice.mean().item()

        # Add prediction statistics for debugging
        metrics['cat_pred_mean'] = cat_preds.mean().item()
        metrics['cat_target_mean'] = cat_targets.mean().item()

        # Attribute Dice score (same per-sample computation as for categories)
        attr_probs = torch.sigmoid(outputs['attribute_logits'])
        attr_preds = (attr_probs > 0.5).float()
        attr_targets = targets['attribute_labels'].to(device).float()

        attr_intersection = (attr_preds * attr_targets).sum(dim=1)
        attr_union = attr_preds.sum(dim=1) + attr_targets.sum(dim=1)
        attr_dice = (2 * attr_intersection + 1e-8) / (attr_union + 1e-8)
        metrics['attribute_dice'] = attr_dice.mean().item()

        # Segmentation IoU: binary IoU per sample, averaged over the batch
        pred_masks = torch.sigmoid(outputs['segmentation_masks']) > 0.5
        target_masks = targets['mask_labels'].to(device)

        if pred_masks.shape != target_masks.shape:
            target_masks = nn.functional.interpolate(
                target_masks.unsqueeze(1).float(),
                size=pred_masks.shape[-2:],
                mode='nearest'
            ).squeeze(1)

        intersection = (pred_masks & target_masks.bool()).float().sum((1, 2))
        union = (pred_masks | target_masks.bool()).float().sum((1, 2))
        batch_ious = (intersection + 1e-8) / (union + 1e-8)
        metrics['segmentation_miou'] = batch_ious.mean().item()

        return metrics

------------------------------------------------------------------------
------------------------------------------------------------------------
# Cell 9: Training Utilities and Pipeline

This cell implements supporting functions for training visualization, logging, and pipeline execution.

## Key Components:

### 1. Training History Visualization
```python
def plot_training_history(history):
    """Creates comprehensive training progress plots"""
```
Features:
- One subplot for the total loss plus one per tracked metric
- Training and validation curves drawn on each subplot
- Figure title stamped with the current date and time
- Grid and legends for clarity
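
Matplotlib 3.6 renamed its bundled seaborn styles to `seaborn-v0_8*`, which is why a bare `plt.style.use('seaborn')` can fail with the `OSError` shown at the end of this cell's output. A minimal sketch of picking a style name that actually exists, written as a pure function over the list of available styles so it can be checked without drawing anything; `pick_plot_style` and its preference order are illustrative assumptions, not part of the original code:

```python
def pick_plot_style(available, preferred=("seaborn-v0_8", "seaborn")):
    """Return the first preferred style name present in `available`.

    `available` is expected to be something like `matplotlib.pyplot.style.available`.
    If none of the preferred names exist, fall back to "default", which
    `plt.style.use` accepts and which restores matplotlib's default rcParams.
    """
    for name in preferred:
        if name in available:
            return name
    return "default"


# Usage with matplotlib (not executed here):
#   import matplotlib.pyplot as plt
#   plt.style.use(pick_plot_style(plt.style.available))
```

Passing the list in, rather than reading `plt.style.available` inside the function, keeps the choice testable and independent of the installed matplotlib version.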

### 2. Logging Setup
```python
def setup_logging(experiment_name):
    """Configures logging system"""
```
Features:
- Rotating file handler (10 MB per file, up to 5 backups)
- Console output
- One log file per run, named with the run's start timestamp
- Formatted log messages
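
The handler setup described above can be sketched as a standalone function using only the standard library. `make_logger` is a hypothetical name; the sketch also removes any handlers already attached to the logger first, because `logging.getLogger(name)` returns the same object on every call and re-running a notebook cell would otherwise stack a second pair of handlers and print every message twice:

```python
import logging
import os
from logging.handlers import RotatingFileHandler


def make_logger(name, log_path, max_bytes=10 * 1024 * 1024, backup_count=5):
    """Configure logger `name` with one rotating file handler and one console handler."""
    os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Drop handlers from earlier calls so messages are not duplicated
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    file_handler = RotatingFileHandler(log_path, maxBytes=max_bytes, backupCount=backup_count)
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

Calling `make_logger("fashion", "logs/run.log")` twice leaves the logger with exactly two handlers rather than four.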

### 3. Training Pipeline
```python
def test_training_pipeline():
    """Main training execution function"""
```

#### Pipeline Steps:
1. **Setup**:
   - Create experiment directories
   - Initialize logging
   - Set up checkpointing

2. **Data Preparation**:
   - Random 80/20 train/validation split of `df` (`random_state=42`)
   - Dataset creation
   - Sanity checks that neither split is empty

3. **Training Execution**:
   - Model initialization
   - Training loop
   - Progress monitoring
   - Error handling

4. **Results Processing**:
   - Metrics logging
   - CSV export
   - Plot generation
   - Memory usage tracking
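
The 80/20 split in the code below is done row by row with `train_test_split`. If `df` holds one row per annotated garment (an assumption; this cell does not show whether rows are per image or per annotation), rows from the same `ImageId` can land in both splits, and validation scores then include images seen during training. A minimal sketch of splitting by image instead, in plain Python; `split_by_image` is a hypothetical helper, and its 20% is measured in images rather than rows:

```python
import random


def split_by_image(image_ids, val_fraction=0.2, seed=42):
    """Split row positions so that all rows of one image fall in the same split.

    `image_ids` has one entry per row (e.g. df['ImageId'].tolist()).
    Returns (train_rows, val_rows) as lists of row positions.
    """
    unique_ids = sorted(set(image_ids))  # deterministic order before shuffling
    rng = random.Random(seed)
    rng.shuffle(unique_ids)
    n_val = max(1, round(len(unique_ids) * val_fraction))
    val_ids = set(unique_ids[:n_val])

    train_rows, val_rows = [], []
    for row, image_id in enumerate(image_ids):
        (val_rows if image_id in val_ids else train_rows).append(row)
    return train_rows, val_rows


# Hypothetical usage with the notebook's DataFrame:
#   train_rows, val_rows = split_by_image(df['ImageId'].tolist())
#   train_df = df.iloc[train_rows].reset_index(drop=True)
#   val_df = df.iloc[val_rows].reset_index(drop=True)
```

scikit-learn's `GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)` with `groups=df['ImageId']` is an off-the-shelf alternative that applies the same idea.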

### Error Handling:
```python
try:
    history, trainer = test_training_pipeline()
except Exception as e:
    # test_training_pipeline() has already logged the traceback and
    # attempted to save error_state.pth; report and re-raise
    print(f"Training failed: {e}")
    raise
```

## Usage:
```python
if __name__ == "__main__":
    torch.manual_seed(42)  # Reproducibility
    np.random.seed(42)
    history, trainer = test_training_pipeline()
```

## Output Files:
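1. Training logs (`logs/<experiment_name>_<timestamp>.log`)
2. Checkpoints in `DETR_CHECKPOINTS/<experiment_name>/`: `checkpoint_latest.pth`, `model_best.pth`, and `checkpoint_epoch_<N>.pth` every 5 epochs
3. Metrics CSV (`DETR_CHECKPOINTS/<experiment_name>/metrics.csv`)
4. Training plots (`DETR_CHECKPOINTS/<experiment_name>/training_history.png`)
5. Error state, only if training fails (`DETR_CHECKPOINTS/<experiment_name>/error_state.pth`)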
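
Which checkpoint files a given epoch produces follows from `save_checkpoint`: the latest checkpoint is always overwritten, `model_best.pth` is written whenever validation loss improves, and an epoch-stamped copy is kept every 5 epochs. A small pure-Python sketch of that rule (the `checkpoint_files` helper is hypothetical and writes nothing to disk) makes it easy to see, for example, that the 5-epoch run configured below keeps exactly one epoch-stamped file:

```python
def checkpoint_files(epoch, is_best, every=5):
    """Return the file names save_checkpoint() writes for a 0-based `epoch`."""
    files = ["checkpoint_latest.pth"]
    if is_best:
        files.append("model_best.pth")
    if (epoch + 1) % every == 0:
        files.append(f"checkpoint_epoch_{epoch + 1}.pth")
    return files


# Files written over a 5-epoch run in which only epochs 1 and 3 improve val loss
for epoch in range(5):
    print(epoch + 1, checkpoint_files(epoch, is_best=epoch in (0, 2)))
```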

In [None]:
import os
import logging
import numpy as np
import torch
import matplotlib.pyplot as plt
from datetime import datetime
import traceback
import pandas as pd
from logging.handlers import RotatingFileHandler
import time
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def plot_training_history(history):
    """
    Enhanced plotting function for training metrics
    """
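    # matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*';
    # use whichever name this installation provides
    if 'seaborn-v0_8' in plt.style.available:
        plt.style.use('seaborn-v0_8')
    elif 'seaborn' in plt.style.available:
        plt.style.use('seaborn')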

    # Create subplots for each metric
    metrics = list(history['train_metrics_history'][0].keys())
    num_plots = len(metrics) + 1  # +1 for loss plot
    fig, axes = plt.subplots(num_plots, 1, figsize=(12, 4*num_plots))

    # Add overall title with timestamp
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    fig.suptitle(f'Training Progress\n{timestamp}', fontsize=16)

    # Plot losses
    axes[0].plot(history['train_losses'], 'b-o', label='Train Loss', markersize=4)
    axes[0].plot(history['val_losses'], 'r-o', label='Val Loss', markersize=4)
    axes[0].set_title('Training and Validation Losses')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend(loc='upper right')

    # Plot each metric (epoch_metrics avoids shadowing the `metrics` list above)
    for i, metric in enumerate(metrics, 1):
        train_metric = [epoch_metrics[metric] for epoch_metrics in history['train_metrics_history']]
        val_metric = [epoch_metrics[metric] for epoch_metrics in history['val_metrics_history']]

        axes[i].plot(train_metric, 'b-o', label=f'Train {metric}', markersize=4)
        axes[i].plot(val_metric, 'r-o', label=f'Val {metric}', markersize=4)
        axes[i].set_title(f'Training and Validation {metric}')
        axes[i].set_xlabel('Epoch')
        axes[i].set_ylabel(metric)
        axes[i].grid(True, alpha=0.3)
        axes[i].legend(loc='upper right')

    plt.tight_layout()
    return fig

def setup_logging(experiment_name):
    """Setup logging configuration"""
    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Create timestamp for unique log file
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'logs/{experiment_name}_{timestamp}.log'

    # Add rotating file handler
    file_handler = RotatingFileHandler(
        log_filename,
        maxBytes=10*1024*1024,  # 10MB
        backupCount=5
    )

    console_handler = logging.StreamHandler()

    # Enhanced formatting
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

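    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Remove handlers left over from a previous call (e.g. re-running this cell);
    # otherwise every message is emitted once per stacked handler
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)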

    return logger

def test_training_pipeline():
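    """
    Run the full training pipeline: split the data, build datasets, train,
    and save metrics, plots and checkpoints.

    Relies on objects defined in earlier cells: `df`, `IMAGE_DIR`,
    `image_processor`, `model`, `FashionMultiTaskDataset` and `MultitaskTrainer`.
    """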
    # Setup experiment name and logging
    experiment_name = 'fashion_multitask_full'
    logger = setup_logging(experiment_name)

    # Create checkpoint directory
    checkpoint_dir = f'DETR_CHECKPOINTS/{experiment_name}'
    os.makedirs(checkpoint_dir, exist_ok=True)

    logger.info("Starting training pipeline")
    logger.info("=" * 50)

    # Log GPU information
    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"Initial GPU Memory: {torch.cuda.memory_allocated()/1024**2:.2f}MB")

    # Create train/val split
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,  # 20% for validation
        random_state=42,
        shuffle=True
    )

    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    # Validate dataset sizes
    assert len(train_df) > 0, "Training dataset is empty"
    assert len(val_df) > 0, "Validation dataset is empty"

    # Print dataset information
    logger.info("Dataset Information:")
    logger.info("-" * 50)
    logger.info(f"Training data shape: {train_df.shape}")
    logger.info(f"Validation data shape: {val_df.shape}")
    logger.info("-" * 50)

    # Create datasets
    train_dataset = FashionMultiTaskDataset(
        image_paths=[f"{IMAGE_DIR}/{image_id}.jpg" for image_id in train_df['ImageId']],
        masks=train_df['EncodedPixels'],
        categories=train_df['CategoryId'],
        attributes=train_df['AttributesIds'],
        image_processor=image_processor,
        augment=True
    )

    val_dataset = FashionMultiTaskDataset(
        image_paths=[f"{IMAGE_DIR}/{image_id}.jpg" for image_id in val_df['ImageId']],
        masks=val_df['EncodedPixels'],
        categories=val_df['CategoryId'],
        attributes=val_df['AttributesIds'],
        image_processor=image_processor,
        augment=False
    )


    # Log model information
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("\nModel Information:")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # Initialize trainer with smaller batch size and gradient accumulation
    trainer = MultitaskTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        batch_size=4,  # Reduced from 16; with 4 accumulation steps the effective batch size is still 16
        num_epochs=5,
        learning_rate=2e-4,
        patience=5,
        max_grad_norm=1.0,
        warmup_ratio=0.1,
        checkpoint_dir=checkpoint_dir,
        logger=logger,
        gradient_accumulation_steps=4
    )

    try:
        logger.info("\nStarting training...")
        logger.info("=" * 50)

        # Record start time
        start_time = time.time()

        # Train the model
        history = trainer.train()

        # Calculate training time
        training_time = time.time() - start_time

        logger.info("\nTraining completed successfully!")
        logger.info("=" * 50)
        logger.info(f"Total training time: {training_time/3600:.2f} hours")

        # Log final metrics
        logger.info("\nFinal Training Metrics:")
        for metric, value in history['train_metrics_history'][-1].items():
            logger.info(f"{metric}: {value:.4f}")

        logger.info("\nFinal Validation Metrics:")
        for metric, value in history['val_metrics_history'][-1].items():
            logger.info(f"{metric}: {value:.4f}")

        # Save metrics to CSV
        metrics_df = pd.DataFrame({
            'epoch': range(1, len(history['train_losses']) + 1),  # 1-based, matching the logs
            'train_loss': history['train_losses'],
            'val_loss': history['val_losses'],
            **{f'train_{k}': [h[k] for h in history['train_metrics_history']]
               for k in history['train_metrics_history'][0].keys()},
            **{f'val_{k}': [h[k] for h in history['val_metrics_history']]
               for k in history['val_metrics_history'][0].keys()}
        })
        metrics_df.to_csv(f'{checkpoint_dir}/metrics.csv', index=False)

        # Plot and save training history
        fig = plot_training_history(history)
        fig.savefig(f'{checkpoint_dir}/training_history.png')
        plt.close(fig)

        # Log final GPU memory usage
        if torch.cuda.is_available():
            logger.info(f"Final GPU Memory: {torch.cuda.memory_allocated()/1024**2:.2f}MB")

        return history, trainer

    except Exception as e:
        logger.error(f"\nError during training: {str(e)}")
        logger.error(traceback.format_exc())

        # Save error state. `batch` and `outputs` are local to the trainer's
        # loops and never exist in this function's scope, so store the model
        # weights together with the error details instead.
        try:
            torch.save({
                'error_state': {
                    'error': str(e),
                    'traceback': traceback.format_exc(),
                    'model_state': model.state_dict()
                }
            }, f'{checkpoint_dir}/error_state.pth')
            logger.info("Error state saved successfully")
        except Exception as save_error:
            logger.error(f"Failed to save error state: {str(save_error)}")

        raise e

# Run the test pipeline
if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    try:
        history, trainer = test_training_pipeline()
    except Exception as e:
        print(f"Training failed: {str(e)}")
        raise e

2024-10-30 04:02:54 - INFO - Starting training pipeline
2024-10-30 04:02:54 - INFO - GPU: NVIDIA GeForce RTX 4090
2024-10-30 04:02:54 - INFO - Initial GPU Memory: 196.52MB
2024-10-30 04:02:54 - INFO - Dataset Information:
2024-10-30 04:02:54 - INFO - --------------------------------------------------
2024-10-30 04:02:54 - INFO - Training data shape: (72584, 6)
2024-10-30 04:02:54 - INFO - Validation data shape: (18147, 6)
2024-10-30 04:02:54 - INFO - --------------------------------------------------
2024-10-30 04:02:54 - INFO - 
Model Information:
2024-10-30 04:02:54 - INFO - Total parameters: 43,156,043
2024-10-30 04:02:54 - INFO - Trainable parameters: 42,933,643
2024-10-30 04:02:54 - INFO - GPU Memory After Cache Clear: 196.52MB
2024-10-30 04:02:54 - INFO - 
Training Setup:
2024-10-30 04:02:54 - INFO - Batch Size: 4
2024-10-30 04:02:54 - INFO - Learning Rate: 0.0002
2024-10-30 04:02:54 - INFO - Number of Epochs: 5
2024-10-30 04:02:54 - INFO - Training Steps per Epoch: 18146
2024-10

Training failed: 'seaborn' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)


OSError: 'seaborn' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)