In [1]:
//Fonctionnement du Modèle
//L'application repose sur un modèle de diffusion discrète pour générer des séquences MIDI à partir d'instructions textuelles. Le pipeline se compose de trois modules principaux :

//Encodage sémantique (FLAN-T5) :
//Une instruction en langage naturel (ex. : "une mélodie joyeuse en do majeur") est encodée par FLAN-T5, un modèle de langage pré-entraîné, pour capturer son sens musical et contextuel.

//Diffusion discrète :
//Une séquence MIDI tokenisée (via un tokenizer REMI) est progressivement corrompue par un processus de bruitage selon un cosine beta schedule. Cela transforme la musique en une version bruitée difficile à reconnaître.

//Denoising par Transformer :
//Un Transformer sert de denoiseur et tente de reconstruire la séquence originale étape par étape. Il est conditionné sur les embeddings de FLAN-T5, guidant ainsi la génération musicale selon le texte initial.

//Le modèle apprend ainsi à générer des compositions musicales cohérentes et expressives à partir d’instructions en langage naturel.

SyntaxError: invalid character '’' (U+2019) (3642017967.py, line 13)

In [None]:
/// USING THE SAME DATASET USED BY AMAAI-Lab TO AVOID RE CLEANING THE DATASET

In [None]:
!git clone https://github.com/AMAAI-Lab/Text2midi.git

In [None]:
!wget https://huggingface.co/datasets/amaai-lab/MidiCaps/resolve/main/midicaps.tar.gz

In [None]:
import tarfile
import os

def decompress_tar_gz(file_path, extract_path="."):
    if file_path.endswith(".tar.gz"):
        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall(path=extract_path)
            print(f"Extracted to: {os.path.abspath(extract_path)}")
    else:
        print("The file is not a .tar.gz archive.")

# Example usage
decompress_tar_gz("midicaps.tar.gz", "output_directory")


In [None]:
!pip install torch transformers accelerate miditok wandb spacy jsonlines pyyaml tqdm nltk
!python -m spacy download en_core_web_sm

In [None]:
import shutil


src = "/kaggle/working/Text2midi/captions/captions.json"
dst = "/kaggle/working/captions.json"

# Move the file
shutil.move(src, dst)

In [None]:
import os


directory = '/kaggle/working/'


for filename in os.listdir(directory):

    if filename.startswith("checkpoint_epoch_") and filename.endswith(".bin"):
        try:

            epoch = int(filename.split('_')[2].split('.')[0])
            

            if 5 <= epoch <= 50:
                file_path = os.path.join(directory, filename)
                os.remove(file_path)
                print(f"Deleted {filename}")
        except Exception as e:
            print(f"Error processing file {filename}: {e}")


In [None]:
MODE = "build_vocab_remi"

In [None]:
#Helper Functions
def _get_clones(module, N):
    return nn.ModuleList([deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation: str):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in https://arxiv.org/abs/2102.09672"""
    steps = torch.arange(timesteps, dtype=torch.float32)
    f_t = torch.cos(((steps / timesteps + s) / (1.0 + s) * math.pi / 2)) ** 2
    betas = 1.0 - f_t / torch.roll(f_t, shifts=1, dims=0)
    betas = torch.clamp(betas, 0.0, 0.999)
    betas[0] = 0.0001  # Avoid zero beta
    return betas


In [None]:
# Cell 4: Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class DiscreteDiffusionModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048,
                 num_steps=1000, dropout=0.1, device=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_steps = num_steps

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.time_emb = nn.Embedding(num_steps, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=5000).to(device)

        # FLAN-T5 encoder
        self.encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device)
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Project FLAN-T5 output (768) to d_model (512)
        self.text_projection = nn.Linear(768, d_model).to(device)

        # Transformer-based denoiser
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.denoiser = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output layer
        self.projection = nn.Linear(d_model, vocab_size)

        # D3PM: Uniform transition matrix
        self.Q = torch.ones(vocab_size, vocab_size, device=device) / vocab_size
        self.Q_bar = torch.ones(num_steps, vocab_size, vocab_size, device=device)
        self.log_Q = torch.log(self.Q + 1e-10)

        # Noise schedule (cosine)
        self.betas = cosine_beta_schedule(num_steps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        # Precompute Q_bar for each timestep
        for t in range(num_steps):
            alpha_bar_t = self.alpha_bar[t]
            self.Q_bar[t] = alpha_bar_t * torch.eye(vocab_size, device=device) + \
                           (1 - alpha_bar_t) * self.Q

        self._reset_parameters()

    def forward(self, x_t, t, src, src_mask):
        # x_t: Noisy tokens (batch, seq_len)
        # t: Timestep (batch,)
        # src: Text input (batch, text_len)
        # src_mask: Text attention mask (batch, text_len)
        x_emb = self.token_emb(x_t) * math.sqrt(self.d_model)
        x_emb = self.pos_encoder(x_emb.transpose(0, 1)).transpose(0, 1)
        t_emb = self.time_emb(t).unsqueeze(1)  # (batch, 1, d_model)
        x_emb = x_emb + t_emb
        memory = self.encoder(src, attention_mask=src_mask).last_hidden_state  # (batch, text_len, 768)
        memory = self.text_projection(memory)  # (batch, text_len, 512)
        output = self.denoiser(x_emb, memory, memory_mask=None)
        logits = self.projection(output)
        return logits

    def sample(self, src, src_mask, seq_len, num_steps=None, ddim=False, eta=0.0):
      device = src.device
      batch_size = src.size(0)
      x_t = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device)
      num_steps = num_steps or self.num_steps

      if ddim:
          step_indices = torch.linspace(0, self.num_steps - 1, steps=self.num_steps // num_steps + 1, device=device).long()
      else:
          step_indices = torch.arange(num_steps, device=device)

      for i in reversed(range(len(step_indices))):
          t = step_indices[i]
          t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
          with torch.no_grad():
              logits = self(x_t, t_tensor, src, src_mask)

          # Debugging to see the shape of logits
          # print(f"Logits shape: {logits.shape}")

          # Ensure logits have shape [batch_size, seq_len, vocab_size]
          if len(logits.shape) == 3:  # [batch_size, seq_len, vocab_size]
              # We need to reshape for multinomial sampling
              # Flatten batch_size and seq_len dimensions
              batch_seq_size = logits.size(0) * logits.size(1)
              vocab_size = logits.size(2)

              # Reshape to [batch_size*seq_len, vocab_size]
              flat_logits = logits.reshape(batch_seq_size, vocab_size)
              probs = F.softmax(flat_logits, dim=-1)

              # Sample from the flattened distribution
              flat_samples = torch.multinomial(probs, num_samples=1).squeeze(-1)

              # Reshape back to [batch_size, seq_len]
              x_t_new = flat_samples.reshape(logits.size(0), logits.size(1))
          else:
              # If logits are already 2D [batch_size, vocab_size]
              probs = F.softmax(logits, dim=-1)
              x_t_new = torch.multinomial(probs, num_samples=1).squeeze(-1)

          if ddim and i > 0:
              t_prev = step_indices[i - 1]
              alpha_bar_t = self.alpha_bar[t]
              alpha_bar_t_prev = self.alpha_bar[t_prev]
              pred_x0 = x_t_new  # Use the new sampled tokens
              sigma = eta * torch.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))
              x_t = torch.where(
                  torch.rand_like(x_t.float()) < sigma,
                  torch.randint(0, self.vocab_size, x_t.shape, device=device),
                  pred_x0
              )
          else:
              x_t = x_t_new

      return x_t

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

In [None]:
# Cell 5: Dataset
class Text2MusicDataset(Dataset):
    def __init__(self, configs, captions, remi_tokenizer, mode="train", shuffle=False):
        self.mode = mode
        self.captions = captions
        if shuffle:
            random.shuffle(self.captions)
        self.dataset_path = configs['raw_data']['dataset_folder']
        self.remi_tokenizer = remi_tokenizer
        self.nlp = English()
        self.nlp.add_pipe('sentencizer')
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']
        self.num_steps = configs['model']['text2midi_model'].get('num_diffusion_steps', 1000)
        self.vocab_size = len(remi_tokenizer)

        # Noise schedule
        self.betas = cosine_beta_schedule(self.num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        self.Q = torch.ones(self.vocab_size, self.vocab_size) / self.vocab_size

        print("Length of dataset:", len(self.captions))

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]['caption']
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])

        if not os.path.exists(midi_filepath):
            raise FileNotFoundError(f"MIDI file not found: {midi_filepath}")

        try:
            tokens = self.remi_tokenizer(midi_filepath)
            tokenized_midi = ([self.remi_tokenizer["BOS_None"]] + tokens.ids +
                             [self.remi_tokenizer["EOS_None"]]) if tokens.ids else [
                             self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
            tokenized_midi = torch.tensor(tokenized_midi)
        except Exception as e:
            raise ValueError(f"Error tokenizing MIDI file {midi_filepath}: {str(e)}")

        if len(tokenized_midi) < self.decoder_max_sequence_length:
            x_0 = F.pad(tokenized_midi, (0, self.decoder_max_sequence_length - len(tokenized_midi))).long()
        else:
            x_0 = tokenized_midi[:self.decoder_max_sequence_length].long()

        t = torch.randint(0, self.num_steps, (1,)).item()
        alpha_bar_t = self.alpha_bar[t]
        Q_t = alpha_bar_t * torch.eye(self.vocab_size) + (1 - alpha_bar_t) * self.Q
        probs = F.one_hot(x_0, num_classes=self.vocab_size).float() @ Q_t
        x_t = torch.multinomial(probs, num_samples=1).squeeze(-1)

        if random.random() > 0.5 and self.mode == "train":
            sentences = list(self.nlp(caption).sents)
            if sentences:
                sent_length = len(sentences)
                drop_pct = (20 + random.random() * 30) / 100
                how_many_to_drop = int(np.floor(drop_pct * sent_length) if sent_length < 4 else np.ceil(drop_pct * sent_length))
                which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
                new_sentences = [s for i, s in enumerate(sentences) if i not in which_to_drop]
                new_sentences = " ".join([s.text for s in new_sentences])
            else:
                new_sentences = caption
        else:
            new_sentences = caption

        inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
        return (inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0),
                x_t.long(), t, x_0.long())

def collate_fn(batch):
    input_ids = nn.utils.rnn.pad_sequence([item[0] for item in batch], batch_first=True, padding_value=0)
    attention_mask = nn.utils.rnn.pad_sequence([item[1] for item in batch], batch_first=True, padding_value=0)
    x_t = nn.utils.rnn.pad_sequence([item[2] for item in batch], batch_first=True, padding_value=0)
    t = torch.tensor([item[3] for item in batch])
    x_0 = nn.utils.rnn.pad_sequence([item[4] for item in batch], batch_first=True, padding_value=0)
    return input_ids, attention_mask, x_t, t, x_0


In [None]:
def generate_midi(model, tokenizer, t5_tokenizer, caption, seq_len, output_dir, device, ddim=True):
    model.eval()
    inputs = t5_tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    
    with torch.no_grad():
        tokens = model.sample(input_ids, attention_mask, seq_len, ddim=ddim)
    
    token_ids = tokens[0].cpu().numpy().tolist()
    if token_ids[0] == tokenizer["BOS_None"]:
        token_ids = token_ids[1:]
    if token_ids[-1] == tokenizer["EOS_None"]:
        token_ids = token_ids[:-1]
    
    try:
        # Use decode instead of tokens_to_midi
        midi_score = tokenizer.decode(token_ids)
        output_path = os.path.join(output_dir, f"generated_{int(time.time())}.mid")
        
        # Use dump_midi method which is available in the ScoreTick object
        midi_score.dump_midi(output_path)
                
        print(f"Generated MIDI saved to {output_path}")
        return output_path
    except Exception as e:
        print(f"Error converting tokens to MIDI: {str(e)}")
        return None

In [None]:
# Cell 8: Vocabulary Building

import datetime

def build_vocab(configs):
    vocab = {}
    for i in INSTRUMENTS:
        vocab[('prefix', 'instrument', i)] = len(vocab) + 1
    velocity = [0, 15, 30, 45, 60, 75, 90, 105, 120, 127]
    midi_pitch = list(range(0, 128))
    onset = list(range(0, 5001, 10))
    duration = list(range(0, 5001, 10))
    for v in velocity:
        for i in INSTRUMENTS:
            for p in midi_pitch:
                if i != "drum":
                    vocab[(i, p, v)] = len(vocab) + 1
    for p in midi_pitch:
        vocab[("drum", p)] = len(vocab) + 1
    for o in onset:
        vocab[("onset", o)] = len(vocab) + 1
    for d in duration:
        vocab[("dur", d)] = len(vocab) + 1
    special_tokens = ["<T>", "<D>", "<U>", "<SS>", "<S>", "<E>", "SEP"]
    for token in special_tokens:
        vocab[token] = len(vocab) + 1
    print(f"Vocabulary length: {len(vocab)}")
    vocab_path = os.path.join(configs["artifact_folder"], "vocab.pkl")
    os.makedirs(configs["artifact_folder"], exist_ok=True)
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print(f"Vocabulary saved to {vocab_path}")

def build_vocab_remi(configs):
    BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": BEAT_RES,
        "num_velocities": 32,
        "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
        "use_chords": False,
        "use_rests": False,
        "use_tempos": True,
        "use_time_signatures": True,
        "use_programs": True,
        "num_tempos": 32,
        "tempo_range": (40, 250),
    }
    config = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(config)
    caption_path = "/kaggle/working/captions.json"
    if not os.path.exists(caption_path):
        raise FileNotFoundError(f"Caption file not found: {caption_path}")
    with jsonlines.open(caption_path) as reader:
        captions = list(reader)
    midi_paths = [os.path.join(configs['raw_data']['dataset_folder'],
                  captions[i]['location']) for i in range(len(captions))][:30000]
    for path in midi_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"MIDI file not found: {path}")
    print(f"Vocabulary length: {tokenizer.vocab_size}")
    vocab_path = os.path.join(configs["artifact_folder"], "vocab_remi.pkl")
    os.makedirs(configs["artifact_folder"], exist_ok=True)
    with open(vocab_path, 'wb') as f:
        pickle.dump(tokenizer, f)
    print(f"Vocabulary saved to {vocab_path}")


In [None]:
MODE = "build_vocab_remi"

In [None]:

# Cell 9: Execute Mode
if MODE == "train":
    train_model_accelerate(CONFIG)
elif MODE == "build_vocab":
    build_vocab(CONFIG)
elif MODE == "build_vocab_remi":
    build_vocab_remi(CONFIG)
else:
    raise ValueError(f"Invalid mode: {MODE}. Choose 'train', 'build_vocab', or 'build_vocab_remi'")


In [None]:
MODE = "train"

In [None]:
pip install mido

In [None]:
from mido import MidiFile, MidiTrack, Message

dummy = MidiFile()
track = MidiTrack()
dummy.tracks.append(track)

# Add a single dummy note
track.append(Message('note_on', note=60, velocity=64, time=0))
track.append(Message('note_off', note=60, velocity=64, time=480))

dummy.save('/kaggle/working/output_directory/lmd_full/9/9d762ce1f025b6df8e87335092024626.mid')
print("ok")

In [None]:
import wandb

try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    anony = None
except:
    anony = "must"
    print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')

In [None]:
pip install miditoolkit

In [None]:
////////////////////////////////

In [None]:
# Inference from checkpoint_epoch_50.bin

import torch
import os
import pickle
from transformers import T5Tokenizer

# --- Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "/kaggle/working/saved/checkpoint_epoch_50.bin"

# --- Load REMI Tokenizer ---
with open("/kaggle/working/artifacts/vocab_remi.pkl", "rb") as f:
    tokenizer = pickle.load(f)

# --- Load Model ---
model = DiscreteDiffusionModel(
    vocab_size=len(tokenizer),
    d_model=CONFIG['model']['text2midi_model']['decoder_d_model'],
    nhead=CONFIG['model']['text2midi_model']['decoder_num_heads'],
    num_layers=CONFIG['model']['text2midi_model']['decoder_num_layers'],
    dim_feedforward=CONFIG['model']['text2midi_model']['decoder_intermediate_size'],
    num_steps=CONFIG['model']['text2midi_model']['num_diffusion_steps'],
    device=device
).to(device)

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# --- Load T5 Tokenizer ---
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

# --- Generate ---
caption = "A chaotic song "

midi_path = generate_midi(
    model=model,
    tokenizer=tokenizer,
    t5_tokenizer=t5_tokenizer,
    caption=caption,
    seq_len=CONFIG['model']['text2midi_model']['decoder_max_sequence_length'],
    output_dir=CONFIG['training']['text2midi_model']['output_dir'],
    device=device,
    ddim=True
)


In [None]:
////////////// edition 80-10-10

In [None]:
//////////used config and train

In [None]:
MODE="train"

In [None]:
# Cell 12: Execute Mode
if MODE == "train":
    train_model_accelerate(CONFIG)
elif MODE == "build_vocab":
    build_vocab(CONFIG)
elif MODE == "build_vocab_remi":
    build_vocab_remi(CONFIG)
else:
    raise ValueError(f"Invalid mode: {MODE}. Choose 'train', 'build_vocab', or 'build_vocab_remi'")

In [None]:
# Cell 1: Imports and Configuration
import os
import pickle
import jsonlines
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm.auto import tqdm
import math
import time
import random
from transformers import T5Tokenizer, T5EncoderModel, get_scheduler
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from spacy.lang.en import English
import wandb
import logging
from miditok import REMI, TokenizerConfig
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.model_selection import train_test_split
import json
from copy import deepcopy
import datetime

# Initialize logger
logger = get_logger(__name__)

# Initialize W&B with API key
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
except:
    print('W&B login failed. Logging locally. Provide W&B API key in Kaggle Secrets as "wandb_api".')

# Configuration
CONFIG = {
    "model": {
        "text2midi_model": {
            "decoder_d_model": 768,
            "decoder_num_heads": 12,
            "decoder_num_layers": 8,
            "decoder_intermediate_size": 2048,
            "decoder_max_sequence_length": 1024,
            "num_diffusion_steps": 1000,
            "ddim_steps": 100
        }
    },
    "training": {
        "text2midi_model": {
            "learning_rate": 0.0002,  # Reduced from 0.0005
            "epochs": 50,
            "max_train_steps": None,
            "num_warmup_steps": 2000,
            "gradient_accumulation_steps": 2,
            "per_device_train_batch_size": 8,  # Reduced from 16
            "output_dir": "/kaggle/working/saved/",
            "with_tracking": True,
            "report_to": "wandb",
            "checkpointing_steps": "epoch",
            "save_every": 5,
            "lr_scheduler_type": "cosine",
            "split_type": "80_10_10",
            "max_grad_norm": 1.0,  # Added explicit gradient clipping
            "weight_decay": 0.01,  # Added weight decay for regularization
            "debug_nan": True,  # Add flag to enable NaN debugging
            "fp16_precision": True  # Explicitly define precision
        }
    },
    "raw_data": {
        "caption_dataset_path": "/kaggle/working/captions.json",
        "dataset_folder": "/kaggle/working/output_directory"
    },
    "artifact_folder": "/kaggle/working/artifacts"
}

# Constants
INSTRUMENTS = ['piano', 'chromatic', 'organ', 'guitar', 'bass', 'strings', 'ensemble',
               'brass', 'reed', 'pipe', 'synth_lead', 'synth_pad', 'synth_effect',
               'ethnic', 'percussive', 'sfx', 'drum']

# Verify dataset paths
if not os.path.exists(CONFIG['raw_data']['caption_dataset_path']):
    logger.error(f"Caption file not found: {CONFIG['raw_data']['caption_dataset_path']}")
if not os.path.exists(CONFIG['raw_data']['dataset_folder']):
    logger.error(f"Dataset folder not found: {CONFIG['raw_data']['dataset_folder']}")

In [None]:
/////

In [None]:
import os
import time
import jsonlines
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from accelerate import Accelerator, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import random
import datetime
import math
from transformers import T5Tokenizer

def train_model_accelerate(configs):
    start_time = time.time()
    print(f"{datetime.datetime.now()}: Starting train_model_accelerate")

    # Initialize Accelerator
    print(f"{datetime.datetime.now()}: Initializing Accelerator")
    accelerator = Accelerator(
        gradient_accumulation_steps=configs['training']['text2midi_model']['gradient_accumulation_steps'],
        mixed_precision='fp16' if configs['training']['text2midi_model']['fp16_precision'] else 'no',
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
    )

    # Set up directories
    if accelerator.is_main_process:
        output_dir = configs['training']['text2midi_model']['output_dir']
        outputs_dir = os.path.join(output_dir, "outputs")
        os.makedirs(outputs_dir, exist_ok=True)
        if configs['training']['text2midi_model']['with_tracking']:
            try:
                wandb.init(project="Text-2-Midi", settings=wandb.Settings(init_timeout=120))
            except Exception as e:
                print(f"{datetime.datetime.now()}: W&B initialization failed: {str(e)}")
                configs['training']['text2midi_model']['with_tracking'] = False

    accelerator.wait_for_everyone()

    # Load vocabulary
    print(f"{datetime.datetime.now()}: Loading vocabulary")
    vocab_path = os.path.join(configs['artifact_folder'], "vocab_remi.pkl")
    if not os.path.exists(vocab_path):
        print(f"{datetime.datetime.now()}: ERROR: Vocabulary file not found: {vocab_path}")
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
    with open(vocab_path, "rb") as f:
        tokenizer = pickle.load(f)

    # Load captions
    print(f"{datetime.datetime.now()}: Loading captions")
    caption_path = configs['raw_data']['caption_dataset_path']
    if not os.path.exists(caption_path):
        print(f"{datetime.datetime.now()}: ERROR: Caption file not found: {caption_path}")
        raise FileNotFoundError(f"Caption file not found: {caption_path}")
    caption_start_time = time.time()
    with jsonlines.open(caption_path) as reader:
        captions = list(reader)
    # Limit dataset size for debugging
    captions = captions[:10000]  # Use only 10,000 captions
    print(f"{datetime.datetime.now()}: Loaded {len(captions)} captions in {time.time() - caption_start_time:.2f} seconds")
    print(f"{datetime.datetime.now()}: WARNING: Skipped MIDI file validation. Missing or corrupted files may cause errors.")

    # Split dataset
    print(f"{datetime.datetime.now()}: Splitting dataset")
    train_captions, temp_captions = train_test_split(captions, test_size=0.2, random_state=42)
    val_captions, test_captions = train_test_split(temp_captions, test_size=0.5, random_state=42)
    print(f"{datetime.datetime.now()}: Train: {len(train_captions)}, Validation: {len(val_captions)}, Test: {len(test_captions)}")

    # Initialize datasets and dataloaders
    print(f"{datetime.datetime.now()}: Initializing datasets")
    train_dataset = Text2MusicDataset(configs, train_captions, remi_tokenizer=tokenizer, mode="train", shuffle=True)
    val_dataset = Text2MusicDataset(configs, val_captions, remi_tokenizer=tokenizer, mode="val")
    test_dataset = Text2MusicDataset(configs, test_captions, remi_tokenizer=tokenizer, mode="test")

    print(f"{datetime.datetime.now()}: Initializing dataloaders")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=configs['training']['text2midi_model']['per_device_train_batch_size'],
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=configs['training']['text2midi_model']['per_device_train_batch_size'],
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=False
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=configs['training']['text2midi_model']['per_device_train_batch_size'],
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=False
    )

    # Initialize model
    print(f"{datetime.datetime.now()}: Initializing model")
    model = DiscreteDiffusionModel(
        vocab_size=len(tokenizer),
        d_model=configs['model']['text2midi_model']['decoder_d_model'],
        nhead=configs['model']['text2midi_model']['decoder_num_heads'],
        num_layers=configs['model']['text2midi_model']['decoder_num_layers'],
        dim_feedforward=configs['model']['text2midi_model']['decoder_intermediate_size'],
        num_steps=configs['model']['text2midi_model']['num_diffusion_steps'],
        device=accelerator.device
    )

    # Load previous best model if it exists
    best_model_path = os.path.join(configs['training']['text2midi_model']['output_dir'], 'best_model.bin')
    if os.path.exists(best_model_path):
        try:
            state_dict = torch.load(best_model_path, map_location=accelerator.device)
            model.load_state_dict(state_dict)
            print(f"{datetime.datetime.now()}: Loaded checkpoint from {best_model_path}")
        except Exception as e:
            print(f"{datetime.datetime.now()}: Error loading checkpoint {best_model_path}: {str(e)}")

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{datetime.datetime.now()}: Total number of trainable parameters: {total_params}")

    # Setup optimizer and scheduler
    print(f"{datetime.datetime.now()}: Setting up optimizer and scheduler")
    optimizer = optim.AdamW(
        model.parameters(),
        lr=configs['training']['text2midi_model']['learning_rate'] * 0.1,
        weight_decay=configs['training']['text2midi_model']['weight_decay']
    )
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / configs['training']['text2midi_model']['gradient_accumulation_steps'])
    max_train_steps = configs['training']['text2midi_model']['epochs'] * num_update_steps_per_epoch
    lr_scheduler = get_scheduler(
        name=configs['training']['text2midi_model']['lr_scheduler_type'],
        optimizer=optimizer,
        num_warmup_steps=configs['training']['text2midi_model']['num_warmup_steps'],
        num_training_steps=max_train_steps,
    )
    print(f"{datetime.datetime.now()}: Training for {configs['training']['text2midi_model']['epochs']} epochs, {max_train_steps} steps")

    # Prepare for distributed training
    print(f"{datetime.datetime.now()}: Preparing for distributed training")
    model, optimizer, lr_scheduler, train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, val_dataloader, test_dataloader
    )

    # Training loop
    criterion = nn.CrossEntropyLoss()
    progress_bar = tqdm(range(max_train_steps), desc="Training", disable=not accelerator.is_local_main_process)
    completed_steps = 0
    best_val_loss = float('inf')
    patience = 5
    no_improve = 0
    t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    smoother = SmoothingFunction().method1

    for epoch in range(configs['training']['text2midi_model']['epochs']):
        print(f"{datetime.datetime.now()}: Starting epoch {epoch + 1}")
        model.train()
        total_train_loss = 0
        total_train_bleu = 0
        total_train_pitch_sim = 0
        total_train_rhythmic_sim = 0
        total_train_perplexity = 0
        total_train_accuracy = 0
        num_batches = 0

        for step, batch in enumerate(train_dataloader):
            try:
                encoder_input, attention_mask, x_t, t, x_0 = batch
                logits = model(x_t, t, encoder_input, attention_mask)
                loss = criterion(logits.view(-1, logits.size(-1)), x_0.view(-1))
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"{datetime.datetime.now()}: WARNING: Invalid loss (nan/inf) at step {step}, skipping")
                    continue
                total_train_loss += loss.detach().float()

                # Compute metrics
                preds = torch.argmax(logits, dim=-1)  # Predicted tokens
                # BLEU score
                bleu_scores = []
                for pred, ref in zip(preds.cpu().tolist(), x_0.cpu().tolist()):
                    bleu_scores.append(sentence_bleu([ref], pred, smoothing_function=smoother))
                batch_bleu = np.mean(bleu_scores)
                total_train_bleu += batch_bleu

                # Pitch similarity (REMI Pitch tokens, MIDI 21–108)
                pitch_mask = (x_0 >= 21) & (x_0 <= 108)
                if pitch_mask.sum() > 0:
                    pitch_sim = (preds[pitch_mask] == x_0[pitch_mask]).float().mean().item()
                    total_train_pitch_sim += pitch_sim
                else:
                    total_train_pitch_sim += 0

                # Rhythmic similarity (REMI Duration tokens, placeholder range)
                rhythm_mask = (x_0 >= 109) & (x_0 <= 200)  # Adjust based on REMI tokenizer
                if rhythm_mask.sum() > 0:
                    rhythmic_sim = (preds[rhythm_mask] == x_0[rhythm_mask]).float().mean().item()
                    total_train_rhythmic_sim += rhythmic_sim
                else:
                    total_train_rhythmic_sim += 0

                # Perplexity
                perplexity = torch.exp(loss.detach()).item()
                total_train_perplexity += perplexity if not math.isinf(perplexity) else 0

                # Token accuracy
                accuracy = (preds == x_0).float().mean().item()
                total_train_accuracy += accuracy

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), configs['training']['text2midi_model']['max_grad_norm'])
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                num_batches += 1
            except Exception as e:
                print(f"{datetime.datetime.now()}: ERROR in training step {step}: {str(e)}")
                raise

            if accelerator.sync_gradients:
                progress_bar.set_postfix({"Loss": loss.item(), "BLEU": batch_bleu})
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps >= max_train_steps:
                break

        # Average metrics over batches
        avg_loss = total_train_loss / num_batches if num_batches > 0 else float('inf')
        avg_bleu = total_train_bleu / num_batches if num_batches > 0 else 0
        avg_pitch_sim = total_train_pitch_sim / num_batches if num_batches > 0 else 0
        avg_rhythmic_sim = total_train_rhythmic_sim / num_batches if num_batches > 0 else 0
        avg_perplexity = total_train_perplexity / num_batches if num_batches > 0 else float('inf')
        avg_accuracy = total_train_accuracy / num_batches if num_batches > 0 else 0

        # Log metrics to W&B
        if configs['training']['text2midi_model']['with_tracking'] and accelerator.is_main_process:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_loss,
                "train_bleu": avg_bleu,
                "train_pitch_sim": avg_pitch_sim,
                "train_rhythmic_sim": avg_rhythmic_sim,
                "train_perplexity": avg_perplexity,
                "train_accuracy": avg_accuracy
            })

        print(f"{datetime.datetime.now()}: Epoch {epoch + 1} completed - Loss: {avg_loss:.4f}, BLEU: {avg_bleu:.4f}, "
              f"Pitch Sim: {avg_pitch_sim:.4f}, Rhythmic Sim: {avg_rhythmic_sim:.4f}, Perplexity: {avg_perplexity:.4f}, "
              f"Accuracy: {avg_accuracy:.4f}")
                # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0 and accelerator.is_main_process:
            checkpoint_path = os.path.join(
                configs['training']['text2midi_model']['output_dir'],
                f"checkpoint_epoch_{epoch + 1}.bin"
            )
            torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path)
            print(f"{datetime.datetime.now()}: Saved checkpoint at {checkpoint_path}")


    print(f"{datetime.datetime.now()}: Training completed in {time.time() - start_time:.2f} seconds")

In [None]:
//////

In [None]:
/// USED CODE IS FOR TRAINING ? EVERYTHING BEFORE THIS REALTED TO TRAINING IS UNUSED //VOCAB CREATION AND INFERANCE IS HERE

In [None]:
import os
import time
import jsonlines
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from transformers import T5Tokenizer, T5EncoderModel, get_scheduler
from miditok import REMI, TokenizerConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb
import math
import random
import logging

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize W&B login once at the start
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    logger.info("W&B login successful")
except Exception as e:
    logger.warning(f"W&B login failed: {e}. Logging locally.")
    WANDB_ENABLED = False
else:
    WANDB_ENABLED = True

# Configuration
CONFIG = {
    "model": {
        "text2midi_model": {
            "decoder_d_model": 768,
            "decoder_num_heads": 12,
            "decoder_num_layers": 8,
            "decoder_intermediate_size": 2048,
            "decoder_max_sequence_length": 1024,
            "num_diffusion_steps": 1000,
            "ddim_steps": 100
        }
    },
    "training": {
        "text2midi_model": {
            "learning_rate": 0.0002,
            "epochs": 50,
            "num_warmup_steps": 2000,
            "gradient_accumulation_steps": 2,
            "per_device_train_batch_size": 8,
            "output_dir": "/kaggle/working/saved/",
            "with_tracking": WANDB_ENABLED,
            "checkpointing_steps": "epoch",
            "save_every": 5,
            "lr_scheduler_type": "cosine",
            "max_grad_norm": 1.0,
            "weight_decay": 0.01,
            "fp16_precision": True
        }
    },
    "raw_data": {
        "caption_dataset_path": "/kaggle/working/captions.json",
        "dataset_folder": "/kaggle/working/output_directory"
    },
    "artifact_folder": "/kaggle/working/artifacts"
}

# Helper Functions
def cosine_beta_schedule(timesteps, s=0.008):
    steps = torch.arange(timesteps, dtype=torch.float32)
    f_t = torch.cos(((steps / timesteps + s) / (1.0 + s) * math.pi / 2)) ** 2
    betas = 1.0 - f_t / torch.roll(f_t, shifts=1, dims=0)
    betas = torch.clamp(betas, 0.0, 0.999)
    betas[0] = 0.0001
    return betas

# Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class DiscreteDiffusionModel(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048,
                 num_steps=1000, dropout=0.1, device=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_steps = num_steps
        self.device = device

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.time_emb = nn.Embedding(num_steps, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=5000).to(device)

        # FLAN-T5 encoder
        self.encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device)
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Project FLAN-T5 output (768) to d_model
        self.text_projection = nn.Linear(768, d_model).to(device)

        # Transformer-based denoiser
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.denoiser = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output layer
        self.projection = nn.Linear(d_model, vocab_size)

        # D3PM: Uniform transition matrix
        self.Q = torch.ones(vocab_size, vocab_size, device=device) / vocab_size
        self.Q_bar = torch.ones(num_steps, vocab_size, vocab_size, device=device)
        self.log_Q = torch.log(self.Q + 1e-10)

        # Noise schedule (cosine)
        self.betas = cosine_beta_schedule(num_steps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        # Precompute Q_bar for each timestep
        for t in range(num_steps):
            alpha_bar_t = self.alpha_bar[t]
            self.Q_bar[t] = alpha_bar_t * torch.eye(vocab_size, device=device) + \
                           (1 - alpha_bar_t) * self.Q

        self._reset_parameters()

    def forward(self, x_t, t, src, src_mask):
        x_emb = self.token_emb(x_t) * math.sqrt(self.d_model)
        x_emb = self.pos_encoder(x_emb.transpose(0, 1)).transpose(0, 1)
        t_emb = self.time_emb(t).unsqueeze(1)
        x_emb = x_emb + t_emb
        memory = self.encoder(src, attention_mask=src_mask).last_hidden_state
        memory = self.text_projection(memory)
        output = self.denoiser(x_emb, memory, memory_mask=None)
        logits = self.projection(output)
        return logits

    def sample(self, src, src_mask, seq_len, num_steps=None, ddim=False, eta=0.0):
        device = src.device
        batch_size = src.size(0)
        x_t = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device)
        num_steps = num_steps or self.num_steps

        if ddim:
            step_indices = torch.linspace(0, self.num_steps - 1, steps=self.num_steps // num_steps + 1, device=device).long()
        else:
            step_indices = torch.arange(num_steps, device=device)

        for i in reversed(range(len(step_indices))):
            t = step_indices[i]
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            with torch.no_grad():
                logits = self(x_t, t_tensor, src, src_mask)
                probs = F.softmax(logits, dim=-1)
                x_t_new = torch.multinomial(probs.view(-1, self.vocab_size), num_samples=1).view(batch_size, seq_len)

            if ddim and i > 0:
                t_prev = step_indices[i - 1]
                alpha_bar_t = self.alpha_bar[t]
                alpha_bar_t_prev = self.alpha_bar[t_prev]
                sigma = eta * torch.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))
                x_t = torch.where(
                    torch.rand_like(x_t.float()) < sigma,
                    torch.randint(0, self.vocab_size, x_t.shape, device=device),
                    x_t_new
                )
            else:
                x_t = x_t_new

        return x_t

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

# Dataset
class Text2MusicDataset(Dataset):
    def __init__(self, configs, captions, remi_tokenizer, mode="train"):
        self.mode = mode
        self.captions = captions
        self.dataset_path = configs["raw_data"]["dataset_folder"]
        self.remi_tokenizer = remi_tokenizer
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.decoder_max_sequence_length = configs["model"]["text2midi_model"]["decoder_max_sequence_length"]
        self.num_steps = configs["model"]["text2midi_model"]["num_diffusion_steps"]
        self.vocab_size = len(remi_tokenizer)

        # Noise schedule
        self.betas = cosine_beta_schedule(self.num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        self.Q = torch.ones(self.vocab_size, self.vocab_size) / self.vocab_size

        # Validate MIDI files
        valid_captions = []
        for cap in captions:
            if not isinstance(cap, dict) or "caption" not in cap or "location" not in cap:
                logger.warning(f"Invalid caption format: {cap}")
                continue
            midi_path = os.path.join(self.dataset_path, cap["location"])
            if os.path.exists(midi_path):
                valid_captions.append(cap)
            else:
                logger.warning(f"MIDI file not found: {midi_path}")
        self.captions = valid_captions
        if not self.captions:
            raise ValueError("No valid captions found after validation")
        logger.info(f"Dataset size after validation: {len(self.captions)}")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]["caption"]
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]["location"])

        try:
            tokens = self.remi_tokenizer(midi_filepath)
            tokenized_midi = ([self.remi_tokenizer["BOS_None"]] + tokens.ids +
                             [self.remi_tokenizer["EOS_None"]]) if tokens.ids else \
                            [self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
            tokenized_midi = torch.tensor(tokenized_midi)
        except Exception as e:
            logger.error(f"Error tokenizing MIDI file {midi_filepath}: {str(e)}")
            tokenized_midi = torch.tensor([self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]])

        if len(tokenized_midi) < self.decoder_max_sequence_length:
            x_0 = F.pad(tokenized_midi, (0, self.decoder_max_sequence_length - len(tokenized_midi))).long()
        else:
            x_0 = tokenized_midi[:self.decoder_max_sequence_length].long()

        t = torch.randint(0, self.num_steps, (1,)).item()
        alpha_bar_t = self.alpha_bar[t]
        Q_t = alpha_bar_t * torch.eye(self.vocab_size) + (1 - alpha_bar_t) * self.Q
        probs = F.one_hot(x_0, num_classes=self.vocab_size).float() @ Q_t
        x_t = torch.multinomial(probs, num_samples=1).squeeze(-1)

        inputs = self.t5_tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
        return (inputs["input_ids"].squeeze(0), inputs["attention_mask"].squeeze(0),
                x_t.long(), t, x_0.long())

def collate_fn(batch):
    input_ids = nn.utils.rnn.pad_sequence([item[0] for item in batch], batch_first=True, padding_value=0)
    attention_mask = nn.utils.rnn.pad_sequence([item[1] for item in batch], batch_first=True, padding_value=0)
    x_t = nn.utils.rnn.pad_sequence([item[2] for item in batch], batch_first=True, padding_value=0)
    t = torch.tensor([item[3] for item in batch])
    x_0 = nn.utils.rnn.pad_sequence([item[4] for item in batch], batch_first=True, padding_value=0)
    return input_ids, attention_mask, x_t, t, x_0

# Training Function
def train_model(configs):
    accelerator = Accelerator(
        gradient_accumulation_steps=configs["training"]["text2midi_model"]["gradient_accumulation_steps"],
        mixed_precision="fp16" if configs["training"]["text2midi_model"]["fp16_precision"] else "no"
    )

    # Set up directories
    if accelerator.is_main_process:
        output_dir = configs["training"]["text2midi_model"]["output_dir"]
        os.makedirs(output_dir, exist_ok=True)
        if configs["training"]["text2midi_model"]["with_tracking"]:
            wandb.init(project="Text-2-Midi")

    # Load tokenizer
    vocab_path = os.path.join(configs["artifact_folder"], "vocab_remi.pkl")
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
    with open(vocab_path, "rb") as f:
        tokenizer = pickle.load(f)

    # Load captions
    caption_path = configs["raw_data"]["caption_dataset_path"]
    if not os.path.exists(caption_path):
        raise FileNotFoundError(f"Caption file not found: {caption_path}")
    with jsonlines.open(caption_path) as reader:
        captions = list(reader)[:10000]  # Limit to 10,000 captions
    if not captions:
        raise ValueError("No captions loaded from file")
    train_captions, temp_captions = train_test_split(captions, test_size=0.2, random_state=42)
    val_captions, test_captions = train_test_split(temp_captions, test_size=0.5, random_state=42)
    logger.info(f"Train: {len(train_captions)}, Val: {len(val_captions)}, Test: {len(test_captions)}")

    # Initialize datasets and dataloaders
    train_dataset = Text2MusicDataset(configs, train_captions, remi_tokenizer=tokenizer, mode="train")
    val_dataset = Text2MusicDataset(configs, val_captions, remi_tokenizer=tokenizer, mode="val")
    train_dataloader = DataLoader(
        train_dataset, batch_size=configs["training"]["text2midi_model"]["per_device_train_batch_size"],
        shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=configs["training"]["text2midi_model"]["per_device_train_batch_size"],
        shuffle=False, num_workers=4, collate_fn=collate_fn, drop_last=False
    )

    # Initialize model
    model = DiscreteDiffusionModel(
        vocab_size=len(tokenizer),
        d_model=configs["model"]["text2midi_model"]["decoder_d_model"],
        nhead=configs["model"]["text2midi_model"]["decoder_num_heads"],
        num_layers=configs["model"]["text2midi_model"]["decoder_num_layers"],
        dim_feedforward=configs["model"]["text2midi_model"]["decoder_intermediate_size"],
        num_steps=configs["model"]["text2midi_model"]["num_diffusion_steps"],
        device=accelerator.device
    )

    # Load checkpoint if available
    best_model_path = os.path.join(output_dir, "best_model.bin")
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=accelerator.device))
        logger.info(f"Loaded checkpoint from {best_model_path}")

    # Setup optimizer and scheduler
    optimizer = optim.AdamW(
        model.parameters(),
        lr=configs["training"]["text2midi_model"]["learning_rate"],
        weight_decay=configs["training"]["text2midi_model"]["weight_decay"]
    )
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / configs["training"]["text2midi_model"]["gradient_accumulation_steps"])
    max_train_steps = configs["training"]["text2midi_model"]["epochs"] * num_update_steps_per_epoch
    lr_scheduler = get_scheduler(
        name=configs["training"]["text2midi_model"]["lr_scheduler_type"],
        optimizer=optimizer,
        num_warmup_steps=configs["training"]["text2midi_model"]["num_warmup_steps"],
        num_training_steps=max_train_steps
    )

    # Prepare for distributed training
    model, optimizer, lr_scheduler, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, val_dataloader
    )

    # Training loop
    criterion = nn.CrossEntropyLoss()
    progress_bar = tqdm(range(max_train_steps), desc="Training", disable=not accelerator.is_local_main_process)
    completed_steps = 0
    best_val_loss = float("inf")

    for epoch in range(configs["training"]["text2midi_model"]["epochs"]):
        model.train()
        total_train_loss = 0
        total_train_accuracy = 0
        total_train_perplexity = 0
        num_batches = 0

        for batch in train_dataloader:
            with accelerator.accumulate(model):
                encoder_input, attention_mask, x_t, t, x_0 = batch
                logits = model(x_t, t, encoder_input, attention_mask)
                loss = criterion(logits.view(-1, logits.size(-1)), x_0.view(-1))

                if torch.isnan(loss) or torch.isinf(loss):
                    logger.warning(f"Invalid loss at epoch {epoch+1}, skipping")
                    continue

                total_train_loss += loss.detach().float()
                preds = torch.argmax(logits, dim=-1)
                total_train_accuracy += (preds == x_0).float().mean().item()
                total_train_perplexity += torch.exp(loss.detach()).item()
                num_batches += 1

                accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), configs["training"]["text2midi_model"]["max_grad_norm"])
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

        if num_batches == 0:
            logger.warning(f"No valid batches processed in epoch {epoch+1}")
            continue

        avg_loss = total_train_loss / num_batches
        avg_accuracy = total_train_accuracy / num_batches
        avg_perplexity = total_train_perplexity / num_batches

        # Validation
        model.eval()
        total_val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for batch in val_dataloader:
                encoder_input, attention_mask, x_t, t, x_0 = batch
                logits = model(x_t, t, encoder_input, attention_mask)
                loss = criterion(logits.view(-1, logits.size(-1)), x_0.view(-1))
                total_val_loss += loss.item()
                val_batches += 1
        avg_val_loss = total_val_loss / val_batches if val_batches > 0 else float("inf")

        # Log metrics
        if configs["training"]["text2midi_model"]["with_tracking"] and accelerator.is_main_process:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_loss,
                "train_accuracy": avg_accuracy,
                "train_perplexity": avg_perplexity,
                "val_loss": avg_val_loss
            })

        logger.info(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Accuracy={avg_accuracy:.4f}, "
                    f"Perplexity={avg_perplexity:.4f}, Val Loss={avg_val_loss:.4f}")

        # Save checkpoint
        if (epoch + 1) % configs["training"]["text2midi_model"]["save_every"] == 0 and accelerator.is_main_process:
            checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.bin")
            torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path)
            logger.info(f"Saved checkpoint at {checkpoint_path}")

        # Save best model
        if avg_val_loss < best_val_loss and accelerator.is_main_process:
            best_val_loss = avg_val_loss
            torch.save(accelerator.unwrap_model(model).state_dict(), best_model_path)
            logger.info(f"Saved best model at {best_model_path}")

    accelerator.end_training()

# Inference Function

# Vocabulary Building
def build_vocab_remi(configs):
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1},
        "num_velocities": 32,
        "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
        "use_chords": False,
        "use_rests": False,
        "use_tempos": True,
        "use_time_signatures": True,
        "use_programs": True,
        "num_tempos": 32,
        "tempo_range": (40, 250),
    }
    config = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(config)
    logger.info(f"Vocabulary length: {tokenizer.vocab_size}")
    vocab_path = os.path.join(configs["artifact_folder"], "vocab_remi.pkl")
    os.makedirs(configs["artifact_folder"], exist_ok=True)
    with open(vocab_path, "wb") as f:
        pickle.dump(tokenizer, f)
    logger.info(f"Vocabulary saved to {vocab_path}")

# Main Execution

def generate_midi(model, tokenizer, t5_tokenizer, caption, seq_len, output_dir, device, ddim=True, num_steps=100):
    model.eval()
    inputs = t5_tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        tokens = model.sample(input_ids, attention_mask, seq_len, num_steps=num_steps, ddim=ddim)

    token_ids = tokens[0].cpu().numpy().tolist()
    if token_ids[0] == tokenizer["BOS_None"]:
        token_ids = token_ids[1:]
    if token_ids[-1] == tokenizer["EOS_None"]:
        token_ids = token_ids[:-1]

    try:
        midi_score = tokenizer.decode(token_ids)
        output_path = os.path.join(output_dir, f"generated_{int(time.time())}.mid")
        midi_score.dump_midi(output_path)
        logger.info(f"Generated MIDI saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting tokens to MIDI: {str(e)}")
        return None



def main(mode="train"):
    if mode == "train":
        train_model(CONFIG)
    elif mode == "build_vocab_remi":
        build_vocab_remi(CONFIG)
    elif mode == "infer":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with open(os.path.join(CONFIG["artifact_folder"], "vocab_remi.pkl"), "rb") as f:
            tokenizer = pickle.load(f)
        model = DiscreteDiffusionModel(
            vocab_size=len(tokenizer),
            d_model=CONFIG["model"]["text2midi_model"]["decoder_d_model"],
            nhead=CONFIG["model"]["text2midi_model"]["decoder_num_heads"],
            num_layers=CONFIG["model"]["text2midi_model"]["decoder_num_layers"],
            dim_feedforward=CONFIG["model"]["text2midi_model"]["decoder_intermediate_size"],
            num_steps=CONFIG["model"]["text2midi_model"]["num_diffusion_steps"],
            device=device
        ).to(device)
        checkpoint_path = os.path.join(CONFIG["training"]["text2midi_model"]["output_dir"], "best_model.bin")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        caption = "A melodic electronic composition with classical influences, featuring a string ensemble, trumpet, brass section, synth strings, and drums. Set in F# minor with a 4/4 time signature, it moves at an Allegro tempo. The mood evokes a cinematic, spacious, and epic atmosphere while maintaining a sense of relaxation."
        generate_long_midi(
            model=model,
            tokenizer=tokenizer,
            t5_tokenizer=t5_tokenizer,
            caption="A dramatic orchestral theme with sweeping strings and pounding drums.",
            seq_len=CONFIG["model"]["text2midi_model"]["decoder_max_sequence_length"],
            output_dir=CONFIG["training"]["text2midi_model"]["output_dir"],
            device=device,
            num_chunks=20,
            remove_silence=True,
            num_steps=200,
            ddim=True
        )

    else:
        raise ValueError(f"Invalid mode: {mode}. Choose 'train', 'build_vocab_remi', or 'infer'")

        
if __name__ == "__main__":
    mode = os.getenv("MODE", "train")
    main(mode)

In [None]:
////genrate with post process :

In [None]:
!pip install miditoolkit

In [None]:
// old genrate midi

In [None]:
def generate_midi(model, tokenizer, t5_tokenizer, caption, seq_len, output_dir, device, ddim=True, num_steps=100):
    model.eval()
    inputs = t5_tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        tokens = model.sample(input_ids, attention_mask, seq_len, num_steps=num_steps, ddim=ddim)

    token_ids = tokens[0].cpu().numpy().tolist()
    if token_ids[0] == tokenizer["BOS_None"]:
        token_ids = token_ids[1:]
    if token_ids[-1] == tokenizer["EOS_None"]:
        token_ids = token_ids[:-1]

    try:
        midi_score = tokenizer.decode(token_ids)
        output_path = os.path.join(output_dir, f"generated_{int(time.time())}.mid")
        midi_score.dump_midi(output_path)
        logger.info(f"Generated MIDI saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting tokens to MIDI: {str(e)}")
        return None


In [None]:
////old main 

In [None]:
def main(mode="infer"):
    if mode == "train":
        train_model(CONFIG)
    elif mode == "build_vocab_remi":
        build_vocab_remi(CONFIG)
    elif mode == "infer":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with open(os.path.join(CONFIG["artifact_folder"], "vocab_remi.pkl"), "rb") as f:
            tokenizer = pickle.load(f)
        model = DiscreteDiffusionModel(
            vocab_size=len(tokenizer),
            d_model=CONFIG["model"]["text2midi_model"]["decoder_d_model"],
            nhead=CONFIG["model"]["text2midi_model"]["decoder_num_heads"],
            num_layers=CONFIG["model"]["text2midi_model"]["decoder_num_layers"],
            dim_feedforward=CONFIG["model"]["text2midi_model"]["decoder_intermediate_size"],
            num_steps=CONFIG["model"]["text2midi_model"]["num_diffusion_steps"],
            device=device
        ).to(device)
        checkpoint_path = os.path.join(CONFIG["training"]["text2midi_model"]["output_dir"], "best_model.bin")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        caption = "A melodic electronic composition with classical influences, featuring a string ensemble, trumpet, brass section, synth strings, and drums. Set in F# minor with a 4/4 time signature, it moves at an Allegro tempo. The mood evokes a cinematic, spacious, and epic atmosphere while maintaining a sense of relaxation."
        generate_long_midi(
            model=model,
            tokenizer=tokenizer,
            t5_tokenizer=t5_tokenizer,
            caption="A dramatic orchestral theme with sweeping strings and pounding drums.",
            seq_len=CONFIG["model"]["text2midi_model"]["decoder_max_sequence_length"],
            output_dir=CONFIG["training"]["text2midi_model"]["output_dir"],
            device=device,
            num_chunks=20,
            remove_silence=True,
            num_steps=200,
            ddim=True
        )

    else:
        raise ValueError(f"Invalid mode: {mode}. Choose 'train', 'build_vocab_remi', or 'infer'")


In [None]:
generate_midi(
            model, tokenizer, t5_tokenizer, caption,
            seq_len=CONFIG["model"]["text2midi_model"]["decoder_max_sequence_length"],
            output_dir=CONFIG["training"]["text2midi_model"]["output_dir"],
            device=device, ddim=True, num_steps=200
        )

In [None]:
!pip install pretty_midi

In [None]:
!pip install midi_utils

In [None]:
!pip install laion_clap

In [None]:
////

In [None]:
import os
import time
import jsonlines
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from transformers import T5Tokenizer, T5EncoderModel, get_scheduler
from miditok import REMI, TokenizerConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb
import math
import random
import logging
import re
import tempfile
import music21
import pretty_midi

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize W&B login
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    logger.info("W&B login successful")
except Exception as e:
    logger.warning(f"W&B login failed: {e}. Logging locally.")
    WANDB_ENABLED = False
else:
    WANDB_ENABLED = True

# Configuration
CONFIG = {
    "model": {
        "text2midi_model": {
            "decoder_d_model": 1024,
            "decoder_num_heads": 8,
            "decoder_num_layers": 12,
            "decoder_intermediate_size": 2048,
            "decoder_max_sequence_length": 2048,
            "num_diffusion_steps": 1000,
            "ddim_steps": 100
        }
    },
    "training": {
        "text2midi_model": {
            "learning_rate": 0.0003,
            "epochs": 50,
            "num_warmup_steps": 1000,
            "gradient_accumulation_steps": 2,
            "per_device_train_batch_size": 8,
            "output_dir": "/kaggle/working/saved/",
            "with_tracking": WANDB_ENABLED,
            "checkpointing_steps": "epoch",
            "save_every": 5,
            "lr_scheduler_type": "cosine",
            "max_grad_norm": 1.0,
            "weight_decay": 0.001,
            "fp16_precision": True
        }
    },
    "raw_data": {
        "caption_dataset_path": "/kaggle/working/captions.json",
        "dataset_folder": "/kaggle/working/output_directory"
    },
    "artifact_folder": "/kaggle/working/artifacts"
}

# Helper Functions
def cosine_beta_schedule(timesteps, s=0.008):
    steps = torch.arange(timesteps, dtype=torch.float32)
    f_t = torch.cos(((steps / timesteps + s) / (1.0 + s) * math.pi / 2)) ** 2
    betas = 1.0 - f_t / torch.roll(f_t, shifts=1, dims=0)
    betas = torch.clamp(betas, 0.0, 0.999)
    betas[0] = 0.0001
    return betas

# Evaluation Metric Helpers
TEMPO_TERM_TO_BPM = {
    "largo": 50,
    "adagio": 70,
    "andante": 90,
    "moderato": 110,
    "allegro": 140,
    "presto": 180,
}

def parse_caption(caption):
    tempo_match = re.search(r'at an? (\w+) tempo', caption, re.IGNORECASE)
    tempo_term = tempo_match.group(1) if tempo_match else None
    key_match = re.search(r'in ([A-G][#b]? (major|minor))', caption, re.IGNORECASE)
    key = key_match.group(1) if key_match else None
    return tempo_term, key

def map_tempo_term_to_bin(tempo_term):
    if tempo_term is None:
        return None
    tempo_term = tempo_term.lower()
    if tempo_term in TEMPO_TERM_TO_BPM:
        target_bpm = TEMPO_TERM_TO_BPM[tempo_term]
        min_tempo = 40
        max_tempo = 250
        num_tempos = 32
        bin_width = (max_tempo - min_tempo) / (num_tempos - 1)
        bin_index = round((target_bpm - min_tempo) / bin_width)
        return max(0, min(bin_index, num_tempos - 1))
    return None

def extract_tempo_bin(tokens, tokenizer):
    min_tempo = 40
    max_tempo = 250
    num_tempos = 32
    bin_width = (max_tempo - min_tempo) / (num_tempos - 1)
    
    for token in tokens:
        token_str = tokenizer[token]
        if token_str.startswith("Tempo_"):
            try:
                tempo_value = float(token_str.split('_')[1])
                bin_index = round((tempo_value - min_tempo) / bin_width)
                return max(0, min(bin_index, num_tempos - 1))
            except (ValueError, IndexError) as e:
                logger.warning(f"Failed to parse tempo token {token_str}: {e}")
                return None
    return None

def detect_key(midi_path):
    try:
        score = music21.converter.parse(midi_path)
        key = score.analyze('key')
        return key.tonic.name + ' ' + key.mode
    except Exception as e:
        logger.warning(f"Key detection failed: {e}")
        return None

def compression_ratio(tokens):
    unique_tokens = len(set([t for t in tokens if t != 0]))  # Exclude padding
    total_tokens = len([t for t in tokens if t != 0])
    return unique_tokens / total_tokens if total_tokens > 0 else 0

# Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class DiscreteDiffusionModel(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=8, dim_feedforward=2048,
                 num_steps=1000, dropout=0.1, device=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_steps = num_steps
        self.device = device

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.time_emb = nn.Embedding(num_steps, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=5000).to(device)
        self.encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device)
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.text_projection = nn.Linear(768, d_model).to(device)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.denoiser = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.projection = nn.Linear(d_model, vocab_size)

        self.Q = torch.ones(vocab_size, vocab_size, device=device) / vocab_size
        self.Q_bar = torch.ones(num_steps, vocab_size, vocab_size, device=device)
        self.log_Q = torch.log(self.Q + 1e-10)
        self.betas = cosine_beta_schedule(num_steps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        for t in range(num_steps):
            alpha_bar_t = self.alpha_bar[t]
            self.Q_bar[t] = alpha_bar_t * torch.eye(vocab_size, device=device) + \
                           (1 - alpha_bar_t) * self.Q

        self._reset_parameters()

    def forward(self, x_t, t, src, src_mask):
        x_emb = self.token_emb(x_t) * math.sqrt(self.d_model)
        x_emb = self.pos_encoder(x_emb.transpose(0, 1)).transpose(0, 1)
        t_emb = self.time_emb(t).unsqueeze(1)
        x_emb = x_emb + t_emb
        memory = self.encoder(src, attention_mask=src_mask).last_hidden_state
        memory = self.text_projection(memory)
        output = self.denoiser(x_emb, memory, memory_mask=None)
        logits = self.projection(output)
        return logits

    def sample(self, src, src_mask, seq_len, num_steps=None, ddim=False, eta=0.0):
        device = src.device
        batch_size = src.size(0)
        x_t = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device)
        num_steps = num_steps or self.num_steps

        if ddim:
            step_indices = torch.linspace(0, self.num_steps - 1, steps=self.num_steps // num_steps + 1, device=device).long()
        else:
            step_indices = torch.arange(num_steps, device=device)

        for i in reversed(range(len(step_indices))):
            t = step_indices[i]
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            with torch.no_grad():
                logits = self(x_t, t_tensor, src, src_mask)
                probs = F.softmax(logits, dim=-1)
                x_t_new = torch.multinomial(probs.view(-1, self.vocab_size), num_samples=1).view(batch_size, seq_len)

            if ddim and i > 0:
                t_prev = step_indices[i - 1]
                alpha_bar_t = self.alpha_bar[t]
                alpha_bar_t_prev = self.alpha_bar[t_prev]
                sigma = eta * torch.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))
                x_t = torch.where(
                    torch.rand_like(x_t.float()) < sigma,
                    torch.randint(0, self.vocab_size, x_t.shape, device=device),
                    x_t_new
                )
            else:
                x_t = x_t_new

        return x_t

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

# Dataset
class Text2MusicDataset(Dataset):
    def __init__(self, configs, captions, remi_tokenizer, mode="train"):
        self.mode = mode
        self.captions = captions
        self.dataset_path = configs["raw_data"]["dataset_folder"]
        self.remi_tokenizer = remi_tokenizer
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.decoder_max_sequence_length = configs["model"]["text2midi_model"]["decoder_max_sequence_length"]
        self.num_steps = configs["model"]["text2midi_model"]["num_diffusion_steps"]
        self.vocab_size = len(remi_tokenizer)

        self.betas = cosine_beta_schedule(self.num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        self.Q = torch.ones(self.vocab_size, self.vocab_size) / self.vocab_size

        valid_captions = []
        for cap in captions:
            if not isinstance(cap, dict) or "caption" not in cap or "location" not in cap:
                logger.warning(f"Invalid caption format: {cap}")
                continue
            midi_path = os.path.join(self.dataset_path, cap["location"])
            if os.path.exists(midi_path):
                valid_captions.append(cap)
            else:
                logger.warning(f"MIDI file not found: {midi_path}")
        self.captions = valid_captions
        if not self.captions:
            raise ValueError("No valid captions found after validation")
        logger.info(f"Dataset size after validation: {len(self.captions)}")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]["caption"]
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]["location"])

        try:
            tokens = self.remi_tokenizer(midi_filepath)
            tokenized_midi = ([self.remi_tokenizer["BOS_None"]] + tokens.ids +
                             [self.remi_tokenizer["EOS_None"]]) if tokens.ids else \
                            [self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
            tokenized_midi = torch.tensor(tokenized_midi)
        except Exception as e:
            logger.error(f"Error tokenizing MIDI file {midi_filepath}: {str(e)}")
            tokenized_midi = torch.tensor([self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]])

        if len(tokenized_midi) < self.decoder_max_sequence_length:
            x_0 = F.pad(tokenized_midi, (0, self.decoder_max_sequence_length - len(tokenized_midi))).long()
        else:
            x_0 = tokenized_midi[:self.decoder_max_sequence_length].long()

        t = torch.randint(0, self.num_steps, (1,)).item()
        alpha_bar_t = self.alpha_bar[t]
        Q_t = alpha_bar_t * torch.eye(self.vocab_size) + (1 - alpha_bar_t) * self.Q
        probs = F.one_hot(x_0, num_classes=self.vocab_size).float() @ Q_t
        x_t = torch.multinomial(probs, num_samples=1).squeeze(-1)

        inputs = self.t5_tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
        return (caption, inputs["input_ids"].squeeze(0), inputs["attention_mask"].squeeze(0),
                x_t.long(), t, x_0.long())

def collate_fn(batch):
    captions = [item[0] for item in batch]
    input_ids = nn.utils.rnn.pad_sequence([item[1] for item in batch], batch_first=True, padding_value=0)
    attention_mask = nn.utils.rnn.pad_sequence([item[2] for item in batch], batch_first=True, padding_value=0)
    x_t = nn.utils.rnn.pad_sequence([item[3] for item in batch], batch_first=True, padding_value=0)
    t = torch.tensor([item[4] for item in batch])
    x_0 = nn.utils.rnn.pad_sequence([item[5] for item in batch], batch_first=True, padding_value=0)
    return captions, input_ids, attention_mask, x_t, t, x_0

# Training Function
def train_model(configs):
    accelerator = Accelerator(
        gradient_accumulation_steps=configs["training"]["text2midi_model"]["gradient_accumulation_steps"],
        mixed_precision="fp16" if configs["training"]["text2midi_model"]["fp16_precision"] else "no"
    )

    output_dir = configs["training"]["text2midi_model"]["output_dir"]
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        if configs["training"]["text2midi_model"]["with_tracking"]:
            wandb.init(project="Text-2-Midi")

    vocab_path = os.path.join(configs["artifact_folder"], "vocab_remi.pkl")
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
    with open(vocab_path, "rb") as f:
        tokenizer = pickle.load(f)

    caption_path = configs["raw_data"]["caption_dataset_path"]
    if not os.path.exists(caption_path):
        raise FileNotFoundError(f"Caption file not found: {caption_path}")
    with jsonlines.open(caption_path) as reader:
        captions = list(reader)[:2000]  # Limit to 2000 captions
    if not captions:
        raise ValueError("No captions loaded from file")
    train_captions, temp_captions = train_test_split(captions, test_size=0.2, random_state=42)
    val_captions, test_captions = train_test_split(temp_captions, test_size=0.5, random_state=42)
    logger.info(f"Train: {len(train_captions)}, Val: {len(val_captions)}, Test: {len(test_captions)}")

    train_dataset = Text2MusicDataset(configs, train_captions, remi_tokenizer=tokenizer, mode="train")
    val_dataset = Text2MusicDataset(configs, val_captions, remi_tokenizer=tokenizer, mode="val")
    train_dataloader = DataLoader(
        train_dataset, batch_size=configs["training"]["text2midi_model"]["per_device_train_batch_size"],
        shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=configs["training"]["text2midi_model"]["per_device_train_batch_size"],
        shuffle=False, num_workers=4, collate_fn=collate_fn, drop_last=False
    )

    model = DiscreteDiffusionModel(
        vocab_size=len(tokenizer),
        d_model=configs["model"]["text2midi_model"]["decoder_d_model"],
        nhead=configs["model"]["text2midi_model"]["decoder_num_heads"],
        num_layers=configs["model"]["text2midi_model"]["decoder_num_layers"],
        dim_feedforward=configs["model"]["text2midi_model"]["decoder_intermediate_size"],
        num_steps=configs["model"]["text2midi_model"]["num_diffusion_steps"],
        device=accelerator.device
    )

    best_model_path = os.path.join(output_dir, "best_model.bin")
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=accelerator.device))
        logger.info(f"Loaded checkpoint from {best_model_path}")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=configs["training"]["text2midi_model"]["learning_rate"],
        weight_decay=configs["training"]["text2midi_model"]["weight_decay"]
    )
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / configs["training"]["text2midi_model"]["gradient_accumulation_steps"])
    max_train_steps = configs["training"]["text2midi_model"]["epochs"] * num_update_steps_per_epoch
    lr_scheduler = get_scheduler(
        name=configs["training"]["text2midi_model"]["lr_scheduler_type"],
        optimizer=optimizer,
        num_warmup_steps=configs["training"]["text2midi_model"]["num_warmup_steps"],
        num_training_steps=max_train_steps
    )

    model, optimizer, lr_scheduler, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, val_dataloader
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    progress_bar = tqdm(range(max_train_steps), desc="Training", disable=not accelerator.is_local_main_process)
    completed_steps = 0
    best_val_loss = float("inf")

    for epoch in range(configs["training"]["text2midi_model"]["epochs"]):
        model.train()
        total_train_loss = 0
        total_train_accuracy = 0
        total_train_perplexity = 0
        num_batches = 0

        for batch in train_dataloader:
            captions, encoder_input, attention_mask, x_t, t, x_0 = batch
            with accelerator.accumulate(model):
                logits = model(x_t, t, encoder_input, attention_mask)
                loss = criterion(logits.view(-1, logits.size(-1)), x_0.view(-1))

                if torch.isnan(loss) or torch.isinf(loss):
                    logger.warning(f"Invalid loss at epoch {epoch+1}, skipping")
                    continue

                total_train_loss += loss.detach().float()
                preds = torch.argmax(logits, dim=-1)
                total_train_accuracy += (preds == x_0).float().mean().item()
                total_train_perplexity += torch.exp(loss.detach()).item()
                num_batches += 1

                accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), configs["training"]["text2midi_model"]["max_grad_norm"])
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

        if num_batches == 0:
            logger.warning(f"No valid batches processed in epoch {epoch+1}")
            continue

        avg_loss = total_train_loss / num_batches
        avg_accuracy = total_train_accuracy / num_batches
        avg_perplexity = total_train_perplexity / num_batches

        # Validation
        model.eval()
        total_val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for batch in val_dataloader:
                _, encoder_input, attention_mask, x_t, t, x_0 = batch
                logits = model(x_t, t, encoder_input, attention_mask)
                loss = criterion(logits.view(-1, logits.size(-1)), x_0.view(-1))
                total_val_loss += loss.item()
                val_batches += 1
        avg_val_loss = total_val_loss / val_batches if val_batches > 0 else float("inf")

        # Evaluate metrics on validation subset
        num_eval_samples = min(10, len(val_dataset))
        eval_indices = random.sample(range(len(val_dataset)), num_eval_samples)
        eval_metrics = {"CR": [], "TB": [], "TBT": [], "CK": [], "CKD": [], "CLAP": []}

        for idx in eval_indices:
            caption, input_ids, attention_mask, _, _, _ = val_dataset[idx]
            input_ids = input_ids.unsqueeze(0).to(accelerator.device)
            attention_mask = attention_mask.unsqueeze(0).to(accelerator.device)

            with torch.no_grad():
                generated_tokens = model.sample(
                    input_ids, attention_mask,
                    seq_len=configs["model"]["text2midi_model"]["decoder_max_sequence_length"],
                    num_steps=configs["model"]["text2midi_model"]["ddim_steps"],
                    ddim=True
                )
            generated_tokens = generated_tokens[0].cpu().numpy().tolist()

            if generated_tokens[0] == tokenizer["BOS_None"]:
                generated_tokens = generated_tokens[1:]
            if generated_tokens[-1] == tokenizer["EOS_None"]:
                generated_tokens = generated_tokens[:-1]

            # Compression Ratio
            cr = compression_ratio(generated_tokens)
            eval_metrics["CR"].append(cr)

            # Tempo and Key Metrics
            tempo_term, target_key = parse_caption(caption)
            target_tempo_bin = map_tempo_term_to_bin(tempo_term)
            generated_tempo_bin = extract_tempo_bin(generated_tokens, tokenizer)

            if generated_tempo_bin is not None and target_tempo_bin is not None:
                tb = 1 if generated_tempo_bin == target_tempo_bin else 0
                tbt = 1 if abs(generated_tempo_bin - target_tempo_bin) <= 1 else 0
            else:
                tb = tbt = 0
            eval_metrics["TB"].append(tb)
            eval_metrics["TBT"].append(tbt)

            # Generate MIDI for Key and CLAP
            try:
                with tempfile.NamedTemporaryFile(suffix=".mid", delete=True) as temp_midi:
                    midi_score = tokenizer.decode(generated_tokens)
                    midi_score.dump(temp_midi.name)
                    detected_key = detect_key(temp_midi.name)
                    ck = 1 if detected_key and target_key and detected_key.lower() == target_key.lower() else 0
                    eval_metrics["CK"].append(ck)
                    eval_metrics["CKD"].append(ck)  # Same as CK for simplicity
                    pm = pretty_midi.PrettyMIDI(temp_midi.name)
                    audio = pm.synthesize(fs=44100)
            except Exception as e:
                logger.warning(f"Error in eval sample {idx}: {str(e)}")
                eval_metrics["CK"].append(0)
                eval_metrics["CKD"].append(0)
                eval_metrics["CLAP"].append(0)

        avg_metrics = {k: sum(v) / len(v) for k, v in eval_metrics.items() if v}

        # Log metrics
        if configs["training"]["text2midi_model"]["with_tracking"] and accelerator.is_main_process:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_loss,
                "train_accuracy": avg_accuracy,
                "train_perplexity": avg_perplexity,
                "val_loss": avg_val_loss,
                "val_CR": avg_metrics.get("CR", 0),
                "val_TB": avg_metrics.get("TB", 0),
                "val_TBT": avg_metrics.get("TBT", 0),
                "val_CK": avg_metrics.get("CK", 0),
                "val_CKD": avg_metrics.get("CKD", 0),
                "val_CLAP": avg_metrics.get("CLAP", 0)
            })

        logger.info(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Accuracy={avg_accuracy:.4f}, "
                    f"Perplexity={avg_perplexity:.4f}, Val Loss={avg_val_loss:.4f}, "
                    f"CR={avg_metrics.get('CR', 0):.4f}, TB={avg_metrics.get('TB', 0):.4f}, "
                    f"TBT={avg_metrics.get('TBT', 0):.4f}, CK={avg_metrics.get('CK', 0):.4f}, "
                    f"CLAP={avg_metrics.get('CLAP', 0):.4f}")

        # Save checkpoints
        if (epoch + 1) % configs["training"]["text2midi_model"]["save_every"] == 0 and accelerator.is_main_process:
            checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.bin")
            torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path)
            logger.info(f"Saved checkpoint at {checkpoint_path}")

        if avg_val_loss < best_val_loss and accelerator.is_main_process:
            best_val_loss = avg_val_loss
            torch.save(accelerator.unwrap_model(model).state_dict(), best_model_path)
            logger.info(f"Saved best model at {best_model_path}")

    accelerator.end_training()

# Vocabulary Building
def build_vocab_remi(configs):
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1},
        "num_velocities": 32,
        "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
        "use_chords": False,
        "use_rests": False,
        "use_tempos": True,
        "use_time_signatures": True,
        "use_programs": True,
        "num_tempos": 32,
        "tempo_range": (40, 250),
    }
    config = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(config)
    logger.info(f"Vocabulary length: {tokenizer.vocab_size}")
    vocab_path = os.path.join(configs["artifact_folder"], "vocab_remi.pkl")
    os.makedirs(configs["artifact_folder"], exist_ok=True)
    with open(vocab_path, "wb") as f:
        pickle.dump(tokenizer, f)
    logger.info(f"Vocabulary saved to {vocab_path}")

# Inference Function
def generate_midi(model, tokenizer, t5_tokenizer, caption, seq_len, output_dir, device, ddim=True, num_steps=100):
    model.eval()
    inputs = t5_tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        tokens = model.sample(input_ids, attention_mask, seq_len, num_steps=num_steps, ddim=ddim)

    token_ids = tokens[0].cpu().numpy().tolist()
    if token_ids[0] == tokenizer["BOS_None"]:
        token_ids = token_ids[1:]
    if token_ids[-1] == tokenizer["EOS_None"]:
        token_ids = token_ids[:-1]

    try:
        midi_score = tokenizer.decode(token_ids)
        output_path = os.path.join(output_dir, f"generated_{int(time.time())}.mid")
        midi_score.dump_midi(output_path)
        logger.info(f"Generated MIDI saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting tokens to MIDI: {str(e)}")
        return None

def main(mode="infer"):
    if mode == "train":
        train_model(CONFIG)
    elif mode == "build_vocab_remi":
        build_vocab_remi(CONFIG)
    elif mode == "infer":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with open(os.path.join(CONFIG["artifact_folder"], "vocab_remi.pkl"), "rb") as f:
            tokenizer = pickle.load(f)
        model = DiscreteDiffusionModel(
            vocab_size=len(tokenizer),
            d_model=CONFIG["model"]["text2midi_model"]["decoder_d_model"],
            nhead=CONFIG["model"]["text2midi_model"]["decoder_num_heads"],
            num_layers=CONFIG["model"]["text2midi_model"]["decoder_num_layers"],
            dim_feedforward=CONFIG["model"]["text2midi_model"]["decoder_intermediate_size"],
            num_steps=CONFIG["model"]["text2midi_model"]["num_diffusion_steps"],
            device=device
        ).to(device)
        checkpoint_path = os.path.join(CONFIG["training"]["text2midi_model"]["output_dir"], "best_model.bin")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        caption = "A melodic electronic composition with classical influences, featuring a string ensemble, trumpet, brass section, synth strings, and drums. Set in F# minor with a 4/4 time signature, it moves at an Allegro tempo. The mood evokes a cinematic, spacious, and epic atmosphere while maintaining a sense of relaxation."
        generate_midi(
            model=model,
            tokenizer=tokenizer,
            t5_tokenizer=t5_tokenizer,
            caption=caption,
            seq_len=CONFIG["model"]["text2midi_model"]["decoder_max_sequence_length"],
            output_dir=CONFIG["training"]["text2midi_model"]["output_dir"],
            device=device,
            num_steps=200,
            ddim=True
        )
    else:
        raise ValueError(f"Invalid mode: {mode}. Choose 'train', 'build_vocab_remi', or 'infer'")

if __name__ == "__main__":
    mode = os.getenv("MODE", "infer")
    main(mode)

In [None]:
rm -rf /kaggle/working/saved

In [None]:
!pip install flask flask-cors torch transformers mido music21 pretty_midi miditok pyngrok

In [None]:
from kaggle_secrets import UserSecretsClient
ngrok.set_auth_token(UserSecretsClient().get_secret("ngrok_authtoken"))

In [None]:


from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import base64
import torch
import logging
import time
import pickle
from transformers import T5Tokenizer
from pyngrok import ngrok

app = Flask(__name__)
CORS(app)

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize model, tokenizer, and device
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load REMI tokenizer
    vocab_path = "/kaggle/working/artifacts/vocab_remi.pkl"  # Adjust path
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
    with open(vocab_path, "rb") as f:
        remi_tokenizer = pickle.load(f)
    logger.info("REMI tokenizer loaded")

    # Load T5 tokenizer
    t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    logger.info("T5 tokenizer loaded")

    # Initialize and load model
    model = DiscreteDiffusionModel(
        vocab_size=len(remi_tokenizer),
        d_model=CONFIG["model"]["text2midi_model"]["decoder_d_model"],
        nhead=CONFIG["model"]["text2midi_model"]["decoder_num_heads"],
        num_layers=CONFIG["model"]["text2midi_model"]["decoder_num_layers"],
        dim_feedforward=CONFIG["model"]["text2midi_model"]["decoder_intermediate_size"],
        num_steps=CONFIG["model"]["text2midi_model"]["num_diffusion_steps"],
        device=device
    ).to(device)
    checkpoint_path = "/kaggle/working/saved/best_model.bin"  # Adjust path
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Failed to initialize model or tokenizers: {str(e)}")
    raise

@app.route('/generate-midi', methods=['POST'])
def generate_midi_endpoint():
    data = request.get_json()
    prompt = data.get('prompt')

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    try:
        # Generate MIDI
        output_path = generate_midi(
            model=model,
            tokenizer=remi_tokenizer,
            t5_tokenizer=t5_tokenizer,
            caption=prompt,
            seq_len=CONFIG["model"]["text2midi_model"]["decoder_max_sequence_length"],
            output_dir="/kaggle/working",
            device=device,
            num_steps=200,
            ddim=True
        )

        if not output_path or not os.path.exists(output_path):
            return jsonify({"error": "Failed to generate MIDI file"}), 500

        # Read and encode MIDI
        with open(output_path, 'rb') as midi_file:
            midi_base64 = base64.b64encode(midi_file.read()).decode('utf-8')

        # Clean up
        try:
            os.remove(output_path)
            logger.info(f"Cleaned up: {output_path}")
        except Exception as e:
            logger.warning(f"Failed to clean up {output_path}: {str(e)}")

        return jsonify({
            "midi": {
                "data": midi_base64,
                "mimetype": "audio/midi"
            }
        }), 200

    except Exception as e:
        logger.error(f"Error generating MIDI: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    # Start ngrok to expose the Flask app
    public_url = ngrok.connect(5000).public_url
    logger.info(f"ngrok tunnel opened at {public_url}")
    
    # Update Flask to run on port 5000
    app.run(host='0.0.0.0', port=5000)