In [5]:
import sys
import os
sys.path.append(os.path.abspath('..'))

In [8]:
import torch
from torch.utils.data import DataLoader, random_split
from datasets.fcnn_dataset import FCNNDataset
from models.fcnn import FCNN
from utils.train import train_model, evaluate
from preprocessing.main_preprocess import preprocess_abc_dataset
import random
import numpy as np

# Set seed for reproducibility
SEED = 42
torch.manual_seed(SEED)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
# === 2. Load Data and Create Test Loader ===
vocab, inv_vocab, indexed_melodies, token_freq, normalized_melodies = preprocess_abc_dataset("../data/")
WINDOW_SIZE = 16
dataset = FCNNDataset(indexed_melodies, vocab, inv_vocab, WINDOW_SIZE)

# Split using the same logic and seed
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
generator = torch.Generator().manual_seed(SEED)
_, _, test_ds = random_split(dataset, [train_size, val_size, test_size], generator=generator)

test_loader = DataLoader(test_ds, batch_size=1024)

Total raw tunes extracted: 1049
Total tunes after cleaning: 1034
Example melody: M:3/4 L:1/4 K:G e|:"G"d2B|"D"A3/2B/2c|"G"B2G|
Example tokens: ['M:3/4', 'L:1/4', 'K:G', 'e', '|:', '"G"', 'd2', 'B', '|', '"D"', 'A3/2', 'B/2', 'c', '|', '"G"', 'B2', 'G', '|']
Number of unique tokens: 440


In [14]:
# === 3. Load the Trained Model ===
model = FCNN(WINDOW_SIZE, vocab_size=len(vocab), embed_dim=128, hidden_dim=512, dropout=0.5).to(device)
model.load_state_dict(torch.load('../saved_models/fcnn_model.pt', map_location=device))
model.to(device)

  model.load_state_dict(torch.load('../saved_models/fcnn_model.pt', map_location=device))


FCNN(
  (embedding): Embedding(440, 128)
  (fc1): Linear(in_features=2432, out_features=512, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=512, out_features=440, bias=True)
)

In [15]:
# === 4. Evaluate on Test Set ===

loss_fn = torch.nn.CrossEntropyLoss()
test_loss = evaluate(model, test_loader, loss_fn, device)
print(f"Test Loss: {test_loss:.4f}")

Test Loss: 1.8832


In [None]:
# === 5. Generate Sample Music ===
def generate_music_sample(model, loader, vocab, inv_vocab, num_tokens=100):
    model.eval()
    with torch.no_grad():
        for batch_x, _ in loader:
            batch_x = batch_x.to(device)
            break

        input_seq = batch_x[0].unsqueeze(0)  # shape: [1, seq_len]
        generated = input_seq.clone()

        context_tokens = input_seq[:, :3]  # M:, L:, K:
        rolling_window = input_seq[:, 3:]  # actual melody

        for _ in range(num_tokens):
            rolling_window = generated[:, -WINDOW_SIZE:]  # keep last notes only
            input_window = torch.cat((context_tokens, rolling_window), dim=1)

            output = model(input_window)
            next_token = torch.argmax(output, dim=-1).unsqueeze(1)
            generated = torch.cat((generated, next_token), dim=1)

        generated_tokens = [inv_vocab[idx.item()] for idx in generated[0]]
        abc_sequence = ''.join(generated_tokens)

    return abc_sequence

generated_abc = generate_music_sample(model, test_loader, vocab, inv_vocab)
print(generated_abc)

M:3/4L:1/4K:Gceg|"C"g2f/2e/2|"G"d3/2c/2B|"D7"AGA|"G"B2B|"D7"AGF|"G"GBG|"G"GBd|"C"e2e|"G"dBG|"C"edc|"G"B2B|"D7"ABA|"G"GBd|"C"e2e|"G"dBG|"C"edc|"G"B2B|"D7"ABA|"G"GBd|"C"e2e|"G"dBG|"C"edc|"G"B2B|"D7"ABA
