# Financial Statements Page Classification - Data Preprocessing

## Overview
This notebook preprocesses the financial statement pages for training three different computer vision models:

1. **ResNet50** - Classic CNN with transfer learning (224x224)
2. **EfficientNet-B2** - Modern efficient architecture (260x260)
3. **Vision Transformer (ViT-Base)** - Transformer-based approach (224x224)

### Dataset Statistics:
- Total pages: 1,179
- Average dimensions: 1277 x 1692 pixels
- 5 Target Classes (with imbalance):
  - Notes (Tabular): 557 (47.2%)
  - Notes (Text): 321 (27.2%)
  - Financial Sheets: 124 (10.5%)
  - Independent Auditor's Report: 111 (9.4%)
  - Other Pages: 66 (5.6%)

### Remaining Preprocessing Steps:
1. ~~Extract pages from PDFs~~ ✅ Already done
2. ~~Create labels~~ ✅ Already done
3. **Train/Val/Test split (stratified)**
4. **Model-specific transformations**
5. **Data augmentation**
6. **Create PyTorch DataLoaders**

---
## 1. Environment Setup

In [None]:
# Install required packages
!pip install -q torch torchvision
!pip install -q timm  # For EfficientNet and ViT
!pip install -q albumentations  # Advanced augmentations
!pip install -q Pillow opencv-python-headless
!pip install -q pandas numpy scikit-learn matplotlib seaborn tqdm

In [None]:
# Core imports
import os
import json
import pickle
import warnings
from pathlib import Path
from collections import Counter
from typing import List, Dict, Tuple, Optional, Callable

# Data handling
import numpy as np
import pandas as pd

# Image processing
import cv2
from PIL import Image

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T

# timm for modern architectures
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Utilities
from tqdm.notebook import tqdm

# Notebook-wide settings: silence library warnings and select the plot theme
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')

# Pick the compute device (GPU when CUDA is available, CPU otherwise) and
# report the runtime environment
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")
if DEVICE.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive (Colab only).
# The import is guarded so that "Restart & Run All" also works on a non-Colab
# runtime; in that case point IMAGES_DIR / LABELS_CSV_PATH / OUTPUT_DIR at
# local folders in the configuration cell.
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available - skipping Google Drive mount")
else:
    drive.mount('/content/drive')

In [None]:
# ============================================
# CONFIGURATION - UPDATE THESE PATHS
# ============================================

# Folder that holds the extracted page images
IMAGES_DIR = "/content/drive/MyDrive/YOUR_IMAGES_FOLDER"  # <-- UPDATE THIS

# CSV file that maps each image to its label
LABELS_CSV_PATH = "/content/drive/MyDrive/YOUR_LABELS.csv"  # <-- UPDATE THIS

# Destination for everything this notebook writes (created if it is missing)
OUTPUT_DIR = "/content/drive/MyDrive/FS_Classification_Processed"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Tunable settings shared by the cells below
CONFIG = dict(
    seed=42,          # seed passed to every RNG and to the splitter
    test_size=0.15,   # fraction of all samples held out for testing
    val_size=0.15,    # fraction of all samples held out for validation
    batch_size=16,
    num_workers=2,
    num_classes=5,
)

# Set seeds for reproducibility
def set_seed(seed: int):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode with auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: `random` is not part of the notebook's import cell.
    import random

    # Python's own RNG was previously left unseeded. Libraries that draw from
    # it (presumably including some Albumentations releases - verify for the
    # installed version) would otherwise differ between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer reproducible cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed before any random operation runs, then echo the settings
set_seed(seed=CONFIG['seed'])
print(f"Configuration: {CONFIG}")

---
## 2. Load Your Existing Data

In [None]:
# Read the label table from disk.
# The file is expected to provide an image path (or filename) column and a
# label column; the actual column names are configured in the next cell.

labels_df = pd.read_csv(LABELS_CSV_PATH)

# Short report: row count, available columns, then a preview of the table
print(
    f"Loaded {len(labels_df)} labeled samples",
    f"\nColumns: {labels_df.columns.tolist()}",
    "\nFirst few rows:",
    sep="\n",
)
display(labels_df.head())

In [None]:
# ============================================
# ADJUST COLUMN NAMES HERE IF NEEDED
# ============================================

# Column containing image filename or path
IMAGE_COL = 'image_path'  # <-- UPDATE if different (e.g., 'filename', 'image_name')

# Column containing label
LABEL_COL = 'label'  # <-- UPDATE if different (e.g., 'class', 'category')

# Fail early, with a readable message, when the configured columns are absent
missing_cols = [col for col in (IMAGE_COL, LABEL_COL) if col not in labels_df.columns]
if missing_cols:
    raise KeyError(
        f"Column(s) {missing_cols} not found in labels CSV; "
        f"available columns: {labels_df.columns.tolist()}"
    )

# Resolve every entry against IMAGES_DIR row by row. os.path.join returns its
# last argument unchanged when that argument is already absolute, so bare
# filenames, relative paths and absolute paths may be mixed in one CSV.
# (The previous all-or-nothing check treated the whole column as absolute as
# soon as a single row started with '/'.)
# astype(str) turns empty cells into the string 'nan', which is then reported
# as a missing file below instead of raising inside os.path.join.
labels_df['image_path_full'] = (
    labels_df[IMAGE_COL].astype(str).map(lambda p: os.path.join(IMAGES_DIR, p))
)

# Verify paths exist
labels_df['exists'] = labels_df['image_path_full'].apply(os.path.exists)
missing = labels_df[~labels_df['exists']]

if len(missing) > 0:
    print(f"⚠️ WARNING: {len(missing)} images not found!")
    print(f"Example missing: {missing['image_path_full'].iloc[0]}")
else:
    print(f"✅ All {len(labels_df)} images found!")

In [None]:
# Keep only rows whose image file exists on disk
df = labels_df.loc[labels_df['exists']].copy()

# Standardise the column names used by the rest of the notebook.
# A rename target may already exist as a different column - with the default
# IMAGE_COL ('image_path') the raw path column collides with the renamed
# 'image_path_full'. Two columns with the same name would make
# row['image_path'] return a Series instead of a string, so the stale column
# is dropped before renaming.
rename_map = {LABEL_COL: 'label', 'image_path_full': 'image_path'}
clashing_cols = [
    new_name for old_name, new_name in rename_map.items()
    if old_name != new_name and new_name in df.columns
]
df = df.drop(columns=clashing_cols).rename(columns=rename_map)

print(f"\nWorking with {len(df)} samples")
print(f"\nLabel distribution:")
print(df['label'].value_counts())

In [None]:
# Bar chart of samples per class, annotated with absolute and relative counts
label_counts = df['label'].value_counts()
colors = sns.color_palette('husl', len(label_counts))
bar_positions = range(len(label_counts))

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(bar_positions, label_counts.values, color=colors)
ax.set_xticks(bar_positions)
ax.set_xticklabels(label_counts.index, rotation=45, ha='right')
ax.set_ylabel('Count')
ax.set_title('Class Distribution (Imbalanced)', fontweight='bold')

# Write "<count>\n(<share>%)" just above each bar
for bar, count in zip(bars, label_counts.values):
    x_center = bar.get_x() + bar.get_width()/2
    ax.text(x_center, bar.get_height() + 5,
            f'{count}\n({count/len(df)*100:.1f}%)', ha='center', fontsize=9)

fig.tight_layout()
plt.show()

# Ratio between the largest and the smallest class
max_class, min_class = label_counts.max(), label_counts.min()
print(f"\nImbalance ratio (max/min): {max_class/min_class:.1f}x")

---
## 3. Encode Labels

In [None]:
# Define class names in consistent order
CLASS_NAMES = [
    'Financial Sheets',
    'Independent Auditor\'s Report',
    'Notes (Tabular)',
    'Notes (Text)',
    'Other Pages'
]

# The class count is configured separately in CONFIG; make sure both agree
# before anything downstream relies on it.
if len(CLASS_NAMES) != CONFIG['num_classes']:
    raise ValueError(
        f"CONFIG['num_classes'] is {CONFIG['num_classes']} but "
        f"CLASS_NAMES lists {len(CLASS_NAMES)} classes"
    )

# Report labels outside CLASS_NAMES (typos, stray whitespace, empty cells)
# with the offending values, rather than letting the encoder fail on them.
unexpected_labels = df.loc[~df['label'].isin(CLASS_NAMES), 'label']
if not unexpected_labels.empty:
    raise ValueError(
        f"{len(unexpected_labels)} row(s) carry labels outside CLASS_NAMES: "
        f"{unexpected_labels.unique().tolist()[:10]}"
    )

# Create label encoder
label_encoder = LabelEncoder()
label_encoder.fit(CLASS_NAMES)

# Encode labels
df['label_encoded'] = label_encoder.transform(df['label'])

print("Label encoding:")
for i, name in enumerate(label_encoder.classes_):
    count = (df['label_encoded'] == i).sum()
    print(f"  {i}: {name} ({count} samples)")

# Save label encoder so the training notebook can map ids back to names
with open(os.path.join(OUTPUT_DIR, 'label_encoder.pkl'), 'wb') as f:
    pickle.dump(label_encoder, f)
print(f"\nLabel encoder saved to {OUTPUT_DIR}/label_encoder.pkl")

---
## 4. Train/Validation/Test Split (Stratified)

In [None]:
def create_stratified_split(df: pd.DataFrame, 
                            test_size: float = 0.15, 
                            val_size: float = 0.15, 
                            seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split ``df`` into train, validation and test frames, stratified on the
    'label_encoded' column.

    ``test_size`` and ``val_size`` are fractions of the full dataset; with the
    defaults the result is 70% train, 15% val, 15% test.

    Returns:
        (train, val, test) DataFrames, each with a fresh 0..n-1 index.
    """
    strat_col = 'label_encoded'

    # Stage 1: carve the test set out of the full data
    remainder, test_part = train_test_split(
        df,
        test_size=test_size,
        stratify=df[strat_col],
        random_state=seed,
    )

    # Stage 2: val_size refers to the full data, so rescale it to the
    # remainder before splitting off the validation set
    val_fraction = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder,
        test_size=val_fraction,
        stratify=remainder[strat_col],
        random_state=seed,
    )

    return tuple(
        part.reset_index(drop=True)
        for part in (train_part, val_part, test_part)
    )

In [None]:
# Split the cleaned table using the ratios and seed from CONFIG
train_df, val_df, test_df = create_stratified_split(
    df, CONFIG['test_size'], CONFIG['val_size'], CONFIG['seed']
)

banner = "=" * 50
print(banner)
print("DATASET SPLITS")
print(banner)

# One line per split: absolute size and share of the full dataset
for tag, part in (("\nTrain:", train_df), ("Val:  ", val_df), ("Test: ", test_df)):
    print(f"{tag} {len(part)} samples ({len(part)/len(df)*100:.1f}%)")
print(f"Total: {sum(len(part) for part in (train_df, val_df, test_df))} samples")

In [None]:
# Verify stratification: plot the class mix of every split and compare each
# split's class shares against the full dataset.

# One fixed class order (and colour per class) for all three panels, so the
# bars can be compared position by position. value_counts() on each split
# separately may order tied classes differently.
class_order = df['label'].value_counts().index
palette = sns.color_palette('husl', len(class_order))
overall_pct = df['label'].value_counts(normalize=True).reindex(class_order) * 100

# Largest tolerated gap (in percentage points) between a split and the full data
tolerance_pct = 1.0
max_deviation = 0.0

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (name, split_df) in zip(axes, [('Train', train_df), ('Val', val_df), ('Test', test_df)]):
    # fill_value=0 keeps a bar slot for classes that are absent from a split
    counts = split_df['label'].value_counts().reindex(class_order, fill_value=0)
    percentages = counts / len(split_df) * 100
    max_deviation = max(max_deviation, float((percentages - overall_pct).abs().max()))
    
    bars = ax.bar(range(len(counts)), counts.values, color=palette)
    ax.set_xticks(range(len(counts)))
    ax.set_xticklabels([c[:15] + '...' if len(c) > 15 else c for c in counts.index], rotation=45, ha='right')
    ax.set_title(f'{name} Set (n={len(split_df)})', fontweight='bold')
    ax.set_ylabel('Count')
    
    # Add percentage labels
    for bar, pct in zip(bars, percentages.values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2, 
                f'{pct:.1f}%', ha='center', fontsize=8)

plt.tight_layout()
plt.show()

# Only claim success when the numbers support it
print(f"\nLargest class-share deviation from the full dataset: {max_deviation:.2f} percentage points")
if max_deviation <= tolerance_pct:
    print("\n✅ Stratification verified - class proportions maintained across splits")
else:
    print(f"\n⚠️ WARNING: class proportions differ by more than {tolerance_pct:.1f} percentage points between splits")

In [None]:
# Persist each split as <name>.csv inside OUTPUT_DIR
split_tables = {'train': train_df, 'val': val_df, 'test': test_df}
for split_name, table in split_tables.items():
    table.to_csv(os.path.join(OUTPUT_DIR, f'{split_name}.csv'), index=False)

# Confirm what was written
print(f"Splits saved to {OUTPUT_DIR}/")
for split_name, table in split_tables.items():
    print(f"  - {split_name}.csv ({len(table)} samples)")

---
## 5. Model-Specific Configurations

In [None]:
# Per-channel normalization statistics shared by all three models
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# (notebook key, square input resolution in pixels, timm model identifier)
_MODEL_SPECS = [
    ('resnet50', 224, 'resnet50'),
    ('efficientnet_b2', 260, 'efficientnet_b2'),
    ('vit_base', 224, 'vit_base_patch16_224'),
]

# Expand the compact spec table into one settings dict per model
MODEL_CONFIGS = {
    key: {
        'input_size': size,
        'mean': IMAGENET_MEAN,
        'std': IMAGENET_STD,
        'timm_name': timm_name,
    }
    for key, size, timm_name in _MODEL_SPECS
}

print("Model configurations:")
for name, config in MODEL_CONFIGS.items():
    side = config['input_size']
    print(f"  {name}: {side}x{side} input")

---
## 6. Data Augmentation

In [None]:
def get_train_transforms(input_size: int, mean: List[float], std: List[float]) -> A.Compose:
    """
    Build the Albumentations pipeline used for training images.

    The page is resized (longest side) and padded with white to a square that
    is 15% larger than ``input_size``, randomly cropped to
    ``input_size`` x ``input_size``, lightly augmented, normalized with
    ``mean``/``std`` and converted to a tensor. Augmentation strengths are
    kept small on purpose because the inputs are document pages.
    """
    white = (255, 255, 255)  # page-coloured fill for padded / exposed borders
    oversized = int(input_size * 1.15)

    # Aspect-preserving resize -> white pad to square -> random crop
    resize_and_crop = [
        A.LongestMaxSize(max_size=oversized),
        A.PadIfNeeded(
            min_height=oversized,
            min_width=oversized,
            border_mode=cv2.BORDER_CONSTANT,
            value=white,
        ),
        A.RandomCrop(height=input_size, width=input_size),
    ]

    # Small geometric perturbations
    geometric = [
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.1,
            rotate_limit=5,
            border_mode=cv2.BORDER_CONSTANT,
            value=white,
            p=0.5,
        ),
        A.Perspective(scale=(0.02, 0.05), p=0.3),
    ]

    # Image-quality degradations: blur, JPEG compression, sensor noise
    degradations = [
        A.OneOf(
            [
                A.GaussianBlur(blur_limit=(3, 5), p=1.0),
                A.MotionBlur(blur_limit=3, p=1.0),
            ],
            p=0.2,
        ),
        A.ImageCompression(quality_lower=75, quality_upper=100, p=0.3),
        A.GaussNoise(var_limit=(5.0, 20.0), p=0.2),
    ]

    # Mild brightness / contrast jitter
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
    ]

    # Normalize and convert HWC numpy image -> CHW tensor
    finalize = [
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]

    return A.Compose(resize_and_crop + geometric + degradations + photometric + finalize)


def get_val_transforms(input_size: int, mean: List[float], std: List[float]) -> A.Compose:
    """
    Build the deterministic pipeline used for validation and test images.

    The page is resized (longest side) to ``input_size``, padded with white to
    a square, center-cropped to ``input_size`` x ``input_size``, normalized
    with ``mean``/``std`` and converted to a tensor. No random augmentation.
    """
    steps = []
    steps.append(A.LongestMaxSize(max_size=input_size))
    steps.append(
        A.PadIfNeeded(
            min_height=input_size,
            min_width=input_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=(255, 255, 255),  # white fill, matching the page background
        )
    )
    steps.append(A.CenterCrop(height=input_size, width=input_size))
    steps.append(A.Normalize(mean=mean, std=std))
    steps.append(ToTensorV2())
    return A.Compose(steps)

In [None]:
# Visualize augmentations
def visualize_augmentations(image_path: str, transform: A.Compose, n_samples: int = 6):
    """Show an image next to several augmented versions of it.

    Args:
        image_path: Path of the image file to load.
        transform: Albumentations pipeline ending in Normalize + ToTensorV2.
            De-normalization for display uses IMAGENET_MEAN / IMAGENET_STD,
            so the pipeline is assumed to normalize with those statistics.
        n_samples: Total number of panels filled. Panel 0 is the original,
            followed by ``n_samples - 1`` augmented versions.

    Raises:
        OSError: If the image cannot be read from ``image_path``.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising
    if img is None:
        raise OSError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # Two rows; at least one column so tiny n_samples still yields a valid grid
    n_cols = max(1, (n_samples + 1) // 2)
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 8))
    axes = axes.flatten()

    # Hide every frame up front so spare panels (odd n_samples) stay blank
    for ax in axes:
        ax.axis('off')
    
    # Original image
    axes[0].imshow(img)
    axes[0].set_title('Original', fontweight='bold')
    
    # Augmented versions
    for i in range(1, n_samples):
        augmented = transform(image=img)['image']
        # CHW tensor -> HWC array, then undo the normalization for display
        aug_img = augmented.permute(1, 2, 0).numpy()
        aug_img = aug_img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
        aug_img = np.clip(aug_img, 0, 1)
        
        axes[i].imshow(aug_img)
        axes[i].set_title(f'Augmented {i}')
    
    plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Preview the training augmentations on the first training image
sample_image = train_df['image_path'].iloc[0]
sample_transform = get_train_transforms(input_size=224, mean=IMAGENET_MEAN, std=IMAGENET_STD)
visualize_augmentations(image_path=sample_image, transform=sample_transform, n_samples=6)

---
## 7. PyTorch Dataset

In [None]:
class FinancialStatementDataset(Dataset):
    """
    PyTorch Dataset for Financial Statement page classification.

    Each item is an ``(image, label)`` pair: the image is loaded from disk
    with OpenCV, converted to RGB and passed through the optional transform;
    the label is the row's integer class id.
    """
    
    def __init__(self, dataframe: pd.DataFrame, transform: Optional[Callable] = None):
        """
        Args:
            dataframe: DataFrame with 'image_path' and 'label_encoded' columns
            transform: Albumentations transform pipeline (called as
                ``transform(image=...)`` and expected to return a dict with
                an 'image' entry); ``None`` returns the raw RGB array
        """
        # Re-index so positional lookups match 0..len-1
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        
    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.df)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load, convert and transform the sample at position ``idx``.

        Raises:
            OSError: If the image file cannot be read.
        """
        row = self.df.iloc[idx]
        image_path = row['image_path']
        
        # cv2.imread returns None (it does not raise) for missing or corrupt
        # files; fail here with the offending path.
        image = cv2.imread(image_path)
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Plain Python int, matching the annotated return type
        label = int(row['label_encoded'])
        
        # Apply transform
        if self.transform is not None:
            image = self.transform(image=image)['image']
        
        return image, label
    
    def get_labels(self) -> np.ndarray:
        """Return all labels (for weighted sampler)."""
        return self.df['label_encoded'].values

---
## 8. Handle Class Imbalance

In [None]:
def get_weighted_sampler(dataset: FinancialStatementDataset) -> WeightedRandomSampler:
    """
    Build a sampler that draws each sample with a weight inversely
    proportional to the size of its class.

    One epoch draws ``len(dataset)`` samples with replacement, so minority
    classes are over-sampled during training. The per-class weights are
    printed as a side effect.
    """
    labels = dataset.get_labels()
    n_total = len(labels)

    # Inverse-frequency weight per class: total / class size
    class_weights = {cls: n_total / n_in_class for cls, n_in_class in Counter(labels).items()}

    # Every sample inherits the weight of its class
    sample_weights = list(map(class_weights.__getitem__, labels))

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=n_total,
        replacement=True,
    )

    print("Class weights for sampling:")
    for cls in sorted(class_weights):
        print(f"  Class {cls} ({label_encoder.classes_[cls][:20]}): {class_weights[cls]:.3f}")

    return sampler


def compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """
    Compute class weights for a weighted loss function.

    Each weight is the inverse class frequency,
    ``total / (num_classes * count)``, rescaled so all weights sum to
    ``num_classes``.

    Args:
        train_df: DataFrame with an integer 'label_encoded' column.
        num_classes: Total number of classes; fixes the length of the result.

    Returns:
        float32 tensor of length ``num_classes`` where entry ``i`` is the
        weight of class ``i``. A class with no samples gets weight 0.

    Raises:
        ValueError: If a label lies outside ``[0, num_classes)``.
    """
    labels = train_df['label_encoded'].to_numpy(dtype=np.int64)
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            f"label_encoded values must lie in [0, {num_classes}); "
            f"found range [{labels.min()}, {labels.max()}]"
        )

    # bincount with minlength keeps one slot per class id, so entry i always
    # belongs to class i even when a class is absent from train_df.
    # (Counting only the classes that occur would shorten the vector and
    # shift later weights onto the wrong class.)
    class_counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    total_samples = len(train_df)

    # Inverse frequency weighting; absent classes keep weight 0 rather than inf
    weights = np.zeros(num_classes, dtype=np.float64)
    present = class_counts > 0
    weights[present] = total_samples / (num_classes * class_counts[present])

    # Normalize so the weights sum to num_classes
    weight_sum = weights.sum()
    if weight_sum > 0:
        weights = weights / weight_sum * num_classes

    return torch.tensor(weights, dtype=torch.float32)

In [None]:
# Derive the loss weights from the training split, report them, then persist them
class_weights = compute_class_weights(train_df, num_classes=CONFIG['num_classes'])

print("\nClass weights for loss function:")
named_weights = zip(label_encoder.classes_, class_weights)
for i, (name, weight) in enumerate(named_weights):
    print(f"  {i}: {name}: {weight:.4f}")

# Stored as a tensor file so the training notebook can load it with torch.load
weights_file = os.path.join(OUTPUT_DIR, 'class_weights.pt')
torch.save(class_weights, weights_file)
print(f"\nClass weights saved to {OUTPUT_DIR}/class_weights.pt")

---
## 9. Create DataLoaders

In [None]:
def create_dataloaders(train_df: pd.DataFrame, 
                       val_df: pd.DataFrame, 
                       test_df: pd.DataFrame,
                       model_name: str,
                       batch_size: int = 16,
                       num_workers: int = 2,
                       use_weighted_sampling: bool = True) -> Dict[str, DataLoader]:
    """
    Create DataLoaders for a specific model configuration.

    Args:
        train_df, val_df, test_df: Split DataFrames with 'image_path' and
            'label_encoded' columns.
        model_name: Key into MODEL_CONFIGS; selects input size and
            normalization statistics.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader.
        use_weighted_sampling: When True the training loader draws samples
            through a class-balancing WeightedRandomSampler; otherwise it
            shuffles uniformly.

    Returns:
        Dict with 'train', 'val' and 'test' DataLoaders. The training loader
        uses the augmenting transforms and drops the last incomplete batch;
        'val' and 'test' use the deterministic transforms and keep every
        sample in a fixed order.

    Raises:
        KeyError: If ``model_name`` is not a key of MODEL_CONFIGS.
    """
    config = MODEL_CONFIGS[model_name]
    input_size = config['input_size']
    mean = config['mean']
    std = config['std']
    
    # Create transforms
    train_transform = get_train_transforms(input_size, mean, std)
    val_transform = get_val_transforms(input_size, mean, std)
    
    # Create datasets
    train_dataset = FinancialStatementDataset(train_df, train_transform)
    val_dataset = FinancialStatementDataset(val_df, val_transform)
    test_dataset = FinancialStatementDataset(test_df, val_transform)
    
    # Setup sampler for imbalanced data. DataLoader does not accept a sampler
    # together with shuffle=True, so shuffling is enabled only without one.
    train_sampler = None
    if use_weighted_sampling:
        print(f"\nCreating weighted sampler for {model_name}...")
        train_sampler = get_weighted_sampler(train_dataset)
    shuffle_train = train_sampler is None
    
    # Settings shared by all three loaders. Pinned memory is requested only
    # when CUDA is available, since it only serves host-to-GPU copies and
    # PyTorch warns when it is requested on a CPU-only runtime.
    common_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    
    # Create dataloaders
    dataloaders = {
        'train': DataLoader(
            train_dataset,
            shuffle=shuffle_train,
            sampler=train_sampler,
            drop_last=True,
            **common_kwargs
        ),
        'val': DataLoader(val_dataset, shuffle=False, **common_kwargs),
        'test': DataLoader(test_dataset, shuffle=False, **common_kwargs),
    }
    
    print(f"\n{model_name.upper()} DataLoaders:")
    print(f"  Input size: {input_size}x{input_size}")
    print(f"  Train batches: {len(dataloaders['train'])}")
    print(f"  Val batches: {len(dataloaders['val'])}")
    print(f"  Test batches: {len(dataloaders['test'])}")
    
    return dataloaders

In [None]:
# Build one train/val/test loader set per configured model
banner = "=" * 60
print(banner)
print("CREATING DATALOADERS FOR ALL MODELS")
print(banner)

# Settings that are identical for every model
loader_settings = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    use_weighted_sampling=True,
)

all_dataloaders = {}
for model_name in MODEL_CONFIGS:
    print(f"\n{'-'*40}")
    all_dataloaders[model_name] = create_dataloaders(
        train_df, val_df, test_df, model_name=model_name, **loader_settings
    )

---
## 10. Verify DataLoaders

In [None]:
def visualize_batch(dataloader: DataLoader, model_name: str, n_samples: int = 8):
    """Plot the first images of one batch, titled with their class names.

    The batch is taken from ``dataloader`` and de-normalized with the
    mean/std configured for ``model_name`` in MODEL_CONFIGS. At most
    ``n_samples`` images are shown, in a grid four columns wide.
    """
    model_cfg = MODEL_CONFIGS[model_name]
    mean, std = np.array(model_cfg['mean']), np.array(model_cfg['std'])

    images, labels = next(iter(dataloader))
    n_samples = min(n_samples, len(images))

    n_cols = 4
    n_rows = -(-n_samples // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4*n_rows))

    for slot, ax in enumerate(axes.flatten()):
        # Frames are hidden on every panel, including unused ones
        ax.axis('off')
        if slot >= n_samples:
            continue

        # CHW tensor -> HWC array, undo normalization, clamp to displayable range
        picture = images[slot].permute(1, 2, 0).numpy()
        picture = np.clip(picture * std + mean, 0, 1)

        label_name = label_encoder.classes_[labels[slot]]
        if len(label_name) > 20:
            title = f"{label_name[:20]}.."
        else:
            title = label_name

        ax.imshow(picture)
        ax.set_title(title, fontsize=9)

    plt.suptitle(f'{model_name.upper()} - Training Batch', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show one training batch per model as a visual sanity check
for model_name in MODEL_CONFIGS:
    train_loader = all_dataloaders[model_name]['train']
    visualize_batch(train_loader, model_name, n_samples=8)

---
## 11. Save Configuration

In [None]:
# Save all configuration.
# json.dump cannot serialize numpy scalar types, so values that originate from
# pandas / scikit-learn are converted to built-in str / int first.
# NOTE(review): whether value_counts().to_dict() yields numpy integers depends
# on the pandas version; the explicit conversion removes that dependency.
class_distribution = {
    str(class_name): int(count)
    for class_name, count in df['label'].value_counts().items()
}

preprocessing_config = {
    'config': CONFIG,
    'model_configs': MODEL_CONFIGS,
    'class_names': [str(class_name) for class_name in label_encoder.classes_],
    'imagenet_mean': IMAGENET_MEAN,
    'imagenet_std': IMAGENET_STD,
    'dataset_stats': {
        'total_samples': len(df),
        'train_size': len(train_df),
        'val_size': len(val_df),
        'test_size': len(test_df),
        'class_distribution': class_distribution
    }
}

config_path = os.path.join(OUTPUT_DIR, 'preprocessing_config.json')
with open(config_path, 'w') as f:
    json.dump(preprocessing_config, f, indent=2)

print(f"Configuration saved to {config_path}")

In [None]:
# Final Summary: recap splits, classes, models and the files written above.
# The emoji / status marks in the messages were mis-encoded and are restored here.
print("="*70)
print("                    PREPROCESSING COMPLETE")
print("="*70)

print(f"\n📁 Output Directory: {OUTPUT_DIR}")

print(f"\n📊 Dataset Splits:")
print(f"   Train: {len(train_df)} samples")
print(f"   Val:   {len(val_df)} samples")
print(f"   Test:  {len(test_df)} samples")

print(f"\n🏷️ Classes ({CONFIG['num_classes']}):")
for i, name in enumerate(label_encoder.classes_):
    train_count = (train_df['label_encoded'] == i).sum()
    print(f"   {i}: {name} ({train_count} train samples)")

print(f"\n🤖 Models Ready:")
for model_name, config in MODEL_CONFIGS.items():
    print(f"   - {model_name}: {config['input_size']}x{config['input_size']}")

print(f"\n⚖️ Class Imbalance Handling:")
print(f"   - Weighted sampling: ENABLED")
print(f"   - Class weights saved for loss function")

# Check that every expected artifact is present in OUTPUT_DIR
print(f"\n📦 Files Saved:")
files = ['train.csv', 'val.csv', 'test.csv', 'label_encoder.pkl', 
         'class_weights.pt', 'preprocessing_config.json']
for file_name in files:
    path = os.path.join(OUTPUT_DIR, file_name)
    exists = '✓' if os.path.exists(path) else '✗'
    print(f"   {exists} {file_name}")

print(f"\n✅ Ready for model training!")

---
## Code for Training Notebook

Use this code to load the preprocessed data:

```python
import json
import pickle
import torch
import pandas as pd

OUTPUT_DIR = "/content/drive/MyDrive/FS_Classification_Processed"

# Load configuration
with open(f'{OUTPUT_DIR}/preprocessing_config.json', 'r') as f:
    config = json.load(f)

# Load splits
train_df = pd.read_csv(f'{OUTPUT_DIR}/train.csv')
val_df = pd.read_csv(f'{OUTPUT_DIR}/val.csv')
test_df = pd.read_csv(f'{OUTPUT_DIR}/test.csv')

# Load label encoder
with open(f'{OUTPUT_DIR}/label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)

# Load class weights for loss function
class_weights = torch.load(f'{OUTPUT_DIR}/class_weights.pt')

print(f"Loaded {len(train_df)} train, {len(val_df)} val, {len(test_df)} test samples")
```