In [None]:
from src.requirements import *
from src.audio_handler import AudioDataset, collate_padding
from src.ssl_model import *
from src.ssl_large import *

In [None]:
# Paths handed to AudioDataset as `metadata_path` and `cache_dir`.
path = os.path.join('data', 'metadata_normal.tsv')
cache = os.path.join('data', 'cache_mmap', 'ssl')
batch_size = 4  # number of items the DataLoader groups per batch

train_dataset = AudioDataset(
    metadata_path=path,
    cache_dir=cache,
    # TOP_DB is not defined in this notebook; presumably it arrives through
    # one of the star imports in the first cell -- TODO confirm.
    top_db=TOP_DB,
)

train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Pinned (page-locked) host memory only speeds up host-to-GPU copies.
    # Requesting it on a CPU-only machine merely makes DataLoader warn, so
    # enable it only when CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_padding,
)

In [None]:
# Smoke test: fetch the first four dataset items and push them through the
# same collate function the DataLoader is configured with.
test_samples = []
for i in range(4):
    test_samples.append(train_dataset[i])

print("Sample shapes:")
for i, sample in enumerate(test_samples):
    # Tensors are described by their shape; anything else by its Python type.
    if isinstance(sample, torch.Tensor):
        description = sample.shape
    else:
        description = type(sample)
    print(f"  Sample {i}: {description}")

# Collate the four items into a single batch and report the result.
collated = collate_padding(test_samples)
print(f"\nCollated shape: {collated.shape}")
print(f"Expected: (4, 1, max_length)")

In [None]:
# Show only the first item the dataset yields; an empty dataset prints nothing.
sample_iter = iter(train_dataset)
try:
    sample = next(sample_iter)
except StopIteration:
    pass
else:
    print(sample)

In [None]:
# For medium model (~95M params) -- the author's estimate; the printout below
# reports the actual count.
learning_rate = 1e-4  # Lower LR for larger model
weight_decay = 0.05   # Higher weight decay
batch_size = 4        # Smaller batch (memory constraints)
gradient_accumulation = 8  # Effective batch = batch_size * gradient_accumulation = 32
max_updates = 150_000  # More updates for larger model

# Kept as a plain string because later cells pass it on unchanged.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Check parameter count
model = LargeSSLModel().to(device)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")

# Per-component breakdown of the four sub-modules inspected here.
for label, component in (
    ("Encoder", model.encoder),
    ("Context", model.context),
    ("Projector", model.projector),
    ("Predictor", model.predictor),
):
    print(f"  {label}: {sum(p.numel() for p in component.parameters()):,}")

# The total is printed once above; only the trainable subset is new here.
print(f"Trainable parameters: {trainable_params:,}")

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(0.9, 0.95)  # Different betas for large models
)

# Scheduler with warmup
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR with linear warmup followed by a single cosine decay.

    The multiplier applied to the optimizer's base learning rate is:
      * ``step / num_warmup_steps`` while ``step < num_warmup_steps``;
      * ``0.5 * (1 + cos(pi * progress))`` afterwards, where ``progress`` runs
        from 0 at the end of warmup to 1 at ``num_training_steps``;
      * 0 for every step at or beyond ``num_training_steps``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of scheduler steps spent in linear warmup.
        num_training_steps: Scheduler step at which the multiplier reaches 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    # A decay phase of zero (or negative) length would otherwise divide by zero.
    decay_steps = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Clamp at 1.0: without it the cosine keeps oscillating and the
        # learning rate climbs back up once training runs past the schedule.
        progress = min(1.0, (current_step - num_warmup_steps) / decay_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

warmup_updates = 10_000  # length of the linear warmup, in scheduler steps

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_updates,
    # Reuse the configured budget (150_000) instead of repeating the literal,
    # so the schedule cannot silently drift from the hyper-parameters.
    num_training_steps=max_updates,
)

# Mixed precision
scaler = torch.amp.GradScaler(device)

print(f"Model size: {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")
# NOTE(review): train(), as called in this notebook, accumulates over 4
# micro-batches and stops after 50_000 updates, which does not match
# `gradient_accumulation` / `max_updates` here -- reconcile before relying on
# the figures printed below.
print(f"Effective batch size: {batch_size * gradient_accumulation}")

print(f"\nTraining Configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Weight decay: {weight_decay}")
# Describe the scheduler that is actually built above (LambdaLR: linear warmup
# followed by cosine decay), not CosineAnnealingWarmRestarts.
print(f"  Scheduler: cosine decay with linear warmup (warmup={warmup_updates:,}, total={max_updates:,})")
# `scaler is not None` was always True; report whether scaling is really active.
print(f"  Mixed precision: {scaler.is_enabled()}")

In [None]:
def save_checkpoint(model, optimizer, scheduler, num_updates, save_path, scaler=None):
    """Write the current training state to ``save_path`` with ``torch.save``.

    The saved dictionary holds ``num_updates`` plus the ``state_dict()`` of the
    model, optimizer and scheduler under the keys ``model_state_dict``,
    ``optimizer_state_dict`` and ``scheduler_state_dict``. When ``scaler`` is
    given, its ``state_dict()`` is stored under ``scaler_state_dict``.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        num_updates: Update counter stored alongside the states.
        save_path: Destination file; missing parent directories are created.
        scaler: Optional gradient scaler exposing ``state_dict()``.
    """
    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Prepare checkpoint dictionary
    checkpoint = {
        'num_updates': num_updates,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }

    # Add scaler if using mixed precision
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()

    # Save checkpoint
    torch.save(checkpoint, save_path)
    print(f"✓ Checkpoint saved: {save_path}")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device, scaler=None):
    """Restore training state from a checkpoint file.

    Reads ``checkpoint_path`` with ``torch.load(..., map_location=device)`` and
    feeds the stored state dicts to ``model``, ``optimizer`` and ``scheduler``
    (in that order). If ``scaler`` is given and the checkpoint contains a
    ``scaler_state_dict`` entry, the scaler is restored as well.

    Returns:
        ``(model, optimizer, scheduler, num_updates)`` when ``scaler`` is None,
        otherwise ``(model, optimizer, scheduler, num_updates, scaler)``.
    """
    print(f"Loading checkpoint from: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location=device)

    # Restore the three mandatory components in a fixed order, reporting each.
    mandatory = (
        (model, 'model_state_dict', "✓ Model state loaded"),
        (optimizer, 'optimizer_state_dict', "✓ Optimizer state loaded"),
        (scheduler, 'scheduler_state_dict', "✓ Scheduler state loaded"),
    )
    for target, key, message in mandatory:
        target.load_state_dict(state[key])
        print(message)

    # The scaler entry is optional: skipped when no scaler was passed or the
    # checkpoint was written without one.
    has_scaler = scaler is not None
    if has_scaler and 'scaler_state_dict' in state:
        scaler.load_state_dict(state['scaler_state_dict'])
        print("✓ Scaler state loaded")

    num_updates = state['num_updates']
    print(f"✓ Resumed from update {num_updates}")

    # Return arity depends on whether a scaler was supplied.
    if not has_scaler:
        return model, optimizer, scheduler, num_updates
    return model, optimizer, scheduler, num_updates, scaler


In [None]:
def train(model, train_dl, optimizer, scaler, scheduler, device,
          accum=4, max_updates=50_000, epochs=25, checkpoint_every=5_000):
    """Run mixed-precision training with gradient accumulation.

    Each batch is moved to ``device``, run under ``torch.autocast`` (float16)
    and its loss divided by ``accum`` before the scaled backward pass. After
    every ``accum`` micro-batches the gradients are unscaled, clipped to a
    norm of 1.0, and one optimizer / scaler / scheduler step is taken,
    followed by ``model.update_momentum()``. Training stops after ``epochs``
    passes over ``train_dl`` or ``max_updates`` optimizer updates, whichever
    comes first.

    Args:
        model: Module whose forward pass returns the loss for a batch. It must
            expose ``encoder_m``, ``context_m``, ``projector_m`` and
            ``update_momentum()``.
        train_dl: Iterable of batches exposing ``.to(device)``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler (``scale`` / ``unscale_`` / ``step`` / ``update``).
        scheduler: LR scheduler, stepped once per optimizer update.
        device: Device string or ``torch.device`` the batches are moved to.
        accum: Micro-batches accumulated per optimizer update.
        max_updates: Optimizer-update budget.
        epochs: Maximum number of passes over ``train_dl``.
        checkpoint_every: Save a checkpoint every this many updates; a falsy
            value disables periodic checkpoints.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: If ``accum`` is smaller than 1.
    """
    if accum < 1:
        raise ValueError(f"accum must be >= 1, got {accum}")

    # autocast needs the bare device type ('cuda' / 'cpu'); this also accepts
    # inputs such as 'cuda:0' or a torch.device.
    device_type = torch.device(device).type

    num_updates = 0
    # Counts micro-batches across the whole run. Keying the accumulation window
    # on the per-epoch batch index left a partial window's gradients behind
    # whenever len(train_dl) was not a multiple of `accum`; they were then
    # folded into the next epoch's first update.
    micro_steps = 0

    # Start from clean gradients so the first window is exact even when the
    # cell is re-run after an interrupted training loop.
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(epochs):
        # Training
        model.train()
        # The *_m sub-modules are kept in eval mode (presumably the momentum
        # copies refreshed by update_momentum() -- TODO confirm).
        model.encoder_m.eval()
        model.context_m.eval()
        model.projector_m.eval()

        train_loss = 0.0
        batches_seen = 0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
            batch = batch.to(device)

            with torch.autocast(device_type=device_type, dtype=torch.float16):
                loss = model(batch) / accum

            scaler.scale(loss).backward()
            # Undo the division so the running total is in per-batch loss units.
            train_loss += loss.item() * accum
            batches_seen += 1
            micro_steps += 1

            if micro_steps % accum == 0:
                # Unscale before clipping so the threshold applies to true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                num_updates += 1
                model.update_momentum()

                # Regular checkpoints
                if checkpoint_every and num_updates % checkpoint_every == 0:
                    save_path = f'models/ssl_model/ssl_model_checkpoint_{num_updates}.pth'
                    save_checkpoint(model, optimizer, scheduler, num_updates, save_path, scaler)

                if num_updates >= max_updates:
                    break

        # Average over the batches actually processed: the epoch may have been
        # cut short by the update budget, so len(train_dl) would under-report.
        avg_train_loss = train_loss / max(1, batches_seen)
        print(f'Epoch {epoch+1}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')

        torch.cuda.empty_cache()

        if num_updates >= max_updates:
            break

    return model


In [None]:
# Start training with the objects built in the cells above. The call stays the
# cell's trailing expression, so its return value is the cell's output.
train(
    model=model,
    train_dl=train_dl,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    device=device,
)