# Train a small transformer model on preprocessed segments

This notebook loads `segments_preproc_24.csv` produced by the preprocessing notebook, builds segment-level sequences, and trains a small transformer-based classifier.

In [2]:
# Imports and configuration
import os
import math
import numpy as np
import pandas as pd
from typing import Dict, Any

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix, 
    ConfusionMatrixDisplay, f1_score, roc_auc_score, roc_curve, 
    auc, precision_recall_curve, average_precision_score
)
from sklearn.preprocessing import label_binarize

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger

import wandb

BASE_DATA_DIR = os.path.abspath("../data")
EXPORT_DIR = os.path.join(BASE_DATA_DIR, "export")
PREPROC_CSV = os.path.join(EXPORT_DIR, "segments_preproc_24.csv")

print("Using preprocessed file:", PREPROC_CSV)
assert os.path.exists(PREPROC_CSV), f"Preprocessed CSV not found: {PREPROC_CSV}"
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Wandb version: {wandb.__version__}")

Using preprocessed file: /work/data/export/segments_preproc_24.csv
PyTorch Lightning version: 2.6.0
CUDA available: True
Wandb version: 0.23.1


In [3]:
# Set random seeds for reproducibility
SEED = 1
seed_everything(SEED, workers=True)
print(f"Random seed locked to {SEED}")

Seed set to 1


Random seed locked to 1


In [4]:
# Wandb authentication - paste your token here
WANDB_API_TOKEN = "e5e50a1ad57d78a1def2302321a9d83243fe6fd8"  # Replace with your actual token

# Login to wandb
wandb.login(key=WANDB_API_TOKEN)
print("Successfully logged in to Weights & Biases!")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33meznagyonkellettmz[0m ([33meznagyonkellettmz-budapesti-m-szaki-s-gazdas-gtudom-nyi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33meznagyonkellettmz[0m ([33meznagyonkellettmz-budapesti-m-szaki-s-gazdas-gtudom-nyi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Successfully logged in to Weights & Biases!


In [5]:
# Model evaluation configuration
# Change this single variable to switch the primary metric for model selection
# Options: 'auc_ovo', 'auc_ovr', 'f1', 'accuracy', 'pr_auc'
PRIMARY_METRIC = 'pr_auc'

# Metric display names and whether higher is better
METRIC_CONFIG = {
    'auc_ovo': {
        'name': 'AUC-ROC (OvO)',
        'short': 'ovo',
        'higher_is_better': True,
        'monitor': 'val_auc_ovo',
        'description': 'One-vs-One: evaluates all pairwise class comparisons'
    },
    'auc_ovr': {
        'name': 'AUC-ROC (OvR)',
        'short': 'ovr',
        'higher_is_better': True,
        'monitor': 'val_auc_ovr',
        'description': 'One-vs-Rest: evaluates each class vs all others'
    },
    'f1': {
        'name': 'F1 Score (macro)',
        'short': 'f1',
        'higher_is_better': True,
        'monitor': 'val_f1',
        'description': 'Harmonic mean of precision and recall'
    },
    'accuracy': {
        'name': 'Accuracy',
        'short': 'acc',
        'higher_is_better': True,
        'monitor': 'val_accuracy',
        'description': 'Proportion of correct predictions'
    },
    'pr_auc': {
        'name': 'PR-AUC (macro)',
        'short': 'pr',
        'higher_is_better': True,
        'monitor': 'val_pr_auc',
        'description': 'Precision-Recall curve area (macro): better for imbalanced classes'
    }
}

print(f"Primary metric: {METRIC_CONFIG[PRIMARY_METRIC]['name']}")
print(f"Description: {METRIC_CONFIG[PRIMARY_METRIC]['description']}")

Primary metric: PR-AUC (macro)
Description: Precision-Recall curve area (macro): better for imbalanced classes


In [6]:
# Load preprocessed dataset and build segment-level sequences
df = pd.read_csv(PREPROC_CSV)
print("Raw preprocessed shape:", df.shape)

# Ensure correct ordering within each segment
df = df.sort_values(["segment_id", "seq_pos"], kind="mergesort").reset_index(drop=True)

feature_cols = [c for c in df.columns if c not in ["segment_id", "label", "csv_file", "seq_pos"]]
print("Feature columns (", len(feature_cols), "):", feature_cols)

# Group into (segment, sequence of length 24, label)
segments = []
labels = []

for seg_id, g in df.groupby("segment_id", sort=True):
    g = g.sort_values("seq_pos", kind="mergesort")
    feat = g[feature_cols].to_numpy(dtype=np.float32)
    # Expect 24 steps; if shorter/longer, adjust with simple strategies
    if feat.shape[0] < 24:
        # pad by repeating last step
        pad = np.repeat(feat[-1:, :], 24 - feat.shape[0], axis=0)
        feat = np.concatenate([feat, pad], axis=0)
    elif feat.shape[0] > 24:
        # truncate extra steps
        feat = feat[:24, :]

    assert feat.shape[0] == 24, feat.shape
    segments.append(feat)
    labels.append(g["label"].iloc[0])

X = np.stack(segments, axis=0)  # (N, 24, F)
y = np.array(labels)

print("Num segments:", X.shape[0], "Seq len:", X.shape[1], "Num features:", X.shape[2])
print("Label distribution:")
print(pd.Series(y).value_counts())

Raw preprocessed shape: (3456, 12)
Feature columns ( 8 ): ['open_norm', 'high_norm', 'low_norm', 'close_norm', 'vol_close', 'vol_high_low', 'compression_ratio', 'trend']
Num segments: 144 Seq len: 24 Num features: 8
Label distribution:
Bullish Normal     40
Bearish Normal     27
Bearish Pennant    26
Bullish Pennant    22
Bullish Wedge      15
Bearish Wedge      14
Name: count, dtype: int64


In [7]:
# Encode labels as integers
label_values = np.sort(pd.unique(y))
label_to_idx = {lbl: i for i, lbl in enumerate(label_values)}
idx_to_label = {i: lbl for lbl, i in label_to_idx.items()}

y_idx = np.vectorize(label_to_idx.get)(y)
num_classes = len(label_values)
print("Classes:", label_values, "-> num_classes =", num_classes)

Classes: ['Bearish Normal' 'Bearish Pennant' 'Bearish Wedge' 'Bullish Normal'
 'Bullish Pennant' 'Bullish Wedge'] -> num_classes = 6


In [8]:
# Train/validation split at segment level
X_train, X_val, y_train, y_val = train_test_split(
    X, y_idx, test_size=0.2, random_state=SEED, stratify=y_idx,
 )

print("Train segments:", X_train.shape[0])
print("Val segments:", X_val.shape[0])

class SegmentDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X)  # (N, T, F)
        self.y = torch.from_numpy(y).long()
    def __len__(self):
        return self.X.shape[0]
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = SegmentDataset(X_train, y_train)
val_ds = SegmentDataset(X_val, y_val)

# Calculate class weights for imbalanced dataset
class_counts = np.bincount(y_idx)
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_weights)
print(f"Class distribution: {class_counts}")
print(f"Class weights: {class_weights}")

Train segments: 115
Val segments: 29
Class distribution: [27 26 14 40 22 15]
Class weights: [0.7823394  0.81242937 1.50879741 0.52807909 0.96014381 1.40821092]


In [9]:
# PyTorch Lightning Module with Wandb integration
class FlagPatternClassifier(pl.LightningModule):
    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        class_weights: np.ndarray,
        hidden_channels: int = 64,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        batch_size: int = 12
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['class_weights'])
        
        # Model architecture
        self.conv = nn.Sequential(
            nn.Conv1d(input_dim, hidden_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_channels, hidden_channels * 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_channels * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_channels * 2, hidden_channels * 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_channels * 2),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(),
            nn.Linear(hidden_channels // 2, num_classes),
        )
        
        # Loss function with class weights
        self.class_weights = torch.FloatTensor(class_weights)
        self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
        
        # Store predictions for epoch-end metrics
        self.validation_step_outputs = []
        self.training_step_outputs = []
        
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x):
        x = x.transpose(1, 2)  # (B, T, F) -> (B, F, T)
        h = self.conv(x)
        h = self.pool(h).squeeze(-1)
        logits = self.fc(h)
        return logits
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        
        self.training_step_outputs.append({
            'loss': loss,
            'preds': preds.detach().cpu(),
            'probs': probs.detach().cpu(),
            'targets': y.detach().cpu()
        })
        
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        
        self.validation_step_outputs.append({
            'loss': loss,
            'preds': preds.detach().cpu(),
            'probs': probs.detach().cpu(),
            'targets': y.detach().cpu()
        })
        
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
    
    def on_train_epoch_end(self):
        self._compute_epoch_metrics(self.training_step_outputs, 'train')
        self.training_step_outputs.clear()
    
    def on_validation_epoch_end(self):
        self._compute_epoch_metrics(self.validation_step_outputs, 'val')
        self.validation_step_outputs.clear()
    
    def _compute_epoch_metrics(self, outputs, prefix):
        all_preds = torch.cat([x['preds'] for x in outputs]).numpy()
        all_probs = torch.cat([x['probs'] for x in outputs]).numpy()
        all_targets = torch.cat([x['targets'] for x in outputs]).numpy()
        
        # Calculate metrics
        accuracy = accuracy_score(all_targets, all_preds)
        f1 = f1_score(all_targets, all_preds, average='macro')
        
        try:
            auc_ovo = roc_auc_score(all_targets, all_probs, multi_class='ovo', average='macro')
            auc_ovr = roc_auc_score(all_targets, all_probs, multi_class='ovr', average='macro')
        except ValueError:
            auc_ovo = 0.0
            auc_ovr = 0.0
        
        try:
            y_bin = label_binarize(all_targets, classes=range(all_probs.shape[1]))
            pr_auc_per_class = []
            for i in range(all_probs.shape[1]):
                pr_auc_per_class.append(average_precision_score(y_bin[:, i], all_probs[:, i]))
            pr_auc = np.mean(pr_auc_per_class)
        except ValueError:
            pr_auc = 0.0
        
        # Log all metrics (wandb will automatically capture these)
        self.log(f'{prefix}_accuracy', accuracy, prog_bar=True)
        self.log(f'{prefix}_f1', f1, prog_bar=True)
        self.log(f'{prefix}_auc_ovo', auc_ovo, prog_bar=True)
        self.log(f'{prefix}_auc_ovr', auc_ovr, prog_bar=True)
        self.log(f'{prefix}_pr_auc', pr_auc, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=20
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

# Initial model creation for reference (will be recreated during sweep)
batch_size = 12

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

model = FlagPatternClassifier(
    input_dim=X.shape[2],
    num_classes=num_classes,
    class_weights=class_weights,
    hidden_channels=64,
    lr=1e-3,
    weight_decay=1e-4,
    batch_size=batch_size
)

print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

FlagPatternClassifier(
  (conv): Sequential(
    (0): Conv1d(8, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (9): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
    (12): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (13): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
  )
  (pool): AdaptiveAvgPool1d(output_size=1)
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (

In [None]:
# Wandb Sweep Configuration
# The sweep will optimize for val_pr_auc (validation PR-AUC)
sweep_config = {
    'method': 'bayes',  # Bayesian optimization
    'metric': {
        'name': 'val_pr_auc',
        'goal': 'maximize'
    },
    'parameters': {
        'lr': {
            'distribution': 'log_uniform_values',
            'min': 1e-4,
            'max': 1e-2
        },
        'batch_size': {
            'values': [8, 12, 16, 24, 32]
        },
        'hidden_channels': {
            'values': [32, 64, 128,]
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-3
        },
        'max_epochs': {
            'value': 60  # Fixed value
        }
    }
}

print("Sweep configuration:")
print(f"  Optimization metric: {sweep_config['metric']['name']} ({sweep_config['metric']['goal']})")
print(f"  Method: {sweep_config['method']}")
print(f"  Hyperparameters to optimize:")
print(f"    - Learning rate: {sweep_config['parameters']['lr']['min']} to {sweep_config['parameters']['lr']['max']}")
print(f"    - Batch size: {sweep_config['parameters']['batch_size']['values']}")
print(f"    - Hidden channels: {sweep_config['parameters']['hidden_channels']['values']}")
print(f"    - Weight decay: {sweep_config['parameters']['weight_decay']['min']} to {sweep_config['parameters']['weight_decay']['max']}")

Sweep configuration:
  Optimization metric: val_pr_auc (maximize)
  Method: bayes
  Hyperparameters to optimize:
    - Learning rate: 0.0001 to 0.01
    - Batch size: [8, 12, 16, 24, 32]
    - Hidden channels: [32, 64, 128, 256]
    - Weight decay: 1e-05 to 0.001


In [11]:
# Training function for wandb sweep
def train_sweep():
    """
    Training function called by wandb agent for each sweep run.
    Wandb automatically injects hyperparameters via wandb.config.
    """
    # Initialize wandb run
    with wandb.init() as run:
        # Get hyperparameters from wandb config
        config = wandb.config
        
        # Create data loaders with sweep batch size
        sweep_train_loader = DataLoader(
            train_ds, 
            batch_size=config.batch_size, 
            shuffle=True, 
            num_workers=0
        )
        sweep_val_loader = DataLoader(
            val_ds, 
            batch_size=config.batch_size, 
            shuffle=False, 
            num_workers=0
        )
        
        # Create model with sweep hyperparameters
        sweep_model = FlagPatternClassifier(
            input_dim=X.shape[2],
            num_classes=num_classes,
            class_weights=class_weights,
            hidden_channels=config.hidden_channels,
            lr=config.lr,
            weight_decay=config.weight_decay,
            batch_size=config.batch_size
        )
        
        # Setup wandb logger
        wandb_logger = WandbLogger(
            project='flag-pattern-classifier',
            log_model=False  # Don't save models during sweep to save space
        )
        
        # Setup checkpoint callback (based on val_pr_auc)
        checkpoint_callback = ModelCheckpoint(
            dirpath=os.path.join(EXPORT_DIR, "checkpoints", f"sweep_{run.id}"),
            filename='best_model_{epoch:02d}_{val_pr_auc:.4f}',
            monitor='val_pr_auc',
            mode='max',
            save_top_k=1,
            verbose=False
        )
        
        # Early stopping to prevent wasting time on bad runs
        early_stop_callback = EarlyStopping(
            monitor='val_pr_auc',
            patience=10,
            mode='max',
            verbose=False
        )
        
        # Create trainer
        sweep_trainer = pl.Trainer(
            max_epochs=config.max_epochs,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1,
            callbacks=[checkpoint_callback, early_stop_callback],
            logger=wandb_logger,
            deterministic=True,
            log_every_n_steps=10,
            enable_progress_bar=False,  # Disable for cleaner sweep output
            enable_model_summary=False
        )
        
        # Train
        sweep_trainer.fit(sweep_model, sweep_train_loader, sweep_val_loader)
        
        # Log best score
        print(f"Run {run.id}: Best val_pr_auc = {checkpoint_callback.best_model_score:.4f}")

print("Training function defined. Ready to start sweep!")

Training function defined. Ready to start sweep!


In [None]:
# Initialize and run the wandb sweep
# This will create a sweep on wandb servers and run the agent locally

# Create the sweep
sweep_id = wandb.sweep(sweep_config, project='flag-pattern-classifier')
print(f"Sweep created with ID: {sweep_id}")
print(f"View sweep at: https://wandb.ai/[your-username]/flag-pattern-classifier/sweeps/{sweep_id}")

# Run the sweep agent
# count=10 means it will run 10 different hyperparameter configurations
# You can adjust this number or remove it to run indefinitely until stopped
wandb.agent(sweep_id, function=train_sweep, count=100)

print("\nSweep completed! Check your wandb dashboard for results.")

In [None]:
# Get the best run from the sweep (run this after the sweep completes)
api = wandb.Api()

# You'll need to get your username from wandb - you can find it in your dashboard URL
# or run: wandb.Api().viewer()['entity']
wandb_username = api.viewer()['entity']
sweep_path = f"{wandb_username}/flag-pattern-classifier/{sweep_id}"

print(f"Fetching sweep: {sweep_path}")
sweep = api.sweep(sweep_path)
best_run = sweep.best_run()

print("\nBest run configuration:")
print(f"  Run ID: {best_run.id}")
print(f"  Run name: {best_run.name}")
print(f"  Best val_pr_auc: {best_run.summary.get('val_pr_auc', 'N/A')}")
print(f"\nBest hyperparameters:")
for param in ['lr', 'batch_size', 'hidden_channels', 'weight_decay']:
    print(f"  {param}: {best_run.config.get(param, 'N/A')}")

print(f"\nAll metrics:")
print(f"  Accuracy: {best_run.summary.get('val_accuracy', 'N/A'):.4f}")
print(f"  F1: {best_run.summary.get('val_f1', 'N/A'):.4f}")
print(f"  PR-AUC: {best_run.summary.get('val_pr_auc', 'N/A'):.4f}")
print(f"  AUC-OvO: {best_run.summary.get('val_auc_ovo', 'N/A'):.4f}")
print(f"  AUC-OvR: {best_run.summary.get('val_auc_ovr', 'N/A'):.4f}")