In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pretty_midi
import math
from torch.utils.data import Dataset, DataLoader

# ---------------------------------------------------------------------------
# Model / data hyper-parameters
# ---------------------------------------------------------------------------
MAX_SEQ_LEN = 512  # Longest token sequence; also sizes the positional-encoding table
VOCAB_SIZE = 512   # Total token ids (special + pitch + duration + instrument ranges, plus headroom)
D_MODEL = 256      # Embedding / hidden width
N_HEADS = 8        # Attention heads per layer
N_LAYERS = 6       # Depth of the encoder stack and of the decoder stack
D_FF = 1024        # Hidden width of each feed-forward sub-layer
DROPOUT = 0.1      # Dropout probability inside the Transformer layers

# Reserved token ids (must stay below MIDIProcessor.start_pitch)
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = 0, 1, 2

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence-first embedding tensor.

    The table is precomputed for ``max_len`` positions and stored as a buffer of
    shape ``(max_len, 1, d_model)`` so it broadcasts over the batch axis of an
    input laid out as ``(seq_len, batch, d_model)``.

    Args:
        d_model: embedding width. Odd widths are supported; the final column
            then carries a sine term with no cosine partner.
        max_len: number of positions precomputed. Longer inputs are rejected.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd d_model has one fewer cosine column than sine column, so trim the
        # frequency vector to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(0)`` positions.

        Raises:
            ValueError: if ``x`` has more positions than the precomputed table.
        """
        seq_len = x.size(0)
        table_len = self.pe.size(0)
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding table size {table_len}"
            )
        return x + self.pe[:seq_len]

class MusicTransformer(nn.Module):
    """Encoder-decoder Transformer over music token ids, in sequence-first layout.

    Source and target share one embedding table and one sinusoidal positional
    encoding. Decoder outputs are projected to ``vocab_size`` logits.

    Token inputs are ``(seq_len, batch)`` LongTensors (the layers use PyTorch's
    default ``batch_first=False``). The returned logits are
    ``(tgt_len, batch, vocab_size)``.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout):
        super(MusicTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # The positional table length caps how many positions src/tgt may have.
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                                    dim_feedforward=d_ff, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=n_layers)
        decoder_layers = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads,
                                                    dim_feedforward=d_ff, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers=n_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

        self.d_model = d_model
        self.max_seq_len = max_seq_len

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` float mask: 0.0 on/below the diagonal, -inf above.

        Added to attention scores, this stops position i attending to any j > i.
        """
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _embed(self, tokens):
        """Scale token embeddings by sqrt(d_model), add positions, then apply dropout."""
        x = self.embedding(tokens) * math.sqrt(self.d_model)
        return self.dropout(self.pos_encoder(x))

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Run encoder and decoder and return per-position vocabulary logits.

        Args:
            src: ``(src_len, batch)`` token ids fed to the encoder.
            tgt: ``(tgt_len, batch)`` token ids fed to the decoder.
            src_mask / tgt_mask: additive attention masks. When ``None``, a
                causal mask from ``generate_square_subsequent_mask`` is used.
            memory_mask: optional ``(tgt_len, src_len)`` mask on decoder-to-encoder
                attention. No mask is applied when ``None``.
            *_key_padding_mask: optional ``(batch, len)`` boolean masks, True at
                positions that should be ignored (e.g. PAD tokens).

        Returns:
            ``(tgt_len, batch, vocab_size)`` logits.
        """
        if src_mask is None:
            src_mask = self.generate_square_subsequent_mask(src.size(0)).to(src.device)
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        memory = self.transformer_encoder(
            self._embed(src),
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        output = self.transformer_decoder(
            self._embed(tgt),
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(output)

class MIDIProcessor:
    """Tokenizer between pretty_midi objects and flat integer token sequences.

    Token id layout (contiguous ranges):
        [0, start_pitch)                         special / reserved (PAD, SOS, EOS, ...)
        [start_pitch, start_duration)            MIDI pitch 0-127
        [start_duration, start_instrument)       log-quantized duration bin
        [start_instrument, end_instrument)       instrument slot (track index, clamped)
        [end_instrument, vocab_size)             unused; decoded as type "unused"

    Each note is encoded as the triple ``[instrument, pitch, duration]``.

    NOTE(review): the encoding carries no onset / time-shift information, so
    ``sequence_to_midi`` can only lay notes end to end. Chords and rests in the
    source are not reproduced.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, max_duration=4.0):
        """
        Args:
            vocab_size: total id space the model uses; must cover every token range.
            max_duration: longest representable note length in seconds. Longer
                notes are clamped to the top duration bin.

        Raises:
            ValueError: if the token ranges do not fit in ``vocab_size`` or
                ``max_duration`` is not positive.
        """
        # Reserved tokens
        self.pad_token = PAD_TOKEN
        self.sos_token = SOS_TOKEN
        self.eos_token = EOS_TOKEN

        # Token ranges; each range starts where the previous one ends.
        self.start_pitch = 10
        self.num_pitches = 128
        self.start_duration = self.start_pitch + self.num_pitches
        self.num_durations = 100  # Quantized (log-scale) duration bins
        self.start_instrument = self.start_duration + self.num_durations
        self.num_instruments = 16  # Instrument slots; midi_to_sequence fills them by track index
        self.end_instrument = self.start_instrument + self.num_instruments  # first id past all ranges

        if max_duration <= 0:
            raise ValueError(f"max_duration must be positive, got {max_duration}")
        self.max_duration = max_duration

        # An exact fit (end_instrument == vocab_size) is valid: the largest id used
        # is end_instrument - 1.
        if self.end_instrument > vocab_size:
            raise ValueError(
                f"Vocabulary size too small: need at least {self.end_instrument} ids, got {vocab_size}"
            )

    def encode_note(self, pitch, duration_bin, instrument):
        """Return the token triple ``[instrument, pitch, duration]`` for one note."""
        pitch_token = self.start_pitch + pitch
        duration_token = self.start_duration + duration_bin
        instrument_token = self.start_instrument + instrument
        return [instrument_token, pitch_token, duration_token]

    def decode_token(self, token):
        """Classify a token id and return ``{"type": ..., "value": ...}``.

        ``value`` is the offset within the token's range. For "special" and
        "unused" tokens it is the raw id.
        """
        if token < self.start_pitch:
            return {"type": "special", "value": token}
        if token < self.start_duration:
            return {"type": "pitch", "value": token - self.start_pitch}
        if token < self.start_instrument:
            return {"type": "duration", "value": token - self.start_duration}
        if token < self.end_instrument:
            return {"type": "instrument", "value": token - self.start_instrument}
        # encode_note never produces these ids, but a model whose output layer
        # spans the whole vocabulary can still sample them.
        return {"type": "unused", "value": token}

    def _log_span(self):
        """Natural log of the largest scaled duration, used to normalize bins."""
        return math.log(1 + self.max_duration * 10)

    def quantize_duration(self, duration):
        """Map a duration in seconds to a bin index in ``[0, num_durations)``.

        Uses a log scale so short durations get finer resolution. Durations are
        clamped to ``[0, max_duration]``.
        """
        duration = min(max(duration, 0.0), self.max_duration)
        bin_idx = int(self.num_durations * math.log(1 + duration * 10) / self._log_span())
        return min(bin_idx, self.num_durations - 1)

    def dequantize_duration(self, bin_idx):
        """Return the representative duration (seconds) of a bin.

        The value is the bin's midpoint in log space, so it quantizes back to
        ``bin_idx`` and bin 0 maps to a small positive duration rather than 0.
        """
        return (math.exp((bin_idx + 0.5) * self._log_span() / self.num_durations) - 1) / 10

    def midi_to_sequence(self, midi_file):
        """Convert a MIDI path or ``pretty_midi.PrettyMIDI`` object to a token list.

        The result is ``[SOS, (instrument, pitch, duration)*, EOS]``. Notes from
        all tracks are merged and ordered by start time. The sort is stable, so
        ties keep track order. The instrument token is the track's index in
        ``midi_data.instruments``, clamped to ``num_instruments - 1``; it is not
        the General MIDI program.
        """
        if isinstance(midi_file, str):
            midi_data = pretty_midi.PrettyMIDI(midi_file)
        else:
            midi_data = midi_file

        all_notes = []
        for i, instrument in enumerate(midi_data.instruments):
            instrument_id = min(i, self.num_instruments - 1)  # Limit to available instrument tokens
            for note in instrument.notes:
                all_notes.append({
                    'start': note.start,
                    'end': note.end,
                    'pitch': note.pitch,
                    'instrument': instrument_id
                })

        all_notes.sort(key=lambda x: x['start'])

        tokens = [self.sos_token]
        for note in all_notes:
            duration_bin = self.quantize_duration(note['end'] - note['start'])
            tokens.extend(self.encode_note(note['pitch'], duration_bin, note['instrument']))
        tokens.append(self.eos_token)

        return tokens

    def sequence_to_midi(self, tokens, tempo=120):
        """Convert a token sequence back into a ``pretty_midi.PrettyMIDI`` object.

        Decoding stops at the first EOS. An instrument token selects the slot for
        subsequent notes. A pitch token immediately followed by a duration token
        emits one note (velocity 100). Every other token is skipped. Notes are
        placed back to back, each starting where the previous one ended. Slot i
        is written with ``program=i``, and slots without notes are omitted.
        ``tempo`` is passed to pretty_midi as the initial tempo; note times are
        in seconds.
        """
        midi_data = pretty_midi.PrettyMIDI(initial_tempo=tempo)
        instruments = [pretty_midi.Instrument(program=i) for i in range(self.num_instruments)]

        current_time = 0.0
        current_instrument = 0

        i = 0
        while i < len(tokens):
            token = tokens[i]
            if token == self.eos_token:
                break

            token_info = self.decode_token(token)

            if token_info['type'] == 'instrument':
                current_instrument = token_info['value']
            elif token_info['type'] == 'pitch' and i + 1 < len(tokens):
                next_info = self.decode_token(tokens[i + 1])
                if next_info['type'] == 'duration':
                    duration = self.dequantize_duration(next_info['value'])
                    instruments[current_instrument].notes.append(pretty_midi.Note(
                        velocity=100,
                        pitch=token_info['value'],
                        start=current_time,
                        end=current_time + duration
                    ))
                    current_time += duration
                    i += 2  # consumed pitch + duration
                    continue
            i += 1

        for instrument in instruments:
            if len(instrument.notes) > 0:
                midi_data.instruments.append(instrument)

        return midi_data

class MusicDataset(Dataset):
    """Fixed-length next-token training pairs, one per MIDI file.

    ``__getitem__`` tokenizes a file with the supplied processor and forces the
    token list to exactly ``max_seq_len`` entries, truncating or right-padding
    with ``PAD_TOKEN``. It returns ``(input_ids, target_ids)``, where the target
    is the input shifted left by one. Both are LongTensors of length
    ``max_seq_len - 1``.

    If a file fails to tokenize, the error is reported with ``print`` and a
    placeholder pair is returned, so one bad file does not abort iteration.
    Files are re-parsed on every access.
    """

    def __init__(self, midi_files, processor, max_seq_len):
        self.midi_files = midi_files
        self.processor = processor
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.midi_files)

    def __getitem__(self, idx):
        midi_file = self.midi_files[idx]
        limit = self.max_seq_len
        try:
            seq = self.processor.midi_to_sequence(midi_file)
            shortfall = limit - len(seq)
            seq = seq[:limit] if shortfall < 0 else seq + [PAD_TOKEN] * shortfall
            return (torch.tensor(seq[:-1], dtype=torch.long),
                    torch.tensor(seq[1:], dtype=torch.long))
        except Exception as e:
            print(f"Error processing {midi_file}: {e}")
            # Placeholder: SOS followed by padding, with EOS as the final target.
            filler = [PAD_TOKEN] * (limit - 2)
            return (torch.tensor([SOS_TOKEN] + filler, dtype=torch.long),
                    torch.tensor(filler + [EOS_TOKEN], dtype=torch.long))

def train_model(model, train_dataloader, val_dataloader=None, epochs=10, lr=0.0001, device='cuda',
                clip_norm=0.5, log_interval=10):
    """Train ``model`` with Adam and token-level cross-entropy, then return it.

    Each batch is ``(src, tgt)`` LongTensors of shape ``(batch, seq)``, as
    produced by ``MusicDataset``. Both are transposed to sequence-first layout.
    The decoder is fed ``tgt[:-1]`` and scored against ``tgt[1:]``. PAD targets
    are excluded from the loss through ``ignore_index``; no key-padding mask is
    passed to the model.

    FIXME(review): ``src`` and ``tgt`` come from the same token sequence, and the
    model call applies no mask to decoder-to-encoder attention. Every decoder
    position can therefore read its own target from ``memory``, and the loss
    rewards copying rather than prediction. A memory mask or a
    prefix/continuation split of each sequence is needed to fix this.

    Args:
        model: a ``MusicTransformer``-like module called as ``model(src, tgt_input)``.
        train_dataloader: iterable of ``(src, tgt)`` batches.
        val_dataloader: optional iterable scored with ``evaluate`` after each epoch.
        epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate.
        device: device the model and batches are moved to.
        clip_norm: max global gradient norm passed to ``clip_grad_norm_``.
        log_interval: print the batch loss every this many batches (falsy disables).

    Returns:
        The trained model (the same object, moved to ``device``).
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (src, tgt) in enumerate(train_dataloader):
            src = src.to(device).transpose(0, 1)  # (seq_len, batch)
            tgt = tgt.to(device).transpose(0, 1)
            tgt_input = tgt[:-1, :]   # Exclude the last token
            tgt_output = tgt[1:, :]   # Exclude the first token

            # With every target equal to PAD, the ignore_index loss is 0/0 = NaN,
            # and stepping on it would corrupt the weights.
            if not bool((tgt_output != PAD_TOKEN).any()):
                continue

            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if log_interval and (batch_idx + 1) % log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss {loss.item():.4f}')

        # An empty loader (e.g. a 0-sized split) reports nan instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}, Average loss: {avg_loss:.4f}')

        if val_dataloader is not None:
            val_loss = evaluate(model, val_dataloader, criterion, device)
            print(f'Validation loss: {val_loss:.4f}')

    return model

def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Uses the same alignment as ``train_model``: ``(batch, seq)`` batches are
    transposed to sequence-first layout, the decoder sees ``tgt[:-1]``, and the
    loss is computed against ``tgt[1:]``. Runs under ``torch.no_grad()`` and
    leaves the model in eval mode.

    Returns:
        The average of ``criterion`` over batches, or ``nan`` if the dataloader
        yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device).transpose(0, 1)  # (seq_len, batch)
            tgt = tgt.to(device).transpose(0, 1)
            tgt_input = tgt[:-1, :]
            tgt_output = tgt[1:, :]

            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    # Small corpora can yield an empty validation split; avoid ZeroDivisionError.
    return total_loss / num_batches if num_batches else float('nan')

def generate_music(model, seed_midi, processor, max_length=1024, temperature=1.0, device='cuda'):
    """Sample a token sequence conditioned on ``seed_midi`` and decode it to MIDI.

    The seed is tokenized with ``processor``, truncated, and used as encoder
    input. The decoder starts from ``SOS_TOKEN`` and adds one sampled token per
    step until ``EOS_TOKEN`` appears or ``max_length`` tokens have been sampled.

    If the model exposes ``max_seq_len`` (the length of its positional-encoding
    table), the seed is capped to that many tokens. Only the most recent
    ``max_seq_len`` generated tokens are fed back to the decoder, so generation
    can run longer than the table without failing.

    Args:
        model: a ``MusicTransformer``-like module.
        seed_midi: path or ``pretty_midi.PrettyMIDI`` accepted by ``processor.midi_to_sequence``.
        processor: a ``MIDIProcessor`` used to encode the seed and decode the output.
        max_length: maximum number of tokens to sample. Also, ``max_length // 2``
            caps the seed length.
        temperature: softmax temperature. Values ``<= 0`` select greedy (argmax) decoding.
        device: device for the input tensors.

    Returns:
        The ``pretty_midi.PrettyMIDI`` produced by ``processor.sequence_to_midi``.
    """
    model.eval()
    context_len = getattr(model, 'max_seq_len', None)

    # Process seed: keep at least one token so the encoder never sees an empty sequence.
    seed_limit = max(1, max_length // 2)
    if context_len is not None:
        seed_limit = min(seed_limit, context_len)
    seed_tokens = processor.midi_to_sequence(seed_midi)[:seed_limit]
    seed_tensor = torch.tensor(seed_tokens, dtype=torch.long, device=device).unsqueeze(1)  # (seq_len, 1)

    generated_tokens = [SOS_TOKEN]

    # Inference only: skip building an autograd graph at every step.
    with torch.no_grad():
        for _ in range(max_length):
            window = generated_tokens if context_len is None else generated_tokens[-context_len:]
            tgt = torch.tensor(window, dtype=torch.long, device=device).unsqueeze(1)  # (len, 1)
            tgt_mask = model.generate_square_subsequent_mask(tgt.size(0)).to(device)

            output = model(seed_tensor, tgt, tgt_mask=tgt_mask)
            next_token_logits = output[-1, 0]
            if temperature > 0:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
            else:
                next_token = int(torch.argmax(next_token_logits).item())

            generated_tokens.append(next_token)
            if next_token == EOS_TOKEN:
                break

    return processor.sequence_to_midi(generated_tokens)

def main(midi_glob=None, batch_size=128, epochs=10, seed=42):
    """End-to-end example: train on a MIDI corpus, save weights, write a sample.

    Writes ``music_transformer.pth`` (the state dict) and ``generated_music.mid``
    (generated from the first corpus file) to the working directory. If no files
    match, it prints a message and returns.

    Args:
        midi_glob: glob pattern for training files. Defaults to the ``MIDI_GLOB``
            environment variable, falling back to the original corpus path.
        batch_size: DataLoader batch size for training and validation.
        epochs: number of training epochs.
        seed: seed for ``torch.manual_seed`` (weight init, split, shuffling).
            ``None`` leaves the RNG unseeded.
    """
    import glob
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if seed is not None:
        torch.manual_seed(seed)

    processor = MIDIProcessor()
    model = MusicTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_ff=D_FF,
        max_seq_len=MAX_SEQ_LEN,
        dropout=DROPOUT
    )

    if midi_glob is None:
        midi_glob = os.environ.get('MIDI_GLOB', 'D:\\Dev\\repo\\bach\\data_cache\\data\\*.mid')
    # glob order is filesystem-dependent; sort so the split and the seed file are reproducible.
    midi_files = sorted(glob.glob(midi_glob))

    if not midi_files:
        print("No MIDI files found for training")
        return

    dataset = MusicDataset(midi_files, processor, MAX_SEQ_LEN)
    # Keep at least one training example; tiny corpora may leave validation empty.
    train_size = max(1, int(0.8 * len(dataset)))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size) if val_size > 0 else None

    model = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=epochs,
        device=device
    )

    torch.save(model.state_dict(), 'music_transformer.pth')

    seed_midi = pretty_midi.PrettyMIDI(midi_files[0])
    generated_midi = generate_music(model, seed_midi, processor, device=device)
    generated_midi.write('generated_music.mid')

# Example of how to use the model for inference only
def inference_example(model_path, seed_midi_path, output_path):
    """Load saved weights, generate from a seed MIDI file, and write the result.

    Args:
        model_path: state-dict checkpoint, such as the one ``main`` saves with ``torch.save``.
        seed_midi_path: MIDI file whose tokens condition the encoder.
        output_path: destination for the generated ``.mid`` file.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hparams = dict(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_ff=D_FF,
        max_seq_len=MAX_SEQ_LEN,
        dropout=DROPOUT,
    )
    model = MusicTransformer(**hparams)
    # SECURITY(review): torch.load unpickles the file and can execute arbitrary
    # code. Only load trusted checkpoints; on newer torch, consider weights_only=True.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)

    seed_midi = pretty_midi.PrettyMIDI(seed_midi_path)
    generated_midi = generate_music(
        model=model,
        seed_midi=seed_midi,
        processor=MIDIProcessor(),
        max_length=1024,
        temperature=1.0,
        device=device,
    )

    generated_midi.write(output_path)
    print(f"Generated music saved to {output_path}")

In [None]:
# Run the end-to-end example defined above: train if MIDI files are found, then generate a sample.
main()

Using device: cuda


