# (Exported from a Jupyter notebook cell: "In [None]:")
# Make sure to launch this script using torchrun, e.g. from bash:
#   torchrun --nproc_per_node=4 train.py

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast
from torchvision import transforms

# --- Your custom dataset ---
class TimeSeriesDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each entry of ``data`` with its label.

    Args:
        data: Indexable collection of samples. The original author's note says
            a tensor of shape [N, 1, 3_000_000] — confirm against callers.
        labels: Indexable collection of labels, aligned index-for-index with
            ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The number of samples is taken from ``data``; ``labels`` is assumed
        # to have the same length.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# --- Your 1D CNN Model ---
class CNN1D(nn.Module):
    """1D CNN classifier for single-channel time series.

    Architecture: two ``Conv1d -> ReLU -> MaxPool1d`` stages, then
    ``Flatten -> Linear(features, 128) -> ReLU -> Linear(128, num_classes)``.
    Input is ``[batch, 1, input_length]``; ``forward`` returns
    ``[batch, num_classes]`` logits.

    The width of the first Linear layer is derived from ``input_length``
    instead of being hard-coded. The former constant ``32 * 187000`` did not
    match the real flattened size for a 3,000,000-step input (``32 * 46874``),
    so every forward pass failed with a shape mismatch.

    Args:
        input_length: Number of time steps per sample. Must match the data
            passed to ``forward``. Defaults to 3,000,000.
        num_classes: Number of output logits. Defaults to 10.

    Raises:
        ValueError: If ``input_length`` is too short for the conv/pool stack
            to leave at least one time step.
    """

    # Channels produced by the last Conv1d; multiplied by the remaining
    # length to give the flattened feature count.
    _FEATURE_CHANNELS = 32
    # (kernel_size, stride) of every length-reducing layer, in forward order.
    # Must mirror the Conv1d/MaxPool1d layers built in __init__
    # (MaxPool1d's stride defaults to its kernel size).
    _LENGTH_REDUCERS = ((7, 2), (4, 4), (5, 2), (4, 4))

    def __init__(self, input_length=3_000_000, num_classes=10):
        super().__init__()
        # Computed up front so a too-short input fails before any layer is built.
        in_features = self._flattened_features(input_length)
        self.net = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=7, stride=2),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Conv1d(16, self._FEATURE_CHANNELS, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Flatten(),
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    @classmethod
    def _flattened_features(cls, input_length):
        """Return the number of features the conv/pool stack flattens to.

        Applies ``L_out = (L_in - kernel) // stride + 1`` to each
        length-reducing layer. This is PyTorch's formula for padding=0 and
        dilation=1 (Conv1d), and for ceil_mode=False (MaxPool1d).
        """
        length = input_length
        for kernel, stride in cls._LENGTH_REDUCERS:
            length = (length - kernel) // stride + 1
            if length < 1:
                raise ValueError(
                    f"input_length={input_length} is too short for CNN1D: "
                    "the conv/pool stack leaves no time steps"
                )
        return cls._FEATURE_CHANNELS * length

    def forward(self, x):
        """Map ``[batch, 1, input_length]`` inputs to ``[batch, num_classes]`` logits."""
        return self.net(x)

# --- Training Function ---
class SyntheticTimeSeriesDataset(torch.utils.data.Dataset):
    """Random placeholder data generated lazily, one sample per ``__getitem__``.

    Stands in for real data. Materializing the previous placeholder,
    ``torch.randn(200_000, 1, 3_000_000)``, would need about 2.4 TB of float32
    in every process. Generating each sample on demand keeps memory at one
    batch. Each index is seeded deterministically, so every DDP rank and
    DataLoader worker sees the same sample for the same index, which is what
    ``DistributedSampler`` assumes.

    Args:
        num_samples: Number of samples reported by ``__len__``.
        seq_len: Length of each ``[1, seq_len]`` float32 sample.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
        seed: Base seed; sample ``idx`` uses ``seed + idx``.
    """

    def __init__(self, num_samples, seq_len, num_classes=10, seed=0):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index {idx} out of range for {self.num_samples} samples")
        gen = torch.Generator().manual_seed(self.seed + idx)
        sample = torch.randn(1, self.seq_len, generator=gen)
        label = torch.randint(0, self.num_classes, (), generator=gen)
        return sample, label


def train(rank, world_size, *, epochs=50, batch_size=4, lr=1e-4, log_every=50):
    """Run DDP training of ``CNN1D`` in this process on one GPU.

    Args:
        rank: Global rank of this process within the process group.
        world_size: Total number of processes in the group.
        epochs: Number of passes over the dataset.
        batch_size: Per-process batch size.
        lr: Adam learning rate.
        log_every: Rank 0 prints the loss every ``log_every`` batches.

    The CUDA device is ``LOCAL_RANK`` when it is set (torchrun). Otherwise it
    is ``rank``, which is correct for a single-node ``mp.spawn`` launch.
    """
    # env:// initialization needs a rendezvous address. torchrun exports these
    # variables, but an mp.spawn launch does not; setdefault keeps any
    # existing values.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # On multi-node torchrun jobs the global rank is not the local GPU index.
    device = int(os.environ.get("LOCAL_RANK", rank))
    # Bind the device before creating the NCCL group so NCCL uses it.
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    try:
        # Dataset and Dataloader
        dataset = SyntheticTimeSeriesDataset(200_000, 3_000_000, num_classes=10)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=4,
            pin_memory=True,
        )

        # Model, Loss, Optimizer, AMP
        model = DDP(CNN1D().to(device), device_ids=[device])
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scaler = GradScaler()

        for epoch in range(epochs):
            # Reshuffles the per-rank shards differently each epoch.
            sampler.set_epoch(epoch)
            model.train()
            for i, (inputs, targets) in enumerate(dataloader):
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                optimizer.zero_grad()
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                if rank == 0 and i % log_every == 0:
                    print(f"[GPU {rank}] Epoch {epoch} Batch {i} Loss: {loss.item():.4f}")
    finally:
        # Tear down the process group even if training raises.
        dist.destroy_process_group()

# --- Entry Point ---
def main():
    """Start one training worker per GPU.

    Under ``torchrun``, which exports ``LOCAL_RANK``, ``RANK`` and
    ``WORLD_SIZE``, each launched process already is one worker, so ``train``
    is called directly. Otherwise one worker per visible CUDA device is
    started with ``mp.spawn``.

    Raises:
        RuntimeError: If not launched by torchrun and no CUDA device is visible.
    """
    if "LOCAL_RANK" in os.environ:
        # torchrun already started one process per GPU. Spawning here would
        # create nproc_per_node**2 workers with colliding ranks.
        train(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
        return

    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError(
            "No CUDA devices found; NCCL-based DDP training needs at least one GPU."
        )
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

# Run main() only when this file is executed as a script, not on import.
# torch.multiprocessing.spawn uses the "spawn" start method by default, which
# re-imports this module in every child. This guard stops those children from
# calling main() again.
if __name__ == "__main__":
    main()
