# Encoder-Decoder Transformer Training on Databricks

This notebook trains an **Encoder-Decoder Transformer** (the original Transformer architecture) on **sequence-to-sequence (Seq2Seq) tasks** using **PyTorch Distributed Data Parallel (DDP)** across **multiple GPUs** on a single node.

## What is Encoder-Decoder?

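- **Encoder**: Encodes the source sequence with bidirectional self-attention (every position sees every other position)
- **Decoder**: Generates the target sequence one token at a time with causal (masked) self-attention
- **Cross-Attention**: Each decoder layer attends to the encoder's output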
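
The three attention patterns above differ only in where queries, keys, and values come from and whether future positions are masked. The stdlib-only sketch below illustrates this with single-head scaled dot-product attention on toy vectors; it omits learned projections, multiple heads, residual connections, and layer norm, so it is an illustration of the idea rather than the project's model code.

```python
import math

def softmax(scores):
    # Subtract the max for numerical stability; exp(-inf) becomes 0.0 for masked positions
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values, causal=False):
    """Single-head scaled dot-product attention on plain Python lists.

    queries: n_q vectors; keys/values: n_k vectors. With causal=True,
    query position i may only attend to key positions 0..i.
    """
    d = len(keys[0])
    outputs = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        if causal:
            # Mask future positions so they receive zero attention weight
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        weights = softmax(scores)
        outputs.append([sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))])
    return outputs

# Toy embeddings standing in for token representations
src = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 source positions
tgt = [[0.5, 0.5], [1.0, 0.0]]              # 2 target positions

enc_out = attention(src, src, src)                # encoder: bidirectional self-attention
dec_self = attention(tgt, tgt, tgt, causal=True)  # decoder: causal self-attention
cross = attention(dec_self, enc_out, enc_out)     # cross-attention: decoder queries, encoder keys/values
```

With the causal mask, the first target position can only attend to itself, so its output equals its input; the cross-attention call is the only place where source and target information meet.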

## Tasks Available

1. **Reversal**: Output the input sequence in reverse order ("hello" → "olleh")
2. **Copy**: Output the input sequence unchanged ("hello" → "hello")
3. **Addition**: Output the sum of two integers ("12+34" → "46")
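
All three tasks are synthetic, so training pairs can be generated on the fly. The sketch below is a hypothetical generator, not the project's `Seq2SeqDataset`: the lowercase alphabet, the length range, and the 0 to 999 operand range are illustrative assumptions.

```python
import random

def make_pair(task, rng, min_len=3, max_len=10):
    """Generate one hypothetical (source, target) string pair for a task."""
    if task == "addition":
        a, b = rng.randint(0, 999), rng.randint(0, 999)
        return f"{a}+{b}", str(a + b)
    # Random lowercase string for the sequence tasks
    length = rng.randint(min_len, max_len)
    source = "".join(rng.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(length))
    if task == "reversal":
        return source, source[::-1]
    if task == "copy":
        return source, source
    raise ValueError(f"Unknown task: {task}")

rng = random.Random(0)  # fixed seed for reproducible pairs
for task in ("reversal", "copy", "addition"):
    print(task, make_pair(task, rng))
```

Copy is the easiest task (the decoder can align position i with source position i), reversal needs the cross-attention to walk the source backwards, and addition additionally requires carrying digits.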

## Cluster Requirements

- **Runtime**: DBR 13.3 LTS ML (GPU) or higher
- **Driver**: Multi-GPU instance (e.g., 8x A10 or 8x V100)
- **Workers**: 0 (single-node multi-GPU)


## Step 0: Check GPU Availability


In [None]:
import torch
import os

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"\nGPU {i}: {props.name}")
        print(f"  Memory: {props.total_memory / 1024**3:.2f} GB")

print(f"\nCurrent directory: {os.getcwd()}")


## Step 1: Setup & Install Dependencies


In [None]:
%pip install einops pyyaml tensorboard tqdm --quiet
print("Packages installed successfully!")


In [None]:
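import os

# Working directory on the DBFS FUSE mount (/dbfs); relative paths in later cells,
# including those passed to torchrun and the inference script, resolve against it
PROJECT_DIR = "/dbfs/tmp/transformer-ddp-lm"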

print(f"Project directory: {PROJECT_DIR}")
os.makedirs(PROJECT_DIR, exist_ok=True)
os.chdir(PROJECT_DIR)
print(f"Changed to: {os.getcwd()}")


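**Note**: The cells below expect the project files (`configs/enc_dec_config.yaml`, `train_enc_dec_ddp.py`, `inference_enc_dec.py`, and the `models/` and `data/` packages) in `PROJECT_DIR`. Upload them, or clone the repository from a `%sh` cell: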
```bash
%sh
cd /dbfs/tmp
git clone https://github.com/your-username/transformer-ddp-lm.git
```
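
Before moving on, it can help to confirm the files are in place. The sketch below is a hypothetical helper, not part of the project; the list is inferred from the paths and imports this notebook uses, and the module file names under `models/` and `data/` are assumptions about how those packages are laid out.

```python
from pathlib import Path

# Paths this notebook reads; the module file names under models/ and data/ are assumptions
REQUIRED_FILES = (
    "configs/enc_dec_config.yaml",
    "train_enc_dec_ddp.py",
    "inference_enc_dec.py",
    "models/transformer_enc_dec.py",
    "models/config.py",
    "data/seq2seq_dataset.py",
)

def missing_project_files(project_dir, required=REQUIRED_FILES):
    """Return the relative paths from `required` that are not files under project_dir."""
    root = Path(project_dir)
    return [rel for rel in required if not (root / rel).is_file()]
```

In the notebook, `missing_project_files(PROJECT_DIR)` should return an empty list once the upload or clone has finished.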


## Step 2: Select Task


In [None]:
import yaml

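# Choose one of: "reversal", "copy", "addition"
TASK = "reversal"
assert TASK in ("reversal", "copy", "addition"), f"Unknown task: {TASK}"

print(f"Selected task: {TASK}")
print("\nAvailable tasks:")
print("  - reversal: Reverse sequences (hello -> olleh)")
print("  - copy: Copy sequences (hello -> hello)")
print("  - addition: Math problems (12+34 -> 46)")

# Load the training config, set the task, and write it back so the
# training run in Step 3 picks up the selected task
with open('configs/enc_dec_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

config['data']['task'] = TASK

# sort_keys=False keeps the file's original key order
with open('configs/enc_dec_config.yaml', 'w') as f:
    yaml.safe_dump(config, f, sort_keys=False)

print(f"\nConfig updated with task: {TASK}")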
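
Step 5 later reads the model hyperparameters back from `checkpoint['config']['model']`, which assumes the training script stores this YAML's contents in the checkpoint. A hypothetical sanity check like the one below can catch a typo in `TASK` or a missing model field before launching a multi-GPU job; the required model keys are exactly the ones Step 5 reads, and the helper name is illustrative.

```python
ALLOWED_TASKS = ("reversal", "copy", "addition")
# Model fields that Step 5 reads back from checkpoint['config']['model']
MODEL_KEYS = ("vocab_size", "max_seq_len", "dim", "depth", "heads", "dim_head", "mlp_dim", "dropout")

def config_problems(config):
    """Return a list of problems found in a loaded config dict (empty if none)."""
    problems = []
    # Treat a missing or empty `data` section as an empty dict
    task = (config.get("data") or {}).get("task")
    if task not in ALLOWED_TASKS:
        problems.append(f"data.task must be one of {ALLOWED_TASKS}, got {task!r}")
    model = config.get("model") or {}
    for key in MODEL_KEYS:
        if key not in model:
            problems.append(f"model.{key} is missing")
    return problems
```

Running `config_problems(config)` after the cell above should return `[]`.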


## Step 3: Train Encoder-Decoder Transformer (Multi-GPU)


In [None]:
import torch

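num_gpus = torch.cuda.device_count()
assert num_gpus > 0, "No CUDA GPUs detected; attach a GPU cluster (see Cluster Requirements)."

print(f"Starting Encoder-Decoder Transformer training with {num_gpus} GPUs...")
print(f"Task: {TASK}")
print("="*80)

# torchrun starts one training process per GPU; --standalone sets up a local single-node rendezvous
!torchrun --standalone --nproc_per_node={num_gpus} train_enc_dec_ddp.py --config configs/enc_dec_config.yaml

print("\n" + "="*80)
# A failing "!" command does not raise an exception, so scan the output above for errors
print("torchrun exited. Check the output above for errors before continuing.")
print("="*80)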
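
`torchrun` starts `nproc_per_node` worker processes and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each. A DDP training script typically reads these, pins itself to GPU `LOCAL_RANK`, calls `torch.distributed.init_process_group`, and gives each rank its own slice of the data with `DistributedSampler`. The contents of `train_enc_dec_ddp.py` are not shown here, so the stdlib-only sketch below only illustrates the mechanics: reading the rank variables, and sharding indices the way `DistributedSampler` does with `shuffle=False` and `drop_last=False`.

```python
import math
import os

def dist_info(env=os.environ):
    """Read the per-process rank variables that torchrun exports to each worker."""
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # rank on this node, conventionally the GPU index
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }

def shard_indices(num_samples, rank, world_size):
    """Per-epoch sample indices for one rank, like DistributedSampler(shuffle=False, drop_last=False).

    The index list is padded by wrapping around to the start, so every rank
    gets the same number of samples.
    """
    total = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Rank r takes every world_size-th index, starting at position r
    return indices[rank:total:world_size]
```

With 10 samples on 4 GPUs, each rank gets 3 indices and samples 0 and 1 are seen twice per epoch; this padding keeps every rank stepping in lockstep, which DDP's gradient all-reduce requires.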


## Step 4: Test Inference


In [None]:
test_inputs = {
    "reversal": ["hello", "world", "transformer"],
    "copy": ["hello", "world", "transformer"],
    "addition": ["12+34", "25+37", "100+200"]
}

print(f"Testing {TASK} task:")
print("="*80)

for source in test_inputs[TASK]:
    print(f"\nTesting: {source}")
    # Quote the source so inputs containing spaces reach the script as a single argument
    !python inference_enc_dec.py --checkpoint checkpoints_enc_dec/best_model.pt --source "{source}" --max-length 30
    print("-"*80)
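
IPython expands `{source}` in a `!` line and then hands the result to a shell, so a source containing spaces or shell metacharacters (for example, `hello world`) would be split into several arguments unless it is quoted. A sketch of building the same command with the standard library's `shlex.join`, which quotes each argument only when needed; `inference_command` is a hypothetical helper name, and the flags are the ones used above.

```python
import shlex

def inference_command(source, checkpoint="checkpoints_enc_dec/best_model.pt", max_length=30):
    """Build the Step 4 inference command with every argument safely shell-quoted."""
    args = [
        "python", "inference_enc_dec.py",
        "--checkpoint", checkpoint,
        "--source", source,
        "--max-length", str(max_length),
    ]
    return shlex.join(args)

cmd = inference_command("hello world")
print(cmd)
```

In a notebook cell, `cmd = inference_command(source)` followed by `!{cmd}` runs it.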


## Step 5: Interactive Testing


In [None]:
import torch
from models.transformer_enc_dec import EncoderDecoderTransformer
from models.config import TransformerConfig
from data.seq2seq_dataset import Seq2SeqDataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the best checkpoint written by the training run
checkpoint_path = 'checkpoints_enc_dec/best_model.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the model with the hyperparameters saved in the checkpoint's config
model_config_dict = checkpoint['config']['model']
model_config = TransformerConfig(
    vocab_size=model_config_dict['vocab_size'],
    max_seq_len=model_config_dict['max_seq_len'],
    dim=model_config_dict['dim'],
    depth=model_config_dict['depth'],
    heads=model_config_dict['heads'],
    dim_head=model_config_dict['dim_head'],
    mlp_dim=model_config_dict['mlp_dim'],
    dropout=model_config_dict['dropout'],
    use_rotary_emb=False,
)

model = EncoderDecoderTransformer(model_config).to(device)

# Weights saved from a DDP-wrapped model carry a "module." key prefix; strip it
state_dict = checkpoint['model_state_dict']
if list(state_dict.keys())[0].startswith('module.'):
    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference

# In this notebook the dataset is used only for its tokenizer (encode/decode) and special-token IDs
dataset = Seq2SeqDataset(num_samples=1, seq_len=50, task=TASK, vocab_size=256)

print("Model loaded successfully!")
print(f"Task: {TASK}")
print(f"Parameters: {model.num_parameters():,}")


In [None]:
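def translate(source_text, max_length=50, temperature=0.8):
    # Encode the source and append the EOS token to mark where it ends
    src = dataset.encode(source_text)
    src = torch.cat([src, torch.tensor([dataset.eos_token_id], dtype=src.dtype)])
    # Add a batch dimension: shape (1, src_len)
    src = src.unsqueeze(0).to(device)
    
    # Sampling (temperature, top-k, top-p) means outputs can differ between runs
    with torch.no_grad():
        generated = model.generate(
            src,
            max_length=max_length,
            bos_token_id=dataset.bos_token_id,
            eos_token_id=dataset.eos_token_id,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
        )
    
    # Decode the first (and only) sequence in the batch
    return dataset.decode(generated[0])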

test_source = "hello world"
result = translate(test_source)

print(f"Source: {test_source}")
print(f"Target: {result}")
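
`translate` samples rather than always taking the most likely token. The `temperature`, `top_k`, and `top_p` arguments commonly work as in the sketch below: divide the logits by the temperature, keep the `top_k` highest-scoring tokens, keep the smallest prefix of those whose cumulative probability reaches `top_p`, then sample from what remains. This is a stdlib-only illustration of that common technique; whether `model.generate` applies the filters in exactly this order is an assumption, since its implementation is not shown here.

```python
import math
import random

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9, rng=random):
    """Sample one token id from raw logits using temperature, top-k, then top-p (nucleus) filtering."""
    # Temperature scaling: < 1 sharpens the distribution, > 1 flattens it
    scaled = [x / temperature for x in logits]
    # Top-k: keep only the k highest-scoring token ids, best first
    kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax over the kept ids (subtract the max for numerical stability)
    m = scaled[kept[0]]
    exps = [math.exp(scaled[i] - m) for i in kept]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    cumulative, cutoff = 0.0, len(kept)
    for n, p in enumerate(probs, start=1):
        cumulative += p
        if cumulative >= top_p:
            cutoff = n
            break
    # random.choices treats the weights as relative, so no renormalization is needed
    return rng.choices(kept[:cutoff], weights=probs[:cutoff], k=1)[0]
```

With `top_k=1` only the argmax survives, so sampling reduces to greedy decoding. Reversal and copy have exactly one correct output, so greedy decoding is usually the better fit there; if `model.generate` implements top-k in the usual way, passing `top_k=1` should make `translate` deterministic.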


## Summary

Congratulations! You've successfully:

1. Trained an Encoder-Decoder Transformer with multi-GPU DDP
2. Trained it on one synthetic Seq2Seq task (reversal, copy, or addition)
3. Generated outputs from the trained checkpoint

### Next Steps

1. Try a different task: change `TASK` in Step 2 and rerun Steps 2 through 5
2. Adjust the model size (e.g., `dim`, `depth`, `heads`) in `configs/enc_dec_config.yaml`
3. Train on your own Seq2Seq dataset
4. Implement beam search instead of sampling for more reliable generation
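
Beam search keeps the `beam_width` highest-scoring partial sequences at each step instead of committing to one token at a time. The sketch below is model-agnostic and stdlib-only: `step_log_probs` is a hypothetical callback that returns next-token log-probabilities for a prefix. Wiring it to the trained model (running the decoder on the prefix and taking a log-softmax) is left as an assumption, since `model.generate`'s internals are not shown here.

```python
import math

def beam_search(step_log_probs, bos, eos, beam_width=3, max_length=10):
    """Minimal beam search returning (best_sequence, total_log_prob).

    step_log_probs(prefix) -> {token: log_prob} for the next token after prefix.
    """
    beams = [([bos], 0.0)]
    finished = []
    for _ in range(max_length):
        # Expand every active beam by every candidate next token
        candidates = []
        for prefix, score in beams:
            for token, lp in step_log_probs(prefix).items():
                candidates.append((prefix + [token], score + lp))
        # Keep the best beam_width expansions; those ending in eos are finished
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            (finished if seq[-1] == eos else beams).append((seq, score))
        if not beams:
            break
    # Fall back to unfinished beams if nothing produced eos within max_length
    return max(finished or beams, key=lambda c: c[1])

# Toy next-token table: greedy picks token 1 first, but 2 -> EOS scores higher overall
TABLE = {
    (0,): {1: 0.6, 2: 0.4},
    (0, 1): {9: 0.4, 1: 0.3, 2: 0.3},
    (0, 2): {9: 0.9, 1: 0.1},
}

def toy(prefix):
    probs = TABLE.get(tuple(prefix), {9: 1.0})
    return {t: math.log(p) for t, p in probs.items()}
```

On the toy table, `beam_search(toy, bos=0, eos=9, beam_width=1)` (greedy) returns `[0, 1, 9]` with probability 0.24, while `beam_width=2` finds `[0, 2, 9]` with probability 0.36. Because this sketch sums raw log-probabilities, it favors shorter outputs; practical implementations usually add a length penalty.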

### Files Created

Paths are relative to `PROJECT_DIR` (`/dbfs/tmp/transformer-ddp-lm`):

- **Model checkpoint**: `checkpoints_enc_dec/best_model.pt`
- **Training logs**: `logs_enc_dec/`
- **Config** (modified in Step 2): `configs/enc_dec_config.yaml`
