In [1]:
import sys
from pathlib import Path

# Add ./src to Python path
sys.path.append(str(Path().resolve() / "src"))

# Now you can import modules
from data_loader import SpectogramDataset
from model import TinyBird
from transformers import BertConfig


2025-08-30 21:36:06.798480: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-30 21:36:06.823513: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Hyperparameters for the TinyBird model, handed to its constructor as one plain dict.
config = dict(
    # Patch extent. The training cell reuses this tuple as the Fold kernel size and
    # stride over an (n_mels, n_timebins) spectrogram, so the order is (mel, time).
    patch_size=(32, 8),
    # Presumably the maximum number of patch tokens per input -- TODO confirm in model.py.
    # (A 128 x 1024 spectrogram cut into 32 x 8 patches yields 4 * 128 = 512 patches.)
    max_seq=512,
    enc_hidden_d=192,   # encoder hidden dimension
    dec_hidden_d=192,   # decoder hidden dimension
    enc_n_head=4,       # encoder attention heads
    enc_n_layer=3,      # encoder transformer layers
    enc_dim_ff=768,     # encoder feed-forward dimension
    dec_n_head=4,       # decoder attention heads
    dec_n_layer=3,      # decoder transformer layers
    dec_dim_ff=768,     # decoder feed-forward dimension
    dropout=0.1,        # dropout rate
    # Presumably the masking probability / fraction used by the model -- TODO confirm.
    mask_p=0.25,
)

tinybird = TinyBird(config)





In [3]:
from torch.utils.data import DataLoader

# Location of the spectrogram files. This is an absolute path on the original
# author's machine; point it somewhere else when running on a different host.
data_root = "/media/george-vengrovski/disk1/llb3_train"

# NOTE(review): despite the "test_" prefix these objects read from the llb3_train
# directory, and the training cell iterates over this same loader.
test_dataset = SpectogramDataset(
    dir=data_root,
    n_mels=128,
    n_timebins=1024,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=True,
)

# Pull a single shuffled batch as a quick smoke test of the data pipeline.
batch = next(iter(test_dataloader))
spec, label = batch

# Disabled scratch check of the patch-embedding path:
#   z = tinybird.encode_pos(tinybird.project_to_patch(spec)); print(z.shape)

Computing dataset statistics across 3451 files...
Dataset statistics - Mean: -63.9062, Std: 16.9844


In [4]:
# basic MAE-style training loop with recon dumps
import os, torch, matplotlib.pyplot as plt
from torch.optim import AdamW
import torch.nn as nn

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = tinybird.to(device)

# The optimizer is created at lr=1e-6, which is also where the training loop's
# linear warmup starts; weight decay is switched off.
opt = AdamW(model.parameters(), lr=1e-6, weight_decay=0.0)

# Schedule and logging knobs.
max_steps = 20_000      # hard cap on optimizer steps
warmup_steps = 5_000    # length of the linear LR warmup, in steps
target_lr = 5e-4        # learning rate the warmup ramps towards
viz_every = 100         # log + dump a reconstruction image every N steps

# Spectrogram geometry, needed to fold patch predictions back into an image.
patch_size = config["patch_size"]
H = test_dataset.n_mels
W = test_dataset.n_timebins

def depatchify(pred: torch.Tensor, out_hw=None, patch_hw=None) -> torch.Tensor:
    """Reassemble flattened, non-overlapping patches into an image batch.

    Args:
        pred: Tensor of shape (B, T, P): T patches per sample, P values per patch.
            Patches must be in the block order and within-patch layout that
            ``torch.nn.Unfold`` produces (NOTE(review): assumed to match the
            model's own patchify step -- confirm in model.py).
        out_hw: ``(height, width)`` of the reassembled image. When omitted, the
            notebook-level ``(H, W)`` is used.
        patch_hw: ``(patch_height, patch_width)``; used as both kernel size and
            stride, i.e. patches do not overlap. When omitted, the notebook-level
            ``patch_size`` is used.

    Returns:
        Tensor of shape (B, C, height, width) with C = P / (patch_height * patch_width);
        for single-channel patches this is (B, 1, H, W).

    Raises:
        RuntimeError: raised by ``fold`` when T or P is inconsistent with the
            requested output and patch sizes.
    """
    # Fall back to the notebook globals only when the caller did not pass sizes,
    # so the function can also be used (and tested) without that hidden state.
    if out_hw is None:
        out_hw = (H, W)
    if patch_hw is None:
        patch_hw = patch_size

    # fold expects (B, P, T); the functional form avoids building a module per call.
    columns = pred.transpose(1, 2)
    return nn.functional.fold(
        columns, output_size=out_hw, kernel_size=patch_hw, stride=patch_hw
    )

def save_recon(x: torch.Tensor, pred: torch.Tensor, step: int, out_dir="recons"):
    """Save a two-row PNG comparing an input spectrogram with its reconstruction.

    Only the first sample and first channel of the batch are drawn.

    Args:
        x: Input batch, indexed as ``x[0, 0]`` (the training loop passes (B,1,H,W)).
        pred: Patch predictions, turned back into an image batch via ``depatchify``.
        step: Training step; used in the file name, zero-padded to six digits.
        out_dir: Directory for the image; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)

    # (title, 2-D tensor) for each row, top to bottom.
    panels = (
        ("input", x[0, 0]),
        ("recon", depatchify(pred)[0, 0]),
    )
    images = [(title, t.detach().cpu().numpy()) for title, t in panels]

    # Wide, short figure: the time axis is much longer than the frequency axis.
    fig = plt.figure(figsize=(12, 3))
    for row, (title, img) in enumerate(images, start=1):
        ax = plt.subplot(2, 1, row)
        ax.imshow(img, origin="lower", aspect="auto")
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()
    fig.savefig(f"{out_dir}/step_{step:06d}.png", dpi=150)
    # Close explicitly so figures do not accumulate over a long training run.
    plt.close(fig)

# Masked-reconstruction training loop: linear LR warmup, EMA-smoothed loss logging,
# and a reconstruction image dumped every `viz_every` steps.
step = 0
model.train()
ema_loss = None
ema_alpha = 0.99   # smoothing factor for the logged exponential moving average
base_lr = 1e-6     # warmup starting LR (same value the optimizer was created with)

for epoch in range(2000):  # upper bound on epochs; max_steps normally ends the run first
    for spec, _ in test_dataloader:
        # Linear warmup from base_lr to target_lr.
        # `<=` (not `<`) so the update at step == warmup_steps lands exactly on
        # target_lr; with `<` the LR stopped one increment short of the target.
        # After warmup the LR is left untouched (constant at target_lr).
        if step <= warmup_steps:
            warmup_frac = step / warmup_steps if warmup_steps > 0 else 1.0
            lr = base_lr + (target_lr - base_lr) * warmup_frac
            for param_group in opt.param_groups:
                param_group['lr'] = lr

        x = spec.float().to(device, non_blocking=True)  # (B,1,H,W)

        # Encode, decode, and score the reconstruction using the mask returned
        # by the encoder.
        h, idx_restore, bool_mask, T = model.forward_encoder(x)
        pred = model.forward_decoder(h, idx_restore, T)
        loss = model.loss_mse(x, pred, bool_mask)

        # Read the scalar once per step (.item() forces a host sync on GPU) and
        # reuse it for both the EMA and the log line.
        loss_val = loss.item()
        if ema_loss is None:
            ema_loss = loss_val
        else:
            ema_loss = ema_alpha * ema_loss + (1 - ema_alpha) * loss_val

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if step % viz_every == 0:
            # Reconstruction dump from a separate eval-mode, no-grad forward pass on
            # the current batch; the logged loss is the training-mode loss above.
            model.eval()
            with torch.no_grad():
                h_v, idx_r_v, _, T_v = model.forward_encoder(x)
                pred_v = model.forward_decoder(h_v, idx_r_v, T_v)
                save_recon(x, pred_v, step)
                current_lr = opt.param_groups[0]['lr']
                print(f"Step {step}: Loss = {loss_val:.6f}, EMA Loss = {ema_loss:.6f}, LR = {current_lr:.2e}")
            model.train()

        step += 1
        if step >= max_steps:
            break
    # Propagate the inner break out of the epoch loop.
    if step >= max_steps:
        break


Step 0: Loss = 0.810557, EMA Loss = 0.810557, LR = 1.00e-06
Step 100: Loss = 0.524330, EMA Loss = 0.650493, LR = 1.10e-05
Step 200: Loss = 0.467206, EMA Loss = 0.498355, LR = 2.10e-05
Step 300: Loss = 0.444053, EMA Loss = 0.430169, LR = 3.09e-05
Step 400: Loss = 0.312186, EMA Loss = 0.376660, LR = 4.09e-05
Step 500: Loss = 0.285156, EMA Loss = 0.338288, LR = 5.09e-05
Step 600: Loss = 0.299312, EMA Loss = 0.315272, LR = 6.09e-05
Step 700: Loss = 0.298715, EMA Loss = 0.303633, LR = 7.09e-05
Step 800: Loss = 0.288405, EMA Loss = 0.296296, LR = 8.08e-05
Step 900: Loss = 0.241069, EMA Loss = 0.292467, LR = 9.08e-05
Step 1000: Loss = 0.249323, EMA Loss = 0.287762, LR = 1.01e-04
Step 1100: Loss = 0.296130, EMA Loss = 0.287744, LR = 1.11e-04
Step 1200: Loss = 0.251789, EMA Loss = 0.284401, LR = 1.21e-04
Step 1300: Loss = 0.234924, EMA Loss = 0.279878, LR = 1.31e-04
Step 1400: Loss = 0.277207, EMA Loss = 0.278854, LR = 1.41e-04
Step 1500: Loss = 0.308589, EMA Loss = 0.280710, LR = 1.51e-04
Step

KeyboardInterrupt: 