In Colab, make sure you run cells in this order:

1. Imports
2. Config class
3. StockDataset class
4. sLSTMCell class (with state initialization)
5. mLSTMCell class (with state initialization)
6. xLSTMBlock class (with sequential processing)
7. xLSTM class
8. Trainer class
9. main() function
10. Call main()


In [148]:
!pip install torch pandas numpy tqdm wandb matplotlib



In [149]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import wandb
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

In [150]:
class Config:
    """Single config class for everything"""
    # Model
    num_features = 5        # Close, Volume, Open, High, Low
    embed_dim = 64
    num_blocks = 2
    block_types = ['slstm']  # alternating

    # Training
    batch_size = 2
    learning_rate = 3e-4
    num_epochs = 10
    seq_length = 50       # Lookback window
    prediction_horizon = 1  # Predict N days ahead
    normalize_data = True
    grad_clip = 1.0

    # Checkpoint Resume
    resume_from_checkpoint = 9  # Set to None to start from scratch

    # Paths
    train_path = "/content/Panantir-5Y.csv"
    val_path = "/content/Panantir-5Y.csv"
    checkpoint_dir = "checkpoints"

    # WandB
    use_wandb = False
    wandb_project = "xlstm-stock-prediction"
    wandb_run_name = None  # Auto-generated if None

    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [151]:
class StockDataset(Dataset):
    """Load CSV stock data and create sequences for time-series prediction"""

    def __init__(self, path, seq_length=50, prediction_horizon=1, normalize=True):
        self.seq_length = seq_length
        self.prediction_horizon = prediction_horizon

        # Load and clean data
        df = pd.read_csv(path)
        df = df.sort_values('Date')  # Ensure chronological order

        # Clean price columns (remove $ and commas)
        price_cols = ['Close/Last', 'Open', 'High', 'Low']
        for col in price_cols:
            df[col] = df[col].str.replace('$', '').str.replace(',', '').astype(float)

        # Select features: Close, Volume, Open, High, Low
        features = df[['Close/Last', 'Volume', 'Open', 'High', 'Low']].values

        # Normalize
        if normalize:
            self.mean = features.mean(axis=0)
            self.std = features.std(axis=0)
            features = (features - self.mean) / (self.std + 1e-8)

        # Create sequences
        self.sequences = []
        for i in range(len(features) - seq_length - prediction_horizon + 1):
            input_seq = features[i:i + seq_length]
            target_seq = features[i + seq_length:i + seq_length + prediction_horizon]
            self.sequences.append((input_seq, target_seq))

    def __getitem__(self, idx):
        input_seq, target_seq = self.sequences[idx]
        return {
            'input_ids': torch.tensor(input_seq, dtype=torch.float32),
            'labels': torch.tensor(target_seq, dtype=torch.float32)
        }

    def __len__(self):
        return len(self.sequences)

In [152]:
class sLSTMCell(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # ADD THIS
        self.W_i = nn.Linear(dim*2, dim)
        self.W_f = nn.Linear(dim*2, dim)
        self.W_o = nn.Linear(dim*2, dim)
        self.W_z = nn.Linear(dim*2, dim)

    def forward(self, x, state):
        # Initialize state if None
        if state is None:
            batch_size = x.size(0)
            device = x.device
            h = torch.zeros(batch_size, self.dim, device=device)
            c = torch.zeros(batch_size, self.dim, device=device)
            n = torch.zeros(batch_size, self.dim, device=device)
            m = torch.zeros(batch_size, self.dim, device=device)
            state = (h, c, n, m)

        h, c, n, m = state
        combined = torch.cat([x, h], -1)

        i_t = self.W_i(combined)
        f_t = self.W_f(combined)
        o_t = self.W_o(combined)
        z = torch.tanh(self.W_z(combined))

        m_new = torch.max(f_t + m, i_t)
        i = torch.exp(i_t - m_new)
        f = torch.exp(f_t + m - m_new)

        c_new = f * c + i * z
        n_new = f * n + i
        h_new = torch.sigmoid(o_t) * (c_new / (n_new + 1e-6))

        return h_new, (h_new, c_new, n_new, m_new)

In [153]:
class mLSTMCell(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.W_i = nn.Linear(dim, dim)
        self.W_f = nn.Linear(dim, dim)
        self.W_o = nn.Linear(dim, dim)
        self.dim = dim

    def forward(self, x, state):
        # Initialize state if None
        if state is None:
            batch_size = x.size(0)
            device = x.device
            C = torch.zeros(batch_size, self.dim, self.dim, device=device)
            n = torch.zeros(batch_size, self.dim, device=device)
            m = torch.zeros(batch_size, self.dim, device=device)
            state = (C, n, m)

        C, n, m = state

        q = self.W_q(x)
        k = self.W_k(x) / (self.dim ** 0.5)
        v = self.W_v(x)
        i_t = self.W_i(x)
        f_t = self.W_f(x)

        m_new = torch.max(f_t + m, i_t)
        i = torch.exp(i_t - m_new).unsqueeze(-1)
        f = torch.exp(f_t + m - m_new).unsqueeze(-1)

        C_new = f * C + i * torch.bmm(v.unsqueeze(-1), k.unsqueeze(-2))
        n_new = f.squeeze(-1) * n + i.squeeze(-1) * k

        h = torch.sigmoid(self.W_o(x)) * torch.bmm(C_new, q.unsqueeze(-1)).squeeze(-1)

        return h, (C_new, n_new, m_new)

In [154]:
class xLSTMBlock(nn.Module):
    def __init__(self, dim, cell_type):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cell = sLSTMCell(dim) if cell_type == 'slstm' else mLSTMCell(dim)
        self.cell_type = cell_type  # ADD THIS
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.GELU(),
            nn.Linear(dim*4, dim)
        )

    def forward(self, x, state):
        batch_size, seq_len, dim = x.shape

        # Reset state to None (each block maintains its own state)
        state = None

        # Process sequence step by step
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # [batch, dim]
            normed = self.norm1(x_t)
            h, state = self.cell(normed, state)

            # Add residual and FFN
            out = x_t + h
            out = out + self.ffn(self.norm2(out))
            outputs.append(out)

        # Stack outputs back to [batch, seq_len, dim]
        x = torch.stack(outputs, dim=1)

        return x, state

In [155]:
class xLSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_proj = nn.Linear(config.num_features, config.embed_dim)
        self.blocks = nn.ModuleList([
            xLSTMBlock(config.embed_dim, config.block_types[i % len(config.block_types)])
            for i in range(config.num_blocks)
        ])
        self.norm = nn.LayerNorm(config.embed_dim)
        self.head = nn.Linear(config.embed_dim, config.num_features)

    def forward(self, x):
        x = self.input_proj(x)
        state = None
        for block in self.blocks:
            x, state = block(x, state)
        x = self.head(self.norm(x))
        return x[:, -1:, :]  # Return only last timestep prediction

In [156]:
import gc
torch.cuda.empty_cache()
gc.collect()

924

In [157]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model.to(config.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        self.optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
        self.criterion = nn.MSELoss()
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_mae': [],
            'val_rmse': []
        }

        Path(config.checkpoint_dir).mkdir(exist_ok=True)

        # Initialize WandB
        if config.use_wandb:
            wandb.init(
                project=config.wandb_project,
                name=config.wandb_run_name,
                config={
                    "num_features": config.num_features,
                    "embed_dim": config.embed_dim,
                    "num_blocks": config.num_blocks,
                    "block_types": config.block_types,
                    "batch_size": config.batch_size,
                    "learning_rate": config.learning_rate,
                    "num_epochs": config.num_epochs,
                    "seq_length": config.seq_length,
                    "prediction_horizon": config.prediction_horizon,
                }
            )
            wandb.watch(self.model, log="all", log_freq=100)

    def load_checkpoint(self, checkpoint_path):
            """Load model and optimizer state from checkpoint"""
            checkpoint = torch.load(checkpoint_path, map_location=self.config.device)

            self.model.load_state_dict(checkpoint['model_state'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])

           # Load history if available
            if 'history' in checkpoint:
                self.history = checkpoint['history']

            start_epoch = checkpoint['epoch'] + 1

            print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
            print(f"Resuming training from epoch {start_epoch}")

            return start_epoch

    def plot_training_history(self):
        """Plot training metrics"""
        epochs = range(1, len(self.history['train_loss']) + 1)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot loss
        ax1.plot(epochs, self.history['train_loss'], 'b-o', label='Train Loss')
        ax1.plot(epochs, self.history['val_loss'], 'r-o', label='Val Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss (MSE)')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        # Plot metrics
        ax2.plot(epochs, self.history['val_mae'], 'g-s', label='Val MAE')
        ax2.plot(epochs, self.history['val_rmse'], 'm-^', label='Val RMSE')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Error')
        ax2.set_title('Validation Metrics')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig('training_history.png', dpi=150)
        plt.close()

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0

        for step, batch in enumerate(tqdm(self.train_loader, desc=f"Training Epoch {epoch+1}")):
            inputs = batch['input_ids'].to(self.config.device)
            labels = batch['labels'].to(self.config.device)

            predictions = self.model(inputs)
            loss = self.criterion(predictions, labels)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
            self.optimizer.step()

            total_loss += loss.item()

            # Log to WandB every 10 steps
            if self.config.use_wandb and step % 10 == 0:
                wandb.log({
                    "train_loss_step": loss.item(),
                    "epoch": epoch,
                    "step": epoch * len(self.train_loader) + step
                })

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def validate(self):
        self.model.eval()
        total_loss = 0
        total_mae = 0

        for batch in self.val_loader:
            inputs = batch['input_ids'].to(self.config.device)
            labels = batch['labels'].to(self.config.device)

            predictions = self.model(inputs)
            loss = self.criterion(predictions, labels)
            mae = torch.abs(predictions - labels).mean()

            total_loss += loss.item()
            total_mae += mae.item()

        avg_loss = total_loss / len(self.val_loader)
        avg_mae = total_mae / len(self.val_loader)
        rmse = np.sqrt(avg_loss)

        return avg_loss, avg_mae, rmse

    @torch.no_grad()
    def plot_predictions(self, num_samples=100):
        """Plot actual vs predicted prices"""
        self.model.eval()

        all_predictions = []
        all_actuals = []

        # Get predictions from validation set
        for i, batch in enumerate(self.val_loader):
            if len(all_predictions) >= num_samples:
                break

            inputs = batch['input_ids'].to(self.config.device)
            labels = batch['labels'].to(self.config.device)

            predictions = self.model(inputs)

            all_predictions.append(predictions.cpu())
            all_actuals.append(labels.cpu())

        # Concatenate all batches
        predictions = torch.cat(all_predictions, dim=0)[:num_samples]
        actuals = torch.cat(all_actuals, dim=0)[:num_samples]

        # Extract closing price (feature index 0)
        pred_close = predictions[:, 0, 0].numpy()
        actual_close = actuals[:, 0, 0].numpy()

        # Create plot
        plt.figure(figsize=(15, 8))

        # Top plot: Predictions vs Actuals
        plt.subplot(2, 1, 1)
        plt.plot(actual_close, label='Actual Price', color='blue', linewidth=2, alpha=0.7)
        plt.plot(pred_close, label='Predicted Price', color='red', linewidth=2, alpha=0.7)
        plt.title('Stock Price Predictions vs Actuals', fontsize=14, fontweight='bold')
        plt.xlabel('Time Steps')
        plt.ylabel('Normalized Price')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Bottom plot: Prediction Error
        plt.subplot(2, 1, 2)
        errors = pred_close - actual_close
        plt.plot(errors, color='purple', linewidth=1.5, alpha=0.7)
        plt.axhline(y=0, color='black', linestyle='--', linewidth=1)
        plt.fill_between(range(len(errors)), errors, 0, alpha=0.3, color='purple')
        plt.title('Prediction Error Over Time', fontsize=14, fontweight='bold')
        plt.xlabel('Time Steps')
        plt.ylabel('Error')
        plt.grid(True, alpha=0.3)

        # Add statistics
        mae = np.abs(errors).mean()
        rmse = np.sqrt((errors**2).mean())
        textstr = f'MAE: {mae:.4f}\nRMSE: {rmse:.4f}'
        plt.text(0.02, 0.98, textstr, transform=plt.gca().transAxes,
                 fontsize=10, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()
        plt.savefig('predictions_vs_actuals.png', dpi=150)
        plt.close()

        if self.config.use_wandb:
            wandb.log({"predictions": wandb.Image('predictions_vs_actuals.png')})

    def train(self):
        start_epoch = 0

        if self.config.resume_from_checkpoint is not None:
            checkpoint_path = f"{self.config.checkpoint_dir}/checkpoint_epoch_{self.config.resume_from_checkpoint}.pt"
            start_epoch = self.load_checkpoint(checkpoint_path)

        for epoch in range(start_epoch, self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_mae, val_rmse = self.validate()

            # Store history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_mae'].append(val_mae)
            self.history['val_rmse'].append(val_rmse)

            print(f"Epoch {epoch+1}: Train Loss={train_loss:.6f}, "
                  f"Val Loss={val_loss:.6f}, Val MAE={val_mae:.6f}, Val RMSE={val_rmse:.6f}")

            # Log to WandB
            if self.config.use_wandb:
                wandb.log({
                    "train_loss_epoch": train_loss,
                    "val_loss": val_loss,
                    "val_mae": val_mae,
                    "val_rmse": val_rmse,
                    "epoch": epoch + 1
                })

            # Save checkpoint
            checkpoint_path = f"{self.config.checkpoint_dir}/checkpoint_epoch_{epoch}.pt"
            torch.save({
                'epoch': epoch,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'history': self.history,  # ADDED: Save history
            }, checkpoint_path)

            # Log checkpoint to WandB
            if self.config.use_wandb:
                wandb.save(checkpoint_path)

        # Plot training history at the end
        self.plot_training_history()

        # Plot predictions vs actuals
        self.plot_predictions(num_samples=100)

        if self.config.use_wandb:
            wandb.log({"training_history": wandb.Image('training_history.png')})
            wandb.finish()

In [158]:
def main():
    config = Config()

    # Load data
    train_dataset = StockDataset(
        config.train_path,
        seq_length=config.seq_length,
        prediction_horizon=config.prediction_horizon,
        normalize=config.normalize_data
    )
    val_dataset = StockDataset(
        config.val_path,
        seq_length=config.seq_length,
        prediction_horizon=config.prediction_horizon,
        normalize=config.normalize_data
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        num_workers=4,
        pin_memory=True
    )

    # Create model
    model = xLSTM(config)
    trainer = Trainer(model, train_loader, val_loader, config)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train
    trainer.train()

if __name__ == "__main__":
    main()



Model parameters: 133,573
Loaded checkpoint from epoch 9
Resuming training from epoch 10
