In [None]:
from src.requirements import *
from src.audio_handler import AudioDataset, collate_padding
from src.models import FeatureEncoder, ContextModule, ContrastivePredictor, SSLModel, compute_mask_indices

In [None]:
def contrastive_loss_chunked(z, q, mask=None, temperature=0.1, chunk_size=256):
    """InfoNCE-style contrastive loss, evaluated over time in chunks to bound memory.

    For every time step t, the prediction ``q[:, t]`` is scored against every
    target step ``z[:, 0..T-1]`` of the same sequence using cosine similarity
    divided by ``temperature``. The target at the same step t is the positive;
    all other steps of that sequence act as negatives. Only ``chunk_size``
    query steps are scored at once, so the similarity tensor is
    (B, chunk, T) rather than (B, T, T).

    Args:
        z: targets; unpacked as ``B, T, D = z.shape``, so it must be 3-D.
        q: predictions; assumed to have the same (B, T, D) shape as ``z`` — TODO confirm.
        mask: optional tensor indexed as ``mask[:, t0:t1]`` and multiplied into
            the per-step losses (presumably (B, T)). A boolean mask therefore
            selects which positions contribute to the loss and to the normalizer.
        temperature: divisor applied to the cosine similarities.
        chunk_size: number of query time steps processed per iteration.

    Returns:
        Sum of per-position losses divided by the number (or total weight) of
        contributing positions, plus 1e-10 to avoid division by zero.
    """
    _, seq_len, _ = z.shape

    # Cosine similarity is the dot product of unit-length vectors.
    z_unit = F.normalize(z, dim=-1)
    q_unit = F.normalize(q, dim=-1)
    z_unit_t = z_unit.transpose(1, 2)  # (B, D, T) for the batched matmul

    loss_sum = 0.0
    count = 0.0
    start = 0
    while start < seq_len:
        stop = min(start + chunk_size, seq_len)
        queries = q_unit[:, start:stop, :]

        # (B, chunk, T): each query in the chunk against every target step.
        logits = torch.bmm(queries, z_unit_t) / temperature
        positive = (queries * z_unit[:, start:stop, :]).sum(dim=-1) / temperature
        per_step = torch.logsumexp(logits, dim=-1) - positive

        if mask is None:
            count += per_step.numel()
        else:
            weights = mask[:, start:stop].float()
            per_step = per_step * weights
            count += weights.sum()

        loss_sum += per_step.sum()
        start = stop

    return loss_sum / (count + 1e-10)
    

In [None]:
# Micro-batch size per forward pass; train() accumulates gradients across
# several micro-batches before each optimizer update.
batch_size = 2
# Metadata file consumed by AudioDataset (its format is defined by that class).
path = os.path.join('data', 'metadata.tsv')

train_dataset = AudioDataset(metadata_path=path)
# collate_padding presumably pads variable-length clips to a common length — confirm in src.audio_handler.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_padding,
    pin_memory=True,
)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-4

model = SSLModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Loss scaler used by train() for mixed-precision backward passes and steps.
scaler = torch.GradScaler(device)

# NOTE(review): gamma=0.1 is StepLR's default, written out explicitly here.
# train() calls scheduler.step() once per epoch, so the learning rate shrinks
# 10x after every epoch (1e-4 -> 1e-5 -> 1e-6 ...). Confirm this is intended
# for a run capped at 250k optimizer updates.
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
def save_checkpoint(model, optimizer, scheduler, num_updates, path, scaler=None):
    """Serialize the training state to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``num_updates``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``scheduler_state_dict``. When ``scaler`` is
    given, its state is also stored under ``scaler_state_dict``, so a
    mixed-precision run can resume with the same loss scale.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        num_updates: optimizer-update counter stored alongside the states.
        path: destination file; missing parent directories are created.
        scaler: optional AMP grad scaler (any object with ``state_dict()``).
    """
    checkpoint = {
        'num_updates': num_updates,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()

    # torch.save does not create missing directories. On a fresh checkout
    # without e.g. models/ssl_model/, a long run would crash at its first
    # checkpoint.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torch.save(checkpoint, path)

In [None]:
def train(model, train_dl, loss_fn, optimizer, scaler, scheduler, device,
          accum=8, max_updates=250_000, save_every=10_000, epochs=999,
          chunk_size=128, save_dir=None):
    """Train the model with mixed precision and gradient accumulation.

    Each optimizer update accumulates gradients over ``accum`` micro-batches.
    Training stops after ``max_updates`` optimizer updates or ``epochs``
    passes over ``train_dl``, whichever comes first. A checkpoint is written
    exactly once for every ``save_every`` optimizer updates.

    Args:
        model: module whose forward returns ``(z, q, mask)`` for a batch.
        train_dl: iterable of batches; each batch must support ``.to(device)``.
        loss_fn: callable ``loss_fn(z, q, mask, chunk_size=...)`` returning a scalar loss.
        optimizer: optimizer over the model's parameters.
        scaler: AMP grad scaler used for ``scale(...).backward()``, ``step`` and ``update``.
        scheduler: LR scheduler, stepped once per epoch.
        device: device string, used both for ``.to`` and as the autocast device type.
        accum: micro-batches per optimizer update.
        max_updates: stop after this many optimizer updates.
        save_every: write a checkpoint every this many optimizer updates.
        epochs: upper bound on passes over ``train_dl``.
        chunk_size: forwarded to ``loss_fn``.
        save_dir: checkpoint directory; defaults to ``models/ssl_model``.
    """
    if save_dir is None:
        save_dir = os.path.join('models', 'ssl_model')
    # torch.save does not create missing directories; make sure the first
    # checkpoint does not abort the run.
    os.makedirs(save_dir, exist_ok=True)

    num_updates = 0
    model.train()

    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")
        # Discard gradients from an incomplete accumulation window at the end
        # of the previous epoch, so every update averages exactly `accum`
        # micro-batches.
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_dl)):
            batch = batch.to(device)

            with torch.autocast(device):
                z, q, mask = model(batch)
                # Dividing by `accum` makes the accumulated gradient a mean
                # over the micro-batches.
                loss = loss_fn(z, q, mask, chunk_size=chunk_size) / accum
            scaler.scale(loss).backward()

            if (i + 1) % accum != 0:
                continue

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            num_updates += 1

            # Checked only right after an update, so each milestone is saved
            # once and no checkpoint is written at num_updates == 0.
            if num_updates % save_every == 0:
                save_path = os.path.join(save_dir, f'ssl_model_prototype_{num_updates}.pth')
                save_checkpoint(model, optimizer, scheduler, num_updates, save_path)

            if num_updates >= max_updates:
                break

        scheduler.step()
        torch.cuda.empty_cache()

        if num_updates >= max_updates:
            break

In [None]:
# Launch pre-training with the chunked contrastive objective defined above.
train(
    model=model,
    train_dl=train_dl,
    loss_fn=contrastive_loss_chunked,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    device=device,
)