# Financial Statements Page Classification - Data Preprocessing

## Overview
This notebook preprocesses the financial statement pages for training three different computer vision models:

1. **ResNet50** - Classic CNN with transfer learning
2. **EfficientNet-B2** - Modern efficient architecture
3. **Vision Transformer (ViT)** - Transformer-based approach

### Dataset Statistics (from EDA):
- Total PDF documents: 30
- Total pages: 1,179
- Average dimensions: 1277 x 1692 pixels
- 5 Target Classes: Independent Auditor's Report, Financial Sheets, Notes (Tabular), Notes (Text), Other Pages

### Preprocessing Pipeline:
1. Extract pages from PDFs as images
2. Create/load labels
3. Train/Val/Test split (stratified)
4. Model-specific transformations
5. Data augmentation
6. Create PyTorch DataLoaders

---
## 1. Environment Setup

In [None]:
# Install required packages
!pip install -q torch torchvision
!pip install -q timm  # For EfficientNet and ViT
!pip install -q albumentations  # Advanced augmentations
!pip install -q pymupdf Pillow opencv-python-headless
!pip install -q pandas numpy scikit-learn matplotlib seaborn tqdm

In [None]:
# Core imports
import os
import sys
import json
import shutil
import warnings
import pickle
from pathlib import Path
from collections import Counter
from typing import List, Dict, Tuple, Optional, Callable

# Data handling
import numpy as np
import pandas as pd

# Image processing
import cv2
from PIL import Image
import fitz  # PyMuPDF

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
from torchvision import models

# timm for modern architectures
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Utilities
from tqdm.notebook import tqdm

# Settings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')

# Check device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ============================================
# CONFIGURATION - UPDATE THESE PATHS
# ============================================

# Source data path (PDFs)
PDF_DATA_PATH = "/content/drive/MyDrive/YOUR_DATASET_FOLDER"  # <-- UPDATE THIS

# Output paths
OUTPUT_DIR = "/content/drive/MyDrive/FS_Classification"  # <-- Where to save processed data
IMAGES_DIR = os.path.join(OUTPUT_DIR, "images")
LABELS_PATH = os.path.join(OUTPUT_DIR, "labels.csv")

# Create output directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(IMAGES_DIR, exist_ok=True)

# Model configuration
CONFIG = {
    'seed': 42,
    'test_size': 0.15,
    'val_size': 0.15,  # fraction of the full dataset; rescaled internally after the test split
    'batch_size': 16,
    'num_workers': 2,
    'num_classes': 5,
    'extraction_dpi': 150,  # DPI for PDF to image conversion
}

# Class names
CLASS_NAMES = [
    'Independent Auditors Report',
    'Financial Sheets',
    'Notes Tabular',
    'Notes Text',
    'Other Pages'
]

# Set seeds for reproducibility
def set_seed(seed: int):
    import random  # local import: seed Python's built-in RNG as well
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG['seed'])
print(f"Configuration: {CONFIG}")

---
## 2. PDF to Image Extraction
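
PyMuPDF measures pages in points (1/72 inch), so the zoom factor `dpi / 72` used in the extraction code determines the rendered pixel size. As a rough sanity check, the sketch below estimates pixel dimensions for two common page sizes. The helper name `estimate_pixels` and the nominal page sizes are assumptions made for illustration, and PyMuPDF's own rounding may differ by a pixel.

```python
# Illustrative helper (not part of the notebook pipeline): estimate the
# rendered size of a page from its size in points and the target DPI.
POINTS_PER_INCH = 72

def estimate_pixels(width_pt: float, height_pt: float, dpi: int = 150) -> tuple:
    """Return the approximate (width, height) in pixels at the given DPI."""
    zoom = dpi / POINTS_PER_INCH  # same zoom factor passed to fitz.Matrix
    return round(width_pt * zoom), round(height_pt * zoom)

# Nominal page sizes in points (assumed for illustration)
a4 = estimate_pixels(595, 842)
letter = estimate_pixels(612, 792)
print(f"A4 at 150 DPI:     {a4[0]} x {a4[1]} px")
print(f"Letter at 150 DPI: {letter[0]} x {letter[1]} px")
```

Both estimates are close to the 1277 x 1692 average reported by the EDA, which is consistent with pages rendered at 150 DPI.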

In [None]:
def extract_pdf_pages(pdf_path: str, output_dir: str, dpi: int = 150) -> List[Dict]:
    """
    Extract all pages from a PDF and save as images.
    
    Args:
        pdf_path: Path to PDF file
        output_dir: Directory to save images
        dpi: Resolution for rendering
    
    Returns:
        List of dictionaries with page info
    """
    pages_info = []
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    
    try:
        doc = fitz.open(pdf_path)
        zoom = dpi / 72
        mat = fitz.Matrix(zoom, zoom)
        
        for page_num in range(len(doc)):
            page = doc[page_num]
            pix = page.get_pixmap(matrix=mat)
            
            # Create filename
            img_filename = f"{pdf_name}_page_{page_num + 1:04d}.png"
            img_path = os.path.join(output_dir, img_filename)
            
            # Save image
            pix.save(img_path)
            
            pages_info.append({
                'pdf_name': pdf_name,
                'page_num': page_num + 1,
                'total_pages': len(doc),
                'image_path': img_path,
                'image_filename': img_filename,
                'width': pix.width,
                'height': pix.height
            })
        
        doc.close()
        
    except Exception as e:
        print(f"Error processing {pdf_path}: {e}")
    
    return pages_info

In [None]:
def extract_all_pdfs(pdf_dir: str, output_dir: str, dpi: int = 150) -> pd.DataFrame:
    """
    Extract all pages from all PDFs in a directory.
    
    Args:
        pdf_dir: Directory containing PDFs
        output_dir: Directory to save images
        dpi: Resolution for rendering
    
    Returns:
        DataFrame with all page information
    """
    # Find all PDFs
    pdf_files = []
    for root, dirs, files in os.walk(pdf_dir):
        for file in files:
            if file.lower().endswith('.pdf'):
                pdf_files.append(os.path.join(root, file))
    
    print(f"Found {len(pdf_files)} PDF files")
    
    # Extract all pages
    all_pages = []
    for pdf_path in tqdm(pdf_files, desc="Extracting PDFs"):
        pages = extract_pdf_pages(pdf_path, output_dir, dpi)
        all_pages.extend(pages)
    
    df = pd.DataFrame(all_pages)
    print(f"\nExtracted {len(df)} pages from {len(pdf_files)} PDFs")
    
    return df

In [None]:
# Check if images already extracted
existing_images = [f for f in os.listdir(IMAGES_DIR) if f.endswith('.png')] if os.path.exists(IMAGES_DIR) else []

if len(existing_images) > 0:
    print(f"Found {len(existing_images)} existing images in {IMAGES_DIR}")
    print("Skipping extraction. Delete the images folder to re-extract.")
    
    # Reconstruct DataFrame from existing images
    pages_data = []
    for img_file in existing_images:
        parts = img_file.rsplit('_page_', 1)
        if len(parts) == 2:
            pdf_name = parts[0]
            page_num = int(parts[1].split('.')[0])
            img_path = os.path.join(IMAGES_DIR, img_file)
            
            # Get image dimensions
            img = Image.open(img_path)
            w, h = img.size
            img.close()
            
            pages_data.append({
                'pdf_name': pdf_name,
                'page_num': page_num,
                'image_path': img_path,
                'image_filename': img_file,
                'width': w,
                'height': h
            })
    
    pages_df = pd.DataFrame(pages_data)
else:
    # Extract all PDFs
    print("Extracting pages from PDFs...")
    pages_df = extract_all_pdfs(PDF_DATA_PATH, IMAGES_DIR, dpi=CONFIG['extraction_dpi'])

print(f"\nTotal pages: {len(pages_df)}")
print(f"Unique PDFs: {pages_df['pdf_name'].nunique()}")

---
## 3. Labeling Strategy

Since manual labeling for 1,179 pages can be time-consuming, we provide multiple options:
1. **Load existing labels** (if you have them)
2. **Load cluster assignments** produced by the EDA notebook
3. **Map clusters to class names** to create pseudo-labels (verify against EDA samples)
4. **Manual labeling tool**
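
Whichever option is used, it may help to audit the resulting labels before encoding them. The sketch below is a minimal, standard-library check for unlabeled pages, misspelled class names, and classes with no pages. The function `audit_labels` and the sample labels are hypothetical and not part of the notebook.

```python
from collections import Counter

# Class names copied from the notebook configuration
CLASS_NAMES = [
    'Independent Auditors Report',
    'Financial Sheets',
    'Notes Tabular',
    'Notes Text',
    'Other Pages',
]

def audit_labels(labels):
    """Summarize a list of labels (None marks an unlabeled page)."""
    known = set(CLASS_NAMES)
    counts = Counter(label for label in labels if label is not None)
    return {
        'unlabeled': sum(label is None for label in labels),
        'unknown': sorted(set(counts) - known),          # typos, stray names
        'missing_classes': sorted(known - set(counts)),  # classes with no pages
    }

# Hypothetical labels for six pages, including one typo and one gap
sample = ['Notes Text', 'Notes Text', 'Notes Tabular', 'Other Pages',
          'Notes Tabluar', None]
report = audit_labels(sample)
print(report)
```

Unknown names matter because `LabelEncoder.transform` raises an error on labels it was not fitted on. Missing classes matter because the example cluster mapping in this notebook never assigns `Financial Sheets`, so a model trained on those pseudo-labels would see no examples of that class.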

In [None]:
# ============================================
# OPTION 1: Load existing labels
# ============================================

# Uncomment and modify if you have labels
# EXISTING_LABELS_PATH = "/content/drive/MyDrive/your_labels.csv"
# if os.path.exists(EXISTING_LABELS_PATH):
#     labels_df = pd.read_csv(EXISTING_LABELS_PATH)
#     pages_df = pages_df.merge(labels_df, on=['pdf_name', 'page_num'], how='left')
#     print(f"Loaded labels for {pages_df['label'].notna().sum()} pages")

In [None]:
# ============================================
# OPTION 2: Load cluster assignments from EDA
# ============================================

EDA_RESULTS_PATH = os.path.join(OUTPUT_DIR, "page_analysis_results.csv")

if os.path.exists(EDA_RESULTS_PATH):
    eda_df = pd.read_csv(EDA_RESULTS_PATH)
    if 'cluster' in eda_df.columns:
        pages_df = pages_df.merge(
            eda_df[['pdf_name', 'page_num', 'cluster']], 
            on=['pdf_name', 'page_num'], 
            how='left'
        )
        print(f"Loaded cluster assignments for {pages_df['cluster'].notna().sum()} pages")
        print(f"\nCluster distribution:")
        print(pages_df['cluster'].value_counts().sort_index())
else:
    print(f"EDA results not found at {EDA_RESULTS_PATH}")
    print("Run EDA notebook first or create labels manually.")

In [None]:
# ============================================
# OPTION 3: Create label column based on rules or clusters
# ============================================

# If using clusters as pseudo-labels, map them to class names
# IMPORTANT: Review cluster samples in EDA to determine appropriate mapping

# Example mapping (ADJUST BASED ON YOUR EDA RESULTS):
# From EDA clustering:
# Cluster 0: 635 pages (53.9%) - Likely Notes (Text) - largest cluster
# Cluster 1: 457 pages (38.8%) - Likely Notes (Tabular) or Financial Sheets
# Cluster 2: 62 pages (5.3%)   - Could be Independent Auditor's Report
# Cluster 3: 2 pages (0.2%)    - Outliers/Other
# Cluster 4: 23 pages (2.0%)   - Likely Other Pages (covers, TOC)

CLUSTER_TO_LABEL_MAP = {
    0: 'Notes Text',
    1: 'Notes Tabular',  # or Financial Sheets - verify with EDA visuals
    2: 'Independent Auditors Report',
    3: 'Other Pages',
    4: 'Other Pages'
}

if 'cluster' in pages_df.columns:
    pages_df['label'] = pages_df['cluster'].map(CLUSTER_TO_LABEL_MAP)
    print("Label distribution (from cluster mapping):")
    print(pages_df['label'].value_counts())
    print("\n⚠️  IMPORTANT: Verify this mapping by reviewing cluster samples in EDA!")
else:
    print("No cluster column found. Please create labels manually.")
    pages_df['label'] = None

In [None]:
# ============================================
# OPTION 4: Manual labeling helper
# ============================================

def create_labeling_interface(df: pd.DataFrame, start_idx: int = 0, n_samples: int = 10):
    """
    Display images for manual labeling.
    
    Args:
        df: DataFrame with image paths
        start_idx: Starting index
        n_samples: Number of samples to display
    """
    print("Class options:")
    for i, name in enumerate(CLASS_NAMES):
        print(f"  {i}: {name}")
    print()
    
    end_idx = min(start_idx + n_samples, len(df))
    
    n_cols = 2
    n_rows = (end_idx - start_idx + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 6*n_rows))
    axes = np.atleast_1d(axes).flatten()  # works for a single row as well
    
    for i, idx in enumerate(range(start_idx, end_idx)):
        row = df.iloc[idx]
        img = Image.open(row['image_path'])
        
        axes[i].imshow(img)
        current_label = row.get('label', 'Unlabeled')
        axes[i].set_title(f"Index: {idx}\n{row['pdf_name']}\nPage: {row['page_num']}\nCurrent: {current_label}", fontsize=9)
        axes[i].axis('off')
    
    for i in range(end_idx - start_idx, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Uncomment to use manual labeling
# create_labeling_interface(pages_df, start_idx=0, n_samples=10)

In [None]:
# Manual label assignment (if needed)
# Uncomment and modify as needed

# Example: Set specific labels
# pages_df.loc[pages_df.index == 0, 'label'] = 'Other Pages'  # First page is cover
# pages_df.loc[pages_df.index == 1, 'label'] = 'Other Pages'  # TOC

# Or batch assign based on patterns
# pages_df.loc[pages_df['page_num'] == 1, 'label'] = 'Other Pages'  # All first pages are covers

In [None]:
# Verify labels
print("="*50)
print("LABEL SUMMARY")
print("="*50)

if 'label' in pages_df.columns and pages_df['label'].notna().any():
    print(f"\nLabeled pages: {pages_df['label'].notna().sum()} / {len(pages_df)}")
    print(f"\nLabel distribution:")
    label_counts = pages_df['label'].value_counts()
    for label, count in label_counts.items():
        print(f"  {label}: {count} ({count/len(pages_df)*100:.1f}%)")
    
    # Visualize
    plt.figure(figsize=(10, 6))
    label_counts.plot(kind='bar', color=sns.color_palette('husl', len(label_counts)))
    plt.title('Label Distribution', fontweight='bold')
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
else:
    print("No labels found. Please create labels using one of the options above.")

In [None]:
# Encode labels
if 'label' in pages_df.columns and pages_df['label'].notna().any():
    # Only use labeled data
    labeled_df = pages_df[pages_df['label'].notna()].copy()
    
    # Create label encoder
    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_NAMES)
    
    labeled_df['label_encoded'] = label_encoder.transform(labeled_df['label'])
    
    print(f"Label encoding:")
    for i, name in enumerate(label_encoder.classes_):
        print(f"  {i}: {name}")
    
    # Save label encoder
    with open(os.path.join(OUTPUT_DIR, 'label_encoder.pkl'), 'wb') as f:
        pickle.dump(label_encoder, f)
else:
    labeled_df = pages_df.copy()
    print("WARNING: No labels available. Models will need labels for training.")

---
## 4. Train/Validation/Test Split
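
The split is done in two stages, so the validation fraction has to be rescaled before the second call to `train_test_split`. The short calculation below shows why, using the values from `CONFIG`. It is a worked example of the arithmetic only and does not call scikit-learn.

```python
# Fractions of the full dataset, as in CONFIG
test_size = 0.15
val_size = 0.15

# After the test split, only (1 - test_size) of the data remains, so the
# second split must take a larger share of that remainder.
rescaled_val = val_size / (1 - test_size)

train_frac = (1 - test_size) * (1 - rescaled_val)
val_frac = (1 - test_size) * rescaled_val

print(f"Second-split test_size: {rescaled_val:.4f}")
print(f"Train/Val/Test fractions: {train_frac:.2f} / {val_frac:.2f} / {test_size:.2f}")
```

With these settings the second split uses roughly 0.1765 of the remaining data, which yields an approximate 70/15/15 split of the full dataset.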

In [None]:
def create_stratified_split(df: pd.DataFrame, test_size: float = 0.15, 
                            val_size: float = 0.15, seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Create stratified train/val/test splits.
    
    Args:
        df: DataFrame with 'label_encoded' column
        test_size: Proportion for test set
        val_size: Proportion of the full dataset for the validation set (rescaled internally for the second split)
        seed: Random seed
    
    Returns:
        train_df, val_df, test_df
    """
    # First split: train+val vs test
    train_val_df, test_df = train_test_split(
        df, 
        test_size=test_size, 
        stratify=df['label_encoded'],
        random_state=seed
    )
    
    # Second split: train vs val
    actual_val_size = val_size / (1 - test_size)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=actual_val_size,
        stratify=train_val_df['label_encoded'],
        random_state=seed
    )
    
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)

In [None]:
# Create splits
if 'label_encoded' in labeled_df.columns:
    train_df, val_df, test_df = create_stratified_split(
        labeled_df, 
        test_size=CONFIG['test_size'],
        val_size=CONFIG['val_size'],
        seed=CONFIG['seed']
    )
    
    print(f"Dataset splits:")
    print(f"  Train: {len(train_df)} ({len(train_df)/len(labeled_df)*100:.1f}%)")
    print(f"  Val:   {len(val_df)} ({len(val_df)/len(labeled_df)*100:.1f}%)")
    print(f"  Test:  {len(test_df)} ({len(test_df)/len(labeled_df)*100:.1f}%)")
    
    # Save splits
    train_df.to_csv(os.path.join(OUTPUT_DIR, 'train.csv'), index=False)
    val_df.to_csv(os.path.join(OUTPUT_DIR, 'val.csv'), index=False)
    test_df.to_csv(os.path.join(OUTPUT_DIR, 'test.csv'), index=False)
    print(f"\nSplits saved to {OUTPUT_DIR}")
else:
    print("Cannot create splits without labels.")

In [None]:
# Verify stratification
if 'label_encoded' in labeled_df.columns:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for ax, (name, df) in zip(axes, [('Train', train_df), ('Val', val_df), ('Test', test_df)]):
        counts = df['label'].value_counts()
        ax.bar(range(len(counts)), counts.values, color=sns.color_palette('husl', len(counts)))
        ax.set_xticks(range(len(counts)))
        ax.set_xticklabels(counts.index, rotation=45, ha='right')
        ax.set_title(f'{name} Set Distribution\n(n={len(df)})', fontweight='bold')
        ax.set_ylabel('Count')
    
    plt.tight_layout()
    plt.show()

---
## 5. Model-Specific Preprocessing

The three models differ in input size but share the same ImageNet normalization:

| Model | Input Size | Normalization |
|-------|------------|---------------|
| ResNet50 | 224x224 | ImageNet mean/std |
| EfficientNet-B2 | 260x260 | ImageNet mean/std |
| ViT-Base | 224x224 | ImageNet mean/std |
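
To make the normalization concrete, the sketch below applies the ImageNet mean and standard deviation to a single white and a single black pixel and then inverts the operation. The helper names `normalize_pixel` and `denormalize_pixel` are illustrative. The scaling by 255 mirrors the default behavior of `A.Normalize`, assuming its default `max_pixel_value`.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale 0-255 values to [0, 1], then standardize each channel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, mean, std)]

def denormalize_pixel(values, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert normalize_pixel and return values in [0, 1]."""
    return [v * s + m for v, m, s in zip(values, mean, std)]

white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print("white:", [round(v, 3) for v in white])
print("black:", [round(v, 3) for v in black])
print("round trip:", [round(v, 3) for v in denormalize_pixel(white)])
```

White page background therefore maps to values a little above 2 in every channel, which is also what the white padding becomes after normalization.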

In [None]:
# ImageNet normalization values
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Model-specific configurations
MODEL_CONFIGS = {
    'resnet50': {
        'input_size': 224,
        'mean': IMAGENET_MEAN,
        'std': IMAGENET_STD,
        'timm_name': 'resnet50',
        'pretrained': True
    },
    'efficientnet_b2': {
        'input_size': 260,
        'mean': IMAGENET_MEAN,
        'std': IMAGENET_STD,
        'timm_name': 'efficientnet_b2',
        'pretrained': True
    },
    'vit_base': {
        'input_size': 224,
        'mean': IMAGENET_MEAN,
        'std': IMAGENET_STD,
        'timm_name': 'vit_base_patch16_224',
        'pretrained': True
    }
}

print("Model configurations:")
for name, config in MODEL_CONFIGS.items():
    print(f"  {name}: input_size={config['input_size']}")

---
## 6. Data Augmentation Strategies
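
The transforms in this section resize the longest side to the target size and pad the shorter side with white. The sketch below works through that geometry for the average page size from the EDA. The function `letterbox_geometry` is a hypothetical helper, and the rounding used by Albumentations may differ by a pixel.

```python
def letterbox_geometry(width: int, height: int, target: int):
    """Return the resized (w, h) and the (pad_w, pad_h) needed to reach target x target."""
    scale = target / max(width, height)  # longest side becomes `target`
    new_w, new_h = round(width * scale), round(height * scale)
    return (new_w, new_h), (target - new_w, target - new_h)

# Average page size from the EDA, validation size for ResNet50 / ViT
(resized_w, resized_h), (pad_w, pad_h) = letterbox_geometry(1277, 1692, 224)
print(f"resized: {resized_w} x {resized_h}, padding: {pad_w} x {pad_h}")
```

A typical portrait page is reduced by a factor of about 7.5, so the models most likely rely on page layout rather than legible text.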

In [None]:
def get_train_transforms(input_size: int, mean: List[float], std: List[float]) -> A.Compose:
    """
    Get training augmentation pipeline using Albumentations.
    
    Augmentations suitable for document images:
    - Geometric: small rotations, slight perspective changes
    - Quality: compression artifacts, slight blur
    - No color jittering (documents are mostly grayscale/B&W)
    
    Args:
        input_size: Target image size
        mean: Normalization mean
        std: Normalization std
    
    Returns:
        Albumentations Compose transform
    """
    return A.Compose([
        # Resize with aspect ratio preservation, then pad/crop
        A.LongestMaxSize(max_size=int(input_size * 1.1)),
        A.PadIfNeeded(
            min_height=int(input_size * 1.1), 
            min_width=int(input_size * 1.1),
            border_mode=cv2.BORDER_CONSTANT,
            value=(255, 255, 255)  # White padding for documents
        ),
        A.RandomCrop(height=input_size, width=input_size),
        
        # Geometric augmentations (subtle for documents)
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.1,
            rotate_limit=5,  # Small rotation for documents
            border_mode=cv2.BORDER_CONSTANT,
            value=(255, 255, 255),
            p=0.5
        ),
        A.Perspective(
            scale=(0.02, 0.05),  # Subtle perspective change
            p=0.3
        ),
        
        # Quality augmentations
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=3, p=1.0),
        ], p=0.2),
        
        A.ImageCompression(quality_lower=75, quality_upper=100, p=0.3),
        
        # Noise
        A.GaussNoise(var_limit=(5.0, 20.0), p=0.2),
        
        # Slight contrast variation
        A.RandomBrightnessContrast(
            brightness_limit=0.1,
            contrast_limit=0.1,
            p=0.3
        ),
        
        # Horizontal flip (may not be suitable for all document types)
        # A.HorizontalFlip(p=0.5),  # Uncomment if appropriate
        
        # Normalize and convert to tensor
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ])

In [None]:
def get_val_transforms(input_size: int, mean: List[float], std: List[float]) -> A.Compose:
    """
    Get validation/test transform pipeline.
    Only resize and normalize - no augmentation.
    
    Args:
        input_size: Target image size
        mean: Normalization mean
        std: Normalization std
    
    Returns:
        Albumentations Compose transform
    """
    return A.Compose([
        # Resize maintaining aspect ratio, then pad to a square with white
        # (CenterCrop is a safeguard; the image is already input_size x input_size)
        A.LongestMaxSize(max_size=input_size),
        A.PadIfNeeded(
            min_height=input_size, 
            min_width=input_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=(255, 255, 255)
        ),
        A.CenterCrop(height=input_size, width=input_size),
        
        # Normalize and convert to tensor
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ])

In [None]:
# Alternative: Simple torchvision transforms (for reference)

def get_torchvision_train_transforms(input_size: int, mean: List[float], std: List[float]) -> T.Compose:
    """Torchvision-based training transforms."""
    return T.Compose([
        T.Resize((int(input_size * 1.1), int(input_size * 1.1))),
        T.RandomCrop(input_size),
        T.RandomRotation(degrees=5),
        T.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])

def get_torchvision_val_transforms(input_size: int, mean: List[float], std: List[float]) -> T.Compose:
    """Torchvision-based validation transforms."""
    return T.Compose([
        T.Resize((input_size, input_size)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])

In [None]:
# Visualize augmentations
def visualize_augmentations(image_path: str, transform: A.Compose, n_samples: int = 6):
    """
    Visualize multiple augmented versions of an image.
    
    Args:
        image_path: Path to source image
        transform: Augmentation pipeline
        n_samples: Number of augmented versions to show
    """
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    fig, axes = plt.subplots(2, (n_samples + 2) // 2, figsize=(15, 8))
    axes = axes.flatten()
    
    # Original
    axes[0].imshow(img)
    axes[0].set_title('Original', fontweight='bold')
    axes[0].axis('off')
    
    # Augmented versions (unnormalize for visualization)
    for i in range(1, n_samples + 1):
        augmented = transform(image=img)['image']
        
        # Unnormalize for display
        aug_img = augmented.permute(1, 2, 0).numpy()
        aug_img = aug_img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
        aug_img = np.clip(aug_img, 0, 1)
        
        axes[i].imshow(aug_img)
        axes[i].set_title(f'Augmented {i}')
        axes[i].axis('off')
    
    plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Demonstrate augmentations on a sample image
if len(train_df) > 0:
    sample_image_path = train_df.iloc[0]['image_path']
    sample_transform = get_train_transforms(224, IMAGENET_MEAN, IMAGENET_STD)
    
    print(f"Visualizing augmentations on: {sample_image_path}")
    visualize_augmentations(sample_image_path, sample_transform, n_samples=5)

---
## 7. PyTorch Dataset Class

In [None]:
class FinancialStatementDataset(Dataset):
    """
    PyTorch Dataset for Financial Statement page classification.
    
    Supports both Albumentations and torchvision transforms.
    """
    
    def __init__(self, 
                 dataframe: pd.DataFrame,
                 transform: Optional[Callable] = None,
                 use_albumentations: bool = True):
        """
        Args:
            dataframe: DataFrame with 'image_path' and 'label_encoded' columns
            transform: Transform pipeline
            use_albumentations: Whether transform is from Albumentations (vs torchvision)
        """
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        self.use_albumentations = use_albumentations
        
    def __len__(self) -> int:
        return len(self.df)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        row = self.df.iloc[idx]
        
        # Load image
        if self.use_albumentations:
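            image = cv2.imread(row['image_path'])
            if image is None:  # cv2.imread returns None instead of raising
                raise FileNotFoundError(f"Could not read image: {row['image_path']}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)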
        else:
            image = Image.open(row['image_path']).convert('RGB')
        
        # Get label
        label = int(row['label_encoded'])  # plain int, matching the return annotation
        
        # Apply transform
        if self.transform:
            if self.use_albumentations:
                transformed = self.transform(image=image)
                image = transformed['image']
            else:
                image = self.transform(image)
        
        return image, label
    
    def get_labels(self) -> np.ndarray:
        """Return all labels (for weighted sampler)."""
        return self.df['label_encoded'].values

In [None]:
class FinancialStatementInferenceDataset(Dataset):
    """
    Dataset for inference (no labels required).
    """
    
    def __init__(self, 
                 image_paths: List[str],
                 transform: Optional[Callable] = None,
                 use_albumentations: bool = True):
        """
        Args:
            image_paths: List of image file paths
            transform: Transform pipeline
            use_albumentations: Whether transform is from Albumentations
        """
        self.image_paths = image_paths
        self.transform = transform
        self.use_albumentations = use_albumentations
        
    def __len__(self) -> int:
        return len(self.image_paths)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        image_path = self.image_paths[idx]
        
        # Load image
        if self.use_albumentations:
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image = Image.open(image_path).convert('RGB')
        
        # Apply transform
        if self.transform:
            if self.use_albumentations:
                transformed = self.transform(image=image)
                image = transformed['image']
            else:
                image = self.transform(image)
        
        return image, image_path

---
## 8. DataLoader Factory
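
The weighted sampler in this section assigns each sample a weight inversely proportional to its class frequency. The sketch below checks the effect of that rule on a small set of labels. The label counts are hypothetical, chosen only to resemble the imbalance seen in the EDA clusters.

```python
from collections import Counter

# Hypothetical training labels with a strong class imbalance
labels = (['Notes Text'] * 64 + ['Notes Tabular'] * 46 +
          ['Independent Auditors Report'] * 6 + ['Other Pages'] * 3)

counts = Counter(labels)
total = len(labels)

# Inverse-frequency weight per class, then one weight per sample
class_weights = {cls: total / n for cls, n in counts.items()}
sample_weights = [class_weights[label] for label in labels]

# Probability that a single draw lands in each class
weight_sum = sum(sample_weights)
class_prob = {cls: n * class_weights[cls] / weight_sum for cls, n in counts.items()}
for cls, p in class_prob.items():
    print(f"{cls}: {p:.2f}")
```

Every class present in the data ends up with the same draw probability. Because sampling is done with replacement, pages from small classes are repeated often within an epoch, which is one reason the training augmentations matter.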

In [None]:
def get_weighted_sampler(dataset: FinancialStatementDataset) -> WeightedRandomSampler:
    """
    Create weighted random sampler for handling class imbalance.
    
    Args:
        dataset: Training dataset
    
    Returns:
        WeightedRandomSampler
    """
    labels = dataset.get_labels()
    class_counts = Counter(labels)
    
    # Calculate weights (inverse frequency)
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
    
    # Assign weight to each sample
    sample_weights = [class_weights[label] for label in labels]
    
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    print(f"Class weights: {class_weights}")
    return sampler

In [None]:
def create_dataloaders(train_df: pd.DataFrame, 
                       val_df: pd.DataFrame, 
                       test_df: pd.DataFrame,
                       model_name: str,
                       batch_size: int = 16,
                       num_workers: int = 2,
                       use_weighted_sampling: bool = True) -> Dict[str, DataLoader]:
    """
    Create DataLoaders for a specific model.
    
    Args:
        train_df: Training DataFrame
        val_df: Validation DataFrame
        test_df: Test DataFrame
        model_name: Name of the model (resnet50, efficientnet_b2, vit_base)
        batch_size: Batch size
        num_workers: Number of data loading workers
        use_weighted_sampling: Whether to use weighted sampling for imbalanced classes
    
    Returns:
        Dictionary with train, val, test DataLoaders
    """
    config = MODEL_CONFIGS[model_name]
    input_size = config['input_size']
    mean = config['mean']
    std = config['std']
    
    # Create transforms
    train_transform = get_train_transforms(input_size, mean, std)
    val_transform = get_val_transforms(input_size, mean, std)
    
    # Create datasets
    train_dataset = FinancialStatementDataset(train_df, train_transform, use_albumentations=True)
    val_dataset = FinancialStatementDataset(val_df, val_transform, use_albumentations=True)
    test_dataset = FinancialStatementDataset(test_df, val_transform, use_albumentations=True)
    
    # Create sampler for training if needed
    train_sampler = None
    shuffle_train = True
    
    if use_weighted_sampling:
        train_sampler = get_weighted_sampler(train_dataset)
        shuffle_train = False  # Sampler handles shuffling
    
    # Create dataloaders
    dataloaders = {
        'train': DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=shuffle_train,
            sampler=train_sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True
        ),
        'val': DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        ),
        'test': DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )
    }
    
    print(f"\nDataLoaders created for {model_name}:")
    print(f"  Input size: {input_size}x{input_size}")
    print(f"  Train batches: {len(dataloaders['train'])}")
    print(f"  Val batches: {len(dataloaders['val'])}")
    print(f"  Test batches: {len(dataloaders['test'])}")
    
    return dataloaders

In [None]:
# Create dataloaders for all three models
all_dataloaders = {}

if 'label_encoded' in train_df.columns:
    for model_name in MODEL_CONFIGS.keys():
        print(f"\n{'='*50}")
        print(f"Creating dataloaders for: {model_name.upper()}")
        print('='*50)
        
        all_dataloaders[model_name] = create_dataloaders(
            train_df, val_df, test_df,
            model_name=model_name,
            batch_size=CONFIG['batch_size'],
            num_workers=CONFIG['num_workers'],
            use_weighted_sampling=True
        )

---
## 9. Verify DataLoaders

In [None]:
def visualize_batch(dataloader: DataLoader, model_name: str, n_samples: int = 8):
    """
    Visualize a batch from the dataloader.
    
    Args:
        dataloader: DataLoader to visualize
        model_name: Name of the model
        n_samples: Number of samples to show
    """
    config = MODEL_CONFIGS[model_name]
    mean = np.array(config['mean'])
    std = np.array(config['std'])
    
    # Get a batch
    images, labels = next(iter(dataloader))
    
    n_samples = min(n_samples, len(images))
    n_cols = 4
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4*n_rows))
    axes = axes.flatten()
    
    for i in range(n_samples):
        img = images[i].permute(1, 2, 0).numpy()
        img = img * std + mean
        img = np.clip(img, 0, 1)
        
        axes[i].imshow(img)
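        # LabelEncoder sorts class names alphabetically, so decode with its
        # classes_ instead of indexing CLASS_NAMES (which uses a different order)
        axes[i].set_title(f"Label: {label_encoder.classes_[int(labels[i])]}")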
        axes[i].axis('off')
    
    for i in range(n_samples, len(axes)):
        axes[i].axis('off')
    
    plt.suptitle(f'{model_name.upper()} - Training Batch Sample', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize batches for each model
if all_dataloaders:
    for model_name, dataloaders in all_dataloaders.items():
        print(f"\n{'='*50}")
        print(f"Sample batch for: {model_name.upper()}")
        print('='*50)
        visualize_batch(dataloaders['train'], model_name, n_samples=8)

---
## 10. Model Architecture Preview

In [None]:
def create_model(model_name: str, num_classes: int, pretrained: bool = True) -> nn.Module:
    """
    Create a model with pretrained weights.
    
    Args:
        model_name: Name of the model
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
    
    Returns:
        PyTorch model
    """
    config = MODEL_CONFIGS[model_name]
    timm_name = config['timm_name']
    
    model = timm.create_model(
        timm_name,
        pretrained=pretrained,
        num_classes=num_classes
    )
    
    return model

In [None]:
# Preview model architectures
print("="*70)
print("MODEL ARCHITECTURE SUMMARY")
print("="*70)

for model_name in MODEL_CONFIGS.keys():
    print(f"\n{model_name.upper()}:")
    print("-"*40)
    
    model = create_model(model_name, CONFIG['num_classes'], pretrained=False)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Input size: {MODEL_CONFIGS[model_name]['input_size']}x{MODEL_CONFIGS[model_name]['input_size']}")
    
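    # Test forward pass in eval mode, without tracking gradients,
    # so BatchNorm statistics are not updated by the dummy input
    config = MODEL_CONFIGS[model_name]
    dummy_input = torch.randn(1, 3, config['input_size'], config['input_size'])
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    print(f"  Output shape: {output.shape}")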
    
    del model

---
## 11. Class Weights for Loss Function

In [None]:
def compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """
    Compute class weights for handling imbalanced dataset.
    
    Args:
        train_df: Training DataFrame with 'label_encoded' column
        num_classes: Number of classes
    
    Returns:
        Tensor of class weights
    """
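    # Count samples per class; reindex so a class missing from the training
    # split keeps its slot and the weights stay aligned with the class ids
    class_counts = (
        train_df['label_encoded']
        .value_counts()
        .reindex(range(num_classes), fill_value=0)
        .to_numpy(dtype=float)
    )
    total_samples = len(train_df)
    
    # Inverse-frequency weights; classes with no samples keep a weight of 0
    weights = np.zeros(num_classes)
    present = class_counts > 0
    weights[present] = total_samples / (num_classes * class_counts[present])
    
    # Rescale so the weights sum to num_classes
    weights = weights / weights.sum() * num_classes
    
    return torch.tensor(weights, dtype=torch.float32)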

In [None]:
# Compute and display class weights
if 'label_encoded' in train_df.columns:
    class_weights = compute_class_weights(train_df, CONFIG['num_classes'])
    
    print("Class weights for loss function:")
    for i, (name, weight) in enumerate(zip(CLASS_NAMES, class_weights)):
        print(f"  {i}: {name}: {weight:.4f}")
    
    # Save class weights
    torch.save(class_weights, os.path.join(OUTPUT_DIR, 'class_weights.pt'))
    print(f"\nClass weights saved to {OUTPUT_DIR}/class_weights.pt")
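
To make the weighting concrete, here is a small worked example of the same inverse-frequency formula in plain Python. The 60/30/10 label counts are made up for illustration, and `inverse_frequency_weights` is a hypothetical stand-in for `compute_class_weights`, not the notebook's function.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Inverse-frequency weights, rescaled so they sum to num_classes."""
    counts = Counter(labels)
    total = len(labels)
    # Raw weight: total / (num_classes * count); absent classes get 0.0
    raw = [
        total / (num_classes * counts[c]) if counts[c] else 0.0
        for c in range(num_classes)
    ]
    scale = num_classes / sum(raw)
    return [w * scale for w in raw]


# Made-up label distribution: 60 / 30 / 10 samples
labels = [0] * 60 + [1] * 30 + [2] * 10
weights = inverse_frequency_weights(labels, 3)
print([round(w, 4) for w in weights])  # [0.3333, 0.6667, 2.0]
```

The rarest class gets six times the weight of the most common one, matching the 60:10 count ratio. Because of the final rescaling, the `total / num_classes` factor cancels out; only the relative class frequencies affect the result.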

---
## 12. Save Preprocessing Configuration

In [None]:
# Save all configuration and metadata
preprocessing_config = {
    'config': CONFIG,
    'model_configs': MODEL_CONFIGS,
    'class_names': CLASS_NAMES,
    'imagenet_mean': IMAGENET_MEAN,
    'imagenet_std': IMAGENET_STD,
    'dataset_stats': {
        'total_pages': len(labeled_df) if 'labeled_df' in dir() else 0,
        'train_size': len(train_df) if 'train_df' in dir() else 0,
        'val_size': len(val_df) if 'val_df' in dir() else 0,
        'test_size': len(test_df) if 'test_df' in dir() else 0,
    },
    'cluster_to_label_map': CLUSTER_TO_LABEL_MAP if 'CLUSTER_TO_LABEL_MAP' in dir() else None
}

# Save as JSON
config_path = os.path.join(OUTPUT_DIR, 'preprocessing_config.json')
with open(config_path, 'w') as f:
    json.dump(preprocessing_config, f, indent=2)

print(f"Configuration saved to {config_path}")
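
`json.dump` only encodes built-in types, so the save above would fail if `CONFIG` or `CLASS_NAMES` held, say, a `Path` or a NumPy array. Whether they do is not visible in this section; the following is a minimal sketch, assuming such values may be present. `to_jsonable` is a hypothetical helper, not part of the notebook.

```python
import json
from pathlib import Path


def to_jsonable(obj):
    """Recursively convert common non-JSON types to JSON-friendly ones."""
    if isinstance(obj, dict):
        # JSON object keys must be strings
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, Path):
        return str(obj)
    if hasattr(obj, 'tolist'):  # NumPy arrays and NumPy scalars
        return to_jsonable(obj.tolist())
    return obj


# Illustrative values only; not the notebook's actual configuration
example = {
    'output_dir': Path('processed'),
    'mean': (0.485, 0.456, 0.406),
    'map': {0: 'Other Pages'},
}
print(json.dumps(to_jsonable(example), indent=2))
```

In the notebook this would be applied as `json.dump(to_jsonable(preprocessing_config), f, indent=2)`. Note that integer dictionary keys come back as strings after a JSON round trip, which matters when reloading a mapping such as `cluster_to_label_map`.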

In [None]:
# Summary
print("="*70)
print("                    PREPROCESSING SUMMARY")
print("="*70)

print(f"\n📁 OUTPUT DIRECTORY: {OUTPUT_DIR}")
print(f"\n📊 DATASET:")
print(f"   Total pages: {len(labeled_df) if 'labeled_df' in dir() else 'N/A'}")
print(f"   Train: {len(train_df) if 'train_df' in dir() else 'N/A'}")
print(f"   Val: {len(val_df) if 'val_df' in dir() else 'N/A'}")
print(f"   Test: {len(test_df) if 'test_df' in dir() else 'N/A'}")

print(f"\n🏷️  CLASSES: {CONFIG['num_classes']}")
for i, name in enumerate(CLASS_NAMES):
    print(f"   {i}: {name}")

print(f"\n🤖 MODELS PREPARED:")
for model_name, config in MODEL_CONFIGS.items():
    print(f"   - {model_name}: input={config['input_size']}x{config['input_size']}")

print(f"\n📦 FILES SAVED:")
saved_files = [
    'train.csv', 'val.csv', 'test.csv',
    'label_encoder.pkl', 'class_weights.pt',
    'preprocessing_config.json'
]
for f in saved_files:
    path = os.path.join(OUTPUT_DIR, f)
    exists = '✓' if os.path.exists(path) else '✗'
    print(f"   {exists} {f}")

print(f"\n✅ Preprocessing complete! Ready for model training.")

---
## Next Steps

The data is now preprocessed and ready for training. The next notebook should:

1. **Load preprocessed data** using the saved CSV files and configurations
2. **Train three models**:
   - ResNet50 with transfer learning
   - EfficientNet-B2 with transfer learning
   - Vision Transformer (ViT-Base)
3. **Evaluate and compare** model performance
4. **Save best model** for inference

### Code to Load Data in Training Notebook:

```python
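import json
import pickle

import pandas as pd
import torch

# Paths are relative to the preprocessing OUTPUT_DIR; adjust them if the
# training notebook runs from a different working directory.

# Load configuration
with open('preprocessing_config.json', 'r') as f:
    config = json.load(f)

# Load data splits
train_df = pd.read_csv('train.csv')
val_df = pd.read_csv('val.csv')
test_df = pd.read_csv('test.csv')

# Load label encoder
with open('label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)

# Load class weights
class_weights = torch.load('class_weights.pt')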
```
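
Once loaded, the class weights would typically be passed to the loss function. A minimal sketch, assuming `nn.CrossEntropyLoss` is the training loss (the notebook does not specify one); the weight values and the dummy batch are placeholders.

```python
import torch
import torch.nn as nn

# Placeholder weights for 5 classes; in the training notebook these would
# come from torch.load('class_weights.pt')
class_weights = torch.tensor([0.4, 0.6, 1.0, 1.2, 1.8], dtype=torch.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The weight tensor must live on the same device as the logits
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

# Dummy batch: 4 samples with 5 class logits each
logits = torch.randn(4, 5, device=device)
targets = torch.tensor([0, 2, 4, 1], device=device)

loss = criterion(logits, targets)
print(loss.item())
```

Since the training DataLoaders already use weighted sampling, applying class weights in the loss as well may overcorrect for imbalance; it is worth comparing both options on the validation set before settling on one.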