# Flash Attention with Docker for Local Development and Scaling

This guide shows how to set up Flash Attention with Docker for local development and scaling. It covers building the Docker image, running the container, and using Flash Attention in your projects.

## Prerequisites
- Docker installed on your machine
- A compatible GPU (NVIDIA) 
- NVIDIA Docker runtime installed

## Sections
- [Building the Docker Image](#building-the-docker-image)
- [Using Flash Attention Locally](#using-flash-attention)
- [Scaling with GCP Cloud Computing](#scaling-with-gcp-cloud-computing)

## Building the Docker Image

To build the Docker image for Flash Attention, follow these steps.

1. Clone the repository:
   ```bash
   git clone https://github.com/gabenavarro/MLContainerLab.git
   cd MLContainerLab
   ```
2. Build the Docker image:
   ```bash
   docker build -f ./assets/build/Dockerfile.flashattn.cu128py26cp312 -t flash-attention:128-26-312 .
   ```

   > Note: The following steps are for development within the container. My tutorials are run inside the container, but you can skip these steps if you are not interested in development.

3. Run the Docker container detached with terminal access and GPUs connected:
   ```bash
   docker run -dt \
     --gpus all \
     -v "$(pwd):/workspace" \
     --name flash-attention \
     --env NVIDIA_VISIBLE_DEVICES=all \
     --env GOOGLE_APPLICATION_CREDENTIALS=/workspace/assets/secrets/gcp-key.json \
     flash-attention:128-26-312
   ```
   > Note: The `-v $(pwd):/workspace` option mounts the current directory to `/workspace` in the container, allowing you to access your files from within the container. <br>
   > Note: The `--env` options set environment variables for GPU visibility and Google Cloud credentials. <br>
   > Note: The `--gpus all` option allows the container to use all available GPUs. <br>
   > Note: The `--name` option names the container `flash-attention`, which you can use to reference it later. <br>
   > Note: The `-dt` option runs the container in detached mode with terminal access. <br>
   > Note: Get your token from [Synapse](https://synapse.org/), and set it as an environment variable in the container. You can also set it in your local environment, but this is not recommended for security reasons. <br>
   > Note: Get a GCP key from [Google Cloud](https://cloud.google.com/docs/authentication/getting-started) and set it as an environment variable in the container. You can also set it in your local environment, but this is not recommended for security reasons. <br>

4. Open the container in VSCode: 
   ```bash
   # In a scriptable manner
   CONTAINER_NAME=flash-attention
   FOLDER=/workspace
   HEX_CONFIG=$(printf {\"containerName\":\"/$CONTAINER_NAME\"} | od -A n -t x1 | tr -d '[\n\t ]')
   code --folder-uri "vscode-remote://attached-container+$HEX_CONFIG$FOLDER"
   ```

   If you have the Dev Containers extension (formerly Remote - Containers) installed, this command opens the container's `/workspace` folder in VSCode, so you can edit files directly in the container.
   If this fails, you can use the GUI to open the container:
   - Open VSCode
   - Press `F1` and type `Dev Containers: Attach to Running Container...` (`Remote-Containers: Attach to Running Container...` in older versions)
   - Select the `flash-attention` container from the list
   - This will open the container in a new VSCode window
   - Set workspace to `/workspace` in the container

   > Note: You may need to install the Dev Containers extension in VSCode to attach to the container. <br>
   > Note: You may need to install the Python extension in VSCode to run the Python code. <br>
   > Note: You may need to install the Jupyter extension in VSCode to run the notebooks. <br>
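
The `HEX_CONFIG` step in the shell snippet above can look opaque, so here is a minimal Python sketch of the same idea: the container name is wrapped in a small JSON object, the JSON bytes are hex-encoded, and the result is embedded in the `vscode-remote://attached-container+...` URI. The helper name `attached_container_uri` is hypothetical and only mirrors the shell commands shown above.

```python
import json


def attached_container_uri(container_name: str, folder: str) -> str:
    """Build a VS Code URI that opens `folder` inside a running container."""
    # Same JSON payload that the shell snippet builds with printf
    config = json.dumps({"containerName": f"/{container_name}"}, separators=(",", ":"))
    # Hex-encode the UTF-8 bytes, mirroring `od -A n -t x1 | tr -d '[\n\t ]'`
    hex_config = config.encode("utf-8").hex()
    return f"vscode-remote://attached-container+{hex_config}{folder}"


uri = attached_container_uri("flash-attention", "/workspace")
print(uri)
```

The printed string is what would be passed to `code --folder-uri`.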

## Using Flash Attention

### Dataset Preparation
Let's set up a simple example to run Flash Attention in the container. First, let's download a sample dataset.

In [None]:
%%capture
# Download and unzip the Bitcoin historical data dataset from Kaggle
!mkdir -p /workspace/datasets
!curl -L -o /workspace/datasets/bitcoin-historical-data.zip \
  https://www.kaggle.com/api/v1/datasets/download/mczielinski/bitcoin-historical-data \
    && unzip -o /workspace/datasets/bitcoin-historical-data.zip -d /workspace/datasets/ \
    && rm /workspace/datasets/bitcoin-historical-data.zip

In [1]:
TIME_SERIES_CSV = "/workspace/datasets/btcusd_1-min_data.csv"
PROCESSED_DATA_DIR = "/workspace/datasets/auto_regressive_processed_timeseries"
CKPT_DIR = "/workspace/datasets/checkpoints"

In [2]:
# Make directories if they do not exist
import os
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

After downloading the dataset, we can load it into a Pandas DataFrame and take a look at the first few rows. The dataset contains one-minute historical data for Bitcoin: a Unix timestamp, the open, high, low, and close prices, and the traded volume.

In [3]:
import pandas as pd
pd.set_option('display.max_columns', None)
pd.read_csv(TIME_SERIES_CSV, low_memory=False, nrows=5).head()

Unnamed: 0,Timestamp,Open,High,Low,Close,Volume,datetime
0,1325412000.0,4.58,4.58,4.58,4.58,0.0,2012-01-01 10:01:00+00:00
1,1325412000.0,4.58,4.58,4.58,4.58,0.0,2012-01-01 10:02:00+00:00
2,1325412000.0,4.58,4.58,4.58,4.58,0.0,2012-01-01 10:03:00+00:00
3,1325412000.0,4.58,4.58,4.58,4.58,0.0,2012-01-01 10:04:00+00:00
4,1325412000.0,4.58,4.58,4.58,4.58,0.0,2012-01-01 10:05:00+00:00


Now that we have the dataset, let's create a dataset and dataloader for the model using litdata.

In [4]:
import pandas as pd
import litdata as ld
import numpy as np
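from typing import Any, Callable, Dict, Optional

def process_timeseries(file_path: str, sequence_length: int = 2048) -> Optional[Callable[[int], Dict[str, Any]]]:
    """Build a per-index sample function for autoregressive modeling from a timeseries CSV.

    Returns None if the file has too few rows for one full sequence.
    """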
    df = pd.read_csv(file_path, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Log-transform price values to handle exponential growth
    for price_col in ['Open', 'High', 'Low', 'Close']:
        # Replace zero, infinite, and missing values with a small positive floor before the log
        df[price_col] = np.log(df[price_col].replace([np.inf, -np.inf, 0, np.nan], 0.01).fillna(0.01))
    
    # log1p-transform volume to reduce extreme skew
    df['Volume'] = np.log1p(df['Volume'].replace([np.inf, -np.inf], np.nan).fillna(0.0))
    
    stats = {}
    for feature in numerical_features:
        vals = df[feature].replace([np.inf, -np.inf], np.nan).dropna()
        mean, std = (vals.mean(), vals.std() or 1.0)
        stats[feature] = {'mean': mean, 'std': std}
        df[feature] = (df[feature] - mean) / std

    if len(df) <= sequence_length:
        print(f"Warning: only {len(df)} rows < sequence_length={sequence_length}")
        return None

    def create_timeseries_sample(index: int) -> Dict[str, Any]:
        if index < sequence_length or index >= len(df):
            return {"index": index,
                    "inputs": np.zeros((sequence_length,5),dtype=np.float32),
                    "mask": np.zeros(sequence_length, dtype=bool),
                    "stats": stats}
        seq = df.iloc[index-sequence_length:index][numerical_features].values
        arr = np.nan_to_num(seq.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        mask = ~np.isnan(seq).any(axis=1)
        return {"index": index, "inputs": arr, "mask": mask, "stats": stats}

    # Store original stats for inverse transformation during inference
    stats['transform_type'] = 'log_then_zscore'
    
    return create_timeseries_sample


# Set the sequence length for your model
sequence_length = 2048

# Get the processing function configured for your specific file
process_function = process_timeseries(TIME_SERIES_CSV, sequence_length)

# Filter indices to exclude those with NaN values
def get_valid_indices():
    df = pd.read_csv(TIME_SERIES_CSV, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Replace infinity values with NaN
    for feature in numerical_features:
        df[feature] = df[feature].replace([np.inf, -np.inf], np.nan)
    
    valid_indices = []
    for idx in range(sequence_length, len(df), int(sequence_length * 0.25)):
        # Check if the entire sequence has no NaN values
        sequence = df.iloc[idx-sequence_length:idx][numerical_features].values
        if not np.isnan(sequence).any() and not np.isinf(sequence).any():
            valid_indices.append(idx)
        else:
            print(f"Skipping index {idx} due to NaN or inf values in the sequence.")
    
    if not valid_indices:
        raise ValueError("No valid sequences found! All sequences contain NaN or inf values.")
    
    return valid_indices

valid_indices = get_valid_indices()

# The optimize function writes data in an optimized format
ld.optimize(
    fn=process_function,              # the function that processes each sample
    inputs=valid_indices,             # the indices of valid samples
    output_dir=PROCESSED_DATA_DIR,    # optimized data is stored here
    num_workers=4,                    # The number of workers on the same machine
    chunk_bytes="64MB"                # size of each chunk
)
# Takes about 30 seconds to run


Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Setting multiprocessing start_method to fork. Tip: Libraries relying on lock can hang with `fork`. To use `spawn` in notebooks, move your code to files and import it within the notebook.
Storing the files under /workspace/datasets/auto_regressive_processed_timeseries
Setup started with fast_dev_run=False.
Setup finished in 0.003 seconds. Found 13690 items to process.
Starting 4 workers with 13690 items. The progress bar is only updated when a worker finishes.




Rank 0 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'str']` data format.
Rank 1 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'str']` data format.
Workers are ready ! Starting data processing...


Progress:   0%|          | 0/13690 [00:00<?, ?it/s]

Rank 2 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'str']` data format.
Rank 3 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'str']` data format.




Worker 1 is terminating.
Worker 0 is terminating.
Worker 1 is done.
Worker 2 is terminating.
Worker 3 is terminating.
Worker 0 is done.
Worker 2 is done.
Worker 3 is done.

Workers are finished.
Finished data processing!
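
The processing function records `transform_type = 'log_then_zscore'` so that predictions can be mapped back to real units later. A minimal sketch of that inverse step follows, assuming the per-feature `mean` and `std` stored in `stats`. The helper name `invert_feature` and the numeric statistics are hypothetical and used only for illustration; it uses the standard library `math` module on scalars, and the same formulas apply elementwise with `np.exp` and `np.expm1`.

```python
import math


def invert_feature(z: float, mean: float, std: float, is_volume: bool = False) -> float:
    """Undo the z-score scaling, then undo the log (prices) or log1p (volume) transform."""
    logged = z * std + mean
    return math.expm1(logged) if is_volume else math.exp(logged)


# Hypothetical statistics, for illustration only (real values come from `stats`)
example_stats = {
    "Close": {"mean": 8.0, "std": 2.0},
    "Volume": {"mean": 1.0, "std": 1.5},
}

# Forward transform of a price, as done during preprocessing
price = 30000.0
z_price = (math.log(price) - example_stats["Close"]["mean"]) / example_stats["Close"]["std"]

# The inverse transform recovers the original price up to floating-point error
recovered_price = invert_feature(z_price, **example_stats["Close"])
print(recovered_price)
```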


Next, let's split the dataset into training, validation, and test sets, using 80% of the data for training, 10% for validation, and 10% for testing. We will also create a dataloader for each split.

In [5]:
from litdata import StreamingDataset, StreamingDataLoader, train_test_split
import torch
streaming_dataset = StreamingDataset(PROCESSED_DATA_DIR) # data were written locally to PROCESSED_DATA_DIR

def custom_collate(batch):
    # Filter out None values
    batch = [item for item in batch if item is not None]
    
    if not batch:
        # Return empty tensors if the batch is empty
        return {
            "index": torch.tensor([], dtype=torch.long),
            "inputs": torch.tensor([], dtype=torch.float32),
            "mask": torch.tensor([], dtype=torch.bool),
            "stats": {}
        }
    
    # Process each key separately
    indices = torch.tensor([item["index"] for item in batch], dtype=torch.long)
    
    # Make sure arrays are writable by copying and convert to tensor
    inputs = torch.stack([torch.tensor(np.nan_to_num(item["inputs"].copy(), nan=0.0), dtype=torch.float32) for item in batch])
    masks = torch.stack([torch.tensor(item["mask"].copy(), dtype=torch.bool) for item in batch])
    
    # Get stats (use first non-empty item's stats)
    stats = next((item["stats"] for item in batch if "stats" in item), {})
    
    return {
        "index": indices,
        "inputs": inputs,
        "mask": masks,
        "stats": stats
    }

print(len(streaming_dataset)) # display the length of your data
# out: 13690

train_dataset, val_dataset, test_dataset = train_test_split(streaming_dataset, splits=[0.8, 0.1, 0.1])

print("Train ", len(train_dataset))
# out: Train  10952
train_dataloader = StreamingDataLoader(train_dataset, num_workers=4, batch_size=32, shuffle=True, collate_fn=custom_collate)  # Create DataLoader for training

print("Validation ", len(val_dataset))
# out: Validation  1369
val_dataloader = StreamingDataLoader(val_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for validation

print("Test ", len(test_dataset))
# out: Test  1369
test_dataloader = StreamingDataLoader(test_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for testing


13690
Train  10952
Validation  1369
Test  1369
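
As a sanity check on the printed sizes, the sketch below shows how the split fractions map to item counts and how many batches each dataloader would yield at a batch size of 32. The floor rounding and the assumption that the last partial batch is kept are mine, not confirmed litdata behavior; the helper names are hypothetical.

```python
import math
from fractions import Fraction


def split_sizes(n_items: int, fractions: list) -> list:
    """Number of items per split, rounding down (assumed rounding rule)."""
    # Fraction(str(f)) avoids binary floating-point error in products such as 13690 * 0.8
    return [math.floor(n_items * Fraction(str(f))) for f in fractions]


def num_batches(n_items: int, batch_size: int) -> int:
    """Batches per epoch, assuming the final partial batch is kept."""
    return math.ceil(n_items / batch_size)


sizes = split_sizes(13690, [0.8, 0.1, 0.1])
batches = [num_batches(n, 32) for n in sizes]
print(sizes)    # [10952, 1369, 1369]
print(batches)  # [343, 43, 43]
```

The sizes agree with the train, validation, and test counts printed above.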


In [6]:
# Take one batch from train_dataloader to find out the dataset shape for the input dimensions
for batch in train_dataloader:
    print(batch['inputs'].shape)  # Check the shape of the input tensor
    # out: torch.Size([32, 2048, 5])
    # 32 is the batch size, 2048 is the sequence length, and 5 is the number of features
    break  # Only need to check the first batch

torch.Size([32, 2048, 5])
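
Each sample in the batch is a window of 2048 consecutive rows. The indexing cell earlier steps the window end positions with `range(sequence_length, len(df), int(sequence_length * 0.25))`, so consecutive windows overlap by 75%. The sketch below reproduces that arithmetic; the row count of 10,000 is a hypothetical value chosen only to keep the numbers small, and `window_ends` is an illustrative helper name.

```python
def window_ends(n_rows: int, sequence_length: int = 2048, stride_fraction: float = 0.25) -> list:
    """End positions (exclusive) of the sliding windows used to build samples."""
    stride = int(sequence_length * stride_fraction)  # 512 rows for the defaults
    return list(range(sequence_length, n_rows, stride))


ends = window_ends(n_rows=10_000)  # hypothetical row count
first_window = (ends[0] - 2048, ends[0])    # rows [0, 2048)
second_window = (ends[1] - 2048, ends[1])   # rows [512, 2560)
overlap = first_window[1] - second_window[0]

print(len(ends), ends[:3])  # 16 [2048, 2560, 3072]
print(overlap)              # 1536
```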


### Model Definition

Now that we have the dataset and dataloader, we can define the model. To do this, we will create a class that inherits from `nn.Module` and define the model architecture in the `__init__` method. We will also define the forward pass in the `forward` method.

The model architecture consists of an input projection layer, a stack of causal transformer layers, and a linear output layer. The input projection maps the five input features into a higher-dimensional embedding space, the transformer layers process the sequence using causal self-attention so that each position attends only to earlier positions, and the linear layer projects back to the feature dimension to produce the next-step predictions.

In [7]:
# Transformer Layer
from flash_attn.modules.mha import MHA
from flash_attn.ops.rms_norm import RMSNorm
from flash_attn.ops.fused_dense import FusedMLP
from torch import nn, optim, Tensor
import torch.nn.functional as F
import lightning as pl
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import numpy as np


class TransformerLayer(nn.Module):
    def __init__(self, layer_idx, embed_dim, num_heads, mlp_ratio=4.0, proj_groups=1,
                 dropout=0.05, fast_attention=True):
        super().__init__()
        self.attn = MHA(embed_dim, num_heads, causal=True, layer_idx=layer_idx,
                        num_heads_kv=num_heads//proj_groups,
                        rotary_emb_dim=embed_dim//num_heads,
                        use_flash_attn=fast_attention,
                        return_residual=False, dropout=dropout)
        self.norm1 = RMSNorm(embed_dim)
        self.mlp   = FusedMLP(embed_dim, int(embed_dim*mlp_ratio))
        self.norm2 = RMSNorm(embed_dim)

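    def forward(self, x, mask=None):
        # Pre-norm residual blocks. `mask` is accepted for interface compatibility but is
        # not passed to attention; causality is enforced by `causal=True` in MHA.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x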
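
To get a feel for the size of one `TransformerLayer`, the following sketch counts its parameters by hand. It assumes that every linear projection carries a bias and that `RMSNorm` holds only a weight vector; these are assumptions about the flash-attn defaults, not something this document states. The helper names are hypothetical. With `embed_dim=64`, `num_heads=8`, and four layers, the totals are consistent with the model summary printed during training later in this document (384 parameters for `input_proj` and 199 K for `layers`).

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a dense layer."""
    return in_features * out_features + (out_features if bias else 0)


def transformer_layer_params(embed_dim: int, num_heads: int,
                             mlp_ratio: float = 4.0, proj_groups: int = 1) -> int:
    """Approximate parameter count of one layer (assumed layout, see the text above)."""
    head_dim = embed_dim // num_heads
    num_heads_kv = num_heads // proj_groups
    # Fused Q, K, V projection: queries for all heads, keys and values for the KV heads
    qkv_out = head_dim * (num_heads + 2 * num_heads_kv)
    hidden = int(embed_dim * mlp_ratio)
    attn = linear_params(embed_dim, qkv_out) + linear_params(embed_dim, embed_dim)
    mlp = linear_params(embed_dim, hidden) + linear_params(hidden, embed_dim)
    norms = 2 * embed_dim  # two RMSNorm weight vectors
    return attn + mlp + norms


per_layer = transformer_layer_params(embed_dim=64, num_heads=8)
print(per_layer)             # 49856
print(4 * per_layer)         # 199424
print(linear_params(5, 64))  # 384
```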
    




### Lightning Module


In [None]:
class AutoregressiveTransformerModel(pl.LightningModule):
    def __init__(
            self, 
            input_dim, embed_dim, num_heads, num_layers,
            mlp_ratio=4.0, lr=5e-5, weight_decay=1e-2
        ):
        super().__init__()
        self.save_hyperparameters()

        # Linear projection layer for input embedding
        self.input_proj = nn.Linear(input_dim, embed_dim)
        
        # Create transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(i, embed_dim, num_heads, mlp_ratio)
            for i in range(num_layers)
        ])

        # Final linear layer for output projection
        self.fc_out = nn.Linear(embed_dim, input_dim)

        # Initialize weights
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.xavier_uniform_(self.fc_out.weight)

        # Provide a small weight for the volume feature
        fw = torch.tensor([1.,1.,1.,1.,0.25])
        self.register_buffer('feature_weights', fw)

        # For tracking metrics
        self.train_loss = 0.0
        self.val_loss = 0.0
        
        # Learning rate
        self.lr = lr
        self.weight_decay = weight_decay
        
        # For storing test predictions
        self.test_predictions = []
        self.test_targets = []
        self.test_masks = []
        self.feature_names = ['Open', 'High', 'Low', 'Close', 'Volume']


    def forward(self, x, mask=None):
        x = self.input_proj(torch.nan_to_num(x, nan=0.0))
        for layer in self.layers:
            x = layer(x, mask=mask)
        return self.fc_out(x)

    
    def _calculate_autoregressive_loss(self, preds, targets, mask):
        # shift for next-step prediction
        p, t = preds[:, :-1], targets[:, 1:]
        m = mask[:, 1:].unsqueeze(-1)
        # mse = F.mse_loss(p, t, reduction='none')
        mse = F.smooth_l1_loss(p, t, reduction='none', beta=0.01)
        # apply per-feature weights
        mse = mse * self.feature_weights.view(1,1,-1)
        # mask and reduce
        mse = mse * m
        denom = m.sum().clamp_min(1.0)
        return mse.sum() / denom

    
    def training_step(self, batch, batch_idx):
        """
        Training step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs, mask)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)        
        
        # Log the loss for monitoring (the gradient norm is logged in on_after_backward)
        self.train_loss = loss.detach()
        self.log('train_loss', loss, prog_bar=True)  
        
        return loss
    
    def on_after_backward(self):
        """
        Called after .backward() and before optimizers do anything.
        This is the right place to check gradients.
        """
        # Safely compute gradient norm - after backward pass when gradients exist
        if any(p.grad is not None for p in self.parameters()):
            grad_list = [p.grad.detach().norm(2) for p in self.parameters() if p.grad is not None]
            if grad_list:  # Make sure list is not empty
                grad_norm = torch.stack(grad_list).norm(2)
                self.log('grad_norm', grad_norm, prog_bar=True)
            else:
                self.log('grad_norm', torch.tensor(0.0), prog_bar=True)
        else:
            self.log('grad_norm', torch.tensor(0.0), prog_bar=True)
    
    def validation_step(self, batch, batch_idx):
        """
        Validation step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs, mask)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
                
        # Log the loss for monitoring
        self.val_loss = loss
        self.log('val_loss', loss, prog_bar=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        """
        Test step for evaluating the model after training
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs, mask)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
        
        # Store predictions and targets for later analysis
        # Use detach() and cpu() to avoid memory leaks
        self.test_predictions.append(predictions.detach().cpu())
        self.test_targets.append(inputs.detach().cpu())
        self.test_masks.append(mask.detach().cpu())
        
        # Log the test loss
        self.log('test_loss', loss, prog_bar=True)
        
        return loss
    
    
    # Configuring the optimizer
    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode='min',
            factor=0.5,
            patience=2
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }
    
    def on_test_epoch_end(self):
        """
        Calculate and log test metrics at the end of the test epoch
        """
        # Concatenate all batches
        all_preds = torch.cat(self.test_predictions, dim=0)
        all_targets = torch.cat(self.test_targets, dim=0)
        all_masks = torch.cat(self.test_masks, dim=0)
        
        # For autoregressive prediction, shift by one step
        pred_shifted = all_preds[:, :-1, :]
        target_shifted = all_targets[:, 1:, :]
        mask_shifted = all_masks[:, 1:]
        
        # Convert to numpy for sklearn metrics
        pred_np   = pred_shifted.detach().cpu().to(torch.float32).numpy()
        target_np = target_shifted.detach().cpu().to(torch.float32).numpy()
        mask_np   = mask_shifted.detach().cpu().to(torch.float32).numpy()
        
        # Calculate metrics for each feature
        metrics = {}
        for i, feature_name in enumerate(self.feature_names):
            # Extract predictions and targets for this feature
            feature_preds = pred_np[:, :, i]
            feature_targets = target_np[:, :, i]
            
            # Keep only the positions where the mask is set (boolean indexing flattens
            # in the same row-major order as looping over batch and sequence indices)
            valid = mask_np.astype(bool)
            valid_preds = feature_preds[valid]
            valid_targets = feature_targets[valid]
            
            if len(valid_preds) > 0:
                # Calculate metrics
                mse = mean_squared_error(valid_targets, valid_preds)
                rmse = np.sqrt(mse)
                mae = mean_absolute_error(valid_targets, valid_preds)
                
                # Calculate MAPE (Mean Absolute Percentage Error) with handling for zeros
                with np.errstate(divide='ignore', invalid='ignore'):
                    mape = np.mean(np.abs((valid_targets - valid_preds) / np.maximum(np.abs(valid_targets), 1e-8))) * 100
                    mape = np.nan_to_num(mape, nan=0.0, posinf=0.0, neginf=0.0)
                
                # R-squared
                r2 = r2_score(valid_targets, valid_preds)
                
                # Store metrics
                metrics[f"{feature_name}_mse"] = mse
                metrics[f"{feature_name}_rmse"] = rmse
                metrics[f"{feature_name}_mae"] = mae
                metrics[f"{feature_name}_mape"] = mape
                metrics[f"{feature_name}_r2"] = r2
                
                # Log each metric
                self.log(f"test_{feature_name}_mse", mse)
                self.log(f"test_{feature_name}_rmse", rmse)
                self.log(f"test_{feature_name}_mae", mae)
                self.log(f"test_{feature_name}_mape", mape)
                self.log(f"test_{feature_name}_r2", r2)
        
        # Calculate average metrics across all features
        avg_mse = np.mean([metrics[f"{feature}_mse"] for feature in self.feature_names if f"{feature}_mse" in metrics])
        avg_rmse = np.mean([metrics[f"{feature}_rmse"] for feature in self.feature_names if f"{feature}_rmse" in metrics])
        avg_mae = np.mean([metrics[f"{feature}_mae"] for feature in self.feature_names if f"{feature}_mae" in metrics])
        avg_mape = np.mean([metrics[f"{feature}_mape"] for feature in self.feature_names if f"{feature}_mape" in metrics])
        avg_r2 = np.mean([metrics[f"{feature}_r2"] for feature in self.feature_names if f"{feature}_r2" in metrics])
        
        # Log average metrics
        self.log("test_avg_mse", avg_mse)
        self.log("test_avg_rmse", avg_rmse)
        self.log("test_avg_mae", avg_mae)
        self.log("test_avg_mape", avg_mape)
        self.log("test_avg_r2", avg_r2)
        
        # Print summary of test metrics
        print("\n===== TEST METRICS =====")
        print(f"Average MSE: {avg_mse:.6f}")
        print(f"Average RMSE: {avg_rmse:.6f}")
        print(f"Average MAE: {avg_mae:.6f}")
        print(f"Average MAPE: {avg_mape:.6f}%")
        print(f"Average R²: {avg_r2:.6f}")
        print("=======================\n")
        
        # Create and save visualizations
        self._create_prediction_visualizations(pred_shifted, target_shifted, mask_shifted)
        
        # Clear stored predictions to free memory
        self.test_predictions = []
        self.test_targets = []
        self.test_masks = []

    def _create_prediction_visualizations(self, predictions, targets, masks, num_samples=3):
        """
        Create and save visualization of predictions vs targets
        
        Args:
            predictions: Model predictions [batch_size, seq_len, feature_dim]
            targets: Target values [batch_size, seq_len, feature_dim]
            masks: Boolean masks [batch_size, seq_len]
            num_samples: Number of samples to visualize
        """
        # Convert to numpy for plotting
        preds_np = predictions.detach().cpu().to(torch.float32).numpy()
        targets_np = targets.detach().cpu().to(torch.float32).numpy()
        masks_np = masks.detach().cpu().to(torch.float32).numpy()
        
        # Create directory for plots if it doesn't exist
        import os
        os.makedirs(os.path.join(CKPT_DIR, "plots"), exist_ok=True)
        
        # Plot for each feature
        for feature_idx, feature_name in enumerate(self.feature_names):
            plt.figure(figsize=(15, 10))
            
            # Plot the first few samples
            for sample_idx in range(min(num_samples, preds_np.shape[0])):
                # Get predictions and targets for this sample and feature
                sample_preds = preds_np[sample_idx, :, feature_idx]
                sample_targets = targets_np[sample_idx, :, feature_idx]
                sample_mask = masks_np[sample_idx, :]
                
                # Create time index for x-axis
                time_idx = np.arange(len(sample_preds))
                
                # Plot targets
                plt.subplot(num_samples, 1, sample_idx + 1)
                plt.plot(time_idx, sample_targets, 'b-', label='Actual', alpha=0.7)
                
                # Plot predictions (only where mask is True)
                masked_preds = np.where(sample_mask, sample_preds, np.nan)
                plt.plot(time_idx, masked_preds, 'r-', label='Predicted', alpha=0.7)
                
                # Add title and legend
                plt.title(f"Sample {sample_idx+1}: {feature_name}")
                plt.legend()
                plt.grid(True, alpha=0.3)
                
                # Calculate metrics for this sample
                valid_indices = np.where(sample_mask)[0]
                if len(valid_indices) > 0:
                    valid_preds = sample_preds[valid_indices]
                    valid_targets = sample_targets[valid_indices]
                    
                    mse = mean_squared_error(valid_targets, valid_preds)
                    rmse = np.sqrt(mse)
                    r2 = r2_score(valid_targets, valid_preds)
                    
                    plt.figtext(0.01, 0.5 - 0.15 * sample_idx, 
                                f"MSE: {mse:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}", 
                                fontsize=9)
            
            plt.tight_layout()
            plt.savefig(os.path.join(CKPT_DIR, "plots", f"{feature_name}_predictions.svg"))
            plt.close()
        
        # Create a summary plot with all features for the first sample
        plt.figure(figsize=(15, 12))
        
        for feature_idx, feature_name in enumerate(self.feature_names):
            plt.subplot(len(self.feature_names), 1, feature_idx + 1)
            
            # Get predictions and targets for first sample and this feature
            sample_preds = preds_np[0, :, feature_idx]
            sample_targets = targets_np[0, :, feature_idx]
            sample_mask = masks_np[0, :]
            
            # Create time index for x-axis
            time_idx = np.arange(len(sample_preds))
            
            # Plot targets
            plt.plot(time_idx, sample_targets, 'b-', label='Actual', alpha=0.7)
            
            # Plot predictions (only where mask is True)
            masked_preds = np.where(sample_mask, sample_preds, np.nan)
            plt.plot(time_idx, masked_preds, 'r-', label='Predicted', alpha=0.7)
            
            # Add title and legend
            plt.title(f"{feature_name}")
            plt.legend()
            plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(CKPT_DIR, "plots", "all_features_predictions.svg"))
        plt.close()
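
The core of `_calculate_autoregressive_loss` is easier to follow on a tiny example. The sketch below re-implements the same idea in plain Python for a single sequence: the prediction at step `t` is compared with the target at step `t + 1`, each feature is weighted, masked steps are skipped, and the sum is divided by the number of valid steps. It is an illustrative re-implementation under those assumptions, not the tensor code used for training, and the toy values are made up.

```python
def smooth_l1(pred: float, target: float, beta: float = 0.01) -> float:
    """Smooth L1 (Huber-style) loss for one value: quadratic below beta, linear above."""
    diff = abs(pred - target)
    if diff < beta:
        return 0.5 * diff * diff / beta
    return diff - 0.5 * beta


def autoregressive_loss(preds, targets, mask, feature_weights, beta=0.01):
    """preds/targets: [T][F] nested lists for one sequence; mask: [T] booleans."""
    total, valid_steps = 0.0, 0
    for t in range(len(preds) - 1):
        if not mask[t + 1]:  # the mask is shifted together with the targets
            continue
        valid_steps += 1
        for f, weight in enumerate(feature_weights):
            total += weight * smooth_l1(preds[t][f], targets[t + 1][f], beta)
    return total / max(valid_steps, 1)  # same role as clamp_min(1.0)


# Toy sequence with 3 steps and 2 features (values are illustrative)
targets = [[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]]
preds = [[1.0, 2.0], [2.5, 4.0], [9.9, 9.9]]  # the last prediction has no target
weights = [1.0, 0.25]

loss_all = autoregressive_loss(preds, targets, [True, True, True], weights)
loss_masked = autoregressive_loss(preds, targets, [True, True, False], weights)
print(loss_all, loss_masked)
```

In the unmasked case only the second step contributes an error (0.5 on the first feature), and masking the final target removes that step from both the sum and the denominator.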



In [None]:
# Initialize the model
# Creating a dict of parameters for the model for better readability
# and to be ready for cloud training
# Note: The parameters below are just examples. You may need to adjust them based on your specific use case.
parameters = dict(
    input_dim = 5,
    embed_dim = 64,
    num_heads = 8,
    num_layers = 4,
    lr = 5e-5,                      # Learning rate
    weight_decay = 1e-3,            # Weight decay
    limit_val_batches = 100,        # Cap validation at 100 batches per epoch for faster training
    accumulate_grad_batches = 8,    # Gradient accumulation
    gradient_clip_val = 0.5,        # Gradient clipping for stability
    batch_size = 32,                # Batch size
    sequence_length = 2048,         # Sequence length
    num_workers = 4,                # Number of workers for data loading
    cktp_dir = CKPT_DIR,            # Checkpoint directory
    nodes = 1,                      # Number of nodes
    devices = 1,                    # Number of devices (GPUs)
    accelerator = "gpu",            # Use GPU for training
    strategy = "auto",              # Let Lightning pick the distributed strategy
    precision = "bf16-mixed",       # Mixed precision training
    epochs = 10,                    # Number of epochs
    log_every_n_steps = 10,         # Log every n steps
    data_dir = PROCESSED_DATA_DIR,  # Directory for processed data
    data_splits = [0.8, 0.1, 0.1],  # Train/Validation/Test splits
)

model = AutoregressiveTransformerModel(**parameters)

# Lightning callbacks for early stopping and model checkpointing
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

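# Note: with patience=10 and epochs=10, early stopping cannot trigger before
# max_epochs is reached; lower the patience if you want it to take effect.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min'
)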

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=parameters['cktp_dir'],
    filename='best-model-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

trainer = pl.Trainer(
    limit_val_batches=parameters['limit_val_batches'],
    max_epochs=parameters['epochs'], 
    accumulate_grad_batches=parameters['accumulate_grad_batches'], 
    gradient_clip_val=parameters['gradient_clip_val'],  
    default_root_dir=parameters['cktp_dir'],
    precision=parameters['precision'],
    log_every_n_steps=parameters['log_every_n_steps'],
    accelerator=parameters['accelerator'],
    devices=parameters['devices'],
    strategy=parameters['strategy'],
    num_nodes=parameters['nodes'],
    callbacks=[early_stop_callback, checkpoint_callback]
) 

# Train the model
trainer.fit(model, train_dataloader, val_dataloader)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /workspace/datasets/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params | Mode 
--------------------------------------------------
0 | input_proj | Linear     | 384    | train
1 | layers     | ModuleList | 199 K  | train
2 | fc_out     | Linear     | 325    | train
--------------------------------------------------
200 K     Trainable params
0         Non-trainable params
200 K     Total params
0.801     Total estimated model params size (MB)
59        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 25. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


`Trainer.fit` stopped: `max_epochs=10` reached.
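
As a rough check on the configuration above, the number of sequences that contribute to each optimizer step follows from `batch_size`, `accumulate_grad_batches`, and the device count. The helper below is a minimal sketch; `effective_batch_size` is a hypothetical name and not part of Lightning, and it assumes data-parallel training where every device processes its own batch.

```python
def effective_batch_size(batch_size: int, accumulate_grad_batches: int,
                         devices: int = 1, nodes: int = 1) -> int:
    """Sequences contributing to one optimizer step under data parallelism."""
    return batch_size * accumulate_grad_batches * devices * nodes


# Local run from this notebook: 32 sequences x 8 accumulation steps x 1 GPU
local_sequences = effective_batch_size(batch_size=32, accumulate_grad_batches=8, devices=1)
print(local_sequences)          # 256
print(local_sequences * 2048)   # 524288 time steps per optimizer step
```

Under the same assumption, moving to 8 GPUs without changing the other settings would raise this to 2048 sequences per step, so the learning rate may need retuning when scaling out.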


Now that the model is trained, let's look at the training metrics captured by TensorBoard during the run. The code below reads the TensorBoard logs and renders the scalar curves as an image.

In [None]:
import glob
import os

import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

def plot_tensorboard_scalars(ckpt_dir: str, tag=None, tags=None):
    """
    Plot scalar curves from the latest TensorBoard run under a checkpoint directory.

    Args:
        ckpt_dir (str): Directory that contains `lightning_logs/version_*`.
        tag (str): Specific tag to plot.
        tags (list): List of tags to plot; takes precedence over `tag`.
    """
    
    # Find the most recent version
    version_dirs = glob.glob(f"{ckpt_dir}/lightning_logs/version_*")
    if not version_dirs:
        print("No training logs found.")
        return
    

    latest_version = max(version_dirs, key=lambda x: int(x.split('_')[-1]))

    # Find any event files in the directory
    files = glob.glob(os.path.join(latest_version, 'events.out.tfevents.*'))
    if not files:
        raise FileNotFoundError(f"No TensorBoard event files found in {latest_version}")

    # Initialize the EventAccumulator to read scalars
    ea = event_accumulator.EventAccumulator(
        latest_version,
        size_guidance={  # Load all scalar data
            event_accumulator.SCALARS: 0,
        }
    )
    ea.Reload()  # Load the event data

    # Determine which tags to plot
    available_tags = ea.Tags().get('scalars', [])
    if tags:
        plot_tags = [t for t in tags if t in available_tags]
    elif tag:
        if tag not in available_tags:
            raise ValueError(f"Tag '{tag}' not found. Available tags: {available_tags}")
        plot_tags = [tag]
    else:
        # If no tag(s) specified, plot all scalar tags
        plot_tags = available_tags

    if not plot_tags:
        raise ValueError("No valid scalar tags to plot.")

    # Plot each tag's values over steps
    plt.figure()
    for t in plot_tags:
        events = ea.Scalars(t)
        steps = [e.step for e in events]
        values = [e.value for e in events]
        plt.plot(steps, values, label=t)

    plt.xlabel('Step')
    plt.ylabel('Value')
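    plt.title("TensorBoard Scalars")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Make sure the output directory exists before saving
    os.makedirs(os.path.join(ckpt_dir, "plots"), exist_ok=True)
    plt.savefig(os.path.join(ckpt_dir, "plots", "tensorBoard_scalars.svg"))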
    plt.close()


# Plot the training and validation loss curves
# The logs live under the Trainer's default_root_dir, i.e. the checkpoint directory
plot_tensorboard_scalars(ckpt_dir=parameters['cktp_dir'], tags=['train_loss', 'val_loss'])

The plot below shows the training and validation loss over the training steps, with the training loss in blue and the validation loss in orange. The training loss decreases over time, indicating that the model is learning. The validation loss also decreases, though more slowly, which suggests that the model is not yet overfitting the training data.

![Flash Attention Predictions](../assets/images/flash_attn_tensorBoard_scalars.svg)
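
To put a number on the visual comparison of the two curves, one option is to smooth both series and look at the difference between them at the end of training. The snippet below is a minimal sketch in plain Python; `moving_average` and `generalization_gap` are hypothetical helpers, and the loss values are made up for illustration rather than taken from this run.

```python
def moving_average(values, window=3):
    """Trailing moving average; early points average over what is available."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed


def generalization_gap(train_losses, val_losses, window=3):
    """Smoothed final validation loss minus smoothed final training loss."""
    return moving_average(val_losses, window)[-1] - moving_average(train_losses, window)[-1]


# Illustrative values only, not the losses logged in this notebook
train = [1.0, 0.6, 0.4, 0.3, 0.25]
val = [1.1, 0.8, 0.6, 0.5, 0.45]
gap = generalization_gap(train, val)
print(f"gap = {gap:.3f}")
```

A gap that stays roughly constant while both curves fall is consistent with the reading above; a gap that widens while the training loss keeps falling would point toward overfitting.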


## Model Evaluation 

### Test Dataset
Now that we have trained the model, we can evaluate it on the test dataset. We load the best checkpoint saved during training and call Lightning's `trainer.test` method with the test dataloader. Below is the code to evaluate the model on the test dataset.

In [10]:
# Load the best checkpoint
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
    print(f"Loading best model from {best_model_path}")
    model = AutoregressiveTransformerModel.load_from_checkpoint(best_model_path)

# Test the model and calculate metrics
print("Evaluating model on test set...")
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loading best model from /workspace/datasets/checkpoints/best-model-epoch=08-val_loss=0.2039.ckpt
Evaluating model on test set...


Testing: |          | 0/? [00:00<?, ?it/s]


===== TEST METRICS =====
Average MSE: 0.254523
Average RMSE: 0.247193
Average MAE: 0.172371
Average MAPE: 80.817858%
Average R²: 0.824211



[{'test_loss': 0.2503935396671295,
  'test_Open_mse': 0.0006383138825185597,
  'test_Open_rmse': 0.02526487410068512,
  'test_Open_mae': 0.016561882570385933,
  'test_Open_mape': 2.5399210453033447,
  'test_Open_r2': 0.9900282025337219,
  'test_High_mse': 0.0005231107934378088,
  'test_High_rmse': 0.02287161536514759,
  'test_High_mae': 0.01634710468351841,
  'test_High_mape': 2.466440439224243,
  'test_High_r2': 0.991832971572876,
  'test_Low_mse': 0.0012872182996943593,
  'test_Low_rmse': 0.03587782382965088,
  'test_Low_mae': 0.022319288924336433,
  'test_Low_mape': 3.4797444343566895,
  'test_Low_r2': 0.9798782467842102,
  'test_Close_mse': 0.0006358284736052155,
  'test_Close_rmse': 0.025215638801455498,
  'test_Close_mae': 0.01602933742105961,
  'test_Close_mape': 2.417116165161133,
  'test_Close_r2': 0.9900671243667603,
  'test_Volume_mse': 1.2695327997207642,
  'test_Volume_rmse': 1.1267354488372803,
  'test_Volume_mae': 0.7905952334403992,
  'test_Volume_mape': 393.18606567382

Now let's summarize the results. The cell below writes a Markdown performance report that collects the test metrics and embeds the prediction plots, with one plot per feature comparing the actual values against the predicted values.

In [None]:
import glob
import os

import pandas as pd


# Create a performance report with all the metrics
def create_performance_report(ckpt_dir: str = CKPT_DIR):
    # Get metrics from the lightning logs
    
    # Find the most recent version
    version_dirs = glob.glob(f"{ckpt_dir}/lightning_logs/version_*")
    if not version_dirs:
        print("No training logs found.")
        return
    
    latest_version = max(version_dirs, key=lambda x: int(x.split('_')[-1]))
    metrics_path = f"{latest_version}/metrics.csv"
    
    if not os.path.exists(metrics_path):
        print(f"No metrics file found at {metrics_path}")
        return
    
    # Load metrics
    metrics_df = pd.read_csv(metrics_path)
    
    # Filter for test metrics only
    test_metrics = metrics_df[metrics_df['step'].isna()]
    
    # Create a report
    with open(f"{ckpt_dir}/test_performance_report.md", 'w') as f:
        f.write("# Model Performance Report\n\n")
        f.write("## Overall Metrics\n\n")
        
        # Add overall metrics
        if 'test_avg_mse' in test_metrics.columns:
            f.write(f"* **Average MSE:** {test_metrics['test_avg_mse'].iloc[-1]:.6f}\n")
        if 'test_avg_rmse' in test_metrics.columns:
            f.write(f"* **Average RMSE:** {test_metrics['test_avg_rmse'].iloc[-1]:.6f}\n")
        if 'test_avg_mae' in test_metrics.columns:
            f.write(f"* **Average MAE:** {test_metrics['test_avg_mae'].iloc[-1]:.6f}\n")
        if 'test_avg_mape' in test_metrics.columns:
            f.write(f"* **Average MAPE:** {test_metrics['test_avg_mape'].iloc[-1]:.6f}%\n")
        if 'test_avg_r2' in test_metrics.columns:
            f.write(f"* **Average R²:** {test_metrics['test_avg_r2'].iloc[-1]:.6f}\n")
        
        f.write("\n## Feature-Specific Metrics\n\n")
        
        # Add feature-specific metrics
        for feature in ['Open', 'High', 'Low', 'Close', 'Volume']:
            f.write(f"### {feature}\n\n")
            
            if f'test_{feature}_mse' in test_metrics.columns:
                f.write(f"* **MSE:** {test_metrics[f'test_{feature}_mse'].iloc[-1]:.6f}\n")
            if f'test_{feature}_rmse' in test_metrics.columns:
                f.write(f"* **RMSE:** {test_metrics[f'test_{feature}_rmse'].iloc[-1]:.6f}\n")
            if f'test_{feature}_mae' in test_metrics.columns:
                f.write(f"* **MAE:** {test_metrics[f'test_{feature}_mae'].iloc[-1]:.6f}\n")
            if f'test_{feature}_mape' in test_metrics.columns:
                f.write(f"* **MAPE:** {test_metrics[f'test_{feature}_mape'].iloc[-1]:.6f}%\n")
            if f'test_{feature}_r2' in test_metrics.columns:
                f.write(f"* **R²:** {test_metrics[f'test_{feature}_r2'].iloc[-1]:.6f}\n")
            
            f.write("\n")
        
        # Add prediction plots
        f.write("## Prediction Visualizations\n\n")
        for feature in ['Open', 'High', 'Low', 'Close', 'Volume']:
            f.write(f"### {feature} Predictions\n\n")
            f.write(f"![{feature} Predictions](plots/{feature}_predictions.svg)\n\n")
        
        f.write("### All Features (Sample 1)\n\n")
        f.write("![All Features](plots/all_features_predictions.svg)\n\n")

# Create the performance report
# Note: metrics.csv is only written when a CSVLogger is attached to the Trainer
create_performance_report(ckpt_dir=parameters['cktp_dir'])
print(f"Test performance report saved to {parameters['cktp_dir']}/test_performance_report.md")

No metrics file found at /workspace/datasets/checkpoints/lightning_logs/version_0/metrics.csv
Test performance report saved to /workspace/datasets/checkpoints/test_performance_report.md


![Flash Attention Predictions](../assets/images/flash_attn_all_features_predictions.svg)

Let's take a look at how the model performs across features. The table below summarizes the test metrics:

| Feature | Flash Attention R² | Flash Attention MAPE (%) |
| ------- | ------------------ | ------------------------ |
| Close   | 0.99               | 2.42                     |
| Open    | 0.99               | 2.54                     |
| High    | 0.99               | 2.47                     |
| Low     | 0.98               | 3.48                     |
| Volume  | 0.16               | 393.19                   |

Let's start with the high R² values. The model performs well on the close, open, high, and low price predictions, with R² values between 0.98 and 0.99, so it explains nearly all of the variance in these features. Their mean absolute percentage error (MAPE) is also low, ranging from about 2.4% to 3.5%. This is a good sign that the model has learned the underlying patterns in the price data. However, it is likely helped by the fact that these four features are highly correlated with each other, and by their relatively low variance, which makes the patterns easier to learn. The average R² of 0.82 reported in the test output is pulled down by Volume.

Volume is clearly the most challenging feature to predict since it has the highest variance in the dataset. With an R² of only 0.16 and an extremely high MAPE of 393%, the model struggles to make accurate volume predictions. The high MAPE indicates substantial percentage errors, which typically arise when:

- Working with normalized data where values are close to zero
- Dealing with highly volatile series like cryptocurrency prices
- The model struggles with abrupt changes
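
The first point can be reproduced with a few lines of arithmetic. The sketch below uses made-up numbers and hypothetical helper names (`mape`, `mae`), not the notebook's metric code. It applies the same absolute error of 0.1 to a series far from zero and to a normalized series close to zero.

```python
def mape(actual, predicted):
    """Mean absolute percentage error, in percent."""
    return 100.0 * sum(abs((a - p) / a) for a, p in zip(actual, predicted)) / len(actual)


def mae(actual, predicted):
    """Mean absolute error."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


# Illustrative values only: both predictions are off by the same 0.1
far_from_zero = [10.0, 12.0, 11.0]
near_zero = [0.05, -0.02, 0.04]   # e.g. a z-scored series
pred_far = [a + 0.1 for a in far_from_zero]
pred_near = [a + 0.1 for a in near_zero]

print(f"MAE  far: {mae(far_from_zero, pred_far):.3f}  near: {mae(near_zero, pred_near):.3f}")
print(f"MAPE far: {mape(far_from_zero, pred_far):.2f}%  near: {mape(near_zero, pred_near):.2f}%")
```

The MAE is the same for both series, while the MAPE of the near-zero series is several hundred percent. For normalized targets, MAE or RMSE is therefore often the more informative metric.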

We can improve the model in many ways, such as:
- Data transformation: Better normalization or scaling
- Feature engineering: Adding more features or using different ones, such as technical indicators, or more complex features such as wavelet transforms.
- Model architecture: Using other architectures such as CNNs or LSTMs, or alternatives to attention such as state space models.
- Hyperparameter tuning: Tuning the hyperparameters of the model to improve performance
- Regularization: Adding regularization to the model to prevent overfitting.
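
As one possible take on the data transformation item, a heavy-tailed feature such as Volume can be log-transformed before standardization. The snippet below is a sketch under that assumption; the helper names and the sample volumes are hypothetical, and this is not the preprocessing used earlier in this notebook.

```python
import math


def log_standardize(values):
    """Apply log1p, then z-score the result. Returns (scaled, mean, std)."""
    logged = [math.log1p(v) for v in values]
    mean = sum(logged) / len(logged)
    std = math.sqrt(sum((x - mean) ** 2 for x in logged) / len(logged))
    scaled = [(x - mean) / std for x in logged]
    return scaled, mean, std


def invert_log_standardize(scaled, mean, std):
    """Map model outputs back to the original scale."""
    return [math.expm1(x * std + mean) for x in scaled]


# Made-up volumes with one large spike
volumes = [120.0, 95.0, 30000.0, 210.0, 80.0]
scaled, mu, sigma = log_standardize(volumes)
restored = invert_log_standardize(scaled, mu, sigma)
print([round(s, 2) for s in scaled])
```

The log compresses the spike so that it no longer dominates the mean and standard deviation. The mean and standard deviation should be fit on the training split only and reused for validation and test data.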

## Scaling with GCP Cloud Computing

This section demonstrates how to scale our machine learning model training to Google Cloud Platform (GCP) using Vertex AI. Vertex AI is a managed machine learning platform that enables you to train and deploy ML models at scale with comprehensive tooling and infrastructure.

### GCP Setup Prerequisites
1. Create a GCP project and enable the Vertex AI API
2. Create a GCP service account with appropriate permissions
3. Download the JSON key file for authentication
4. Set the environment variable `GOOGLE_APPLICATION_CREDENTIALS` to the path of the JSON key file

Prerequisites 1-3 are covered in the [GCP setup](#gcp-setup) section. With the credentials in place, the walkthrough below continues with steps 4-6:
- Building and pushing Docker images to GCP Artifact Registry
- Uploading configuration, training scripts, and datasets to GCP Cloud Storage
- Submitting training jobs to Vertex AI
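
Before moving on, it can help to confirm that `GOOGLE_APPLICATION_CREDENTIALS` points at a usable key file. The check below is a minimal sketch using only the standard library; `check_service_account_key` is a hypothetical helper, and it only inspects the fields normally present in a service account JSON key without contacting GCP.

```python
import json
import os

REQUIRED_FIELDS = ("type", "project_id", "client_email", "private_key")


def check_service_account_key(path):
    """Return the key's client_email if the file looks like a service account key."""
    if not path or not os.path.isfile(path):
        raise FileNotFoundError(f"Key file not found: {path!r}")
    with open(path) as f:
        key = json.load(f)
    missing = [field for field in REQUIRED_FIELDS if field not in key]
    if missing:
        raise ValueError(f"Key file is missing fields: {missing}")
    if key["type"] != "service_account":
        raise ValueError(f"Unexpected key type: {key['type']!r}")
    return key["client_email"]


# Example usage inside the container:
# print(check_service_account_key(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")))
```

A passing check only means that the file is well formed. Whether the account has the required Vertex AI and Cloud Storage permissions still has to be verified in GCP.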

### Step 4: Build and Push the Docker Image

First, authenticate Docker with GCP Artifact Registry:

```bash
# Configure Docker to use GCP credentials for authentication
gcloud auth configure-docker your-location-docker.pkg.dev
```

Next, build the Docker image with the same Dockerfile used for local development:

```bash
# Build the Docker image with GCP Artifact Registry URL
docker build -f ./assets/build/Dockerfile.flashattn.cu128py26cp312 \
  -t your-location-docker.pkg.dev/your-project-id/repositories/flash-attention:latest .

# Push the Docker image to GCP Artifact Registry
docker push your-location-docker.pkg.dev/your-project-id/repositories/flash-attention:latest

# Verify the Docker image in GCP Artifact Registry
gcloud artifacts docker images list your-location-docker.pkg.dev/your-project-id/repositories/flash-attention
```

> **Notes:**
> - Replace `your-location` with your GCP region (e.g., `us-central1`)
> - Replace `your-project-id` with your GCP project ID
> - You must first create a repository in Artifact Registry called `repositories` (or use your preferred name)
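
The image name used in the commands above follows the pattern `LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG`. The helper below is a small sketch that assembles it from its parts so the same value can be reused later; the function name and the example project ID are placeholders.

```python
def artifact_registry_image(location, project_id, repository, image, tag="latest"):
    """Build an Artifact Registry Docker image URI from its components."""
    return f"{location}-docker.pkg.dev/{project_id}/{repository}/{image}:{tag}"


uri = artifact_registry_image(
    location="us-central1",
    project_id="my-project-12345",   # placeholder project ID
    repository="repositories",
    image="flash-attention",
)
print(uri)  # us-central1-docker.pkg.dev/my-project-12345/repositories/flash-attention:latest
```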

### Step 5: Upload Files to GCP Cloud Storage

First, create a GCP Cloud Storage bucket:

```bash
# Create a GCP Cloud Storage bucket
gsutil mb -l your-location gs://your-gcp-training-bucket
```

Next, create the configuration file:

In [None]:
import yaml

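# Must match the upload destination used below (datasets/, not dataset/)
GCP_DATA_DIR = "/gcs/your-gcp-training-bucket/flash-attn-example/datasets/auto_regressive_processed_timeseries/"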
GCP_CKPT_DIR = "/gcs/your-gcp-training-bucket/flash-attn-example/checkpoints/"

# Model and training parameters
parameters = dict(
    input_dim = 5,                   # Number of features in the input data  
    embed_dim = 64,                  # Embedding dimension
    num_heads = 8,                   # Number of attention heads    
    num_layers = 4,                  # Number of transformer layers
    lr = 5e-5,                       # Learning rate
    weight_decay = 1e-3,             # Weight decay for regularization
    limit_val_batches = 100,         # Max validation batches per validation run
    accumulate_grad_batches = 8,     # Gradient accumulation steps
    gradient_clip_val = 0.5,         # Gradient clipping value
    batch_size = 32,                 # Batch size
    sequence_length = 2048,          # Sequence length
    num_workers = 4,                 # Data loading workers
    cktp_dir = GCP_CKPT_DIR,         # Checkpoint directory
    nodes = 1,                       # Number of nodes
    devices = 8,                     # Number of GPUs
    accelerator = "gpu",             # Accelerator type
    strategy = "auto",               # Training strategy
    precision = "bf16-mixed",        # Mixed precision
    epochs = 50,                     # Training epochs
    log_every_n_steps = 10,          # Logging frequency
    data_dir = GCP_DATA_DIR,         # Data directory
    data_splits = [0.8, 0.1, 0.1],   # Train/Val/Test splits
)

# Save the parameters to a YAML file
with open("/workspace/datasets/flash_attn_crypto_model_config.yaml", 'w') as f:
    yaml.dump(parameters, f, default_flow_style=False)

Then, upload the configuration file, training script, and dataset to GCP Cloud Storage:

```bash
# Cloud Storage has no real directories: the config/, scripts/, datasets/,
# checkpoints/, and staging/ prefixes are created implicitly by the uploads
# and by the training job (`gsutil mb` only creates buckets, so it is not needed here)

# Upload the config file
gsutil cp /workspace/datasets/flash_attn_crypto_model_config.yaml \
  gs://your-gcp-training-bucket/flash-attn-example/config/

# Upload the training script
gsutil cp /workspace/scripts/flash_attn_train.py \
  gs://your-gcp-training-bucket/flash-attn-example/scripts/

# Upload the dataset (multi-threaded for faster upload)
gsutil -m cp -r /workspace/datasets/auto_regressive_processed_timeseries \
  gs://your-gcp-training-bucket/flash-attn-example/datasets/

# Verify uploads
gsutil ls -r gs://your-gcp-training-bucket/flash-attn-example/
```
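
Note that the upload commands use `gs://` URIs while the configuration file uses `/gcs/` paths. Vertex AI custom training jobs expose Cloud Storage buckets through a Cloud Storage FUSE mount under `/gcs/`, so both forms refer to the same objects. The helpers below are a minimal sketch for converting between the two; the function names are hypothetical.

```python
def gcs_uri_to_fuse_path(uri: str) -> str:
    """Convert gs://bucket/path to the /gcs/bucket/path form seen inside the job."""
    prefix = "gs://"
    if not uri.startswith(prefix):
        raise ValueError(f"Not a Cloud Storage URI: {uri!r}")
    return "/gcs/" + uri[len(prefix):]


def fuse_path_to_gcs_uri(path: str) -> str:
    """Convert /gcs/bucket/path back to gs://bucket/path."""
    prefix = "/gcs/"
    if not path.startswith(prefix):
        raise ValueError(f"Not a /gcs/ path: {path!r}")
    return "gs://" + path[len(prefix):]


config_uri = "gs://your-gcp-training-bucket/flash-attn-example/config/flash_attn_crypto_model_config.yaml"
print(gcs_uri_to_fuse_path(config_uri))
```

The `/gcs/` form only resolves inside the Vertex AI job. On your local machine, keep using `gsutil` with `gs://` URIs.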

### Step 6: Submit the Training Job to Vertex AI

Use the `google-cloud-aiplatform` Python library to submit the training job:

In [None]:
from google.cloud import aiplatform
from google.oauth2 import service_account
import os

# Vertex AI Configuration
SERVICE_KEY_PATH = os.getenv(
    "GOOGLE_APPLICATION_CREDENTIALS", 
    "/path/to/your/service_account_key.json"
)
LOCATION = "your-gcp-region"         # e.g., "us-central1"
ZONE = "your-gcp-zone"               # e.g., "us-central1-a"
PROJECT_ID = "your-gcp-project-id"   # e.g., "my-project-12345"
RESERVATION_TYPE = "ANY_RESERVATION" # or "NO_RESERVATION"
STAGING_BUCKET = "gs://your-gcp-training-bucket/flash-attn-example/staging"
SERVICE_ACCOUNT = f"vertex-ai@{PROJECT_ID}.iam.gserviceaccount.com"
TRAIN_IMAGE = f"{LOCATION}-docker.pkg.dev/{PROJECT_ID}/repositories/flash-attention:latest"
DISPLAY_NAME = "flash-attn-crypto-model-training"

# Hardware Configuration
NODES = 1
MACHINE_TYPE = "a3-megagpu-8g"
ACCELERATOR_TYPE = "NVIDIA_H100_MEGA_80GB"
ACCELERATOR_COUNT = 8

# Training Command
CMD = [
    "python3", 
    "/gcs/your-gcp-training-bucket/flash-attn-example/scripts/flash_attn_train.py",
    "--config", 
    "/gcs/your-gcp-training-bucket/flash-attn-example/config/flash_attn_crypto_model_config.yaml",
]

# Worker pool specification
worker_pool_specs=[
    {
        "replica_count": NODES,
        "machine_spec": {
            "machine_type": MACHINE_TYPE,
            "accelerator_type": ACCELERATOR_TYPE,
            "accelerator_count": ACCELERATOR_COUNT,
            "reservation_affinity": {
                "reservation_affinity_type": RESERVATION_TYPE,
            }
        },
        "container_spec": {
            "image_uri": TRAIN_IMAGE,
            "command": CMD
        }
    }
]

# Initialize Vertex AI
aiplatform.init(
    project=PROJECT_ID,
    location=LOCATION,
    credentials=service_account.Credentials.from_service_account_file(
        SERVICE_KEY_PATH
    )
)

# Create and submit the training job
job = aiplatform.CustomJob(
    display_name=DISPLAY_NAME, 
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
)

job.submit(
    service_account=SERVICE_ACCOUNT
)

# Print job details
print(f"Job ID: {job.resource_name}")
print(f"Job state: {job.state}")

### Using Specific Reservations

If you have a specific hardware reservation, you can specify it like this:

In [None]:
# Worker pool specification with specific reservation
worker_pool_specs=[
    {
        "replica_count": 1,
        "machine_spec": {
            "machine_type": "a3-megagpu-8g",
            "accelerator_type": "NVIDIA_H100_MEGA_80GB",
            "accelerator_count": 8,
            "reservation_affinity": {
                "reservation_affinity_type": "SPECIFIC_RESERVATION",
                "key": "compute.googleapis.com/reservation-name",
                "values": [
                    f"projects/{PROJECT_ID}/zones/{ZONE}/reservations/your-reservation-name",
                ]
            }
        },
        "container_spec": {
            "image_uri": TRAIN_IMAGE,
            "command": CMD
        }
    }
]
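
The reservation block above can also be generated from its parts, which makes it harder to mistype the resource name. The sketch below mirrors the dictionary shown in this section; `specific_reservation_affinity` and `zone_in_region` are hypothetical helpers, and the project, zone, and reservation names are placeholders.

```python
def specific_reservation_affinity(project_id, zone, reservation_name):
    """Build the reservation_affinity dict for one named reservation."""
    return {
        "reservation_affinity_type": "SPECIFIC_RESERVATION",
        "key": "compute.googleapis.com/reservation-name",
        "values": [
            f"projects/{project_id}/zones/{zone}/reservations/{reservation_name}",
        ],
    }


def zone_in_region(zone, region):
    """Zones are named <region>-<letter>, e.g. us-central1-a is in us-central1."""
    return zone.rsplit("-", 1)[0] == region


# Placeholder values for illustration
affinity = specific_reservation_affinity("my-project-12345", "us-central1-a", "your-reservation-name")
print(affinity["values"][0])
print(zone_in_region("us-central1-a", "us-central1"))  # True
```

Checking that the reservation's zone lies inside the `LOCATION` passed to `aiplatform.init` catches one common source of mismatched configuration before the job is submitted.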