In [None]:
from src.requirements import *
from src.audio_handler import AudioDataset, collate_padding, ABXDataset, run_abx_val
from src.ssl_model import *

In [None]:
# Location of the metadata table consumed by both datasets.
path = os.path.join('data', 'metadata.tsv')
batch_size = 4

# Training data: shuffled, padded per batch by `collate_padding`; the final
# incomplete batch is dropped so every step sees exactly `batch_size` items.
train_dataset = AudioDataset(metadata_path=path)
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=2,  # was `num_worker`, which DataLoader rejects with a TypeError
    collate_fn=collate_padding,
    shuffle=True,
    drop_last=True,
)

# ABX evaluation data, read in a fixed order.
# NOTE(review): segment_len = 16000 * 2 is presumably 2 s of 16 kHz audio - confirm.
abx_dataset = ABXDataset(metadata_path=path, segment_len=16000 * 2)
abx_loader = DataLoader(abx_dataset, batch_size=8, shuffle=False)

In [None]:
# Base learning rate and the device everything is placed on.
learning_rate = 2e-4
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = SSLModel()
model = model.to(device)

# Put every BatchNorm1d layer in eval mode and turn off running-stat tracking.
# NOTE(review): `model.train()` inside `train()` flips these layers back to
# training mode - confirm that is the intended behaviour.
batch_norm_layers = [
    layer for layer in model.modules() if isinstance(layer, nn.BatchNorm1d)
]
for layer in batch_norm_layers:
    layer.eval()
    layer.track_running_stats = False

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Gradient scaler for the float16 autocast region used during training.
scaler = torch.GradScaler(device)

# Cosine schedule with warm restarts: first cycle lasts 5 000 scheduler steps,
# each following cycle is twice as long, and the LR never drops below 10 % of
# the base rate.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=5_000, T_mult=2, eta_min=learning_rate * 0.1
)

In [None]:
def save_checkpoint(model, optimizer, scheduler, num_updates, path):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``'num_updates'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'scheduler_state_dict'``, serialised with ``torch.save``.

    Args:
        model: Object exposing ``state_dict()`` (the network being trained).
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        num_updates: Number of optimizer updates performed so far.
        path: Destination file. Missing parent directories are created.
    """
    import os  # local import keeps the helper usable on its own

    checkpoint = {
        'num_updates' : num_updates,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
        'scheduler_state_dict' : scheduler.state_dict()
    }

    # torch.save does not create directories; without this the first save
    # fails when the target folder does not exist yet.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, path)

In [None]:
def train(model, train_dl, optimizer, scaler, scheduler, device,
          accum=4, max_updates=20_000, epochs=999, save_every=1_000):
    """Run mixed-precision training with gradient accumulation.

    Each batch is passed through ``model`` (whose forward call returns the
    loss) inside a float16 autocast region. Gradients from ``accum``
    consecutive batches are summed before one optimizer/scheduler step.
    After every step ``model.update_momentum()`` is called, and every
    ``save_every`` steps a checkpoint is written to
    ``models/ssl_model/ssl_model_prototype_<num_updates>.pth``.

    Args:
        model: Network to train. Must expose ``encoder_m``, ``projector_m``
            and ``update_momentum()``.
        train_dl: Iterable of batches; each batch must support ``.to(device)``.
        optimizer: Optimizer over the model's trainable parameters.
        scaler: Gradient scaler used for the float16 backward pass.
        scheduler: LR scheduler, stepped once per optimizer update.
        device: Device string or ``torch.device`` the batches are moved to.
        accum: Number of batches accumulated per optimizer update.
        max_updates: Stop once this many optimizer updates have been made.
        epochs: Upper bound on passes over ``train_dl``.
        save_every: Checkpoint interval, in optimizer updates.
    """
    num_updates = 0
    # torch.autocast needs the bare device type ('cuda' / 'cpu'); normalising
    # here also accepts torch.device objects and strings such as 'cuda:0'.
    device_type = torch.device(device).type

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        print(f"\n--- Epoch [{epoch+1}/{epochs}] ---")
        # NOTE(review): model.train() also switches the BatchNorm1d layers that
        # the setup cell put into eval mode back to training mode - confirm.
        model.train()
        # Presumably the momentum (target) branches; they stay in eval mode.
        model.encoder_m.eval()
        model.projector_m.eval()

        # The accumulation window restarts with the batch index each epoch, so
        # drop any partial accumulation left over from the previous epoch.
        optimizer.zero_grad(set_to_none=True)

        for i, batch in enumerate(tqdm(train_dl)):
            batch = batch.to(device)

            with torch.autocast(device_type=device_type, dtype=torch.float16):
                # Divide so the summed gradient is the mean over the window.
                loss = model(batch) / accum

            scaler.scale(loss).backward()
            total_loss += loss.item()
            num_batches += 1

            if (i+1) % accum == 0:
                scaler.step(optimizer)
                scaler.update()
                # Clear gradients only after the step. Zeroing on every batch
                # (as before) discarded all but the last micro-batch.
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                num_updates += 1
                model.update_momentum()

                # Checkpoint inside the step branch so each milestone is
                # written once instead of on every micro-batch until the
                # update counter moves on.
                if num_updates % save_every == 0:
                    save_path = os.path.join('models', 'ssl_model', f'ssl_model_prototype_{num_updates}.pth')
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    save_checkpoint(model, optimizer, scheduler, num_updates, save_path)

            if num_updates >= max_updates:
                break

        # Average over the batches actually processed; the epoch may have been
        # cut short by max_updates, and the loader may be empty.
        avg_reported_loss = total_loss / max(num_batches, 1)
        real_loss = avg_reported_loss * accum

        print(f'Reported Loss: {avg_reported_loss:.4f} | Real Loss: {real_loss:.4f}')
        torch.cuda.empty_cache()

        if num_updates >= max_updates:
            break

In [None]:
# Start training with the objects built in the cells above.
train(
    model=model,
    train_dl=train_dl,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    device=device,
)