# Flash Attention with Docker for Local Development and Scaling

This guide shows how to set up Flash Attention with Docker for local development and scaling. It covers building the Docker image, running the container, and using Flash Attention in your projects.

## Prerequisites
- Docker installed on your machine
- A compatible GPU (NVIDIA) 
- NVIDIA Docker runtime installed

## Sections
- [Building the Docker Image](#building-the-docker-image)
- [Using Flash Attention Locally](#using-flash-attention)
- [Scaling with GCP Cloud Computing](#scaling-with-gcp-cloud-computing)

## Building the Docker Image

To build the Docker image for Flash Attention, follow these steps.

1. Clone the repository:
   ```bash
   git clone https://github.com/gabenavarro/MLContainerLab.git
   cd MLContainerLab
   ```
2. Build the Docker image:
   ```bash
   docker build -f ./assets/build/Dockerfile.flashattn.cu128py26cp312 -t flash-attention:128-26-312 .
   ```

   > Note: The following steps set up development inside the container. The tutorials are run inside the container, so you can skip these steps if you are not interested in development.

3. Run the Docker container detached with terminal access and GPUs connected:
   ```bash
   docker run -dt \
   --gpus all \
   -v "$(pwd):/workspace" \
   --name flash-attention \
   --env NVIDIA_VISIBLE_DEVICES=all \
   --env GOOGLE_APPLICATION_CREDENTIALS=/workspace/assets/secrets/gcp-key.json \
   flash-attention:128-26-312
   ```
   > Note: The `-v $(pwd):/workspace` option mounts the current directory to `/workspace` in the container, allowing you to access your files from within the container. <br>
   > Note: The `--env` options set environment variables for GPU visibility and Google Cloud credentials. <br>
   > Note: The `--gpus all` option allows the container to use all available GPUs. <br>
   > Note: The `--name` option names the container `flash-attention`, which you can use to reference it later. <br>
   > Note: The `-dt` option runs the container in detached mode with terminal access. <br>
   > Note: Get your token from [Synapse](https://synapse.org/), and set it as an environment variable in the container. You can also set it in your local environment, but this is not recommended for security reasons. <br>
   > Note: Get a GCP key from [Google Cloud](https://cloud.google.com/docs/authentication/getting-started) and save it as `assets/secrets/gcp-key.json`; the `GOOGLE_APPLICATION_CREDENTIALS` variable in the command above points to that file inside the container. You can also set it in your local environment, but this is not recommended for security reasons. <br>

4. Open the container in VSCode: 
   ```bash
   code --folder-uri vscode-remote://dev-container+flash-attention/workspace
   ```

   If you have the Dev Containers extension (formerly Remote - Containers) installed, this command opens the container's `/workspace` folder in VSCode, allowing you to edit files directly in the container.
   If this fails, you can attach to the container from the GUI:
   - Open VSCode
   - Press `F1` and run `Dev Containers: Attach to Running Container...` (listed as `Remote-Containers: Attach to Running Container...` in older versions)
   - Select the `flash-attention` container from the list
   - This will open the container in a new VSCode window
   - Set workspace to `/workspace` in the container

   > Note: You may need to install the Dev Containers (formerly Remote - Containers) extension in VSCode to attach to the container. <br>
   > Note: To run the notebooks in this tutorial, you may also need the Python and Jupyter extensions in VSCode. <br>
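
Once the container is running, a quick check from inside it can confirm that the options passed to `docker run` took effect. The following is a minimal sketch using only the Python standard library; the helper name `check_container_env` is hypothetical, and looking for `nvidia-smi` on the `PATH` is only a rough proxy for GPU access, not a guarantee that CUDA works.

```python
import os
import shutil


def check_container_env(environ=None):
    """Return simple checks on the environment set up by `docker run`.

    `environ` defaults to the real process environment; passing a dict
    makes the function easy to try out without a container.
    """
    env = os.environ if environ is None else environ
    cred_path = env.get("GOOGLE_APPLICATION_CREDENTIALS", "")
    return {
        # The NVIDIA runtime normally exposes nvidia-smi inside the container
        "nvidia_smi_on_path": shutil.which("nvidia-smi") is not None,
        # Set by `--env NVIDIA_VISIBLE_DEVICES=all` in the command above
        "visible_devices": env.get("NVIDIA_VISIBLE_DEVICES", "<unset>"),
        # Set by `--env GOOGLE_APPLICATION_CREDENTIALS=...` in the command above
        "gcp_credentials_set": bool(cred_path),
        "gcp_credentials_file_exists": bool(cred_path) and os.path.isfile(cred_path),
    }


for name, value in check_container_env().items():
    print(f"{name}: {value}")
```

If `gcp_credentials_set` is true but `gcp_credentials_file_exists` is false, the key file is most likely missing from the mounted directory.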

## Using Flash Attention

### Dataset Preparation
Let's set up a simple example to run Flash Attention in the container. First, let's download a sample dataset.

In [1]:
TIME_SERIES_CSV = "/workspace/datasets/btcusd_1-min_data.csv"
PROCESSED_DATA_DIR = "/workspace/datasets/auto_regressive_processed_timeseries"
CKPT_DIR = "/workspace/datasets/checkpoints"

In [2]:
# Make directories if they do not exist
import os
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

In [None]:
%%capture
# Download and unzip the Bitcoin historical data dataset from Kaggle
!curl -L -o /workspace/datasets/bitcoin-historical-data.zip \
  https://www.kaggle.com/api/v1/datasets/download/mczielinski/bitcoin-historical-data \
    && unzip -o /workspace/datasets/bitcoin-historical-data.zip -d /workspace/datasets/ \
    && rm /workspace/datasets/bitcoin-historical-data.zip

After downloading the dataset, we can load it into a Pandas DataFrame and take a look at the first few rows. The dataset contains historical data for Bitcoin, including the date, open, high, low, close prices, and volume.

In [4]:
import pandas as pd
pd.set_option('display.max_columns', None)
pd.read_csv(TIME_SERIES_CSV, low_memory=False, nrows=5).head()

| Unnamed: 0 | Timestamp | Open | High | Low | Close | Volume | datetime |
|---|---|---|---|---|---|---|---|
| 0 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:01:00+00:00 |
| 1 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:02:00+00:00 |
| 2 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:03:00+00:00 |
| 3 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:04:00+00:00 |
| 4 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:05:00+00:00 |


Now that we have the dataset, let's convert it into litdata's optimized format, which the streaming dataset and dataloader for the model will read.

In [5]:
import pandas as pd
import litdata as ld
import numpy as np
from typing import Any, Callable, Dict, Optional
import torch

def process_timeseries(file_path: str, sequence_length: int = 2048) -> Optional[Callable[[int], Dict[str, Any]]]:
    """Build a per-index sample function for autoregressive modeling with a validity mask.

    Returns None if the file has too few rows for a single sequence."""
    
    # Read the CSV file
    df = pd.read_csv(file_path, low_memory=False)
    
    # Drop datetime
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    
    # Sort by timestamp to ensure chronological order
    df = df.sort_values('Timestamp')
    
    # Select the numerical columns for prediction
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Normalize the data (z-score per feature) to prevent numerical instability
    # Store statistics for later use
    stats = {}
    for feature in numerical_features:
        # Replace infinity values with NaN
        df[feature] = df[feature].replace([np.inf, -np.inf], np.nan)
        
        # Calculate statistics using only finite values
        finite_values = df[feature].dropna()
        if len(finite_values) > 0:
            mean = finite_values.mean()
            std = finite_values.std()
            # Prevent zero std which would cause division by zero
            if std == 0:
                std = 1.0
            
            # Store stats for this feature
            stats[feature] = {'mean': mean, 'std': std}
            
            # Normalize the feature
            df[feature] = (df[feature] - mean) / std
    
    # Check if we have enough data
    if len(df) <= sequence_length:
        print(f"Warning: File {file_path} has {len(df)} rows, which is less than the required sequence_length {sequence_length}.")
        return None  # Return None for the entire file, not just specific indices
    
    # Function to create samples for litdata - modified to handle edge cases and NaN values
    def create_timeseries_sample(index: int) -> Dict[str, Any]:
        # Only process indices that have enough previous data
        if index < sequence_length or index >= len(df):
            # Create writable arrays
            input_array = np.zeros((sequence_length, len(numerical_features)), dtype=np.float32).copy()
            mask_array = np.zeros(sequence_length, dtype=np.bool_).copy()  # All positions invalid
            
            return {
                "index": index,
                "inputs": input_array,
                "mask": mask_array,  # No valid positions
                "stats": stats       # Include normalization stats
            }
            
        # Get the sequence of previous data points
        input_sequence = df.iloc[index-sequence_length:index][numerical_features].values
        
        # Make a writable copy of the arrays
        input_array = input_sequence.astype(np.float32).copy()
        
        # Additional check to replace any remaining NaN or inf values
        input_array = np.nan_to_num(input_array, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Create mask based on both sequence positions and NaN values
        # True for valid positions (not NaN), False for invalid (NaN)
        mask_array = ~np.isnan(input_sequence).any(axis=1)
        
        # Return the required keys: index, inputs, and mask
        return {
            "index": index,
            "inputs": input_array,
            "mask": mask_array,
            "stats": stats  # Include normalization stats
        }
    
    return create_timeseries_sample


# Set the sequence length for your model
sequence_length = 2048

# Get the processing function configured for your specific file
process_function = process_timeseries(TIME_SERIES_CSV, sequence_length)

# Filter indices to exclude those with NaN values
def get_valid_indices():
    df = pd.read_csv(TIME_SERIES_CSV, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Replace infinity values with NaN
    for feature in numerical_features:
        df[feature] = df[feature].replace([np.inf, -np.inf], np.nan)
    
    valid_indices = []
    for idx in range(sequence_length, len(df), int(sequence_length * 0.25)):
        # Check if the entire sequence has no NaN values
        sequence = df.iloc[idx-sequence_length:idx][numerical_features].values
        if not np.isnan(sequence).any() and not np.isinf(sequence).any():
            valid_indices.append(idx)
        else:
            print(f"Skipping index {idx} due to NaN or inf values in the sequence.")
    
    if not valid_indices:
        raise ValueError("No valid sequences found! All sequences contain NaN or inf values.")
    
    return valid_indices

valid_indices = get_valid_indices()

# The optimize function writes data in an optimized format
ld.optimize(
    fn=process_function,              # the function that processes each sample
    inputs=valid_indices,             # the indices of valid samples
    output_dir=PROCESSED_DATA_DIR,    # optimized data is stored here
    num_workers=4,                    # The number of workers on the same machine
    chunk_bytes="64MB"                # size of each chunk
)
# Takes about 30 seconds to run


Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Setting multiprocessing start_method to fork. Tip: Libraries relying on lock can hang with `fork`. To use `spawn` in notebooks, move your code to files and import it within the notebook.
Storing the files under /workspace/datasets/auto_regressive_processed_timeseries
Setup started with fast_dev_run=False.
Setup finished in 0.001 seconds. Found 13690 items to process.
Starting 4 workers with 13690 items. The progress bar is only updated when a worker finishes.




Rank 0 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float']` data format.
Rank 1 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float']` data format.
Workers are ready ! Starting data processing...



Progress:   0%|          | 0/13690 [00:00<?, ?it/s]

Rank 2 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float']` data format.
Rank 3 inferred the following `['int', 'numpy', 'no_header_numpy:13', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float', 'float']` data format.
Worker 1 is terminating.
Worker 1 is done.
Worker 2 is terminating.
Worker 0 is terminating.
Worker 3 is terminating.
Worker 2 is done.
Worker 0 is done.
Worker 3 is done.

Workers are finished.
Finished data processing!
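
The sample indices passed to `ld.optimize` come from a sliding window: each index marks the end of a window of `sequence_length` rows, and consecutive windows start `int(sequence_length * 0.25)` rows apart, so neighbouring samples overlap by 75%. Here is a minimal plain-Python sketch of that indexing; the helper name `window_end_indices` and the toy row count of 4096 are illustrative and not part of the notebook.

```python
def window_end_indices(num_rows, sequence_length=2048, stride_fraction=0.25):
    """End indices of the sliding windows, mirroring the range() used above."""
    stride = int(sequence_length * stride_fraction)
    return list(range(sequence_length, num_rows, stride))


sequence_length = 2048
ends = window_end_indices(num_rows=4096, sequence_length=sequence_length)

# Each sample covers rows [end - sequence_length, end)
windows = [(end - sequence_length, end) for end in ends]
print(windows)

# Rows shared by two consecutive windows
overlap = sequence_length - (ends[1] - ends[0])
print(overlap)
```

The NaN filtering in `get_valid_indices` then removes any window from this list that contains missing values.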


Next, let's split the dataset into training, validation, and test sets: 80% for training, 10% for validation, and 10% for testing. We also create a dataloader for each split.

In [6]:
from litdata import StreamingDataset, StreamingDataLoader, train_test_split
import torch
streaming_dataset = StreamingDataset(PROCESSED_DATA_DIR) # data are read from the local optimized directory

def custom_collate(batch):
    # Filter out None values
    batch = [item for item in batch if item is not None]
    
    if not batch:
        # Return empty tensors if the batch is empty
        return {
            "index": torch.tensor([], dtype=torch.long),
            "inputs": torch.tensor([], dtype=torch.float32),
            "mask": torch.tensor([], dtype=torch.bool),
            "stats": {}
        }
    
    # Process each key separately
    indices = torch.tensor([item["index"] for item in batch], dtype=torch.long)
    
    # Make sure arrays are writable by copying and convert to tensor
    inputs = torch.stack([torch.tensor(np.nan_to_num(item["inputs"].copy(), nan=0.0), dtype=torch.float32) for item in batch])
    masks = torch.stack([torch.tensor(item["mask"].copy(), dtype=torch.bool) for item in batch])
    
    # Get stats (use first non-empty item's stats)
    stats = next((item["stats"] for item in batch if "stats" in item), {})
    
    return {
        "index": indices,
        "inputs": inputs,
        "mask": masks,
        "stats": stats
    }

print(len(streaming_dataset)) # display the length of your data
# out: 13690

train_dataset, val_dataset, test_dataset = train_test_split(streaming_dataset, splits=[0.8, 0.1, 0.1])

print("Train ", len(train_dataset))
# out: 10952
train_dataloader = StreamingDataLoader(train_dataset, num_workers=4, batch_size=32, shuffle=True, collate_fn=custom_collate)  # Create DataLoader for training

print("Validation ", len(val_dataset))
# out: 1369
val_dataloader = StreamingDataLoader(val_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for validation

# The test split holds the remaining 10% of the samples (its length is not printed here)
test_dataloader = StreamingDataLoader(test_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for testing


13690
Train  10952
Validation  1369
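
The printed sizes follow directly from the split fractions. The sketch below reproduces them with integer arithmetic; rounding each split down is an assumption that happens to match the counts above, not a statement about how `train_test_split` is implemented, and the helper name `split_sizes` is hypothetical.

```python
def split_sizes(total, percents):
    """Size of each split, rounding down (assumed rounding rule)."""
    return [total * pct // 100 for pct in percents]


train_size, val_size, test_size = split_sizes(13690, [80, 10, 10])
print(train_size, val_size, test_size)
```

With 13690 samples the three sizes add up to the full dataset, so no sample is left out of the split in this case.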


In [7]:
# Get one batch from train_dataloader and find out the input shape
for batch in train_dataloader:
    print(batch['inputs'].shape)  # Check the shape of the input tensor
    # out: torch.Size([32, 2048, 5])
    # 32 is the batch size, 2048 is the sequence length, and 5 is the number of features
    break  # Only need to check the first batch

torch.Size([32, 2048, 5])
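
The values in `inputs` are z-scores, not prices: each feature was normalized with its own mean and standard deviation, and those statistics are stored with every sample under `stats`. The following standard-library sketch shows the normalization and its inverse. The helper names are hypothetical and the toy prices are made up; `statistics.stdev` is used because it computes the sample standard deviation, as pandas' `Series.std()` does by default.

```python
import statistics


def fit_stats(values):
    """Mean and sample standard deviation, with the same zero-std guard as above."""
    mean = statistics.fmean(values)
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    if std == 0:
        std = 1.0  # avoid division by zero for constant features
    return {"mean": mean, "std": std}


def normalize(values, stats):
    return [(v - stats["mean"]) / stats["std"] for v in values]


def denormalize(values, stats):
    """Map normalized values (or model predictions) back to original units."""
    return [v * stats["std"] + stats["mean"] for v in values]


closes = [4.0, 6.0, 8.0]  # toy prices, not taken from the dataset
stats = fit_stats(closes)
z_scores = normalize(closes, stats)
print(z_scores)
print(denormalize(z_scores, stats))
```

The same `denormalize` step would apply to model predictions before comparing them with real prices.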


### Model Definition

Now that we have the dataset and dataloaders, we can define the model. Each transformer block is a class that inherits from `nn.Module`, and the full model inherits from Lightning's `LightningModule`. In both, the architecture is defined in the `__init__` method and the forward pass in the `forward` method.

The model architecture consists of a linear input projection, a stack of causal transformer layers, and a linear output layer. The input projection maps the 5 input features into the embedding dimension, the transformer layers process the sequence using causal self-attention, and the output layer projects back to the 5 features to predict the next time step.

In [8]:
# Transformer Layer
from flash_attn.modules.mha import MHA
from flash_attn.ops.rms_norm import RMSNorm
from flash_attn.ops.fused_dense import FusedMLP
from torch import nn, optim, Tensor
import lightning as pl

class TransformerLayer(nn.Module):
    def __init__(
            self,
            layer_idx: int,
            embed_dim: int,
            num_heads: int,
            mlp_ratio: float = 4.0,
            proj_groups: int = 1,
            dropout: float = 0.05,
            fast_attention: bool = True,
        ):
        super().__init__()

        self.attention = MHA(
            embed_dim = embed_dim,                   # Dimension of the model
            num_heads = num_heads,                   # Number of attention heads
            causal = True,                           # Causal attention
            layer_idx = layer_idx,                   # Layer index (identifies this layer's KV cache at inference)
            num_heads_kv = num_heads // proj_groups, # Number of heads for key/value
            rotary_emb_dim = embed_dim // num_heads, # Rotary embedding dimension (full head dimension)
            use_flash_attn = fast_attention,         # Use flash attention
            return_residual = False,                 # Do not return the input alongside the output
            dropout = dropout,                       # Attention dropout rate
        )
        self.norm1 = RMSNorm(embed_dim)
        self.mlp = FusedMLP(embed_dim, int(embed_dim * mlp_ratio))
        self.norm2 = RMSNorm(embed_dim)

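    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # `mask` is accepted for interface symmetry but is not passed to the
        # attention module; invalid positions are excluded in the loss instead.
        # Pre-normalization: normalize before each sub-block, then add the residual
        normed_x = self.norm1(x)
        attn_output = self.attention(normed_x)
        x = x + attn_output
        
        normed_x = self.norm2(x)
        mlp_output = self.mlp(normed_x)
        x = x + mlp_output
        return x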

    
# Transformer Model
class AutoregressiveTransformerModel(pl.LightningModule):
    def __init__(
        self, 
        input_dim: int, 
        embed_dim: int, 
        num_heads: int, 
        num_layers: int, 
        mlp_ratio: float = 4.0,
        lr: float = 1e-4,
        weight_decay: float = 1e-2,  # Added weight decay
        **kwargs
    ):
        super().__init__()
        
        # Save hyperparameters
        self.save_hyperparameters()
        
        # Linear layer to expand feature size
        self.input_proj = nn.Linear(input_dim, embed_dim)
        
        # Initialize weights properly
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.zeros_(self.input_proj.bias)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(layer_idx, embed_dim, num_heads, mlp_ratio) for layer_idx in range(num_layers)
        ])
        
        # Final linear layer to project back to input dimension for autoregressive prediction
        self.fc_out = nn.Linear(embed_dim, input_dim)
        nn.init.xavier_uniform_(self.fc_out.weight)
        nn.init.zeros_(self.fc_out.bias)
        
        # For tracking metrics
        self.train_loss = 0.0
        self.val_loss = 0.0
        
        # Learning rate
        self.lr = lr
        self.weight_decay = weight_decay
        
    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """
        Forward pass with optional attention mask
        
        Args:
            x: Input tensor of shape [batch_size, seq_len, input_dim]
            mask: Boolean mask tensor of shape [batch_size, seq_len]
            
        Returns:
            Tensor of shape [batch_size, seq_len, input_dim]
        """
        # Check for NaN in input and replace with zeros
        if torch.isnan(x).any():
            x = torch.nan_to_num(x, nan=0.0)
            
        # Project input to embedding dimension
        x = self.input_proj(x)
        
        # Apply the transformer layers (causal attention). `mask` is forwarded to each
        # layer, but the layers do not use it; masking takes effect in the loss.
        for layer in self.layers:
            x = layer(x, mask=mask)
            
            # Check for NaN after each layer (in case of instability)
            if torch.isnan(x).any():
                print("WARNING: NaN detected in transformer layer output")
                x = torch.nan_to_num(x, nan=0.0)
            
        # Project back to feature dimension 
        x = self.fc_out(x)
        
        return x
    
    def _calculate_autoregressive_loss(self, predictions: Tensor, targets: Tensor, mask: Tensor) -> Tensor:
        """
        Calculate autoregressive loss by comparing predictions at each timestep 
        with the actual next timestep values
        
        Args:
            predictions: Model predictions [batch_size, seq_len, feature_dim]
            targets: Target values [batch_size, seq_len, feature_dim]
            mask: Boolean mask [batch_size, seq_len]
            
        Returns:
            Tensor: Scalar loss value
        """
        # Shift predictions and targets to align for autoregressive loss
        # We predict the next timestep based on previous timesteps
        pred_shifted = predictions[:, :-1, :]  # Remove last prediction
        target_shifted = targets[:, 1:, :]    # Remove first target
        mask_shifted = mask[:, 1:]            # Adjust mask accordingly
        
        # Calculate MSE loss only on masked positions
        mse_loss = nn.MSELoss(reduction='none')(pred_shifted, target_shifted)
        
        # Apply mask to consider only valid positions (reshape mask to match loss dimensions)
        mask_expanded = mask_shifted.unsqueeze(-1).expand_as(mse_loss)
        masked_loss = mse_loss * mask_expanded
        
        # Average the loss over the masked positions, with safety check
        num_valid = mask_expanded.sum()
        
        # Prevent division by zero
        if num_valid == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True)
        
        # Check for any NaN in the loss
        if torch.isnan(masked_loss).any():
            print("WARNING: NaN in loss calculation!")
            masked_loss = torch.nan_to_num(masked_loss, nan=0.0)
        
        loss = masked_loss.sum() / num_valid.clamp(min=1)  # Ensure denominator is at least 1
        
        return loss
    
    def training_step(self, batch, batch_idx):
        """
        Training step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs, mask)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
        
        # Check for NaN loss and handle it
        if torch.isnan(loss).any():
            print("NaN loss detected in training step!")
            # Return a small constant loss instead
            loss = torch.tensor(0.01, device=loss.device, requires_grad=True)
        
        
        # Log the loss for monitoring
        self.train_loss = loss
        self.log('train_loss', loss, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        """
        Validation step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs, mask)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
        
        # Check for NaN loss and handle it
        if torch.isnan(loss).any():
            print("NaN loss detected in validation step!")
            # Return a small constant loss instead
            loss = torch.tensor(0.01, device=loss.device, requires_grad=True)
        
        # Log the loss for monitoring
        self.val_loss = loss
        self.log('val_loss', loss, prog_bar=True)
        
        return loss
    
    # Configuring the optimizer
    def configure_optimizers(self):
        # AdamW with decoupled weight decay
        optimizer = optim.AdamW(
            self.parameters(), 
            lr=self.lr, 
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8          # PyTorch's default epsilon
        )
        
        # Add a learning rate scheduler for better convergence
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=10,               # Match with max_epochs 
            eta_min=self.lr * 0.1   # Anneal to 10% of the base rate; eta_min=self.lr would keep the rate constant
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }
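
The loss in `_calculate_autoregressive_loss` pairs the prediction made at step `t` with the true value at step `t + 1`, and only counts target steps whose mask is `True`. The plain-Python sketch below mirrors that logic on nested lists so it can be followed without a GPU or PyTorch. It is an illustration of the idea, not the model's code, and the numbers are made up.

```python
def masked_next_step_mse(predictions, targets, mask):
    """Mean squared error between prediction[t] and target[t + 1].

    predictions, targets: one list of feature values per time step
    mask: one bool per time step, True where the step is valid
    """
    pred_shifted = predictions[:-1]   # drop the last prediction (no target for it)
    target_shifted = targets[1:]      # drop the first target (nothing predicts it)
    mask_shifted = mask[1:]           # mask follows the targets

    total, count = 0.0, 0
    for pred_row, target_row, valid in zip(pred_shifted, target_shifted, mask_shifted):
        if not valid:
            continue  # invalid target steps contribute nothing
        for pred, target in zip(pred_row, target_row):
            total += (pred - target) ** 2
            count += 1
    return total / max(count, 1)  # avoid division by zero when everything is masked


targets = [[1.0], [2.0], [3.0], [4.0]]
predictions = [[2.0], [3.5], [0.0], [9.0]]
mask = [True, True, True, False]

print(masked_next_step_mse(predictions, targets, mask))
```

Here the first prediction matches its target exactly, the second is off by 0.5, and the third is ignored because its target step is masked, so the loss averages two squared errors.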



In [None]:

# Initialize the model
parameters = dict(
    input_dim = 5,
    embed_dim = 128,
    num_heads = 8,
    num_layers = 6,
    lr = 5e-5,            # Reduced learning rate
    weight_decay = 1e-2   # Added weight decay
)

model = AutoregressiveTransformerModel(**parameters)

# Lightning callbacks for early stopping and model checkpointing
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=CKPT_DIR,
    filename='best-model-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

trainer = pl.Trainer(
    limit_val_batches=200,
    max_epochs=10,
    accumulate_grad_batches=16,  
    gradient_clip_val=0.25,  # Updated for better stability 
    default_root_dir=CKPT_DIR,
    precision="bf16-mixed",  # Use mixed precision
    log_every_n_steps=10,
    callbacks=[early_stop_callback, checkpoint_callback]
) 

# Train the model
trainer.fit(model, train_dataloader, val_dataloader)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /workspace/datasets/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params | Mode 
--------------------------------------------------
0 | input_proj | Linear     | 768    | train
1 | layers     | ModuleList | 1.2 M  | train
2 | fc_out     | Linear     | 645    | train
--------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.758     Total estimated model params size (MB)
87        Modules in train mode
0         Modules in eval mode
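
The parameter counts in this summary can be checked by hand. The two `Linear` rows follow from weights plus biases; for the transformer layers, the sketch below assumes a fused QKV projection, an output projection, and a two-layer MLP, all with biases, plus two RMSNorm weight vectors per layer. That per-layer breakdown is an assumption about the `flash_attn` modules rather than something stated in this guide, although it does reproduce the reported 4.758 MB at 4 bytes per parameter.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters in a dense layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


input_dim, embed_dim, num_layers, mlp_ratio = 5, 128, 6, 4.0
hidden_dim = int(embed_dim * mlp_ratio)

input_proj = linear_params(input_dim, embed_dim)
fc_out = linear_params(embed_dim, input_dim)

# Assumed per-layer breakdown (not confirmed by the source)
per_layer = (
    linear_params(embed_dim, 3 * embed_dim)   # fused Q, K, V projection
    + linear_params(embed_dim, embed_dim)     # attention output projection
    + linear_params(embed_dim, hidden_dim)    # MLP up-projection
    + linear_params(hidden_dim, embed_dim)    # MLP down-projection
    + 2 * embed_dim                           # two RMSNorm weight vectors
)

total = input_proj + fc_out + num_layers * per_layer
size_mb = total * 4 / 1e6  # 4 bytes per float32 parameter

print(input_proj, fc_out, total, round(size_mb, 3))
```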


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]
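
During training the learning rate is driven by the `CosineAnnealingLR` scheduler set up in `configure_optimizers`, stepped once per epoch with `T_max=10` to match `max_epochs=10`. The sketch below evaluates the closed-form cosine annealing rule in plain Python; the function name is hypothetical and the base rate of `5e-5` is the `lr` passed to the model above. Note that the schedule only decays when `eta_min` is lower than the base rate: if the two are equal, the rate stays constant.

```python
import math


def cosine_annealing_lr(epoch, t_max, base_lr, eta_min=0.0):
    """Closed-form cosine annealing from base_lr (epoch 0) to eta_min (epoch t_max)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


base_lr = 5e-5

# Decaying schedule: eta_min below the base rate
decaying = [cosine_annealing_lr(e, 10, base_lr, eta_min=0.0) for e in range(11)]

# Flat schedule: eta_min equal to the base rate
flat = [cosine_annealing_lr(e, 10, base_lr, eta_min=base_lr) for e in range(11)]

for epoch in (0, 5, 10):
    print(epoch, f"{decaying[epoch]:.2e}", f"{flat[epoch]:.2e}")
```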