In [1]:
import io
import math
import os
import sys
from datetime import datetime

# Intel OpenMP: tolerate duplicate runtime copies (common torch + MKL workaround).
# Must be set before torch is imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from transformers import CLIPModel

# Project root: two directories above the notebook's working directory. It is put
# on sys.path so the local `utils` package can be imported. The membership check
# keeps re-runs of this cell from appending duplicate sys.path entries.
source_path = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if source_path not in sys.path:
    sys.path.append(source_path)

# definitions

In [2]:
class CustomMLP(nn.Module):
    """Feed-forward head: one ``Linear -> ReLU`` pair per hidden width, then a final ``Linear``.

    Args:
        input_size: number of input features.
        hidden_sizes: hidden-layer widths, applied in order (may be empty).
        output_size: number of output features (e.g. class logits).
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], output_size))
        # Kept under the attribute name `model` so state_dict keys stay "model.<i>.*".
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``input_size``."""
        return self.model(x)
class WrappedModel(torch.nn.Module):
    """Image encoder followed by a trainable head on L2-normalised image embeddings.

    Args:
        model: object exposing ``get_image_features(x)`` (used here with a
            ``transformers.CLIPModel``).
        custom_mlp: head applied to the normalised features; its input size must
            equal the width returned by ``get_image_features``.
        type_of_output: stored as ``self.type_of_output``; ``forward`` does not read it.
    """

    def __init__(self, model, custom_mlp, type_of_output='cls'):
        super().__init__()
        # Registration order (model, then mlp) fixes the state_dict key order.
        self.model = model
        self.type_of_output = type_of_output
        self.mlp = custom_mlp

    def forward(self, x):
        """Return the head's output for the unit-length image embeddings of ``x``."""
        embeddings = self.model.get_image_features(x)
        lengths = embeddings.norm(p=2, dim=-1, keepdim=True)
        return self.mlp(embeddings / lengths)

## scheduling

In [3]:
def get_cosine_schedule_with_warmup(optimizer, warmup_epochs, total_epochs):
    """Linear warmup followed by cosine decay to 0, as a ``LambdaLR`` stepped once per epoch.

    The multiplier applied to each param group's base LR for (0-based) epoch ``e`` is:

    * ``e < warmup_epochs``: ``(e + 1) / warmup_epochs``, ramping up to 1.0 on the
      last warmup epoch. Using ``e / warmup_epochs`` would give a multiplier of 0 for
      the entire first epoch, i.e. an epoch in which no parameter changes.
    * otherwise: ``0.5 * (1 + cos(pi * progress))`` with
      ``progress = (e - warmup_epochs) / (total_epochs - warmup_epochs)`` clamped to
      at most 1, decaying from 1.0 towards 0 at ``total_epochs``.

    Args:
        optimizer: optimizer whose learning rates are scheduled.
        warmup_epochs: number of warmup epochs (0 disables warmup).
        total_epochs: total number of epochs, warmup included.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return float(current_epoch + 1) / float(max(1, warmup_epochs))
        progress = float(current_epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        # Clamp so stepping past total_epochs does not make the LR rise again.
        progress = min(1.0, progress)
        return 0.5 * (1. + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

## training

In [4]:
def _state_dict_size_mb(model):
    """Size in MB (1e6 bytes) of ``model.state_dict()`` serialised with ``torch.save``."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6


def _run_classification_epoch(model, dataloader, loss_fn, device, optimizer=None, desc=""):
    """Run one pass over ``dataloader``: train when ``optimizer`` is given, else evaluate.

    Batches must be dicts with an ``'image'`` input tensor and a ``'label'`` tensor.

    Returns:
        ``(mean batch loss, accuracy)`` over the pass; both are 0.0 for an empty loader.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # progress bars are optional
        def tqdm(iterable, **_kwargs):
            return iterable

    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()
    total_loss, correct, seen, n_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(is_training):
        for batch in tqdm(dataloader, desc=desc):
            x, labels = batch['image'].to(device), batch['label'].to(device)
            logits = model(x)
            loss = loss_fn(logits, labels)
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            # Accuracy uses this batch's own logits (validation previously reused
            # the last *training* batch's outputs).
            correct += logits.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)
            n_batches += 1
    return total_loss / max(n_batches, 1), correct / max(seen, 1)


def _plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, lrs, epoch):
    """Show loss, accuracy and learning-rate histories side by side (``epoch`` is 0-based)."""
    fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=(15, 5))

    ax_loss.plot(train_losses, label="Train Loss")
    ax_loss.plot(val_losses, label="Val Loss")
    ax_loss.set(xlabel="Epoch", ylabel="Loss")

    ax_acc.plot(train_accuracies, label="Train Acc")
    ax_acc.plot(val_accuracies, label="Val Acc")
    ax_acc.set(xlabel="Epoch", ylabel="Accuracy")

    ax_lr.plot(lrs, label="Learning Rate", color='orange')
    ax_lr.set(xlabel="Epoch", ylabel="LR")

    for ax in (ax_loss, ax_acc, ax_lr):
        ax.legend()
    fig.suptitle(f"Epoch {epoch+1}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)


def train(
    model,
    train_dataloader,
    val_dataloader,
    device,
    base_lr,
    weight_decay=0.01,
    warmup_epochs=10,
    total_epochs=100,
    checkpoint_path='checkpoint.pt',
    plot_every=5,
    early_stopping_patience=10
):
    """Fine-tune ``model`` for classification with AdamW, warmup + cosine LR and early stopping.

    Both dataloaders must yield dicts with an ``'image'`` tensor and a ``'label'``
    class-index tensor (as ``CustomPatchDataset`` does). The LR schedule is stepped
    once per epoch. Whenever the validation loss improves, a checkpoint
    (model/optimizer/scheduler state plus metric history) is written to
    ``checkpoint_path``. If that file already exists at start-up, training resumes from it.

    Args:
        model: ``nn.Module`` mapping a batch of images to class logits.
        train_dataloader: training batches (see above).
        val_dataloader: validation batches (see above).
        device: device to train on; the model is moved there.
        base_lr: peak learning rate reached at the end of warmup.
        weight_decay: AdamW weight decay.
        warmup_epochs: number of linear-warmup epochs.
        total_epochs: total number of epochs, warmup included.
        checkpoint_path: file the best checkpoint is saved to / resumed from.
        plot_every: plot the curves every this many epochs. Curves are also plotted
            after the final epoch and when early stopping triggers. A falsy value
            disables the periodic plots.
        early_stopping_patience: stop after this many consecutive epochs without a
            validation-loss improvement.

    Returns:
        None. The model is updated in place; the best weights are in the checkpoint.
    """
    model = model.to(device)
    print(f"Model size: {_state_dict_size_mb(model):.2f} MB")

    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, warmup_epochs=warmup_epochs, total_epochs=total_epochs
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Per-epoch metric history and early-stopping state.
    start_epoch = 0
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    lrs = []
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    # Optional resume from checkpoint
    if os.path.exists(checkpoint_path):
        print(f"🔄 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint['best_val_loss']
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])
        lrs = checkpoint.get('lrs', [])
        # Restore the accuracies as well, so every curve has the same length after a resume.
        train_accuracies = checkpoint.get('train_accuracies', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
    else:
        print(f"📂 No checkpoint found at {checkpoint_path}. Starting fresh training.")

    for epoch in range(start_epoch, total_epochs):
        avg_train_loss, train_acc = _run_classification_epoch(
            model, train_dataloader, loss_fn, device, optimizer=optimizer,
            desc=f"Epoch {epoch+1}/{total_epochs} [Train]",
        )
        train_losses.append(avg_train_loss)
        lrs.append(optimizer.param_groups[0]['lr'])  # LR used during this epoch
        train_accuracies.append(train_acc)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}")

        # Validation (no optimizer => eval mode, gradients disabled)
        avg_val_loss, val_acc = _run_classification_epoch(
            model, val_dataloader, loss_fn, device,
            desc=f"Epoch {epoch+1}/{total_epochs} [Val]",
        )
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_acc)
        scheduler.step()  # per-epoch schedule
        print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {lrs[-1]:.6f}")

        # Checkpointing based on val loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'lrs': lrs,
                'train_accuracies': train_accuracies,
                'val_accuracies': val_accuracies,
            }
            torch.save(checkpoint, checkpoint_path)
            epochs_without_improvement = 0
            print(f"✅ Saved new best model at epoch {epoch+1}")
        else:
            epochs_without_improvement += 1
            print(f"⏳ No improvement for {epochs_without_improvement} epoch(s)")

        stop_early = epochs_without_improvement >= early_stopping_patience
        if stop_early:
            print(f"⛔ Early stopping at epoch {epoch+1} (no improvement for {early_stopping_patience} epochs)")

        # Plot periodically, after the final epoch, and when stopping early
        # (previously the curves were never shown for an early-stopped run).
        is_last_epoch = (epoch + 1) == total_epochs
        if stop_early or is_last_epoch or (plot_every and (epoch + 1) % plot_every == 0):
            _plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, lrs, epoch)

        if stop_early:
            break

In [5]:
def train_model(model, train_loader, val_loader, criterion, optimizer,
                device, num_epochs=5, checkpoint_path=None, early_stopping_patience=10, scheduler=None,
                data_type='image'):
    """Simple train/validate loop with optional checkpointing and early stopping.

    Args:
        model: ``nn.Module`` producing class logits. It is not moved to ``device`` here.
        train_loader: iterable of batch dicts holding the input under ``data_type``
            and class indices under ``'label'``.
        val_loader: validation batches in the same format.
        criterion: loss function called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        num_epochs: maximum number of epochs.
        checkpoint_path: optional string *prefix*; ``'best_checkpoint.pth'`` or
            ``'last_checkpoint.pth'`` is appended to it, so pass a directory with a
            trailing separator. ``None`` disables checkpointing.
        early_stopping_patience: stop after this many consecutive epochs without a
            validation-loss improvement.
        scheduler: optional LR scheduler, stepped once per epoch.
        data_type: batch key holding the model input.

    Checkpoints: on an epoch that improves the validation loss the state is written
    to ``<prefix>best_checkpoint.pth``; on any other epoch to ``<prefix>last_checkpoint.pth``.
    ``'val_loss'`` is that epoch's validation loss and ``'best_val_loss'`` the best so far.

    Returns:
        ``(model, train_losses, val_losses)`` with one mean loss per completed epoch.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # progress bars are optional
        def tqdm(iterable, **_kwargs):
            return iterable

    start_time = datetime.now()
    train_losses, val_losses = [], []
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        train_loss, correct, total = 0.0, 0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs = batch[data_type].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)

        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_losses.append(avg_train_loss)
        train_acc = correct / max(total, 1)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}")

        # ---- validation ----
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch[data_type].to(device)
                labels = batch['label'].to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                correct += outputs.argmax(dim=1).eq(labels).sum().item()
                total += labels.size(0)

        avg_val_loss = val_loss / max(len(val_loader), 1)
        val_losses.append(avg_val_loss)
        val_acc = correct / max(total, 1)
        print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # ---- early-stopping bookkeeping ----
        # Tracked independently of checkpointing. Previously best_loss was only
        # updated when checkpoint_path was set, so with None patience grew every epoch.
        improved = avg_val_loss < best_loss
        if improved:
            best_loss = avg_val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        # ---- checkpointing (skipped entirely when checkpoint_path is None) ----
        if checkpoint_path:
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'best_val_loss': best_loss,
                'val_acc': val_acc,
                'train_loss': avg_train_loss,
                'train_acc': train_acc,
                'epoch': epoch,
                'time_from_start': datetime.now() - start_time,
            }
            file_suffix = 'best_checkpoint.pth' if improved else 'last_checkpoint.pth'
            torch.save(checkpoint, checkpoint_path + file_suffix)
            print(f"Checkpoint saved: {checkpoint_path}" + file_suffix)

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break
        if scheduler:
            scheduler.step()

    return model, train_losses, val_losses

## data loading

In [6]:
class CustomPatchDataset(Dataset):
    """Dataset of rectangular patches cropped from page images listed in a DataFrame.

    Each row of ``df`` describes one patch:
    - ``file_name``: path of the page image.
    - ``x``, ``y``, ``x2``, ``y2``: crop box, passed to ``PIL.Image.crop`` as
      (left, upper, right, lower).
    - ``writer``: writer id.
    - ``label_column``: the target.

    Args:
        df: DataFrame with the columns listed above.
        label_column: name of the column holding the target label.
        transform: when ``huggingface`` is True, a Hugging Face image processor called as
            ``transform(images=patch, return_tensors="pt")`` (required in that mode).
            Otherwise, an optional callable applied to the PIL patch.
        huggingface: selects how ``transform`` is applied (see above).

    Each item is a dict ``{'image': patch, 'writer': int, 'label': label}``.
    """

    def __init__(self, df, label_column, transform=None, huggingface=True):
        self.image_files = df['file_name'].tolist()
        self.img_labels = df[label_column].tolist()
        self.img_writers = df['writer'].tolist()
        # Crop box: (x1, y1) is the top-left corner, (x2, y2) the bottom-right corner.
        self.x1 = df['x'].tolist()
        self.y1 = df['y'].tolist()
        self.x2 = df['x2'].tolist()
        self.y2 = df['y2'].tolist()
        self.transform = transform
        self.huggingface = huggingface

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        box = (self.x1[idx], self.y1[idx], self.x2[idx], self.y2[idx])
        # The context manager closes the file handle immediately, rather than
        # leaving it open until the lazily loaded Image is garbage-collected.
        with Image.open(self.image_files[idx]) as page:
            patch = page.convert("RGB").crop(box)

        if self.huggingface:
            # `transform` is a Hugging Face processor here. It returns a batch of one
            # image, so drop only that leading batch dimension.
            inputs = self.transform(images=patch, return_tensors="pt")
            patch = inputs['pixel_values'].squeeze(0)
        elif self.transform:
            patch = self.transform(patch)

        return {
            'image': patch,
            'writer': int(self.img_writers[idx]),
            'label': self.img_labels[idx],
        }

# initialization

In [14]:
import random

# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and all GPUs), and make cuDNN choose deterministic kernels.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

In [15]:
# --- input data ---
input_filename = 'icdar_train_df_patches_20250515_164130.csv'  # gw=5 m_patches=5
patches = True
N_max = 282  # upper bound used to filter the data; reassigned further down

# --- preprocessing / model ---
huggingface = True         # transforms are Hugging Face processors
pooling = False            # if True, transformer models use pooling; if False, only the CLS token
custom_transform = False
transform_mode = 'resize'
selected_model = 'clip-vit-large-patch14'
truncation = 'remove head'
num_classes = 2
hidden_size = 1024         # hidden width of the MLP head

# --- machine names used to remap the stored file paths ---
running = 'new-laptop'
saved = 'old-laptop'

In [16]:
# Optimisation / training hyper-parameters
weight_decay = 0.01
base_lr = 3e-5
warmup_epochs = 10
total_epochs = 100
patience = 10        # early-stopping patience, in epochs
batch_size = 32
p_train = 0.9        # probability that a writer is assigned to the training split

# Built with os.path.join so the path is valid on every OS (it was hard-coded with '\\').
# The directory is created up front because torch.save fails if it does not exist.
checkpoint_path = os.path.join(source_path, 'outputs', 'models', 'clip-vit', 'checkpoint.pt')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

In [19]:
# `u_transforms` is otherwise bound only by the reload_modules() cell at the end of the
# notebook, so import it here to make Restart & Run All work.
import utils.utils_transforms as u_transforms

transform = u_transforms.get_transform(selected_model, use_patches=patches, custom=custom_transform, mode=transform_mode)

In [20]:
# Pretrained CLIP backbone selected by `selected_model` in the config cell.
# TODO: training probably needs reduced precision to fit in memory.
model = CLIPModel.from_pretrained(f'openai/{selected_model}')

In [21]:
# CLIPModel.get_image_features returns vectors of size config.projection_dim
# (768 for clip-vit-large-patch14). The head's input size is read from the config
# because the hard-coded 384 did not match and made the forward pass fail.
# Run this cell right after the cell above; running it twice would wrap the model again.
custom_mlp = CustomMLP(input_size=model.config.projection_dim, hidden_sizes=[hidden_size], output_size=num_classes)
model = WrappedModel(model, custom_mlp, type_of_output='cls')

In [22]:
# `file_IO` is otherwise bound only by the reload_modules() cell at the end of the
# notebook, so import it here to make Restart & Run All work.
import utils.file_IO as file_IO

train_df = pd.read_csv(os.path.join(source_path, 'outputs', 'preprocessed_data', input_filename))
# Presumably rewrites the stored file paths from machine `saved` to machine `running`
# (see the config cell) -- NOTE(review): confirm against utils.file_IO.
train_df = file_IO.change_filename_from_to(train_df, fr=saved, to=running)

In [25]:
# Peek at the first rows of the loaded patch table.
train_df.head(n=5)

Unnamed: 0,writer,isEng,same_text,file_name,male,train,index,x,y,x2,y2,n_cc
0,1,0,0,C:\Users\andre\PhD\Datasets\ICDAR 2013 - Gende...,0,1,0,0,493,493,986,111
1,1,0,0,C:\Users\andre\PhD\Datasets\ICDAR 2013 - Gende...,0,1,1,493,493,986,986,96
2,1,0,0,C:\Users\andre\PhD\Datasets\ICDAR 2013 - Gende...,0,1,2,1479,493,1972,986,99
3,1,0,0,C:\Users\andre\PhD\Datasets\ICDAR 2013 - Gende...,0,1,3,986,493,1479,986,90
4,1,0,0,C:\Users\andre\PhD\Datasets\ICDAR 2013 - Gende...,0,1,4,1972,493,2465,986,99


In [24]:
# Set the probability of being 0
N = train_df['writer'].nunique()

# Create a dataframe with writer column from 1 to 282
pages_df = pd.DataFrame({'writer': np.arange(1, N)})

# Add a train column that is randomly 0 or 1 with probability p of being 0
pages_df['train'] = np.random.choice([0, 1], size=len(pages_df), p=[1-p_train, p_train])

# Merge with the train_df dataframe on the writer column
train_df = train_df.merge(pages_df, on='index', how='left')

# Display the dataframe
train_df.head()

KeyError: 'index'

In [None]:
# Use every writer (N = number of distinct writers, computed in the split cell).
N_max = N
# CustomPatchDataset requires the label column; omitting it raised a TypeError.
label_column = 'male'  # NOTE(review): assumed target of the 2-class task (num_classes = 2) -- confirm
# N_max bounds writer ids, so filter on 'writer'.
# NOTE(review): this previously filtered on the 'index' column -- confirm that was unintended.
selected_writers = train_df['writer'] <= N_max

train_dataset = CustomPatchDataset(
    train_df[(train_df['train'] == 1) & selected_writers], label_column,
    transform=transform, huggingface=huggingface,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = CustomPatchDataset(
    train_df[(train_df['train'] == 0) & selected_writers], label_column,
    transform=transform, huggingface=huggingface,
)
# Evaluation does not need shuffling; a fixed order keeps validation runs comparable.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# run 

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Current CUDA allocator usage, in MB.
print(f"[GPU Memory] Allocated: {torch.cuda.memory_allocated() / 1e6:.2f} MB | Reserved: {torch.cuda.memory_reserved() / 1e6:.2f} MB")

In [None]:
# Launch (or resume from `checkpoint_path`) the fine-tuning run configured above.
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    base_lr=base_lr,
    weight_decay=weight_decay,
    warmup_epochs=warmup_epochs,
    total_epochs=total_epochs,
    checkpoint_path=checkpoint_path,
    plot_every=5,
    early_stopping_patience=patience,
)

# tests

In [None]:
# Flow check: push a random batch through a freshly built model and confirm that
# gradients reach the parameters. Uses its own variable names so the `model` trained
# above is not overwritten and the builtin `input` is not shadowed.
check_backbone = CLIPModel.from_pretrained(f'openai/{selected_model}')
# The head must match get_image_features' width (config.projection_dim); 384 did not.
check_mlp = CustomMLP(input_size=check_backbone.config.projection_dim, hidden_sizes=[hidden_size], output_size=num_classes)
check_model = WrappedModel(check_backbone, check_mlp, type_of_output='cls')
dummy_pixels = torch.randn(2, 3, 224, 224)  # stand-in for processor pixel_values

check_output = check_model(dummy_pixels)
check_loss = torch.nn.CrossEntropyLoss()(check_output, torch.tensor([0, 1]))  # example labels
check_loss.backward()

for name, param in check_model.named_parameters():
    if param.grad is not None:
        print(name, "has gradients")

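# easy access: reload the project `utils` modules (this cell also binds `u_transforms` and `file_IO`, which cells above use)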

In [18]:
def reload_modules():
    """(Re)import the project's ``utils`` helper modules and return them.

    Reloading picks up edits to the helper ``.py`` files without restarting the kernel.

    Returns:
        tuple: ``(image_processing, file_IO, visualization, tests, data_loading, u_transforms)``.
    """
    import importlib
    import utils.image_processing as image_processing
    import utils.file_IO as file_IO
    import utils.visualization as visualization
    import utils.tests as tests
    import utils.data_loading as data_loading
    import utils.utils_transforms as u_transforms

    # Reload order: data_loading first, u_transforms last.
    for module in (data_loading, file_IO, image_processing, visualization, tests, u_transforms):
        importlib.reload(module)

    return image_processing, file_IO, visualization, tests, data_loading, u_transforms


image_processing, file_IO, visualization, tests, data_loading, u_transforms = reload_modules()