# PyTorch DataLoader Integration with DeepBioP

This notebook demonstrates how to use DeepBioP's streaming datasets with PyTorch's DataLoader for efficient training on biological sequence data.

## Features Demonstrated
- Loading FASTQ, FASTA, and BAM files with streaming datasets
- Using PyTorch DataLoader for batching
- Multiprocessing with `num_workers > 0`
- Custom collate functions for biological data
- Sharding the stream across processes for distributed training

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

## 1. Basic DataLoader Usage with FASTQ Files

The simplest way to use DeepBioP datasets with PyTorch is to create a streaming dataset and wrap it in a DataLoader.

In [None]:
from deepbiop.fq import FastqStreamDataset

# Create a streaming dataset
dataset = FastqStreamDataset("../tests/data/test.fastq")

# Wrap in DataLoader
loader = DataLoader(
    dataset,
    batch_size=4,
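    num_workers=0,  # Load in the main process for now
    shuffle=False,  # Iterable (streaming) datasets can't be shuffled by the DataLoader
    collate_fn=list,  # Keep each batch as a list of record dicts instead of merging them
)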

# Iterate through batches
for batch_idx, batch in enumerate(loader):
    print(f"Batch {batch_idx}:")
    print(f"  - Number of records: {len(batch)}")

    # Each batch is a list of dicts
    for item in batch:
        print(f"    ID: {item['id']}")
        print(f"    Sequence shape: {item['sequence'].shape}")
        print(
            f"    Quality shape: {item['quality'].shape if item['quality'] is not None else 'None'}"
        )

    if batch_idx >= 2:  # Only show first 3 batches
        break
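PyTorch's default collate function does not return a list of records: it merges a list of dicts into one dict and stacks array fields into tensors, which fails when reads have different lengths. The sketch below shows this with synthetic records (stand-ins shaped like the FASTQ dicts, not real DeepBioP output). Passing `collate_fn=list` to `DataLoader` is one way to keep each batch as a plain list of records.

```python
import numpy as np
from torch.utils.data import default_collate

# Synthetic stand-ins for streamed FASTQ records (equal lengths, so stacking succeeds)
records = [
    {"id": f"read{i}", "sequence": np.arange(5, dtype=np.uint8) + i}
    for i in range(3)
]

# Default collation: one dict of batched fields, not a list of dicts
merged = default_collate(records)
print(type(merged).__name__)  # dict
print(merged["id"])  # ['read0', 'read1', 'read2']
print(merged["sequence"].shape)  # torch.Size([3, 5])

# Identity-style collation (what collate_fn=list does) keeps the records as they are
as_list = list(records)
print(len(as_list), as_list[0]["id"])  # 3 read0

# Variable-length arrays cannot be stacked, so default collation raises an error
ragged = [
    {"sequence": np.zeros(4, dtype=np.uint8)},
    {"sequence": np.zeros(6, dtype=np.uint8)},
]
ragged_failed = False
try:
    default_collate(ragged)
except RuntimeError:
    ragged_failed = True
print("ragged batch failed:", ragged_failed)  # ragged batch failed: True
```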

## 2. Multiprocessing with DataLoader

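For faster data loading, you can use multiple worker processes. DeepBioP datasets support pickling, so the DataLoader can send a copy of the dataset to each worker. Note that PyTorch does not split an iterable dataset between workers on its own: each worker runs its own iterator, so unless the dataset shards records by worker, every record is loaded once per worker.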

In [None]:
# Create loader with multiple workers
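loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,  # 4 worker processes
    shuffle=False,
    collate_fn=list,  # Batches stay lists of records, so len(batch) counts records
)

# Count total records loaded. If the dataset does not shard records across
# workers, each record is counted once per worker.
total_records = sum(len(batch) for batch in loader)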
print(f"Total records loaded with multiprocessing: {total_records}")
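If the total above comes out as `num_workers` times the number of reads in the file, every worker is streaming the whole file. One way to split an iterable stream is to let each worker keep every `num_workers`-th record, using `torch.utils.data.get_worker_info()`. `WorkerShardedStream` below is a hypothetical helper, not part of DeepBioP; it assumes the wrapped dataset is a plain iterable of records that does not already shard by worker.

```python
import itertools

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class WorkerShardedStream(IterableDataset):
    """Hypothetical wrapper: each DataLoader worker yields a disjoint slice of the stream."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        info = get_worker_info()  # None when loading in the main process
        if info is None:
            return iter(self.dataset)
        # Worker k keeps records k, k + n, k + 2n, ... where n is the number of workers
        return itertools.islice(iter(self.dataset), info.id, None, info.num_workers)


# Stand-in stream of 10 records; in the notebook this would be the DeepBioP dataset
toy_records = [{"id": f"read{i}"} for i in range(10)]
toy_loader = DataLoader(WorkerShardedStream(toy_records), batch_size=4, collate_fn=list)
counted = sum(len(batch) for batch in toy_loader)
print(counted)  # 10

# The slices 3 workers would receive from a 10-record stream
shards = [list(itertools.islice(range(10), k, None, 3)) for k in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

With real data, `DataLoader(WorkerShardedStream(dataset), batch_size=8, num_workers=4, collate_fn=list)` should then count each read once, provided the assumption above holds.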

## 3. Custom Collate Function for Sequence Data

Often the records need to be transformed as they are batched. Here's an example collate function that converts sequences and quality scores to tensors and zero-pads them to the longest sequence in the batch.

In [None]:
def bio_collate_fn(batch):
    """
    Custom collate function for biological sequences.

    Converts sequences to tensors and pads to max length in batch.
    """
    # Extract sequences and quality scores
    sequences = [torch.from_numpy(item["sequence"]).long() for item in batch]
    qualities = [
        torch.from_numpy(item["quality"]).long()
        if item["quality"] is not None
        else None
        for item in batch
    ]
    ids = [item["id"] for item in batch]

    # Pad sequences to max length in batch
    max_len = max(seq.shape[0] for seq in sequences)

    padded_seqs = torch.zeros(len(sequences), max_len, dtype=torch.long)
    for i, seq in enumerate(sequences):
        padded_seqs[i, : seq.shape[0]] = seq

    # Pad quality scores similarly (records without qualities stay all-zero)
    if any(qual is not None for qual in qualities):
        padded_quals = torch.zeros(len(qualities), max_len, dtype=torch.long)
        for i, qual in enumerate(qualities):
            if qual is not None:
                padded_quals[i, : qual.shape[0]] = qual
    else:
        padded_quals = None

    return {
        "sequences": padded_seqs,
        "qualities": padded_quals,
        "ids": ids,
        "lengths": torch.tensor([seq.shape[0] for seq in sequences]),
    }


# Use custom collate function
loader = DataLoader(dataset, batch_size=4, collate_fn=bio_collate_fn, num_workers=2)

# Check output format
batch = next(iter(loader))
print(f"Batch keys: {batch.keys()}")
print(f"Sequences shape: {batch['sequences'].shape}")
print(f"Lengths: {batch['lengths']}")
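The manual padding loop in `bio_collate_fn` can also be written with `torch.nn.utils.rnn.pad_sequence`, which stacks a list of variable-length tensors into one padded tensor. A minimal sketch, using synthetic records that carry only a `sequence` field (a simplification) and the same padding value of 0:

```python
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Sketch: pad variable-length sequences with pad_sequence instead of a manual loop."""
    seqs = [torch.from_numpy(item["sequence"]).long() for item in batch]
    return {
        # batch_first=True gives shape (batch, max_len); shorter reads are padded with 0
        "sequences": pad_sequence(seqs, batch_first=True, padding_value=0),
        "lengths": torch.tensor([s.shape[0] for s in seqs]),
    }


toy_batch = [
    {"sequence": np.array([65, 67, 71], dtype=np.uint8)},
    {"sequence": np.array([84, 65], dtype=np.uint8)},
]
padded = pad_collate(toy_batch)
print(padded["sequences"].tolist())  # [[65, 67, 71], [84, 65, 0]]
print(padded["lengths"].tolist())  # [3, 2]
```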

## 4. Training Loop Example

Here's a complete example of a training loop with a simple sequence classification model.

In [None]:
import torch.nn as nn
import torch.optim as optim


class SequenceClassifier(nn.Module):
    """Simple CNN for sequence classification."""

    def __init__(self, vocab_size=256, embed_dim=32, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv1 = nn.Conv1d(embed_dim, 64, kernel_size=7, padding=3)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape (batch, seq_len) containing sequence indices.

        Returns:
            Output tensor of shape (batch, num_classes) containing class logits.
        """
        # x: (batch, seq_len)
        x = self.embedding(x)  # (batch, seq_len, embed_dim)
        x = x.transpose(1, 2)  # (batch, embed_dim, seq_len)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x).squeeze(2)  # (batch, 128)
        return self.fc(x)


# Create model and optimizer
model = SequenceClassifier()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training loop
model.train()
num_epochs = 2

for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0

    for batch in loader:
        sequences = batch["sequences"]

        # Create dummy labels for demonstration
        labels = torch.randint(0, 2, (sequences.size(0),))

        # Forward pass
        outputs = model(sequences)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if num_batches >= 10:  # Limit for demo
            break

    avg_loss = total_loss / max(num_batches, 1)  # Guard against an empty loader
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")
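The classifier above max-pools over every position, including the zero padding added by the collate function. Since each batch also carries `lengths`, one option is to mask padded positions before pooling. `masked_max_pool` is a hypothetical sketch that could stand in for `self.pool(x).squeeze(2)` if `lengths` were passed to `forward`; it assumes the convolutions keep the sequence length unchanged, as they do here (`padding=(kernel_size - 1) // 2`).

```python
import torch


def masked_max_pool(features, lengths):
    """Max over the sequence axis, ignoring positions at or beyond each sequence's length.

    features: (batch, channels, seq_len); lengths: (batch,) valid lengths, each >= 1.
    """
    seq_len = features.size(2)
    # valid[b, t] is True for real positions and False for padding
    valid = torch.arange(seq_len, device=features.device)[None, :] < lengths[:, None]
    masked = features.masked_fill(~valid[:, None, :], float("-inf"))
    return masked.max(dim=2).values  # (batch, channels)


feats = torch.tensor([[[1.0, 5.0, 9.0]], [[2.0, 3.0, 9.0]]])  # (2, 1, 3)
toy_lengths = torch.tensor([3, 2])  # the second sequence has one padded position
pooled = masked_max_pool(feats, toy_lengths)
print(pooled.tolist())  # [[9.0], [3.0]]
```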

## 5. Using FASTA and BAM Files

The same DataLoader pattern works with FASTA and BAM files.

In [None]:
from deepbiop.bam import BamStreamDataset
from deepbiop.fa import FastaStreamDataset

# FASTA example
fasta_dataset = FastaStreamDataset("../tests/data/test.fasta")
fasta_loader = DataLoader(fasta_dataset, batch_size=4, collate_fn=list)  # Batches as lists of records

print("FASTA samples:")
for batch in fasta_loader:
    print(f"  Batch size: {len(batch)}")
    print(f"  First ID: {batch[0]['id']}")
    break

# BAM example (with threading support)
bam_dataset = BamStreamDataset("../tests/data/test.bam", threads=4)
bam_loader = DataLoader(bam_dataset, batch_size=4, collate_fn=list)  # Batches as lists of records

print("\nBAM samples:")
for batch in bam_loader:
    print(f"  Batch size: {len(batch)}")
    print(f"  First ID: {batch[0]['id']}")
    break
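To stream several files through one loader (for example, reads split across multiple FASTQ files), PyTorch's `torch.utils.data.ChainDataset` yields from iterable datasets one after another. The sketch uses a hypothetical `ToyStream` in place of DeepBioP datasets so it runs on its own; `ChainDataset` only accepts `IterableDataset` instances, so using it with real DeepBioP datasets assumes they subclass `IterableDataset`.

```python
from torch.utils.data import ChainDataset, DataLoader, IterableDataset


class ToyStream(IterableDataset):
    """Hypothetical stand-in for a streaming dataset over one file."""

    def __init__(self, name, n):
        self.name = name
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield {"id": f"{self.name}:{i}"}


chained = ChainDataset([ToyStream("a.fastq", 3), ToyStream("b.fastq", 2)])
chain_loader = DataLoader(chained, batch_size=4, collate_fn=list)
ids = [record["id"] for batch in chain_loader for record in batch]
print(ids)  # ['a.fastq:0', 'a.fastq:1', 'a.fastq:2', 'b.fastq:0', 'b.fastq:1']
```

Batches can span file boundaries (the first batch here holds three reads from `a.fastq` and one from `b.fastq`).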

## 6. Worker Init Function for Reproducibility

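For reproducible experiments with multiprocessing, use `worker_init_fn` to seed the Python, NumPy, and PyTorch random number generators in each worker. The version below gives each worker a fixed seed derived from its ID, so the same random stream repeats in every run and, with the default `persistent_workers=False` (workers are restarted each epoch), in every epoch.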

In [None]:
import random


def worker_init_fn(worker_id):
    """Initialize each worker with a unique but deterministic seed."""
    seed = 42 + worker_id
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


loader = DataLoader(dataset, batch_size=4, num_workers=4, worker_init_fn=worker_init_fn)

print("DataLoader with deterministic worker initialization created")
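PyTorch's reproducibility notes suggest a variant that derives each worker's seed from `torch.initial_seed()`. PyTorch already seeds every worker as a base seed plus the worker ID, drawing the base seed from the loader's `generator` once per epoch, so seeding that generator makes whole runs repeatable while still giving each epoch fresh randomness. A sketch, using a small map-style `TensorDataset` in place of the FASTQ stream so that shuffling applies:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    """Seed Python and NumPy from the per-worker seed PyTorch already assigned."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def epoch_order(seed):
    """Return the sample order of one shuffled pass for a given generator seed."""
    g = torch.Generator()
    g.manual_seed(seed)
    toy = TensorDataset(torch.arange(8))
    # num_workers=0 keeps this runnable anywhere; seed_worker only runs when workers exist
    toy_loader = DataLoader(
        toy, batch_size=4, shuffle=True, generator=g,
        worker_init_fn=seed_worker, num_workers=0,
    )
    return [x.item() for (xs,) in toy_loader for x in xs]


print(epoch_order(0) == epoch_order(0))  # True: same generator seed, same order
```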

## 7. Distributed Training Setup

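For distributed training across multiple GPUs or nodes, each process (rank) must read a different part of the data. `DistributedSampler` is designed for map-style datasets: it needs `len(dataset)` and selects records by index, and `DataLoader` raises a `ValueError` if a `sampler` is combined with an `IterableDataset`. A streaming dataset is read sequentially, so the stream itself has to be sharded instead. The `ShardedStream` wrapper below is a minimal sketch, not part of DeepBioP; it assumes the wrapped dataset does not already split records between ranks or workers.

In [None]:
import itertools

from torch.utils.data import IterableDataset, get_worker_info

# Note: In actual distributed training, these would come from torch.distributed
num_replicas = 2  # Number of GPUs/processes
rank = 0  # Current process rank


class ShardedStream(IterableDataset):
    """Hypothetical wrapper: yields only this rank's and this worker's share of a stream."""

    def __init__(self, dataset, rank, num_replicas):
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas

    def __iter__(self):
        info = get_worker_info()  # None when loading in the main process
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        # Each (rank, worker) pair takes every num_shards-th record from its own offset
        shard = self.rank * num_workers + worker_id
        num_shards = self.num_replicas * num_workers
        return itertools.islice(iter(self.dataset), shard, None, num_shards)


loader = DataLoader(
    ShardedStream(dataset, rank=rank, num_replicas=num_replicas),
    batch_size=4,
    num_workers=2,
    collate_fn=list,  # Keep batches as lists of records
)

# Each process sees a disjoint subset of the data
records_seen = sum(len(batch) for batch in loader)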
print(f"Rank {rank} processed {records_seen} records")
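In a real multi-process run, the rank and the number of replicas come from the initialized process group rather than being hard-coded. A small helper (hypothetical, not part of DeepBioP) that falls back to a single process when `torch.distributed` is not initialized:

```python
import torch.distributed as dist


def get_rank_and_world_size():
    """Return (rank, world_size), defaulting to (0, 1) outside distributed runs."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


detected_rank, detected_world_size = get_rank_and_world_size()
print(detected_rank, detected_world_size)  # 0 1 when no process group is initialized
```

When each rank reads a different shard of a stream, ranks can end up with different numbers of batches. With `DistributedDataParallel`, uneven inputs can leave collective operations waiting on ranks that have already finished, so consider capping every rank at the same number of steps or wrapping the loop in DDP's `join()` context manager.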

## Summary

This notebook demonstrated:
- ✅ Basic DataLoader usage with streaming datasets
- ✅ Multiprocessing with `num_workers`
- ✅ Custom collate functions for biological data
- ✅ Complete training loop example
- ✅ FASTA and BAM file support
- ✅ Worker initialization for reproducibility
- ✅ Distributed training setup

DeepBioP's streaming datasets plug into PyTorch's DataLoader with little extra code, reading records from disk as they are iterated rather than loading whole files into memory, which keeps memory use low on large biological files.