In [15]:
import os
import glob
import random
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pretty_midi
import numpy as np
from tqdm import tqdm

In [16]:
SAMPLE_RATE = 16000
CHUNK_SIZE = 16000  # 1 second of audio
QUANTIZATION_CHANNELS = 256
MIDI_DIR = "./maestro-v3.0.0"     # Maestro dataset root directory
DATA_DIR = "./midi_wav"           # Directory to store WAV files
MAX_MIDI_FILES = 10               # Only convert this many MIDI files
MAX_DURATION = 300


In [17]:
def midi_to_wav(midi_path, wav_path):
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        audio = midi.fluidsynth(fs=SAMPLE_RATE)
        if len(audio) == 0:
            print(f"Warning: Empty audio generated from {midi_path}")
            return False
        audio_tensor = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # [1, T]
        torchaudio.save(wav_path, audio_tensor, SAMPLE_RATE)
        return True
    except Exception as e:
        print(f"Error converting {midi_path}: {e}")
        return False

In [18]:
def preprocess_wav(path):
    waveform, sr = torchaudio.load(path)
    if sr != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
        waveform = resampler(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)  # mono
    mu_law = torchaudio.transforms.MuLawEncoding(quantization_channels=QUANTIZATION_CHANNELS)
    encoded = mu_law(waveform)  # [1, T]
    return encoded.squeeze(0)  # [T]

In [19]:
class AudioDataset(Dataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths
        self.offsets = []  # list of (file_index, offset) pairs
        self.preprocessed = []

        for file_idx, path in enumerate(file_paths):
            try:
                encoded = preprocess_wav(path)
                self.preprocessed.append(encoded)
                # Create overlapping chunks
                for i in range(0, len(encoded) - CHUNK_SIZE, CHUNK_SIZE // 2):
                    self.offsets.append((file_idx, i))
            except Exception as e:
                print(f"Error processing {path}: {e}")

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        file_idx, i = self.offsets[idx]
        encoded = self.preprocessed[file_idx]
        chunk = encoded[i:i+CHUNK_SIZE+1]
        
        # Ensure we have exactly CHUNK_SIZE + 1 samples
        if len(chunk) < CHUNK_SIZE + 1:
            # Pad with zeros if necessary
            chunk = torch.cat([chunk, torch.zeros(CHUNK_SIZE + 1 - len(chunk), dtype=chunk.dtype)])
        
        return chunk[:-1], chunk[1:]  # input and target


In [20]:
class SimpleWaveNet(nn.Module):
    def __init__(self, n_channels=QUANTIZATION_CHANNELS, residual_channels=64, layers=8):
        super().__init__()
        self.n_channels = n_channels
        self.residual_channels = residual_channels
        
        self.embedding = nn.Embedding(n_channels, residual_channels)
        
        # Causal dilated convolutions
        self.dilated_layers = nn.ModuleList()
        self.residual_layers = nn.ModuleList() 
        self.skip_layers = nn.ModuleList()
        
        for i in range(layers):
            dilation = 2 ** i
            # Use causal padding to maintain sequence length
            padding = dilation
            
            conv = nn.Conv1d(
                residual_channels,
                2 * residual_channels,  # For gated activation
                kernel_size=2,
                dilation=dilation,
                padding=padding
            )
            self.dilated_layers.append(conv)
            
            # Residual and skip connections
            self.residual_layers.append(nn.Conv1d(residual_channels, residual_channels, 1))
            self.skip_layers.append(nn.Conv1d(residual_channels, residual_channels, 1))
        
        # Output layers
        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(residual_channels, residual_channels, 1),
            nn.ReLU(),
            nn.Conv1d(residual_channels, n_channels, 1)
        )

    def forward(self, x):
        # x shape: [batch, time]
        batch_size, time_steps = x.shape
        
        # Embed input
        x = self.embedding(x).transpose(1, 2)  # [batch, channels, time]
        
        skip_connections = []
        
        for i, (dilated_conv, res_conv, skip_conv) in enumerate(
            zip(self.dilated_layers, self.residual_layers, self.skip_layers)
        ):
            # Apply dilated convolution
            conv_out = dilated_conv(x)
            
            # Remove extra padding to maintain causal property
            if conv_out.size(-1) > time_steps:
                conv_out = conv_out[:, :, :time_steps]
            
            # Gated activation: tanh(W * x) * sigmoid(V * x)
            tanh_out, sigmoid_out = conv_out.chunk(2, dim=1)
            gated = torch.tanh(tanh_out) * torch.sigmoid(sigmoid_out)
            
            # Residual connection
            residual = res_conv(gated)
            if residual.size(-1) != x.size(-1):
                residual = residual[:, :, :x.size(-1)]
            x = x + residual
            
            # Skip connection
            skip = skip_conv(gated)
            if skip.size(-1) != time_steps:
                skip = skip[:, :, :time_steps]
            skip_connections.append(skip)
        
        # Sum skip connections
        skip_sum = sum(skip_connections)
        
        # Final output
        out = self.output(skip_sum)  # [batch, n_channels, time]
        return out.transpose(1, 2)  # [batch, time, n_channels]

In [21]:
def train(model, dataloader, epochs=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=True)
        
        for batch_idx, (x, y) in enumerate(loop):
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            out = model(x)
            
            # Ensure shapes match exactly
            assert out.shape[:2] == y.shape, f"Shape mismatch: out {out.shape} vs y {y.shape}"
            
            loss = criterion(out.reshape(-1, QUANTIZATION_CHANNELS), y.reshape(-1))
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            
            # Update progress bar
            loop.set_postfix(loss=loss.item())
            
            # Break early for testing
            if batch_idx >= 100:  # Only train on first 100 batches per epoch
                break
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

In [22]:
def generate(model, seed_token=128, length=16000, temperature=1.0):
    device = next(model.parameters()).device
    model.eval()
    
    # Start with a longer seed for better context
    generated = [seed_token] * min(512, length // 4)  
    
    with torch.no_grad():
        for i in tqdm(range(length), desc="Generating"):
            # Use recent context for prediction
            context_length = min(len(generated), CHUNK_SIZE)
            inp = torch.tensor(generated[-context_length:], dtype=torch.long).unsqueeze(0).to(device)
            
            out = model(inp)[:, -1, :]  # Get last timestep prediction
            
            # Apply temperature for sampling diversity
            logits = out / temperature
            probs = torch.softmax(logits, dim=-1)
            
            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
    
    return torch.tensor(generated[len(generated)-length:])  # Return only the generated part


In [23]:
# Clear existing WAV files if the directory exists
if os.path.exists(DATA_DIR):
    for file in os.listdir(DATA_DIR):
        file_path = os.path.join(DATA_DIR, file)
        if os.path.isfile(file_path):
            os.remove(file_path)
else:
    os.makedirs(DATA_DIR)

# Find and filter MIDI files
all_midi_files = glob.glob(os.path.join(MIDI_DIR, "**/*.mid*"), recursive=True)
random.shuffle(all_midi_files)

filtered_midi_files = []
for path in all_midi_files:
    try:
        midi = pretty_midi.PrettyMIDI(path)
        if midi.get_end_time() <= MAX_DURATION:
            filtered_midi_files.append(path)
        if len(filtered_midi_files) >= MAX_MIDI_FILES:
            break
    except Exception as e:
        print(f"Skipped {path}: {e}")

midi_files = filtered_midi_files
print(f"Selected {len(midi_files)} MIDI files for conversion")

# Convert MIDI to WAV
successful_conversions = 0
for midi_file in midi_files:
    wav_name = os.path.splitext(os.path.basename(midi_file))[0] + ".wav"
    wav_path = os.path.join(DATA_DIR, wav_name)
    if midi_to_wav(midi_file, wav_path):
        successful_conversions += 1

print(f"Successfully converted {successful_conversions} MIDI files to WAV")

# Load dataset and train
wav_files = glob.glob(os.path.join(DATA_DIR, "*.wav"))
if not wav_files:
    print("No WAV files found! Check MIDI conversion.")
else:
    dataset = AudioDataset(wav_files)
    print(f"Dataset created with {len(dataset)} chunks")
    
    if len(dataset) == 0:
        print("No valid audio chunks found!")
    else:
        dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

        model = SimpleWaveNet()
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created with {total_params:,} parameters")

        train(model, dataloader, epochs=3)

        # Generate audio
        print("Generating audio...")
        samples = generate(model, seed_token=128, length=SAMPLE_RATE*2, temperature=0.8)  # 2 seconds
        
        # Decode and save
        decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=QUANTIZATION_CHANNELS)
        wav = decoder(samples.unsqueeze(0).float())
        torchaudio.save("generated.wav", wav, SAMPLE_RATE)
        print("Generated audio saved as 'generated.wav'")

Selected 10 MIDI files for conversion
Successfully converted 10 MIDI files to WAV
Dataset created with 3430 chunks
Model created with 235,840 parameters
Training on device: cpu


Epoch 1:  12%|█▏        | 100/858 [04:01<30:33,  2.42s/it, loss=3.66]


Epoch 1, Average Loss: 4.6776


Epoch 2:  12%|█▏        | 100/858 [04:08<31:22,  2.48s/it, loss=3.27]


Epoch 2, Average Loss: 3.4886


Epoch 3:  12%|█▏        | 100/858 [04:01<30:28,  2.41s/it, loss=3]   


Epoch 3, Average Loss: 3.1358
Generating audio...


Generating: 100%|██████████| 32000/32000 [1:03:10<00:00,  8.44it/s]


Generated audio saved as 'generated.wav'
