# PyTorch Lightning Integration with DeepBioP

This notebook demonstrates how to use DeepBioP's BiologicalDataModule with PyTorch Lightning for streamlined deep learning workflows.

## Features Demonstrated
- Using BiologicalDataModule for train/val/test splits
- Integration with Lightning Trainer
- Automatic file type detection
- Multi-GPU training support
- Best practices for biological data

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn

from deepbiop.lightning import BiologicalDataModule

## 1. Basic BiologicalDataModule Setup

BiologicalDataModule takes one file path per split (train, validation, test), builds the datasets in `setup()`, and returns ready-to-use DataLoaders from its `*_dataloader()` methods.

In [None]:
# Create data module with train/val/test splits
data_module = BiologicalDataModule(
    train_path="../tests/data/test.fastq",
    val_path="../tests/data/test.fastq",  # Using same file for demo
    test_path="../tests/data/test.fastq",  # Using same file for demo
    batch_size=8,
    num_workers=2,
)

# Setup creates the datasets
data_module.setup(stage="fit")

# Access dataloaders
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
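
The batch counts printed above follow from the dataset size and `batch_size`. As a minimal sketch, assuming the module returns standard PyTorch `DataLoader`s over a map-style dataset with `drop_last=False` (the PyTorch default), `len(loader)` is the sample count divided by the batch size, rounded up. The helper `expected_num_batches` is hypothetical and only illustrates the arithmetic:

```python
import math


def expected_num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader reports for a map-style dataset."""
    if drop_last:
        # The final, incomplete batch is discarded.
        return num_samples // batch_size
    # The final batch may be smaller than batch_size.
    return math.ceil(num_samples / batch_size)


# Hypothetical dataset size, chosen only to show the rounding behaviour.
print(expected_num_batches(25, 8))                  # 4
print(expected_num_batches(25, 8, drop_last=True))  # 3
```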

## 2. Automatic File Type Detection

BiologicalDataModule automatically detects file types from extensions.

In [None]:
# FASTQ files (.fastq, .fq, .fastq.gz)
fastq_module = BiologicalDataModule(train_path="../tests/data/test.fastq", batch_size=4)

# FASTA files (.fasta, .fa, .fasta.gz)
fasta_module = BiologicalDataModule(train_path="../tests/data/test.fasta", batch_size=4)

# BAM files (.bam)
bam_module = BiologicalDataModule(train_path="../tests/data/test.bam", batch_size=4)

print("Data modules created with automatic file type detection")
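
To make the idea of extension-based detection concrete, here is a minimal sketch. The function `detect_file_type` and the mapping `EXTENSION_TO_TYPE` are hypothetical and are not DeepBioP's API; the library's real detection logic may differ.

```python
from pathlib import Path

# Hypothetical mapping for illustration; not taken from DeepBioP.
EXTENSION_TO_TYPE = {
    ".fastq": "fastq",
    ".fq": "fastq",
    ".fasta": "fasta",
    ".fa": "fasta",
    ".bam": "bam",
}


def detect_file_type(path: str) -> str:
    """Guess the file type from the extension, ignoring a trailing .gz."""
    suffixes = [s.lower() for s in Path(path).suffixes]
    if suffixes and suffixes[-1] == ".gz":
        suffixes = suffixes[:-1]  # look at the extension underneath the compression
    if not suffixes or suffixes[-1] not in EXTENSION_TO_TYPE:
        raise ValueError(f"Unsupported file extension: {path}")
    return EXTENSION_TO_TYPE[suffixes[-1]]


print(detect_file_type("../tests/data/test.fastq"))  # fastq
print(detect_file_type("reads.fq"))                  # fastq
print(detect_file_type("genome.fasta.gz"))           # fasta
print(detect_file_type("aligned.bam"))               # bam
```

Raising on unknown extensions, rather than guessing, surfaces a mistyped path before any training starts.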

## 3. Creating a Lightning Module

Define a LightningModule for your model. Here's a simple sequence classifier.

In [None]:
class SequenceClassifier(pl.LightningModule):
    """Simple CNN-based sequence classifier."""

    def __init__(
        self, vocab_size=256, embed_dim=32, num_classes=2, learning_rate=0.001
    ):
        super().__init__()
        self.save_hyperparameters()

        # Model layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv1 = nn.Conv1d(embed_dim, 64, kernel_size=7, padding=3)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(128, num_classes)

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, sequences):
        """Forward pass."""
        x = self.embedding(sequences)  # (batch, seq_len, embed_dim)
        x = x.transpose(1, 2)  # (batch, embed_dim, seq_len)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x).squeeze(2)  # (batch, 128)
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        """Training step."""
        # batch is a list of dicts from our dataset.
        # torch.stack needs equal-length sequences; pad variable-length
        # reads in a custom collate_fn (section 6).
        sequences = torch.stack(
            [torch.from_numpy(item["sequence"]).long() for item in batch]
        ).to(self.device)

        # Create dummy labels for demonstration, on the same device as the model
        labels = torch.randint(0, 2, (len(batch),), device=self.device)

        # Forward pass
        logits = self(sequences)
        loss = self.criterion(logits, labels)

        # Log metrics
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step."""
        sequences = torch.stack(
            [torch.from_numpy(item["sequence"]).long() for item in batch]
        ).to(self.device)
        # Dummy labels: the reported accuracy is not meaningful
        labels = torch.randint(0, 2, (len(batch),), device=self.device)

        logits = self(sequences)
        loss = self.criterion(logits, labels)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Test step."""
        sequences = torch.stack(
            [torch.from_numpy(item["sequence"]).long() for item in batch]
        ).to(self.device)
        # Dummy labels: the reported accuracy is not meaningful
        labels = torch.randint(0, 2, (len(batch),), device=self.device)

        logits = self(sequences)
        loss = self.criterion(logits, labels)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("test_loss", loss)
        self.log("test_acc", acc)
        return loss

    def configure_optimizers(self):
        """Configure optimizer."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
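
The default `vocab_size=256` fits an encoding in which each base is stored as its raw byte value. This is an assumption about what `item["sequence"]` contains, not something the notebook confirms, so check the array's dtype and value range for your data. Under that assumption, a minimal sketch of how a read maps to embedding indices:

```python
import numpy as np

read = b"ACGTN"  # made-up read for illustration

# Interpret each character as an unsigned byte (its ASCII code).
encoded = np.frombuffer(read, dtype=np.uint8)
print(encoded.tolist())  # [65, 67, 71, 84, 78]

# Every value lies in [0, 255], so nn.Embedding(256, embed_dim) can index it.
assert encoded.min() >= 0 and encoded.max() < 256
```

If the dataset uses a compact encoding instead (for example integers 0 to 4), a smaller `vocab_size` would avoid unused embedding rows.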

## 4. Training with Lightning Trainer

Use the Lightning Trainer to train your model with the BiologicalDataModule.

In [None]:
# Create model
model = SequenceClassifier(
    vocab_size=256, embed_dim=32, num_classes=2, learning_rate=0.001
)

# Create data module
data_module = BiologicalDataModule(
    train_path="../tests/data/test.fastq",
    val_path="../tests/data/test.fastq",
    test_path="../tests/data/test.fastq",  # Needed for trainer.test() in section 5
    batch_size=8,
    num_workers=2,
)

# Create trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="auto",  # Automatically use GPU if available
    devices=1,
    enable_progress_bar=True,
    enable_checkpointing=True,
    default_root_dir=".",  # Lightning writes to <root>/lightning_logs/ by default
)

# Train the model
trainer.fit(model, data_module)

## 5. Testing the Model

After training, evaluate on the test set.

In [None]:
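# trainer.test() calls data_module.setup(stage="test") itself, so the explicit
# call below is optional. It requires the data module to have a test_path.
data_module.setup(stage="test")

# Test the model; results is a list with one metrics dict per test dataloader
results = trainer.test(model, data_module)
print("Test results:", results)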

## 6. Advanced: Custom Collate Function

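The per-step tensor conversion in section 3 only works when every sequence in a batch has the same length. For production use, subclass the DataModule and pad each batch in a collate function. Note that this collate function returns a dict of tensors, so a model trained with it should read `batch["sequences"]` and `batch["labels"]` instead of iterating over a list of items.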

In [None]:
from torch.utils.data import DataLoader


class CustomBioDataModule(BiologicalDataModule):
    """Custom DataModule with padding collate function."""

    @staticmethod
    def collate_fn(batch):
        """Custom collate with padding."""
        sequences = [torch.from_numpy(item["sequence"]).long() for item in batch]

        # Pad to max length
        max_len = max(seq.shape[0] for seq in sequences)
        padded = torch.zeros(len(sequences), max_len, dtype=torch.long)

        for i, seq in enumerate(sequences):
            padded[i, : seq.shape[0]] = seq

        # Create dummy labels
        labels = torch.randint(0, 2, (len(batch),))

        return {
            "sequences": padded,
            "labels": labels,
            "lengths": torch.tensor([seq.shape[0] for seq in sequences]),
        }

    def train_dataloader(self):
        """Override to use custom collate."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Override to use custom collate."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )


# Use custom data module
custom_data_module = CustomBioDataModule(
    train_path="../tests/data/test.fastq",
    val_path="../tests/data/test.fastq",
    batch_size=8,
    num_workers=2,
)

print("Custom DataModule with padding created")
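
PyTorch also ships `torch.nn.utils.rnn.pad_sequence`, which performs the same zero-padding as the loop in `collate_fn`. The sketch below uses made-up tensors in place of real reads and derives a boolean mask from the lengths, which a model could use to ignore padded positions:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Made-up encoded reads of different lengths.
sequences = [
    torch.tensor([65, 67, 71, 84, 65]),
    torch.tensor([71, 71, 67]),
    torch.tensor([84, 65, 67, 71]),
]

# batch_first=True gives shape (batch, max_len); shorter reads are filled with 0.
padded = pad_sequence(sequences, batch_first=True, padding_value=0)
lengths = torch.tensor([seq.shape[0] for seq in sequences])

# mask[i, j] is True where position j holds a real base of read i.
mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

print(padded.shape)     # torch.Size([3, 5])
print(mask.sum(dim=1))  # tensor([5, 3, 4])
```

Returning the lengths (or the mask) alongside the padded batch matters because 0 is also a valid index into the embedding table, so the model cannot tell padding from data by value alone.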

## 7. Multi-GPU Training

Lightning enables multi-GPU training through Trainer arguments alone; the model and DataModule code stay unchanged.

In [None]:
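# Multi-GPU trainer (constructing it raises an error without 2 visible GPUs)
if torch.cuda.device_count() >= 2:
    multi_gpu_trainer = pl.Trainer(
        max_epochs=3,
        accelerator="gpu",
        devices=2,  # Use 2 GPUs
        # "ddp" is meant for scripts and is rejected in interactive sessions;
        # inside a notebook, use "ddp_notebook" instead.
        strategy="ddp_notebook",  # Distributed Data Parallel
        enable_progress_bar=True,
    )
    print("Multi-GPU trainer configured")
else:
    print("Fewer than 2 GPUs found; skipping multi-GPU trainer")

# By default, Lightning adds a DistributedSampler to DataLoaders over
# map-style datasets, so each process sees its own shard of the data.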
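
Under DDP each process builds its own DataLoader, so `batch_size` is the per-GPU batch size and the number of samples contributing to each optimizer step grows with the device count. A small sketch of that arithmetic (the helper name `effective_batch_size` is made up for illustration):

```python
def effective_batch_size(
    per_device_batch: int,
    devices: int,
    num_nodes: int = 1,
    accumulate_grad_batches: int = 1,
) -> int:
    """Samples contributing to one optimizer step under data-parallel training."""
    return per_device_batch * devices * num_nodes * accumulate_grad_batches


# batch_size=8 as in the data modules above, on 2 GPUs.
print(effective_batch_size(8, devices=2))  # 16
```

When moving from one GPU to several, it is worth revisiting the learning rate, since it was tuned for the smaller effective batch.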

## 8. Logging and Checkpointing

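Lightning provides built-in logging and checkpointing through loggers and callbacks passed to the Trainer. Both callbacks below monitor `val_loss`, the metric that `validation_step` logs, so they require a validation dataloader.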

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Create logger
logger = TensorBoardLogger(save_dir="./logs", name="sequence_classifier")

# Create callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="./checkpoints",
    filename="best-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    mode="min",
)

early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, mode="min")

# Trainer with callbacks and logger
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger,
    enable_progress_bar=True,
)

print("Trainer with logging and checkpointing configured")
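
To see what `patience=3` with `mode="min"` means in practice, the sketch below simulates the stopping rule on a made-up list of validation losses. It is a simplified model of the callback's behaviour, not Lightning's implementation, and the function name `early_stop_epoch` is hypothetical:

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None.

    Simplified model of EarlyStopping(mode="min"): a check counts as an
    improvement only if the loss drops below the best value by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0  # an improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses: the best value so far (0.80) is reached at epoch 1.
losses = [0.90, 0.80, 0.85, 0.82, 0.81, 0.70]
print(early_stop_epoch(losses, patience=3))  # 4
```

Training stops at epoch 4, after three checks in a row without improvement, so the lower loss at epoch 5 is never reached. A larger `patience` tolerates such plateaus at the cost of longer runs.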

## Summary

This notebook demonstrated:
- ✅ BiologicalDataModule for train/val/test splits
- ✅ Automatic file type detection (FASTQ, FASTA, BAM)
- ✅ Creating Lightning modules for biological data
- ✅ Training and testing with Lightning Trainer
- ✅ Custom collate functions for padding
- ✅ Multi-GPU training setup
- ✅ Logging and checkpointing

BiologicalDataModule simplifies deep learning on biological data by handling:
- Data loading from multiple file formats
- Separate train/val/test datasets, one file path per split
- DataLoader configuration
- Integration with Lightning's ecosystem

This lets you focus on model development rather than data engineering.