In [1]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

df = pd.read_csv('data/chordomicon_clean.csv')

# Hiperparàmetres
embedding_dim = 64
hidden_dim = 128
num_layers = 2
learning_rate = 1e-5
num_epochs = 10
batch_size = 16
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# --- Definició del dataset ---
class ChordDataset(Dataset):
    def __init__(self, data):
        self.data = self._preprocess(data)
        self.chord_vocab = self._build_vocab(self.data['chords_tokenized'])
        self.section_vocab = sorted(list(set(self.data['sections'])))
        self.genre_vocab = sorted(list(set(self.data['genres']))) 
        self.chord_to_index = {chord: idx for idx, chord in enumerate(self.chord_vocab)}
        self.section_to_index = {section: idx for idx, section in enumerate(self.section_vocab)}
        self.genre_to_index = {genre: idx for idx, genre in enumerate(self.genre_vocab)}
        self.index_to_chord = {idx: chord for chord, idx in self.chord_to_index.items()}

    def _preprocess(self, df):
        sections = []
        chords_tokenized = []
        genres = []
        for index, row in df.iterrows():
            chords_str = row['chords']
            main_genre = row['main_genre']
            # print(f"\nProcesando chords_str: '{chords_str}'")
            split_chords = chords_str.split('<')
            # print(f"Resultado de split('<'): {split_chords}")
            for item in split_chords:
                if '>' in item:
                    section_label, chord_sequence = item.split('>', 1)
                    cleaned_label = section_label.strip()
                    sections.append(cleaned_label)
                    chords_tokenized.append([chord.strip() for chord in chord_sequence.strip().split()])
                    genres.append(main_genre)
                    # print(f"  - Item: '{item}', Etiqueta extraída: '{cleaned_label}'")
        return {'sections': sections, 'chords_tokenized': chords_tokenized, 'genres': genres}

    def _build_vocab(self, token_lists):
        tokens = []
        for token_list in token_lists:
            tokens.extend(token_list)
        return sorted(list(set(tokens)))

    def __len__(self):
        return len(self.data['sections'])

    def __getitem__(self, idx):
        section = self.data['sections'][idx]
        chords = self.data['chords_tokenized'][idx]
        genre = self.data['genres'][idx]

        # print(f"Intentando acceder a la sección: '{section}'")
        if section not in self.section_to_index:
            print(f"¡¡¡ERROR!!! La sección '{section}' no está en self.section_to_index: {self.section_to_index.keys()}")

        section_index = self.section_to_index[section]
        chord_indices = [self.chord_to_index[chord] for chord in chords]
        genre_index = self.genre_to_index[genre]

        return {
            'section': torch.tensor(section_index, dtype=torch.long),
            'chords': torch.tensor(chord_indices[:-1], dtype=torch.long),
            'next_chord': torch.tensor(chord_indices[1:], dtype=torch.long),
            'genre': torch.tensor(genre_index, dtype=torch.long)
        }

In [3]:
def pad_sequences(batch):
    sections = [item['section'] for item in batch]
    genres = [item['genre'] for item in batch]
    chords = [item['chords'] for item in batch]
    next_chords = [item['next_chord'] for item in batch]

    chords_padded = torch.nn.utils.rnn.pad_sequence(chords, batch_first=True)
    next_chords_padded = torch.nn.utils.rnn.pad_sequence(next_chords, batch_first=True)

    return {
        'section': torch.stack(sections),
        'chords': chords_padded,
        'next_chord': next_chords_padded,
        'genre': torch.stack(genres)
    }

In [4]:
def create_dataloader(df, batch_size=32, shuffle=True):
    dataset = ChordDataset(df)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=pad_sequences)
    return dataloader, dataset.chord_vocab, dataset.section_vocab, dataset.genre_vocab, dataset.index_to_chord

dataloader, chord_vocab, section_vocab, genre_vocab, index_to_chord = create_dataloader(df, batch_size=batch_size, shuffle=True)

In [5]:
# --- Definició del model ---
class ChordGenerator(nn.Module):
    def __init__(self, chord_vocab_size, section_vocab_size, genre_vocab_size, embedding_dim, hidden_dim, num_layers):
        super(ChordGenerator, self).__init__()
        self.chord_embedding = nn.Embedding(chord_vocab_size, embedding_dim)
        self.section_embedding = nn.Embedding(section_vocab_size, embedding_dim)
        self.genre_embedding = nn.Embedding(genre_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim * 3, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, chord_vocab_size)

    def forward(self, chords, section, genre):
        chord_embedded = self.chord_embedding(chords)
        section_embedded = self.section_embedding(section).unsqueeze(1).expand(-1, chords.size(1), -1)
        genre_embedded = self.genre_embedding(genre).unsqueeze(1).expand(-1, chords.size(1), -1)
        embedded = torch.cat((chord_embedded, section_embedded, genre_embedded), dim=2)
        output, _ = self.lstm(embedded)
        prediction = self.linear(output)
        return prediction
model = ChordGenerator(len(chord_vocab), len(section_vocab), len(genre_vocab), embedding_dim, hidden_dim, num_layers)
model.load_state_dict(torch.load("chord_generator2.pth"))


<All keys matched successfully>

In [7]:
# --- Train loop ---
def train(model, dataloader, learning_rate, num_epochs, device):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            chords = batch['chords'].to(device)
            next_chords = batch['next_chord'].to(device)
            sections = batch['section'].to(device)
            genres = batch['genre'].to(device)

            optimizer.zero_grad()
            outputs = model(chords, sections, genres)

            loss = criterion(outputs.view(-1, outputs.size(-1)), next_chords.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Un valor de max_norm = 1.0 es un buen punto de partida

            optimizer.step()

            # *** DIAGNÓSTICO TEMPORAL DE LOSS ***
            if torch.isnan(loss).any() or torch.isinf(loss).any():
                print(f"ATENCIÓN: Loss se ha vuelto NaN/Inf en el batch {batch_idx}, epoch {epoch}. Interrumpiendo el entrenamiento.")
                # Opcional: podrías guardar el modelo aquí como 'modelo_fallido.pth' para inspeccionarlo
                return # Detiene el entrenamiento si el loss se corrompe
            # *** FIN DIAGNÓSTICO ***

            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        
    if not (torch.isnan(torch.tensor(avg_loss)) or torch.isinf(torch.tensor(avg_loss))):
            torch.save(model.state_dict(), f"chord_generator_epoch_{epoch+1}.pth") # Guarda un modelo por época
            torch.save(model.state_dict(), "chord_generator.pth") # Sobreescribe el principal

train(model, dataloader, learning_rate, num_epochs, device)

Epoch [1/10], Loss: 0.7670
Epoch [2/10], Loss: 0.6276
Epoch [3/10], Loss: 0.5993
Epoch [4/10], Loss: 0.5813
Epoch [5/10], Loss: 0.5690
Epoch [6/10], Loss: 0.5583
Epoch [7/10], Loss: 0.5517
Epoch [8/10], Loss: 0.5453
Epoch [9/10], Loss: 0.5408
Epoch [10/10], Loss: 0.5371


In [6]:
# --- Generació d'acords ---
def generate_chords(model, start_sequence, section_label, genre_label, chord_to_index, section_to_index, genre_to_index, index_to_chord, max_length=50, device="cpu"):
    model.to(device)
    model.eval()
    with torch.no_grad():
        start_indices = [chord_to_index[chord] for chord in start_sequence]
        input_sequence = torch.tensor(start_indices, dtype=torch.long).unsqueeze(0).to(device)
        section_index = torch.tensor([section_to_index[section_label]], dtype=torch.long).to(device)
        genre_index = torch.tensor([genre_to_index[genre_label]], dtype=torch.long).to(device)

        generated_sequence = start_indices[:]

        for _ in range(max_length):
            outputs = model(input_sequence, section_index, genre_index)
            probabilities = torch.softmax(outputs[:, -1, :], dim=-1)
            next_chord_index = torch.multinomial(probabilities, num_samples=1).item()
            generated_sequence.append(next_chord_index)
            input_sequence = torch.cat((input_sequence, torch.tensor([[next_chord_index]], dtype=torch.long).to(device)), dim=1)

        return [index_to_chord[idx] for idx in generated_sequence]

In [7]:
# Exemple de generació
start_sequence = ["C", "F"]
section_label = "verse"
genre_label = "pop rock"
if start_sequence[0] not in chord_vocab or section_label not in section_vocab or genre_label not in genre_vocab:
    print(genre_vocab)
else:
    generated_chords = generate_chords(
        model, start_sequence, section_label, genre_label,
        dataloader.dataset.chord_to_index, dataloader.dataset.section_to_index,
        dataloader.dataset.genre_to_index, index_to_chord, device=device, max_length=2
    )
    print(f"\nGeneración para inicio: {start_sequence}, sección: {section_label}, género: {genre_label}")
    print(f"Acordes generados: {generated_chords}")

# Exemple de generació per una secció diferent
start_sequence_intro = ["C"]
section_label_intro = "intro"
genre_label_pop = "pop"
if section_label_intro not in section_vocab or genre_label_pop not in genre_vocab:
    print("Error: Uno de los tokens de entrada no está en el vocabulario.")
else:
    generated_chords_intro = generate_chords(
        model, start_sequence_intro, section_label_intro, genre_label_pop,
        dataloader.dataset.chord_to_index, dataloader.dataset.section_to_index,
        dataloader.dataset.genre_to_index, index_to_chord, device=device, max_length=10
    )
    print(f"\nGeneración para inicio: {start_sequence_intro}, sección: {section_label_intro}, género: {genre_label_pop}")
    print(f"Acordes generados para intro: {generated_chords_intro}")


Generación para inicio: ['C', 'F'], sección: verse, género: pop rock
Acordes generados: ['C', 'F', 'G', 'C']

Generación para inicio: ['C'], sección: intro, género: pop
Acordes generados para intro: ['C', 'D', 'Em', 'G', 'C', 'D', 'Em', 'G', 'C', 'D', 'Em']
