In [None]:
# Install necessary libraries
!pip install pretty_midi music21 transformers torch matplotlib

# Import necessary libraries
import os
import pretty_midi
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, DataCollatorForLanguageModeling, Trainer, default_data_collator
import music21
import matplotlib.pyplot as plt
import warnings

# Ignore UserWarning-category warnings whose originating module name matches 'pretty_midi'
warnings.filterwarnings("ignore", category=UserWarning, module='pretty_midi')

# Ignore warnings of every category whose originating module name matches 'music21'
warnings.filterwarnings("ignore", module='music21')

# Define the MIDIDataset class
class MIDIDataset(Dataset):
    """Torch dataset that turns MIDI files into tokenized pitch sequences.

    Each item corresponds to one MIDI file. The pitch numbers of every note
    of every instrument are written as a space-separated string, which is
    then tokenized and truncated/padded to ``max_length`` tokens.

    Args:
        root_dirs: A directory path, or an iterable of directory paths, that
            is searched recursively for MIDI files.
        tokenizer: Tokenizer called on the pitch string. It is invoked with
            ``truncation``, ``padding``, ``max_length`` and
            ``return_tensors='pt'`` and must return an object exposing
            ``input_ids`` and ``attention_mask``.
        max_length: Number of tokens every item is truncated/padded to.

    Attributes:
        files: Sorted paths of the MIDI files that could be read by
            pretty_midi and contain at least one note.
        total_duration: Sum of the music21 ``quarterLength`` durations of
            all files in ``files``.
    """

    # File-name suffixes (compared case-insensitively) treated as MIDI files.
    MIDI_EXTENSIONS = ('.mid', '.midi')

    def __init__(self, root_dirs, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.files = self._gather_midi_files(root_dirs)
        self.files = self._filter_valid_files(self.files)
        self.total_duration = self.calculate_total_duration()

    def _gather_midi_files(self, root_dirs):
        """Return the sorted paths of all MIDI files below ``root_dirs``."""
        # A bare path would otherwise be iterated character by character.
        if isinstance(root_dirs, (str, os.PathLike)):
            root_dirs = [root_dirs]

        files = []
        for root_dir in root_dirs:
            for subdir, _, filenames in os.walk(root_dir):
                for filename in filenames:
                    # Lower-case first so that '.MID' / '.Midi' are found too.
                    if filename.lower().endswith(self.MIDI_EXTENSIONS):
                        files.append(os.path.join(subdir, filename))

        # os.walk order depends on the filesystem; sort for a stable index.
        files.sort()
        return files

    def _filter_valid_files(self, files):
        """Keep only the files pretty_midi can parse and that contain notes.

        Unreadable or note-less files are reported on stdout and dropped.
        """
        valid_files = []
        for file in files:
            try:
                midi_data = pretty_midi.PrettyMIDI(file)
            except Exception as e:
                # Best effort: a broken file must not abort dataset creation.
                print(f"Error reading {file}: {e}. Skipping...")
                continue

            if any(instrument.notes for instrument in midi_data.instruments):
                valid_files.append(file)
            else:
                print(f"No valid notes found in {file}. Skipping...")
        return valid_files

    def calculate_total_duration(self):
        """Return the summed ``quarterLength`` duration of ``self.files``.

        Files that music21 fails to parse are reported on stdout and
        contribute nothing to the total.
        """
        total_duration = 0.0
        for file in self.files:
            try:
                score = music21.converter.parse(file)
                total_duration += score.duration.quarterLength
            except Exception as e:
                print(f"Error processing {file}: {e}. Skipping...")
        return total_duration

    def __len__(self):
        """Return the number of usable MIDI files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the tokenized pitch sequence of the ``idx``-th file.

        Returns:
            A dict with ``input_ids`` and ``attention_mask``, each taken
            from the tokenizer output with its leading dimension squeezed
            away.
        """
        midi_file = self.files[idx]
        midi_data = pretty_midi.PrettyMIDI(midi_file)

        # Pitches are collected instrument by instrument, in the order in
        # which pretty_midi lists the instruments and their notes.
        notes = []
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                notes.append(note.pitch)

        tokens = self.tokenizer(
            " ".join(map(str, notes)),
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' yields batched tensors; drop the batch axis so
        # the collator can stack items itself.
        return {
            'input_ids': tokens.input_ids.squeeze(0),
            'attention_mask': tokens.attention_mask.squeeze(0)
        }

# Folder holding the Indian MIDI files
indian_dataset_path = '/content/drive/MyDrive/Dataset Music/Msc/Indian'

# GPT-2 tokenizer; the end-of-sequence token is reused as the padding token
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Build the dataset from the Indian folder only and report its size
indian_dataset = MIDIDataset(root_dirs=[indian_dataset_path], tokenizer=tokenizer)
print(f"Total duration of the Indian dataset: {indian_dataset.total_duration:.2f} quarter lengths")

# Nothing to train on if no usable MIDI file was found
if not len(indian_dataset):
    raise ValueError("The Indian dataset is empty. Please check if the directory contains valid MIDI files.")

# 80 % of the files go to training, the remainder to validation
train_size = int(len(indian_dataset) * 0.8)
val_size = len(indian_dataset) - train_size
train_dataset, val_dataset = random_split(indian_dataset, lengths=[train_size, val_size])

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=default_data_collator,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=default_data_collator,
)

# Pre-trained GPT-2 with its padding id aligned to the end-of-sequence id
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.config.pad_token_id = model.config.eos_token_id

# Collator for language modelling without masking (mlm=False)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Hyper-parameters and bookkeeping for the Trainer
train_args = TrainingArguments(
    # Output locations
    output_dir='./results',
    logging_dir='./logs',
    # Optimisation schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Per-device batch sizes
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Log every 10 steps; evaluate and checkpoint once per epoch
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    # No external experiment tracker
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Fine-tune on the Indian dataset
trainer.train()

# Generate music using the fine-tuned model
input_prompt = "60 62 64 65"  # Example input prompt

# Tokenize the prompt once and place it on the model's device: the Trainer
# may have moved the model to an accelerator, and generate() requires the
# inputs and the weights to live on the same device.
prompt_encoding = tokenizer(input_prompt, return_tensors='pt').to(model.device)
input_ids = prompt_encoding.input_ids
attention_mask = prompt_encoding.attention_mask

# Make sure dropout is disabled while sampling.
model.eval()

output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=100,
    num_return_sequences=1,
    do_sample=True,
    temperature=1.2,
    top_k=50,
    top_p=0.95,
    repetition_penalty=2.0,
    # Passed explicitly so generate() does not have to fall back to the
    # end-of-sequence id on its own (and warn about it).
    pad_token_id=tokenizer.eos_token_id,
)

output_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated MIDI pitch sequence:", output_text)

# Convert the generated sequence to a music21 score
def convert_sequence_to_score(sequence, note_duration=0.5):
    """Build a single-part music21 score from a whitespace-separated pitch string.

    Args:
        sequence: Text whose whitespace-separated tokens are MIDI pitch
            numbers, e.g. ``"60 62 64"``. Tokens that are not integers are
            skipped and reported on stdout, because sampled model output is
            not guaranteed to be purely numeric.
        note_duration: Duration, in quarter lengths, given to every note.

    Returns:
        A ``music21.stream.Score`` containing one ``Part`` with one note per
        accepted token, in input order.
    """
    part = music21.stream.Part()
    skipped = []

    for token in sequence.split():
        try:
            note_pitch = int(token)
        except ValueError:
            skipped.append(token)
            continue

        # NOTE(review): values outside the MIDI range 0-127 are handed to
        # music21 unchanged -- confirm whether they should be filtered.
        note = music21.note.Note(note_pitch)
        note.quarterLength = note_duration
        part.append(note)

    if skipped:
        print(f"Skipped {len(skipped)} non-numeric token(s): {skipped}")

    score = music21.stream.Score()
    score.append(part)
    return score

# Turn the decoded model output into a music21 score
score = convert_sequence_to_score(output_text)

# Optionally, add chords to the score
def add_chords_to_score(score):
    """Insert a fixed C - F - G - C chord progression into ``score``.

    The four chords are two quarter lengths long each and are placed one
    after another starting at offset 0.

    Args:
        score: The music21 stream the chords are inserted into (modified in
            place).

    Returns:
        The same ``score`` object, for call chaining.
    """
    chords = [
        music21.chord.Chord(['C4', 'E4', 'G4']),
        music21.chord.Chord(['F4', 'A4', 'C5']),
        music21.chord.Chord(['G4', 'B4', 'D5']),
        music21.chord.Chord(['C4', 'E4', 'G4'])
    ]

    # A freshly created chord has offset 0, so inserting at ``chord.offset``
    # would stack every chord on the first beat. Track the running position
    # instead so the chords follow each other.
    offset = 0.0
    for chord in chords:
        chord.quarterLength = 2  # Each chord lasts for two beats
        score.insert(offset, chord)
        offset += chord.quarterLength

    return score

# Optionally, add lyrics to the score
def add_lyrics_to_score(score, lyrics):
    """Attach one lyric per note to the first part of ``score``.

    Notes and lyrics are paired in order; whichever runs out first ends the
    pairing, so surplus notes keep their lyric untouched and surplus lyrics
    are ignored.

    Args:
        score: Object whose ``parts[0].notes`` yields the notes to label
            (modified in place).
        lyrics: Sequence of lyric strings.

    Returns:
        The same ``score`` object.
    """
    first_part = score.parts[0]
    for target, word in zip(first_part.notes, lyrics):
        target.lyric = word
    return score

# Decorate the score: chord progression first, then one lyric word per note
lyrics = "This is a test of generated music".split()
score = add_lyrics_to_score(add_chords_to_score(score), lyrics)

# Function to save the score as a MIDI file without using MuseScore
def save_score_as_midi(score, output_midi_file='generated_music21.mid'):
    """Write ``score`` to ``output_midi_file`` as a MIDI file.

    The score is translated with music21's MIDI translator and written
    directly, and the destination path is printed afterwards.

    Args:
        score: The music21 object to export.
        output_midi_file: Path of the MIDI file to create or overwrite.
    """
    mf = music21.midi.translate.music21ObjectToMidiFile(score)
    mf.open(output_midi_file, 'wb')
    try:
        mf.write()
    finally:
        # Release the file handle even if writing fails part-way.
        mf.close()
    print(f"Saved MIDI file to {output_midi_file}")

# Export the score to disk
save_score_as_midi(score, output_midi_file='generated_music21.mid')

# Load the exported file back and hand it to music21's MIDI player
score = music21.converter.parse("generated_music21.mid")
score.show('midi')


Total duration of the Indian dataset: 69218.08 quarter lengths


Epoch,Training Loss,Validation Loss
1,3.8938,3.979025
2,3.9761,3.899576
3,3.8455,3.862927


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated MIDI pitch sequence: 60 62 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 86 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 56 57 58 59 60 63 87 49 48
Saved MIDI file to generated_music21.mid



**Justification:** The code fine-tunes a GPT-2 model using a dataset of Indian MIDI files, generates new music sequences, converts them into a music score, and saves the result as a MIDI file. It also plays back the generated music using music21. This process enables automated music composition and playback, tailored to Indian musical data.