# Train a small CNN model on preprocessed segments

This notebook loads `segments_preproc_24.csv` produced by the preprocessing notebook, builds segment-level sequences, and trains a small 1-D convolutional (CNN) classifier.

In [18]:
# Imports and configuration
import os
import math
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# Location of the exported, preprocessed segment table (relative to the notebook)
BASE_DATA_DIR = os.path.abspath("../data")
EXPORT_DIR = os.path.join(BASE_DATA_DIR, "export")
PREPROC_CSV = os.path.join(EXPORT_DIR, "segments_preproc_24.csv")

print("Using preprocessed file:", PREPROC_CSV)
assert os.path.exists(PREPROC_CSV), f"Preprocessed CSV not found: {PREPROC_CSV}"

# Seed the global RNGs so Restart & Run All gives repeatable runs: layer
# initialisation and DataLoader shuffling both draw from torch's global
# generator. torch.manual_seed covers CPU and CUDA generators; some GPU kernels
# can still be nondeterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Using preprocessed file: /work/data/export/segments_preproc_24.csv
Device: cuda


In [19]:
# Load preprocessed dataset and build segment-level sequences
SEQ_LEN = 24  # number of time steps every segment is padded/truncated to
META_COLS = ("segment_id", "label", "csv_file", "seq_pos")  # everything else is a feature

df = pd.read_csv(PREPROC_CSV)
print("Raw preprocessed shape:", df.shape)

# Ensure correct ordering within each segment (stable sort keeps tied rows in file order)
df = df.sort_values(["segment_id", "seq_pos"], kind="mergesort").reset_index(drop=True)

feature_cols = [c for c in df.columns if c not in META_COLS]
print("Feature columns (", len(feature_cols), "):", feature_cols)

# Group into (segment, sequence of length SEQ_LEN, label)
segments = []
labels = []

# groupby preserves the row order inside each group, so the sort above already
# guarantees seq_pos order here and no per-group re-sort is needed.
for seg_id, g in df.groupby("segment_id", sort=True):
    # One label per segment is assumed below; fail loudly instead of silently
    # picking the first row's label when the preprocessing output is inconsistent.
    assert g["label"].nunique(dropna=False) == 1, (
        f"Segment {seg_id} has inconsistent labels: {g['label'].unique()}"
    )

    feat = g[feature_cols].to_numpy(dtype=np.float32)
    n_steps = feat.shape[0]
    if n_steps < SEQ_LEN:
        # pad by repeating last step
        pad = np.repeat(feat[-1:, :], SEQ_LEN - n_steps, axis=0)
        feat = np.concatenate([feat, pad], axis=0)
    elif n_steps > SEQ_LEN:
        # truncate extra steps
        feat = feat[:SEQ_LEN, :]

    assert feat.shape[0] == SEQ_LEN, feat.shape
    segments.append(feat)
    labels.append(g["label"].iloc[0])

X = np.stack(segments, axis=0)  # (N, SEQ_LEN, F)
y = np.array(labels)

print("Num segments:", X.shape[0], "Seq len:", X.shape[1], "Num features:", X.shape[2])
print("Label distribution:")
print(pd.Series(y).value_counts())

Raw preprocessed shape: (1392, 12)
Feature columns ( 8 ): ['Open_norm', 'High_norm', 'Low_norm', 'Close_norm', 'vol_close', 'vol_high_low', 'compression_ratio', 'trend']
Num segments: 58 Seq len: 24 Num features: 8
Label distribution:
Bearish Pennant    15
Bullish Normal     14
Bullish Wedge       9
Bullish Pennant     9
Bearish Wedge       6
Bearish Normal      5
Name: count, dtype: int64


In [20]:
# Give every distinct label a stable integer id (ids follow sorted label order)
label_values = np.sort(pd.unique(y))
label_to_idx = dict(zip(label_values, range(len(label_values))))
idx_to_label = dict(enumerate(label_values))

# Translate the per-segment labels into their integer ids
y_idx = np.array([label_to_idx[name] for name in y])
num_classes = len(label_values)
print("Classes:", label_values, "-> num_classes =", num_classes)

Classes: ['Bearish Normal' 'Bearish Pennant' 'Bearish Wedge' 'Bullish Normal'
 'Bullish Pennant' 'Bullish Wedge'] -> num_classes = 6


In [21]:
# Hold out 20% of the segments for validation, stratified on the class id
# so both splits keep the overall label proportions
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y_idx,
    test_size=0.2,
    stratify=y_idx,
    random_state=42,
)

print(f"Train segments: {len(X_train)}")
print(f"Val segments: {len(X_val)}")

class SegmentDataset(Dataset):
    """In-memory dataset pairing each feature sequence with its integer class id.

    Parameters
    ----------
    X : array-like
        Feature sequences indexed along the first axis, e.g. ``(N, T, F)``.
        Stored as ``float32`` regardless of the input dtype so it matches the
        default parameter dtype of the model.
    y : array-like
        One class id per sequence; stored as ``int64`` as required by
        ``nn.CrossEntropyLoss`` targets.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` do not contain the same number of samples.
    """

    def __init__(self, X, y):
        # as_tensor reuses the NumPy buffer when the dtype already matches and
        # converts (copies) otherwise, so float64 / int32 inputs are accepted too.
        self.X = torch.as_tensor(X, dtype=torch.float32)  # (N, T, F)
        self.y = torch.as_tensor(y, dtype=torch.long)
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        """Number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, class_id)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]

# Wrap both splits as datasets and batch them; only the training split is
# reshuffled on each pass
batch_size = 12
train_ds, val_ds = SegmentDataset(X_train, y_train), SegmentDataset(X_val, y_val)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=batch_size)

# Batches per pass for (train, val)
len(train_loader), len(val_loader)

Train segments: 46
Val segments: 12


(4, 1)

In [22]:
# Simple CNN-based classifier for pattern recognition
class SimpleCNN(nn.Module):
    """Two-layer 1-D convolutional classifier for ``(B, T, F)`` sequences.

    The input is transposed to channels-first, passed through two
    ``Conv1d -> ReLU`` stages (kernel size 3, padding 1, so the time length is
    kept), averaged over time, and mapped to class logits by a two-layer MLP.

    Args:
        input_dim: number of features per time step (``F``).
        num_classes: number of output logits.
        hidden_channels: width of both conv layers and of the hidden MLP layer.
    """

    def __init__(self, input_dim: int, num_classes: int, hidden_channels: int = 32):
        super().__init__()
        width = hidden_channels

        # Conv1d wants (B, F, T); forward() takes care of the transpose
        first_conv = nn.Conv1d(input_dim, width, kernel_size=3, padding=1)
        second_conv = nn.Conv1d(width, width, kernel_size=3, padding=1)
        self.conv = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

        # Collapse the time axis to a single averaged step
        self.pool = nn.AdaptiveAvgPool1d(1)

        self.fc = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, num_classes),
        )

        # Replace the default initialisation with Kaiming (He) normal weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights (fan_out for conv, fan_in for linear), zero biases."""
        fan_mode_for = ((nn.Conv1d, 'fan_out'), (nn.Linear, 'fan_in'))
        for layer in self.modules():
            for layer_type, fan_mode in fan_mode_for:
                if not isinstance(layer, layer_type):
                    continue
                nn.init.kaiming_normal_(layer.weight, mode=fan_mode, nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                break

    def forward(self, x):  # x: (B, T, F)
        """Return class logits of shape ``(B, num_classes)``."""
        channels_first = torch.transpose(x, 1, 2)  # (B, F, T)
        feature_maps = self.conv(channels_first)   # (B, C, T)
        pooled = self.pool(feature_maps)           # (B, C, 1)
        return self.fc(pooled.squeeze(-1))         # (B, num_classes)

# Build the network on the selected device, together with its loss and optimizer
model = SimpleCNN(input_dim=X.shape[-1], num_classes=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

print(model)

SimpleCNN(
  (conv): Sequential(
    (0): Conv1d(8, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (3): ReLU()
  )
  (pool): AdaptiveAvgPool1d(output_size=1)
  (fc): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=6, bias=True)
  )
)


In [23]:
# One pass over a loader: trains when an optimizer is given, evaluates otherwise
def run_epoch(loader, model, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches.
        model: network mapping a batch of inputs to per-class logits.
        criterion: loss taking ``(logits, targets)``. The per-batch value is
            weighted by the batch size, which assumes the criterion returns the
            batch mean (the default for ``nn.CrossEntropyLoss``).
        optimizer: when given, the model is put in train mode and updated after
            every batch; when ``None`` the model is put in eval mode and
            gradients are disabled.

    Returns:
        Tuple ``(avg_loss, acc)``: loss averaged over the samples actually
        processed, and the fraction of samples whose argmax prediction matches
        the target.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is equivalent to eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level global; parameterless models leave the batches where they are.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    n_seen = 0
    all_preds = []
    all_targets = []

    for xb, yb in loader:
        if run_device is not None:
            xb = xb.to(run_device)
            yb = yb.to(run_device)

        with torch.set_grad_enabled(is_train):
            if is_train:
                optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_n = xb.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
        all_preds.append(logits.argmax(dim=1).detach().cpu().numpy())
        all_targets.append(yb.detach().cpu().numpy())

    if n_seen == 0:
        raise ValueError("run_epoch received a loader that yielded no samples")

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    # Divide by the samples actually seen: len(loader.dataset) over-counts when
    # the loader drops the last batch or draws from a sampler subset.
    avg_loss = total_loss / n_seen
    acc = accuracy_score(all_targets, all_preds)
    return avg_loss, acc

# Alternate one optimisation pass and one validation pass per epoch, logging both
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = run_epoch(train_loader, model, criterion, optimizer)
    # No optimizer -> evaluation only
    val_loss, val_acc = run_epoch(val_loader, model, criterion)
    print(
        f"Epoch {epoch:02d} "
        f"| train_loss={train_loss:.4f} acc={train_acc:.4f} "
        f"| val_loss={val_loss:.4f} acc={val_acc:.4f}"
    )

Epoch 01 | train_loss=1.8676 acc=0.1304 | val_loss=1.8365 acc=0.0833
Epoch 02 | train_loss=1.7734 acc=0.2174 | val_loss=1.7998 acc=0.1667
Epoch 03 | train_loss=1.7454 acc=0.3043 | val_loss=1.7766 acc=0.2500
Epoch 04 | train_loss=1.7159 acc=0.2609 | val_loss=1.7622 acc=0.3333
Epoch 05 | train_loss=1.7034 acc=0.2826 | val_loss=1.7539 acc=0.3333
Epoch 05 | train_loss=1.7034 acc=0.2826 | val_loss=1.7539 acc=0.3333
Epoch 06 | train_loss=1.6910 acc=0.2826 | val_loss=1.7404 acc=0.3333
Epoch 06 | train_loss=1.6910 acc=0.2826 | val_loss=1.7404 acc=0.3333
Epoch 07 | train_loss=1.6756 acc=0.3043 | val_loss=1.7258 acc=0.3333
Epoch 08 | train_loss=1.6585 acc=0.3261 | val_loss=1.7142 acc=0.3333
Epoch 07 | train_loss=1.6756 acc=0.3043 | val_loss=1.7258 acc=0.3333
Epoch 08 | train_loss=1.6585 acc=0.3261 | val_loss=1.7142 acc=0.3333
Epoch 09 | train_loss=1.6412 acc=0.3478 | val_loss=1.7009 acc=0.3333
Epoch 10 | train_loss=1.6242 acc=0.3478 | val_loss=1.6907 acc=0.3333
Epoch 09 | train_loss=1.6412 acc=0

In [None]:
# Final evaluation and per-class metrics on validation set
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Put the network in evaluation mode before collecting predictions
model.eval()

def get_predictions(loader):
    """Return ``(predictions, targets)`` as 1-D NumPy arrays for every batch in ``loader``.

    Uses the notebook-level ``model`` and ``device``; predictions are the argmax
    over the class logits. Gradients are disabled for the forward passes.
    """
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            scores = model(inputs.to(device))
            pred_chunks.append(scores.argmax(dim=1).cpu().numpy())
            target_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(target_chunks)

# Get predictions for train and validation sets
train_preds, train_targets = get_predictions(train_loader)
val_preds, val_targets = get_predictions(val_loader)

print(f"Train accuracy: {accuracy_score(train_targets, train_preds):.4f}")
print(f"Val accuracy: {accuracy_score(val_targets, val_preds):.4f}")

# Pin the class axis explicitly. Without `labels=`, scikit-learn infers the
# classes from the values present in a split, so a class missing from both the
# targets and the predictions would shrink the matrix/report and no longer line
# up with the full list of label names.
class_ids = np.arange(len(label_values))

# Plot confusion matrices side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training confusion matrix
cm_train = confusion_matrix(train_targets, train_preds, labels=class_ids)
disp_train = ConfusionMatrixDisplay(confusion_matrix=cm_train, display_labels=label_values)
disp_train.plot(ax=axes[0], cmap='Blues', xticks_rotation=45)
axes[0].set_title("Training Confusion Matrix")

# Validation confusion matrix
cm_val = confusion_matrix(val_targets, val_preds, labels=class_ids)
disp_val = ConfusionMatrixDisplay(confusion_matrix=cm_val, display_labels=label_values)
disp_val.plot(ax=axes[1], cmap='Blues', xticks_rotation=45)
axes[1].set_title("Validation Confusion Matrix")

plt.tight_layout()
plt.show()

# Classification report for validation.
# zero_division=0 reports 0.0 for classes that were never predicted (the same
# value as the default) without emitting an UndefinedMetricWarning per metric.
print("\nClassification report (validation):")
print(classification_report(
    val_targets,
    val_preds,
    labels=class_ids,
    target_names=[str(lbl) for lbl in label_values],
    zero_division=0,
))

Final val_loss=1.6907, val_acc=0.3333
Classification report (validation):
                 precision    recall  f1-score   support

 Bearish Normal       0.00      0.00      0.00         1
Bearish Pennant       0.33      1.00      0.50         3
  Bearish Wedge       0.00      0.00      0.00         1
 Bullish Normal       0.00      0.00      0.00         3
Bullish Pennant       1.00      0.50      0.67         2
  Bullish Wedge       0.00      0.00      0.00         2

       accuracy                           0.33        12
      macro avg       0.22      0.25      0.19        12
   weighted avg       0.25      0.33      0.24        12



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
