Imports

In [None]:
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

Configuració global

In [None]:
# ---------- CONFIG ----------
# Path of the JSON chord corpus that the loading step reads.
JSON_PATH = "data/simplified_chords.json"

# Optimisation hyper-parameters used by the training loop.
EPOCHS = 10
BATCH_SIZE = 64
LR = 1e-3

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# ---------- LOAD DATA ----------
# Decode the corpus explicitly as UTF-8: without an encoding argument Python
# falls back to the platform locale, which can mangle non-ASCII chord or
# song names on some systems.
with open(JSON_PATH, "r", encoding="utf-8") as f:
    nested_data = json.load(f)

# Flatten {artist: {song_name: [chords]}} into {song_name: [chords]}.
# NOTE(review): two artists with an identically titled song overwrite each
# other here (last one wins) -- confirm whether that loss is acceptable.
raw_data = {}
for artist, songs in nested_data.items():
    for song_name, chord_list in songs.items():
        raw_data[song_name] = chord_list

# Build the chord vocabulary. Sorting makes the chord -> index mapping
# deterministic for a given corpus.
all_chords = [ch for song in raw_data.values() for ch in song]
unique_chords = sorted(set(all_chords))
chord2idx = {ch: i for i, ch in enumerate(unique_chords)}
idx2chord = {i: ch for ch, i in chord2idx.items()}


def split_structure_fixed(seq):
    """Split a token sequence into a fixed verse / bridge / chorus layout.

    The first 10 distinct items of ``seq`` (in order of first appearance)
    are collected; iteration stops as soon as the tenth one is found.

    Returns:
        ``None`` when ``seq`` holds fewer than 10 distinct items, otherwise
        a dict with keys ``verse1``, ``verse2``, ``bridge`` and ``chorus``.
        The verse is items 0-3, the bridge items 4-5 and the chorus items
        6-9. ``verse1`` and ``verse2`` refer to the same list object.
    """
    distinct = []
    for token in seq:
        if token in distinct:
            continue
        distinct.append(token)
        if len(distinct) == 10:
            break

    if len(distinct) != 10:
        return None

    verse, bridge, chorus = distinct[:4], distinct[4:6], distinct[6:]
    return dict(verse1=verse, verse2=verse, bridge=bridge, chorus=chorus)


# Encode each song as vocabulary indices and keep only the songs that can be
# split into the fixed verse / bridge / chorus layout.
structured_data = []
for song in raw_data.values():
    tokens = [chord2idx[ch] for ch in song if ch in chord2idx]
    parts = split_structure_fixed(tokens)
    if not parts:
        # Too few distinct chords to fill every section; skip this song.
        continue
    structured_data.append(parts)

# ---------- DATASET ----------
class ChordDataset(Dataset):
    """Next-token training pairs built from structured songs.

    For every song the sections are concatenated in the order verse1,
    verse2, bridge, chorus, and each pair of adjacent tokens in that
    sequence is stored as one ``(current, next)`` sample.
    """

    def __init__(self, structured_data):
        self.samples = []
        for sections in structured_data:
            sequence = [
                *sections["verse1"],
                *sections["verse2"],
                *sections["bridge"],
                *sections["chorus"],
            ]
            # Pair each token with its immediate successor.
            self.samples.extend(zip(sequence, sequence[1:]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        current, following = self.samples[idx]
        return torch.tensor(current), torch.tensor(following)


# Wrap the (current, next) chord pairs in a shuffled mini-batch loader.
dataset = ChordDataset(structured_data=structured_data)
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Model

In [None]:
# ---------- MODEL ----------
class ChordLSTM(nn.Module):
    """Chord language model: embedding, stacked LSTM, then a linear layer.

    Args:
        vocab_size: Number of distinct chord indices; also the size of the
            output logits.
        embed_dim: Width of the chord embedding.
        hidden_dim: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Return per-step vocabulary logits and the updated LSTM state.

        ``x`` holds chord indices laid out batch-first (the LSTM is built
        with ``batch_first=True``); ``hidden`` is an optional LSTM state
        carried over from a previous call.
        """
        embedded = self.embedding(x)
        lstm_out, hidden = self.lstm(embedded, hidden)
        return self.fc(lstm_out), hidden


# One output class per chord in the vocabulary; move the network to DEVICE
# before handing its parameters to the optimizer.
model = ChordLSTM(vocab_size=len(chord2idx))
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

Train loop

In [None]:
# ---------- TRAIN ----------
# Train the model to predict the next chord from the current one and save
# the resulting weights to "model.pth".
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0  # sum of per-batch losses for this epoch
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = x.unsqueeze(1)  # shape: (batch, seq_len=1)
        out, _ = model(x)  # shape: (batch, 1, vocab)
        # Remove only the sequence axis. A bare squeeze() would also drop
        # the batch axis when the final batch holds a single sample (or the
        # vocab axis for a one-chord vocabulary), leaving logits whose rank
        # no longer matches the targets.
        loss = criterion(out.squeeze(1), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")

torch.save(model.state_dict(), "model.pth")

Test loop

In [None]:
# ---------- GENERATE ----------
def generate_structured_song(seed_chord):
    """Generate a verse / verse / bridge / chorus chord progression.

    Reloads the weights from "model.pth", then samples chords from the
    model one step at a time, starting from ``seed_chord``, until 10
    distinct chords have been collected. Chords already used are masked out
    of the sampling distribution, so every step yields a new chord and the
    loop always terminates.

    Args:
        seed_chord: Chord name to start from; must be a key of ``chord2idx``.

    Returns:
        A list of 14 chord names: the 4-chord verse twice, a 2-chord bridge
        and a 4-chord chorus.

    Raises:
        KeyError: If ``seed_chord`` is not in the vocabulary.
        ValueError: If the vocabulary holds fewer than 10 distinct chords,
            in which case the structure can never be filled.
    """
    num_unique = 10  # 4 verse + 2 bridge + 4 chorus
    if len(chord2idx) < num_unique:
        raise ValueError(
            f"Need at least {num_unique} distinct chords in the vocabulary, "
            f"got {len(chord2idx)}"
        )

    # map_location lets a checkpoint saved on one device load on another.
    model.load_state_dict(torch.load("model.pth", map_location=DEVICE))
    model.eval()
    chord_seq = [chord2idx[seed_chord]]
    hidden = None
    with torch.no_grad():
        while len(chord_seq) < num_unique:
            x = torch.tensor([[chord_seq[-1]]], device=DEVICE)
            out, hidden = model(x, hidden)
            probs = torch.softmax(out[0, -1], dim=0)
            # Forbid repeats by zeroing the weight of every used chord;
            # multinomial accepts unnormalised non-negative weights.
            probs[chord_seq] = 0.0
            if probs.sum() <= 0:
                # All remaining mass underflowed to zero: fall back to a
                # uniform choice among the unused chords.
                probs = torch.ones_like(probs)
                probs[chord_seq] = 0.0
            next_chord = torch.multinomial(probs, 1).item()
            chord_seq.append(next_chord)

    verse = chord_seq[0:4]
    bridge = chord_seq[4:6]
    chorus = chord_seq[6:10]
    full_song = verse + verse + bridge + chorus
    return [idx2chord[i] for i in full_song]

# ---------- EXAMPLE ----------
# Pick a random chord from the vocabulary as the seed and generate one song.
random_seed_chord = random.choice(list(chord2idx))
generated_song = generate_structured_song(random_seed_chord)

# The generated list is laid out as verse (4), verse (4), bridge (2),
# chorus (4). The banner previously printed mis-decoded bytes in place of
# the musical-note emoji.
print("\n🎵 Generated Song Structure 🎵")
print("Verse:", generated_song[0:4])
print("Verse:", generated_song[4:8])
print("Bridge:", generated_song[8:10])
print("Chorus:", generated_song[10:14])