In [1]:
import os
import glob
import random
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pretty_midi
import numpy as np

In [2]:
# --- Audio settings ---
SAMPLE_RATE = 16_000             # Hz; used when synthesizing, loading and saving audio
CHUNK_SIZE = 16_000              # samples per training window (1 second at SAMPLE_RATE)
QUANTIZATION_CHANNELS = 4        # number of mu-law levels; token ids are 0..QUANTIZATION_CHANNELS-1

# --- Data locations and limits ---
MIDI_DIR = "./maestro-v3.0.0"    # root directory of the Maestro MIDI dataset
DATA_DIR = "./midi_wav"          # directory that receives the rendered WAV files
MAX_MIDI_FILES = 10              # at most this many MIDI files are rendered to WAV


In [None]:
def midi_to_wav(midi_path, wav_path, sf2_path="soundfonts/FluidR3_GM/FluidR3_GM.sf2", sample_rate=None):
    """Render a MIDI file to a single-channel WAV file using FluidSynth.

    Args:
        midi_path: Path of the MIDI file to read.
        wav_path: Path of the WAV file to write.
        sf2_path: SoundFont file used for synthesis. The default is the
            path that was previously hard-coded, so existing calls behave
            the same.
        sample_rate: Sample rate in Hz for both synthesis and the written
            file. ``None`` (default) uses the notebook-wide ``SAMPLE_RATE``.
    """
    if sample_rate is None:
        # Resolved at call time so re-running the config cell takes effect.
        sample_rate = SAMPLE_RATE
    midi = pretty_midi.PrettyMIDI(midi_path)
    audio = midi.fluidsynth(sf2_path=sf2_path, fs=sample_rate)
    # torchaudio.save expects [channels, time]; add the channel axis.
    audio_tensor = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # [1, T]
    torchaudio.save(wav_path, audio_tensor, sample_rate)


In [4]:
def preprocess_wav(path):
    """Load an audio file and return it as a 1-D tensor of mu-law codes.

    The audio is resampled to ``SAMPLE_RATE`` when the file uses a different
    rate, averaged over its channels to mono, and then quantized into
    ``QUANTIZATION_CHANNELS`` mu-law levels.

    Args:
        path: Audio file to read with ``torchaudio.load``.

    Returns:
        Tensor of shape [T] containing the encoded samples.
    """
    audio, source_rate = torchaudio.load(path)

    # Bring every file to the common sample rate before anything else.
    if source_rate != SAMPLE_RATE:
        to_target_rate = torchaudio.transforms.Resample(source_rate, SAMPLE_RATE)
        audio = to_target_rate(audio)

    # Collapse the channel axis but keep it as a size-1 dimension: [1, T].
    mono = audio.mean(dim=0, keepdim=True)

    encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=QUANTIZATION_CHANNELS)
    codes = encoder(mono)  # [1, T]
    return codes.squeeze(0)  # [T]

In [5]:
class AudioDataset(Dataset):
    """Next-sample prediction windows over mu-law encoded audio files.

    Every item is an ``(input, target)`` pair of length ``chunk_size`` where
    ``target`` is ``input`` shifted one sample into the future.

    Windows are located on demand instead of being stored one by one: only
    the encoded tracks and a small cumulative-count table are kept. With the
    default ``stride=1`` the dataset has the same length, ordering and item
    contents as eagerly slicing every offset, without creating one tensor
    object per audio sample.

    Args:
        file_paths: Iterable of audio file paths passed to ``preprocess_wav``.
        chunk_size: Length of each input/target sequence. ``None`` (default)
            uses the notebook-wide ``CHUNK_SIZE``.
        stride: Distance in samples between the starts of consecutive
            windows of the same file. ``1`` (default) yields every offset.

    Raises:
        ValueError: If ``chunk_size`` or ``stride`` is smaller than 1.
    """

    def __init__(self, file_paths, chunk_size=None, stride=1):
        if chunk_size is None:
            chunk_size = CHUNK_SIZE
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.chunk_size = chunk_size
        self.stride = stride
        self._tracks = []  # encoded audio, one 1-D tensor per usable file
        boundaries = []    # boundaries[k] = number of windows in tracks[0..k]
        total = 0
        for path in file_paths:
            encoded = preprocess_wav(path)
            # A window needs chunk_size + 1 samples (the extra one is the
            # last target), so valid starts are 0 .. len - chunk_size - 1.
            n_windows = len(range(0, len(encoded) - chunk_size, stride))
            if n_windows == 0:
                continue  # file too short to yield a single window
            total += n_windows
            self._tracks.append(encoded)
            boundaries.append(total)
        self._boundaries = np.asarray(boundaries, dtype=np.int64)

    @property
    def samples(self):
        """All windows (``chunk_size + 1`` samples each) as a list.

        Kept for code that reads ``dataset.samples``; the list is built on
        every access and can be very large.
        """
        windows = []
        for track in self._tracks:
            for start in range(0, len(track) - self.chunk_size, self.stride):
                windows.append(track[start:start + self.chunk_size + 1])
        return windows

    def __len__(self):
        return int(self._boundaries[-1]) if len(self._boundaries) else 0

    def __getitem__(self, idx):
        total = len(self)
        idx = int(idx)
        if idx < 0:
            idx += total  # same negative-index semantics as a list
        if not 0 <= idx < total:
            raise IndexError("AudioDataset index out of range")

        # First track whose cumulative window count exceeds idx.
        track_idx = int(np.searchsorted(self._boundaries, idx, side="right"))
        windows_before = int(self._boundaries[track_idx - 1]) if track_idx else 0
        start = (idx - windows_before) * self.stride

        chunk = self._tracks[track_idx][start:start + self.chunk_size + 1]  # +1 for target
        return chunk[:-1], chunk[1:]  # input, target

In [6]:
class SimpleWaveNet(nn.Module):
    """Small WaveNet-style stack of causal dilated 1-D convolutions.

    Token ids are embedded, passed through ``layers`` residual dilated
    convolutions (dilation 1, 2, 4, ...) and projected to per-step logits.

    Each convolution is left-padded so that the output at step ``t`` depends
    only on inputs at steps ``<= t`` and the sequence length is preserved.
    The logits therefore line up one-to-one with next-sample targets, and
    inputs shorter than the receptive field (even a single token) are valid.

    Args:
        n_channels: Vocabulary size, used for both the embedding and the
            output logits. ``None`` (default) uses the notebook-wide
            ``QUANTIZATION_CHANNELS`` at construction time.
        residual_channels: Width of the embedding and of every hidden layer.
        layers: Number of dilated convolutions; layer ``i`` has dilation
            ``2**i``.
    """

    def __init__(self, n_channels=None, residual_channels=32, layers=5):
        super().__init__()
        if n_channels is None:
            # Late-bound so a re-run config cell is honoured.
            n_channels = QUANTIZATION_CHANNELS
        self.embedding = nn.Embedding(n_channels, residual_channels)
        self.dilated_layers = nn.ModuleList([
            nn.Conv1d(residual_channels, residual_channels, kernel_size=2, dilation=2**i)
            for i in range(layers)
        ])
        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(residual_channels, residual_channels, 1),
            nn.ReLU(),
            nn.Conv1d(residual_channels, n_channels, 1)
        )

    def forward(self, x):
        """Map token ids ``[B, T]`` to logits ``[B, T, V]``."""
        x = self.embedding(x).permute(0, 2, 1)  # [B, T] -> [B, C, T]
        for layer in self.dilated_layers:
            # Pad only on the left (the past) by the span of the kernel so
            # the convolution is causal and returns exactly T steps.
            left_pad = layer.dilation[0] * (layer.kernel_size[0] - 1)
            out = layer(nn.functional.pad(x, (left_pad, 0)))
            x = x + out  # Residual connection
        return self.output(x).permute(0, 2, 1)  # [B, T, V]

In [7]:
def train(model, dataloader, epochs=5):
    """Train ``model`` for next-sample prediction with Adam and cross-entropy.

    Args:
        model: Module mapping token ids ``[B, T]`` to logits ``[B, T_out, V]``
            with ``T_out <= T``. It is moved to CUDA when available,
            otherwise it stays on the CPU.
        dataloader: Iterable of ``(input, target)`` batches of token ids,
            both shaped ``[B, T]``.
        epochs: Number of passes over ``dataloader``.

    Returns:
        List with the mean batch loss of every epoch (``nan`` for an epoch
        in which the loader produced no batches).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        model.train()
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)  # [B, T_out, V]
            # A model without padding returns fewer steps than it was given;
            # its last logits then belong to the last targets.
            y = y[:, -out.size(1):]
            # reshape (not view): the logits are usually a permuted,
            # non-contiguous tensor. The class count is read from the logits
            # so any vocabulary size works.
            loss = criterion(out.reshape(-1, out.size(-1)), y.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches ourselves: works for loaders without __len__ and
        # avoids a ZeroDivisionError on an empty loader.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
    return epoch_losses


In [8]:
def generate(model, seed_token=128, length=16000, context_size=None):
    """Sample ``length`` tokens autoregressively from ``model``.

    Starting from ``seed_token``, each step feeds the most recent
    ``context_size`` tokens to the model, turns the logits of the last time
    step into probabilities and draws the next token from them.

    Args:
        model: Module mapping token ids ``[1, T]`` to logits ``[1, T', V]``.
            It is switched to eval mode.
        seed_token: Token id that starts the sequence (not returned). It
            must be a valid index of ``model.embedding`` when the model has
            such an attribute.
        length: Number of tokens to generate.
        context_size: Maximum number of past tokens given to the model per
            step. ``None`` (default) uses the notebook-wide ``CHUNK_SIZE``.

    Returns:
        1-D CPU tensor of ``length`` generated token ids.

    Raises:
        ValueError: If ``context_size`` is smaller than 1, or ``seed_token``
            lies outside the model's embedding table.
    """
    if context_size is None:
        context_size = CHUNK_SIZE
    if context_size < 1:
        raise ValueError(f"context_size must be >= 1, got {context_size}")

    # Check the seed up front: an out-of-range id otherwise surfaces as an
    # opaque indexing error inside the embedding lookup.
    vocab_size = getattr(getattr(model, "embedding", None), "num_embeddings", None)
    if vocab_size is not None and not 0 <= seed_token < vocab_size:
        raise ValueError(
            f"seed_token={seed_token} is outside the model vocabulary "
            f"[0, {vocab_size}); use e.g. {vocab_size // 2}"
        )

    device = next(model.parameters()).device
    model.eval()

    # Pre-allocate the whole sequence on the model's device so each step
    # slices a view instead of rebuilding a tensor from a Python list.
    tokens = torch.empty(length + 1, dtype=torch.long, device=device)
    tokens[0] = seed_token
    with torch.no_grad():
        for step in range(1, length + 1):
            start = max(0, step - context_size)
            inp = tokens[start:step].unsqueeze(0)  # [1, <=context_size]
            out = model(inp)[:, -1, :]  # last timestep
            probs = torch.softmax(out, dim=-1)
            tokens[step] = torch.multinomial(probs, num_samples=1)[0, 0]
    return tokens[1:].cpu()  # remove seed

In [None]:
# Clear existing WAV files if the directory exists, otherwise create it.
# Only *.wav files are removed; anything else in DATA_DIR is left alone.
if os.path.isdir(DATA_DIR):
    for stale_wav in glob.glob(os.path.join(DATA_DIR, "*.wav")):
        if os.path.isfile(stale_wav):
            os.remove(stale_wav)
else:
    os.makedirs(DATA_DIR, exist_ok=True)

# Pick a random subset of the MIDI files found below MIDI_DIR.
midi_files = glob.glob(os.path.join(MIDI_DIR, "**/*.mid*"), recursive=True)
if not midi_files:
    raise FileNotFoundError(f"No MIDI files found under {MIDI_DIR!r}")
random.shuffle(midi_files)
midi_files = midi_files[:MAX_MIDI_FILES]  # randomly pick subset

# Render each selected MIDI file to a WAV file of the same base name.
for midi_file in midi_files:
    wav_name = os.path.splitext(os.path.basename(midi_file))[0] + ".wav"
    wav_path = os.path.join(DATA_DIR, wav_name)
    midi_to_wav(midi_file, wav_path)

wav_files = glob.glob(os.path.join(DATA_DIR, "*.wav"))
dataset = AudioDataset(wav_files)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

model = SimpleWaveNet()
train(model, dataloader, epochs=5)

# The seed must be a valid token id (0 .. QUANTIZATION_CHANNELS - 1). The
# middle of the range is the code mu-law encoding assigns to silence.
seed_token = QUANTIZATION_CHANNELS // 2
samples = generate(model, seed_token=seed_token, length=16000)
decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=QUANTIZATION_CHANNELS)
# torchaudio.save requires a 2-D [channels, time] tensor, so add exactly
# one leading axis.
wav = decoder(samples.unsqueeze(0).float())  # [1, T]
torchaudio.save("generated.wav", wav, SAMPLE_RATE)
