# Transformer-based Time Series Classification


In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import fbeta_score, precision_score, recall_score
import numpy as np
from collections import Counter

from parser import load_saved_dataframe
from helper import *


# Dataset


In [2]:
class TimeSeriesDataset(Dataset):
    """In-memory dataset of six-channel time series with binary labels.

    Every DataFrame row is one sample.  The columns listed in
    ``CHANNEL_COLUMNS`` each hold one numeric sequence per row (all sequences
    must have the same length so they can be stacked), and the ``group``
    column holds the class name.  Rows whose group equals ``'control'`` are
    labelled 0; every other row is labelled 1.

    Attributes:
        X: float32 tensor of shape (samples, 6, series_length).
        y: int64 tensor of shape (samples,).
    """

    # Channel order along axis 1 of ``X``.
    CHANNEL_COLUMNS = (
        'timeseries_V1',
        'timeseries_V2',
        'timeseries_V3',
        'timeseries_V4',
        'timeseries_V5',
        'timeseries_V6',
    )

    def __init__(self, df):
        """Materialise all samples of ``df`` as tensors.

        Args:
            df: pandas DataFrame with the six ``timeseries_V*`` columns and a
                ``group`` column.

        Raises:
            KeyError: if one of the required columns is missing.
            ValueError: if the sequences cannot be stacked into a regular
                array (e.g. differing lengths).
        """
        # Build one (samples, series_length) array per channel and stack them
        # on axis 1.  Going through NumPy column-wise replaces two row-wise
        # ``iterrows()`` passes and avoids ``torch.tensor`` walking a nested
        # Python list element by element.
        per_channel = [
            np.asarray(df[column].tolist(), dtype=np.float32)
            for column in self.CHANNEL_COLUMNS
        ]
        self.X = torch.from_numpy(np.stack(per_channel, axis=1))

        # Anything that is not exactly 'control' counts as the positive class.
        is_positive = (df['group'] != 'control').to_numpy()
        self.y = torch.from_numpy(is_positive.astype(np.int64))

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(series, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]


# Transformer Model


In [3]:
class TimeSeriesTransformer(nn.Module):
    """Transformer-encoder classifier for multichannel time series.

    The forward pass treats every time step as one token: the channel vector
    of each step is linearly projected to ``d_model`` features, the token
    sequence is run through a stack of ``nn.TransformerEncoderLayer`` blocks,
    the encoded tokens are averaged over the time axis, and a linear head
    produces one logit per class.

    NOTE(review): no positional encoding is applied anywhere in this module.
    Together with average pooling over time, the logits appear to be
    independent of the order of the time steps -- confirm this is intended.

    Args:
        input_channels: number of channels per time step.
        num_classes: number of output logits.
        d_model: token embedding width used by the encoder.
        nhead: number of attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, input_channels=6, num_classes=2, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        # Per-time-step embedding of the channel vector.
        self.input_proj = nn.Linear(input_channels, d_model)
        # batch_first=True keeps tensors in (batch, seq_len, d_model) layout.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )
        # Collapses the trailing (time) axis to a single averaged value.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, seq_len) to class logits.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits.
        """
        tokens = x.permute(0, 2, 1)                   # (batch, seq_len, channels)
        embedded = self.input_proj(tokens)            # (batch, seq_len, d_model)
        encoded = self.transformer_encoder(embedded)  # (batch, seq_len, d_model)
        # The pooling layer averages over the last axis, so move time there.
        time_last = encoded.permute(0, 2, 1)          # (batch, d_model, seq_len)
        pooled = self.global_pool(time_last).squeeze(-1)  # (batch, d_model)
        return self.fc(pooled)                        # (batch, num_classes)


# Utility Functions


In [4]:
def calculate_class_weights(y):
    """Return inverse-frequency ("balanced") weights for the labels in ``y``.

    Each class gets ``len(y) / (number_of_classes * count_of_that_class)``,
    so rarer classes receive larger weights.

    Args:
        y: sized iterable of hashable class labels.

    Returns:
        dict mapping each label that occurs in ``y`` to its float weight;
        empty when ``y`` is empty.
    """
    occurrences = Counter(y)
    n_samples = len(y)
    n_classes = len(occurrences)
    return {
        label: n_samples / (n_classes * n_label)
        for label, n_label in occurrences.items()
    }

def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: module mapping an input batch to per-class logits.
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable applied to ``(logits, labels)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, f2, precision, recall)``.  ``avg_loss`` is the mean
        of the per-batch loss values (each batch counts equally, regardless of
        its size); the three metrics are support-weighted averages over the
        classes, computed from arg-max predictions.

    Raises:
        ZeroDivisionError: if ``dataloader`` has length 0.
    """
    model.eval()
    running_loss = 0
    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            running_loss += criterion(logits, labels).item()

            # Predicted class = index of the largest logit.
            batch_predictions = torch.argmax(logits, dim=1)
            predicted_labels.extend(batch_predictions.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    f2 = fbeta_score(true_labels, predicted_labels, beta=2, average='weighted')
    precision = precision_score(true_labels, predicted_labels, average='weighted')
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    return avg_loss, f2, precision, recall

def train_model_with_early_stopping(model, train_loader, val_loader,
                                    criterion, optimizer, scheduler,
                                    device, num_epochs=10,
                                    patience=10, min_delta=0.001):
    """Train ``model`` with early stopping on the validation F2 score.

    After every epoch the model is evaluated on ``val_loader``.  Whenever the
    validation F2 improves on the best value so far by more than
    ``min_delta``, an independent snapshot of the weights is stored.  Training
    stops once ``patience`` consecutive epochs pass without such an
    improvement, and the best snapshot (if any) is loaded back into ``model``
    before returning.

    Args:
        model: module to train; modified in place.
        train_loader: sized iterable of ``(inputs, labels)`` training batches.
        val_loader: sized iterable of ``(inputs, labels)`` validation batches.
        criterion: loss callable applied to ``(logits, labels)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler whose ``step`` accepts the
            validation loss (e.g. ``ReduceLROnPlateau``).
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        patience: number of consecutive non-improving epochs tolerated.
        min_delta: minimum F2 gain that counts as an improvement.

    Returns:
        dict with per-epoch lists ``train_losses``, ``val_losses``,
        ``train_f2_scores`` and ``val_f2_scores``, plus ``best_val_f2`` and
        ``epochs_trained``.
    """
    # Local import: only needed for snapshotting the best weights.
    import copy

    # History tracking
    train_losses = []
    val_losses = []
    train_f2_scores = []
    val_f2_scores = []

    # Early stopping variables
    best_val_f2 = 0
    patience_counter = 0
    best_model_state = None
    # Tracked explicitly so the return value is defined even when
    # num_epochs == 0 (the loop variable would not exist in that case).
    epochs_trained = 0

    print(f"Training for up to {num_epochs} epochs with early stopping...")
    print(f"Early stopping patience: {patience} epochs")
    print(f"Minimum improvement delta: {min_delta}")
    print("-" * 60)

    for epoch in range(num_epochs):
        epochs_trained = epoch + 1

        # Training phase
        model.train()
        train_loss = 0
        train_preds = []
        train_targets = []

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            train_targets.extend(y_batch.cpu().numpy())

        # Compute train F2 for this epoch (predictions were collected while
        # the weights were still changing within the epoch).
        train_f2 = fbeta_score(train_targets, train_preds, beta=2, average='weighted')
        train_f2_scores.append(train_f2)

        # Validation phase
        val_loss, val_f2, val_precision, val_recall = evaluate_model(
            model, val_loader, criterion, device
        )

        # Update learning rate
        scheduler.step(val_loss)

        # Track metrics
        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        val_losses.append(val_loss)
        val_f2_scores.append(val_f2)

        # Print progress: every one of the first ten epochs, then every fifth.
        if epoch % 5 == 0 or epoch < 10:
            print(f"Epoch {epoch + 1:3d}/{num_epochs} | "
                  f"Train Loss: {avg_train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val F2: {val_f2:.4f} | "
                  f"Val Precision: {val_precision:.4f} | "
                  f"Val Recall: {val_recall:.4f}")

        # Early stopping check
        if val_f2 > best_val_f2 + min_delta:
            best_val_f2 = val_f2
            patience_counter = 0
            # Deep copy is required: state_dict() returns references to the
            # live parameter tensors, so a shallow dict copy would be
            # silently overwritten by every later optimizer step.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best validation F2: {val_f2:.4f} (epoch {epoch + 1})")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            print(f"Best validation F2: {best_val_f2:.4f}")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Loaded best model weights")

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_f2_scores': train_f2_scores,
        'val_f2_scores': val_f2_scores,
        'best_val_f2': best_val_f2,
        'epochs_trained': epochs_trained
    }

def get_predictions(model, dataloader, device):
    """Collect true labels and arg-max predictions over ``dataloader``.

    Args:
        model: module mapping an input batch to per-class logits.
        dataloader: iterable yielding ``(inputs, labels)`` batches; the labels
            are read directly with ``.numpy()`` and are never moved to
            ``device``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Tuple ``(targets, predictions)`` of NumPy arrays, in batch order.
    """
    model.eval()
    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs.to(device))
            batch_predictions = torch.argmax(logits, dim=1)
            predicted_labels.extend(batch_predictions.cpu().numpy())
            true_labels.extend(labels.numpy())

    return np.array(true_labels), np.array(predicted_labels)


# Data Loading and Splitting


In [5]:
# Load the saved DataFrame and record the length of the first V1 series.
df = load_saved_dataframe("timeseries_data.pkl")
series_length = len(df.iloc[0]['timeseries_V1'])

# Stratified split on 'group': 80% train, then the held-out 20% is halved
# into validation and test (10% each).  Fixed seeds keep the split stable.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['group'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['group'],
)

train_dataset, val_dataset, test_dataset = (
    TimeSeriesDataset(split) for split in (train_df, val_df, test_df)
)

# Class weights come from the training split only; 'control' is class 0 and
# every other group is class 1.
train_labels = [0 if group == 'control' else 1 for group in train_df['group']]
class_weights = calculate_class_weights(train_labels)
weight_tensor = torch.tensor(
    [class_weights[label] for label in (0, 1)],
    dtype=torch.float32,
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset, batch_size=32, shuffle=False)
    for dataset in (val_dataset, test_dataset)
)


DataFrame loaded successfully from timeseries_data.pkl


# Model, Optimizer, Training


In [None]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Deliberately small model: one encoder layer, one head, 8-dim embeddings.
model = TimeSeriesTransformer(
    input_channels=6, num_classes=2, d_model=8, nhead=1, num_layers=1
).to(device)

# The class weights must live on the same device as the logits.
weight_tensor = weight_tensor.to(device)
criterion = nn.CrossEntropyLoss(weight=weight_tensor)

# The optimizer is created after the model has been moved to `device`.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the learning rate after 5 epochs without a lower validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

history = train_model_with_early_stopping(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    num_epochs=50,
    patience=15,
    min_delta=0.001,
)



Training for up to 50 epochs with early stopping...
Early stopping patience: 15 epochs
Minimum improvement delta: 0.001
------------------------------------------------------------


# Training History Plots


In [None]:
# Visualise the training run.  `history` is the dict returned by
# train_model_with_early_stopping (per-epoch train/val losses and F2 scores).
# NOTE(review): plot_training_history is not defined in this file; presumably
# it is provided by `from helper import *` -- confirm.
plot_training_history(history)

# Evaluation


In [None]:
# Score the model on the validation and test loaders (in that order) with
# the same criterion that was used for training.
(
    (val_loss, val_f2, val_precision, val_recall),
    (test_loss, test_f2, test_precision, test_recall),
) = (
    evaluate_model(model, loader, criterion, device)
    for loader in (val_loader, test_loader)
)

# Summary: fresh evaluation numbers, followed by what training recorded.
print(f"\nValidation Loss: {val_loss:.4f}")
print(f"Validation F2 Score: {val_f2:.4f}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test F2 Score: {test_f2:.4f}")
print(f"Best Validation F2: {history['best_val_f2']:.4f}")
print(f"Total Epochs Trained: {history['epochs_trained']}")

# Confusion Matrices


In [None]:
# Gather (targets, predictions) for the validation and test loaders, in
# that order.
(val_targets, val_preds), (test_targets, test_preds) = (
    get_predictions(model, loader, device)
    for loader in (val_loader, test_loader)
)

# The three lists are parallel: targets, predictions and titles per matrix.
# NOTE(review): plot_confusion_matrices is not defined in this file;
# presumably it is provided by `from helper import *` -- confirm.
plot_confusion_matrices(
    [val_targets, test_targets],
    [val_preds, test_preds],
    ["Validation Confusion Matrix", "Test Confusion Matrix"],
)

# Save Model

In [None]:
# Bundle the weights, the optimizer state and the run's metrics into a single
# checkpoint file.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_val_f2': history['best_val_f2'],
    'test_f2': test_f2,
    'class_weights': class_weights,
    'history': history,
}
torch.save(checkpoint, 'best_transformer_model.pth')
print(f"\nModel saved as 'best_transformer_model.pth'")