In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pretty_midi
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pathlib import Path
import os
from tqdm.auto import tqdm
np.set_printoptions(threshold = np.inf)
import matplotlib.pyplot as plt
plt.style.use('dark_background')

In [2]:
# Preprocessing configuration shared by the cells below
sequence_length = 100  # number of piano-roll frames in one model input window
time_step = 0.1  # seconds per piano-roll frame (i.e. 10 frames per second)
min_pitch = 21  # A0
max_pitch = 108  # C8
# Both endpoints are playable keys, so the count is inclusive: 108 - 21 + 1 = 88
n_pitches = max_pitch - min_pitch + 1

In [3]:
# Turn a MIDI file into a binary piano roll : (time steps, pitch) float tensor
def midi_to_sequence(midi_path):
  """Load `midi_path` and return its on/off piano roll as a float tensor.

  The roll is sampled at `1 / time_step` frames per second, transposed to
  (time steps, pitch) and restricted to the pitch columns
  `min_pitch..max_pitch` inclusive. Every non-zero velocity becomes 1.0 and
  silence stays 0.0, which discards velocity information to keep the
  learning problem simple.
  """
  midi = pretty_midi.PrettyMIDI(midi_path)
  frames_per_second = 1/time_step
  # get_piano_roll yields (pitch, time steps); flip it and keep only the configured pitch range
  roll = midi.get_piano_roll(fs=frames_per_second).T[:, min_pitch: max_pitch+1]
  return torch.tensor(roll).ne(0).float()

In [4]:
# Create sequences : (no of sequences, window, pitches) tensor
# and targets : (no of sequences, pitches) tensor
# from a binary piano roll : (time, pitches) tensor
def create_sequences(piano_roll, window=None):
  """Slice a piano roll into sliding input windows and their next-frame targets.

  Args:
    piano_roll: tensor whose first dimension is time.
    window: number of frames per input sequence. Defaults to the
      module-level `sequence_length`.

  Returns:
    (sequences, targets) where `sequences[i]` is `piano_roll[i:i + window]`
    and `targets[i]` is the frame right after it, `piano_roll[i + window]`.
    A roll with T frames yields exactly `T - window` pairs. When the roll is
    too short to yield any pair, two empty tensors with the right trailing
    shape are returned instead of raising, so they can still be concatenated
    with the results from other files.
  """
  if window is None:
    window = sequence_length

  # The last usable start index is T - window - 1 (its target is frame T - 1),
  # so there are T - window pairs in total.
  n_pairs = len(piano_roll) - window
  if n_pairs <= 0:
    frame_shape = tuple(piano_roll.shape[1:])
    return (piano_roll.new_zeros((0, window) + frame_shape),
            piano_roll.new_zeros((0,) + frame_shape))

  sequences = torch.stack([piano_roll[i:i + window] for i in range(n_pairs)])
  # Target i is frame i + window, i.e. every frame from `window` onwards;
  # clone so the result does not share memory with the input roll.
  targets = piano_roll[window:].clone()
  return sequences, targets

In [None]:
sequences = []
targets = []
# rglob does not promise any particular order, so sort the paths to make the
# choice of the first max_count files reproducible between runs and machines.
midi_files = sorted(Path(r"C:\Users\aniru\Desktop\projects\AI Music Generator\custom\maestro-v3.0.0").rglob("*.midi"))
count = 0
max_count = 100
# Combine sequences and targets from max_count successfully parsed midi files.
# Files that fail to parse are reported and skipped without counting.
for midi_file in tqdm(midi_files):
  if (count == max_count):
    break
  try:
    piano_roll = midi_to_sequence(str(midi_file))
    if piano_roll is not None:
      seqs, tgts = create_sequences(piano_roll)
      sequences.append(seqs)
      targets.append(tgts)
      count += 1
  except Exception as e:
    print(f"Error processing {midi_file}: {e}")

# torch.cat rejects an empty list with a message that does not point at the
# real cause (wrong path or every file failing), so fail with a clear one.
if not sequences:
  raise RuntimeError("No MIDI file could be converted to training sequences; check the dataset path")

print("Converting training data to tensors")
sequences = torch.cat(sequences, dim=0)
targets = torch.cat(targets, dim=0)
print("Shape of Sequences")
print(sequences.shape)
print("Shape of Targets")
print(targets.shape)
print("Converting training data to dataloader")
dataset = TensorDataset(sequences, targets)
dataloader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)

In [7]:
# Persist the preprocessed tensors so later sessions can skip the MIDI parsing step
torch.save(obj=sequences, f='anirudh_sequences.pth')
torch.save(obj=targets, f='anirudh_targets.pth')

In [None]:
# Show the binarized piano roll of one file for reference.
# `midi_file` is whatever path the loading loop above bound last.
print(midi_file)
piano_roll_tmp = midi_to_sequence(str(midi_file))
# imshow draws rows on the vertical axis, so transpose to (pitch, time steps)
prev_numpy = piano_roll_tmp.T.numpy()
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(prev_numpy, aspect='auto', origin='lower', cmap='Greens')
ax.set_title("Original Piano Roll")
ax.set_xlabel("Time Steps")
ax.set_ylabel("Pitch")
plt.show()

In [9]:
class MusicLSTM(nn.Module):
  """LSTM that predicts the next binary piano-roll frame from a window of frames.

  Args:
    num_pitches: width of one piano-roll frame (input and output size).
    hidden_size: number of hidden units in each LSTM layer.
    num_layers: number of stacked LSTM layers.
  """

  def __init__(self, num_pitches=88, hidden_size=256, num_layers=1):
    super().__init__()
    # hidden_size is the width of the recurrent state (units per layer), not a
    # layer count; a wider state can represent more complex patterns.
    self.lstm = nn.LSTM(input_size=num_pitches, hidden_size=hidden_size,
                        num_layers=num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_pitches)
    # No sigmoid: the model emits logits, meant to be used with BCEWithLogitsLoss

  def forward(self, x, hidden=None):
    """Return next-frame logits and the final LSTM state.

    Args:
      x: (batch_size, sequence_length, num_pitches) tensor.
      hidden: optional initial LSTM state; None starts from zeros.

    Returns:
      (logits, hidden) where logits is (batch_size, num_pitches).
    """
    out, hidden = self.lstm(x, hidden)
    # out : (batch_size, sequence_length, hidden_size); only the last time
    # step feeds the output layer.
    out = self.fc(out[:, -1, :])
    return out, hidden

  def generate(self, seed_sequence, steps=480, show_progress=True):
    """Autoregressively generate `steps` binary frames after `seed_sequence`.

    Args:
      seed_sequence: (1, sequence_length, num_pitches) tensor used as the
        first input window. It is not modified.
      steps: number of frames to generate.
      show_progress: wrap the loop in a tqdm progress bar.

    Returns:
      (steps, num_pitches) float tensor of 0/1 values on the model's device.

    The model is switched to eval mode and left in it.
    """
    self.eval()
    # Run wherever the weights live instead of assuming a GPU is present.
    reference = next(self.parameters())
    window = seed_sequence.detach().clone().to(device=reference.device, dtype=reference.dtype)
    generated_sequence = torch.zeros((steps, self.fc.out_features), device=reference.device)
    step_iter = tqdm(range(steps)) if show_progress else range(steps)
    with torch.no_grad():
      for step in step_iter:
        # Each call starts from a fresh state: the sliding window already
        # contains the whole context, so also carrying the previous state
        # would feed the same frames twice and differ from how forward() is
        # used during training (full window, hidden=None).
        logits, _ = self(window)
        # sigmoid(0) = 0.5, so thresholding logits at 0 equals thresholding
        # probabilities at 0.5. Cast to float so it can be appended to the window.
        frame = (logits >= 0).to(window.dtype)
        generated_sequence[step] = frame.squeeze(0)

        # Slide the window: drop the oldest frame, append the new one
        window = torch.cat([window[:, 1:, :], frame.unsqueeze(1)], dim=1)
    return generated_sequence

In [10]:
# Write a (time_steps, pitches) velocity roll to a MIDI file.
# Note boundaries come from the on/off mask; plotting the roll remains a useful cross-check.
def sequence_to_midi(sequence, output_path):
    """
    Convert a piano-roll sequence to a MIDI file.

    A note starts at the first frame where a pitch becomes non-zero and ends
    at the first following frame where it is zero again (or at the end of the
    sequence). Velocity changes inside a held note do not split it.

    Frame duration comes from the module-level `time_step` and the MIDI pitch
    of column 0 from the module-level `min_pitch`.

    Args:
        sequence: Tensor (or array) of shape (time_steps, pitches) with
            velocities; 0 means the pitch is silent in that frame.
        output_path: Path to save MIDI file
    """
    # Create a new MIDI object
    pm = pretty_midi.PrettyMIDI(initial_tempo=120)
    piano = pretty_midi.Instrument(program=0)  # program 0 is piano

    # Accept tensors on any device (and array-likes) by normalising to numpy
    velocities = torch.as_tensor(sequence).detach().cpu().numpy()

    # Work on the on/off mask so that only 0 <-> non-zero transitions count.
    # Padding one silent frame on each side guarantees that every note has
    # both an onset and an offset, even if it touches the first or last frame.
    active = (velocities > 0).astype(np.int8)
    padded = np.pad(active, ((1, 1), (0, 0)))
    # transitions[k] == +1 : pitch switches on at frame k
    # transitions[k] == -1 : pitch is off at frame k after being on at k - 1
    transitions = np.diff(padded, axis=0)

    for pitch in range(velocities.shape[1]):
        onsets = np.flatnonzero(transitions[:, pitch] == 1)
        offsets = np.flatnonzero(transitions[:, pitch] == -1)

        # Onsets and offsets alternate strictly, so they pair up in order
        for start_idx, end_idx in zip(onsets, offsets):
            # Use the loudest frame of the note, clamped to the valid
            # note-on velocity range 1..127
            peak = float(velocities[start_idx:end_idx, pitch].max())
            velocity = int(min(max(round(peak), 1), 127))

            note = pretty_midi.Note(
                velocity=velocity,
                pitch=pitch + min_pitch,  # column 0 is the lowest configured pitch (A0 = 21)
                start=float(start_idx * time_step),
                end=float(end_idx * time_step)  # end_idx is the first silent frame
            )
            piano.notes.append(note)

    pm.instruments.append(piano)
    pm.write(output_path)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MusicLSTM()
# Move the model first so the optimizer is built over the parameters it will actually update
model.to(device)
criterion = torch.nn.BCEWithLogitsLoss() # applies the sigmoid itself, so the model outputs raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    # Wrap the dataloader itself (not enumerate) so tqdm knows the number of batches.
    # The loop variables deliberately do not reuse the names `sequences` /
    # `targets`: those hold the full dataset tensors, and overwriting them
    # here would make re-running the save cell store a single batch.
    for batch_sequences, batch_targets in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_sequences = batch_sequences.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        # No initial state is passed, so every window starts from a zero hidden state
        logits, _hidden = model(batch_sequences)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [None]:
# Generate music from a seed window and write it to a MIDI file
# dataset[10000][0] : (sequence_length, 88) tensor on cpu, an arbitrary training window;
# unsqueeze(0) adds the batch dimension -> (1, sequence_length, 88) on `device`
seed_sequence = dataset[10000][0].to(device).unsqueeze(0)
# generated_sequence : (steps, 88) tensor of 0/1 values
generated_sequence = model.generate(seed_sequence, steps=600)
# The model only predicts on/off, so give every sounding note a fixed velocity of 110
note_velocity = 110.0
generated_piano_roll = generated_sequence * note_velocity
sequence_to_midi(generated_piano_roll.cpu(), "generated_music.mid")

In [None]:
# Plot the generated piano roll as a visual sanity check
# imshow draws rows on the vertical axis, so transpose to (pitch, time steps)
gen_numpy = generated_sequence.T.cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(gen_numpy, aspect='auto', origin='lower', cmap='Greens')
ax.set_title("Generated Piano Roll")
ax.set_xlabel("Time Steps")
ax.set_ylabel("Pitch")
plt.show()

In [None]:
# Sanity check: print the first five generated frames
first_frames = generated_sequence[:5]
print(first_frames)

In [15]:
# Persist the trained weights (for later testing) and the optimizer state (to resume training)
torch.save(obj=model.state_dict(), f='anirudh_model_weights.pth')
torch.save(obj=optimizer.state_dict(), f='anirudh_optimizer.pth')