# Melody Model Training
This notebook performs model pre-training, fine-tuning using chords
extracted from personal WAV files, and inference. All artifacts are
saved under `outputs/<session_id>/`. Front-end or app code is omitted.


In [ ]:
import uuid
from pathlib import Path

import torch
import torchaudio
from IPython.display import Audio, display

# Simple chord extraction for user-provided WAV files
import wave
import numpy as np

# Names of the 12 pitch classes in ascending semitone order, starting at C
# (index 0 = C ... index 11 = B). Sharps only; there are no flat spellings.
# _freq_to_note() indexes this array with a note number taken modulo 12.
NOTE_NAMES = np.array(["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"])


def _freq_to_note(freq: float):
    """Map a frequency in Hz to the name of its nearest pitch class.

    The frequency is converted to a note number on a scale where 440 Hz
    corresponds to 69, rounded to the nearest semitone, and reduced modulo 12
    so the octave is discarded.

    Args:
        freq: Frequency in Hz.

    Returns:
        The matching entry of ``NOTE_NAMES``, or ``None`` when ``freq`` is
        zero or negative (no pitch can be assigned).
    """
    if freq <= 0:
        return None
    # Distance from the 440 Hz reference, measured in semitones.
    semitones_from_reference = 12 * np.log2(freq / 440.0)
    note_number = int(np.round(semitones_from_reference + 69))
    pitch_class = note_number % 12
    return NOTE_NAMES[pitch_class]


def _pcm_bytes_to_float(raw: bytes, sampwidth: int):
    """Decode little-endian integer PCM bytes into float32 samples in [-1, 1).

    Returns ``None`` when ``sampwidth`` is not 1, 2, 3 or 4 bytes.
    """
    if sampwidth not in (1, 2, 3, 4):
        return None
    # Drop a trailing partial sample so the buffer divides evenly.
    raw = raw[: len(raw) - len(raw) % sampwidth]
    if sampwidth == 1:
        # 8-bit WAV samples are unsigned, with the zero level at 128.
        return (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    if sampwidth == 2:
        return np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    if sampwidth == 3:
        # NumPy has no 24-bit dtype: assemble each sample from its three
        # bytes, then sign-extend values with the top bit set.
        parts = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
        values = parts[:, 0] | (parts[:, 1] << 8) | (parts[:, 2] << 16)
        values = np.where(values >= (1 << 23), values - (1 << 24), values)
        return values.astype(np.float32) / 8388608.0
    return np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0


def extract_chords_from_wav(path: str, frame_size: int = 4096, hop_size: int = 2048):
    """Extract a sequence of dominant note names from an integer-PCM WAV file.

    Despite the name, no chord recognition is performed: each Hann-windowed
    frame contributes the pitch class of its strongest FFT bin, and
    consecutive duplicates are collapsed.

    Args:
        path: Path of the WAV file to analyse.
        frame_size: Number of samples per analysis frame.
        hop_size: Number of samples between the starts of successive frames.

    Returns:
        A list of note names in order of appearance. The list is empty when
        the file cannot be read, uses an unsupported sample width, or is
        shorter than one frame; unreadable files also emit a
        ``RuntimeWarning`` so the failure is visible.

    Raises:
        ValueError: If ``frame_size`` or ``hop_size`` is not positive.
    """
    import warnings

    if frame_size <= 0 or hop_size <= 0:
        raise ValueError("frame_size and hop_size must be positive")

    # Only the file access is guarded; decoding problems are handled below.
    try:
        with wave.open(path, "rb") as wf:
            sr = wf.getframerate()
            channels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            raw = wf.readframes(wf.getnframes())
    except (wave.Error, OSError, EOFError) as exc:
        # Best effort: an unreadable file yields no notes, but say why.
        warnings.warn(
            f"Could not read WAV file {path!r}: {exc}", RuntimeWarning, stacklevel=2
        )
        return []

    audio = _pcm_bytes_to_float(raw, sampwidth)
    if audio is None or sr <= 0:
        warnings.warn(
            f"Unsupported WAV format in {path!r} "
            f"(sample width {sampwidth} bytes, rate {sr} Hz)",
            RuntimeWarning,
            stacklevel=2,
        )
        return []

    if channels > 1:
        # Discard an incomplete trailing frame, then down-mix to mono.
        audio = audio[: len(audio) - len(audio) % channels]
        audio = audio.reshape(-1, channels).mean(axis=1)

    window = np.hanning(frame_size)
    # Bin centre frequencies depend only on frame_size and sr, so compute once.
    freqs = np.fft.rfftfreq(frame_size, 1 / sr)

    chords = []
    last = None
    for start in range(0, len(audio) - frame_size + 1, hop_size):
        frame = audio[start:start + frame_size]
        magnitudes = np.abs(np.fft.rfft(frame * window))
        note = _freq_to_note(freqs[int(np.argmax(magnitudes))])
        # A DC-dominated frame maps to None and is skipped.
        if note is not None and note != last:
            chords.append(note)
            last = note
    return chords

class TinyModel(torch.nn.Module):
    """Minimal stand-in network: one 100 -> 100 linear layer followed by tanh."""

    def __init__(self):
        super().__init__()
        # The attribute name ``lin`` determines the keys of the saved state_dict.
        self.lin = torch.nn.Linear(in_features=100, out_features=100)

    def forward(self, x):
        """Return ``tanh(lin(x))``; the last dimension of ``x`` must be 100."""
        projected = self.lin(x)
        return projected.tanh()

    def training_step(self, x, y):
        """Return the mean-squared error between the model output for ``x`` and ``y``."""
        prediction = self(x)
        loss = torch.nn.functional.mse_loss(prediction, y)
        return loss

def prepare_session(root='outputs'):
    """Create a new session directory beneath ``root``.

    The session id is the first eight hexadecimal characters of a random
    UUID4. Missing parent directories are created, and an already existing
    directory with the same id is reused without error.

    Args:
        root: Directory that holds all session folders.

    Returns:
        A ``(session_id, session_dir)`` tuple, where ``session_dir`` is the
        ``Path`` ``root / session_id``.
    """
    short_id = uuid.uuid4().hex[:8]
    target_dir = Path(root).joinpath(short_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return short_id, target_dir

def pretrain(model, epochs=1):
    """Run a placeholder pre-training loop on random data.

    One random batch of 4 samples with 100 features (and an equally shaped
    random target) is drawn once and reused for every epoch. The model is
    updated in place with Adam at a learning rate of 1e-3.

    Args:
        model: Module exposing ``parameters()`` and ``training_step(x, y)``
            that returns a scalar loss.
        epochs: Number of optimisation steps to take.
    """
    inputs = torch.randn(4, 100)
    targets = torch.randn(4, 100)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _epoch in range(epochs):
        step_loss = model.training_step(inputs, targets)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

def finetune(model, wav_paths, session_dir):
    """Extract note sequences from the given WAV files and store them.

    The sequences returned by ``extract_chords_from_wav`` for each file are
    concatenated in input order and written, space-separated, to
    ``chords.txt`` inside ``session_dir``. ``model`` is accepted but not
    touched by this function.

    Args:
        model: Currently unused.
        wav_paths: Iterable of WAV file paths; each is converted with ``str``.
        session_dir: ``Path`` of the directory that receives ``chords.txt``.
    """
    per_file = (extract_chords_from_wav(str(wav)) for wav in wav_paths)
    sequence = [symbol for symbols in per_file for symbol in symbols]
    chord_file = session_dir / 'chords.txt'
    chord_file.write_text(' '.join(sequence))

def inference(model, session_dir):
    """Write a placeholder audio clip and the model weights to ``session_dir``.

    A single channel of 16000 zero samples is saved at 16 kHz as
    ``generated.wav`` and shown with an IPython audio widget; the model is not
    used to produce the audio. Afterwards ``model.state_dict()`` is saved as
    ``model_weights.pth``.

    Args:
        model: Module whose ``state_dict()`` is persisted.
        session_dir: ``Path`` of the directory that receives both files.
    """
    sample_rate = 16000
    silence = torch.zeros(1, sample_rate)
    wav_path = session_dir / 'generated.wav'
    torchaudio.save(wav_path, silence, sample_rate)
    display(Audio(wav_path))
    weights_path = session_dir / 'model_weights.pth'
    torch.save(model.state_dict(), weights_path)

# End-to-end demo run: create a session folder, pre-train, extract notes from
# a placeholder WAV file, then run inference and save the weights.
session_id, session = prepare_session()
model = TinyModel()
pretrain(model)
dummy_wav = session / 'input.wav'
# Write the placeholder input as 16-bit signed PCM. For a float32 tensor
# torchaudio defaults to a 32-bit floating-point WAV, which the stdlib
# ``wave`` reader used by extract_chords_from_wav() cannot open, so no notes
# would ever be extracted from the file.
torchaudio.save(
    dummy_wav,
    torch.zeros(1, 16000),
    16000,
    encoding="PCM_S",
    bits_per_sample=16,
)
finetune(model, [dummy_wav], session)
inference(model, session)
print(f'Artifacts saved to {session}')

