In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
from transformers import BertConfig, TrainingArguments, Trainer
from datasets import load_dataset

from model.model import RhythmicControlBert 
dataset = load_dataset("efraimdahl/syncopation-dataset",split="train")

In [30]:
print(dataset)
print(dataset[0])

Dataset({
    features: ['syncopation', 'distances', 'index', 'metric_vector', 'spectral_vector'],
    num_rows: 4
})
{'syncopation': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.396282718049293, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.36034531901159267, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.0, 0.0, 0.396282718049293, 0.5302370585389454, 0.0, 0.8571428571428572, 0.8571428571428572, 0.5302370585389454, 0.36034531901159267, 0.5302370585389454, 0.396282718049293, 0.8571428571428572, 0.0, 0.0, 0.0, 0.8571428571428572, 0.8571428571428572, 0.5302370585389454, 0.36034531901159267, 0.0, 0.0, 0.8571428571428572, 0.0], 'distances': [48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 18.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 12.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 18.0, 24.0, 48.0, 48.0, 48.0, 24.0

In [32]:

config = BertConfig(
    hidden_size=768,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=3072,
    max_position_embeddings=2048 + 2,  # +2 for CLS/SEP if needed
    vocab_size=1,  # irrelevant since we're using inputs_embeds
)

GRID_SIZE = 48

In [40]:
def preprocess(example):
    return {
        "control_seq": torch.tensor(example["metric_vector"], dtype=torch.float),
        "labels": torch.tensor(example["syncopation"], dtype=torch.float),
    }

# Optional: preprocess and save as a new dataset
dataset.set_transform(preprocess)

def collate_fn(batch):
    control_seq = torch.stack([item["control_seq"] for item in batch])  # [B, 2048, 48]
    labels = torch.stack([item["labels"] for item in batch])            # [B, 2048]
    return {"control_seq": control_seq, "labels": labels}


model = RhythmicControlBert(config=config, control_dim=GRID_SIZE, num_targets=1)



In [41]:
print(dataset)
print(dataset.column_names)
print(dataset[0])

Dataset({
    features: ['syncopation', 'distances', 'index', 'metric_vector', 'spectral_vector'],
    num_rows: 4
})
['syncopation', 'distances', 'index', 'metric_vector', 'spectral_vector']
{'control_seq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'labels': tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.3963, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.3603, 0.8571, 0.8571, 0.8571, 0.8571, 0.8571, 0.0000, 0.0000,
        0.3963, 0.5302, 0.0000, 0.8571, 0.8571, 0.5302, 0.3603, 0.5302, 0.3963,
        0.8571, 0.0000, 0.0000, 0.0000, 0.8571, 0.8571, 0.5302, 0.3603, 0.0000,
        0.0000, 0.8571, 0.0000])}


In [42]:
control_seq =  dataset[0]["control_seq"].unsqueeze(0)  # [1, 2048, 48]
labels = dataset[0]["labels"].unsqueeze(0)            # [1, 2048]
model.eval()
with torch.no_grad():
    out = model(control_seq, labels=labels)
    print("Test forward pass:", out["logits"].shape, "Loss:", out["loss"].item())

Test forward pass: torch.Size([1, 57, 1]) Loss: 0.7588778138160706


In [43]:
wandb.init(project="syncopation-transformer", name="pretraining-run")


training_args = TrainingArguments(
    output_dir="./checkpoints",
    evaluation_strategy="no",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size = 16
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    report_to="wandb",  # logs to W&B
    push_to_hub=True,
    hub_model_id="your-username/syncopation-transformer",
    hub_strategy="every_save",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)


trainer.train()
trainer.push_to_hub()

NameError: name 'wandb' is not defined