<a href="https://colab.research.google.com/github/mohammadreza-mohammadi94/Deep-Learning-Projects/blob/main/Rythym_Detection_MIDI_LSTM/Rythym_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Frameworks and Set Up Environment

In [None]:
# Install required libraries
!pip install -q music21 mido requests
!pip install -q tensorflow

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import numpy as np
import tensorflow as tf
from music21 import converter, note, chord
import mido
import requests
import zipfile
import io
import tarfile

from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

In [None]:
# Setup warning
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Set up logging: every record goes both to `log.log` and to the notebook output.
# `force=True` (Python 3.8+) removes and closes any handlers already attached to
# the root logger before installing these. Without it, basicConfig() is a silent
# no-op whenever the root logger is already configured (by an earlier call or by a
# library), and the format below is never applied. It also keeps re-running this
# cell from stacking duplicate handlers.
import logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler('log.log'),
        logging.StreamHandler()
    ],
    force=True,
)

logger = logging.getLogger(__name__)

# Functions

In [None]:
# !wget http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz

**Download Dataset**

In [None]:
def _safe_extractall(tar, dest):
    """Extract every member of `tar` into `dest`, refusing members that escape `dest`."""
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ (and the security backports to 3.8-3.11): the "data" filter
        # rejects absolute paths, `..` components, unsafe links and device files.
        tar.extractall(dest, filter="data")
        return
    # Fallback for older interpreters: basic path-traversal check on member names.
    dest_root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise ValueError(f"Refusing to extract {member.name!r} outside {dest!r}")
    tar.extractall(dest)


def download_lakh_midi_dataset(
    dataset_path="lmd_full",
    url="http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz",
    timeout=60,
):
    """Download and extract the Lakh MIDI Dataset unless `dataset_path` already exists.

    The archive is streamed to a temporary file in chunks, so it is never held in
    memory as a whole. It is then extracted into `<dataset_path>.partial`, which is
    renamed to `dataset_path` only once extraction has fully succeeded. An
    interrupted run therefore never leaves a half-populated `dataset_path` that a
    later run would mistake for a complete dataset.

    Args:
        dataset_path: Directory the archive is extracted into.
        url: Location of the `.tar.gz` archive.
        timeout: Seconds `requests` waits to connect / between received bytes.

    Returns:
        `dataset_path`.

    Raises:
        Any download, HTTP or extraction error, re-raised after being logged.
    """
    import shutil
    import tempfile

    if os.path.exists(dataset_path):
        logger.info("Dataset Already Exists...")
        return dataset_path

    logger.info("Downloading Lakh MIDI Dataset...")
    partial_path = f"{dataset_path}.partial"
    try:
        # Remove leftovers from a previously interrupted extraction.
        shutil.rmtree(partial_path, ignore_errors=True)
        download_dir = os.path.dirname(os.path.abspath(dataset_path))
        with tempfile.TemporaryFile(dir=download_dir) as archive:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Check for HTTP errors
                for chunk in response.iter_content(chunk_size=1 << 20):
                    archive.write(chunk)
            archive.seek(0)
            with tarfile.open(fileobj=archive, mode="r:gz") as tar:
                _safe_extractall(tar, partial_path)
        # Publish the dataset only after extraction completed without error.
        os.replace(partial_path, dataset_path)
        logger.info("Dataset Extracted Successfully.")
    except Exception as e:
        shutil.rmtree(partial_path, ignore_errors=True)
        logger.error(f"Failed to Download Or Extract Dataset: {str(e)}")
        raise  # Re-raise to stop execution if download fails
    return dataset_path

**Extract Rhythm From MIDI**

In [None]:
def extract_rythm_from_midi(file_path, max_length=100):
    """Parse a MIDI file and return the durations of its notes and chords.

    Args:
        file_path: Path to the MIDI file.
        max_length: Keep at most this many durations; `None` keeps all of them.

    Returns:
        A list of `float` durations in quarter notes (music21 `quarterLength`), in
        the order the flattened stream's `.notes` yields them, truncated to
        `max_length`. Returns `None` if the file cannot be parsed; the error is
        logged rather than raised, so one corrupt file does not abort a batch.
    """
    logger.info(f"Processing MIDI File: {file_path}")
    try:
        midi = converter.parse(file_path)
        # `.flat` is deprecated since music21 v7 in favour of `.flatten()`;
        # fall back to `.flat` only on very old music21 versions.
        flat_stream = midi.flatten() if hasattr(midi, "flatten") else midi.flat
        rhythm_sequence = [
            float(element.quarterLength)
            for element in flat_stream.notes
            # Only pitched notes and chords count; other `.notes` members are skipped.
            if isinstance(element, (note.Note, chord.Chord))
        ]
        logger.debug(
            f"Extracted rhythm sequence of length: {len(rhythm_sequence)} "
            f"From {file_path}"
        )
        if max_length is None:
            return rhythm_sequence
        return rhythm_sequence[:max_length]
    except Exception as e:
        logger.error(f"Error processing {file_path}: {str(e)}")
        return None

**Preparing Data**

In [None]:
def prepare_dataset(midi_files, sequence_length=50, max_files=100):
    """Build sliding-window (input, next-duration) training pairs from MIDI files.

    For every file whose rhythm sequence `r` is longer than `sequence_length`,
    each window `r[i:i + sequence_length]` becomes one input and `r[i + sequence_length]`
    (the duration that immediately follows it) becomes its target.

    Args:
        midi_files: Sequence of MIDI file paths.
        sequence_length: Number of consecutive durations per input window (>= 1).
        max_files: Only the first `max_files` paths are read; `None` reads all.

    Returns:
        `(X, y)`, where `X` has shape `(n_windows, sequence_length, 1)` and `y` has
        shape `(n_windows,)`.

    Raises:
        ValueError: If `sequence_length` < 1, or if no file yields at least one window.
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
    logger.info("Preparing Dataset...")
    selected_files = midi_files if max_files is None else midi_files[:max_files]
    X, y = [], []
    for file in selected_files:
        rhythm = extract_rythm_from_midi(file)
        # Each window needs one extra element as its target, so a sequence of
        # exactly `sequence_length` durations cannot contribute anything.
        if not rhythm or len(rhythm) <= sequence_length:
            continue
        for i in range(len(rhythm) - sequence_length):
            X.append(rhythm[i:i + sequence_length])
            y.append(rhythm[i + sequence_length])
    if not X:
        logger.error("No valid data extracted from MIDI files. Check dataset or MIDI parsing.")
        raise ValueError("No valid data to prepare dataset.")
    # Append a trailing feature axis: (n_windows, sequence_length) -> (..., 1) for the LSTM.
    X = np.array(X)[..., np.newaxis]
    y = np.array(y)
    logger.info(f"Dataset prepared: X Shape: {X.shape}, y shape: {y.shape}")
    return X, y

**Build Model**

In [None]:
def build_rnn_model(sequence_length, n_features=1):
    """Build and compile a two-layer LSTM regressor that predicts the next duration.

    Args:
        sequence_length: Number of time steps per input window.
        n_features: Number of values per time step (1 for a plain duration series).

    Returns:
        A compiled `Sequential` model (Adam optimizer, mean-squared-error loss)
        with a single linear output unit.
    """
    logger.info("Building RNN Model...")
    model = Sequential([
        # An explicit Input layer is the recommended way to declare the input shape;
        # passing `input_shape=` to the first layer triggers a warning in current Keras.
        tf.keras.Input(shape=(sequence_length, n_features)),
        LSTM(128, return_sequences=True),
        Dropout(0.3),
        LSTM(64),
        Dropout(0.3),
        Dense(32, activation='relu'),
        Dense(1, activation='linear')
    ])
    model.compile(optimizer='adam', loss='mse')
    logger.info("Model built successfully.")
    return model

**Collect MIDI Files**

In [None]:
def get_midi_files(dataset_path):
    """Recursively collect the paths of all MIDI files under `dataset_path`.

    Matches the `.mid` / `.midi` extensions case-insensitively. The result is
    sorted because `os.walk` (via `os.scandir`) yields entries in arbitrary order;
    sorting makes downstream "first N files" selections reproducible.

    Returns:
        A sorted list of file paths (empty if nothing matches or the path does not exist).
    """
    logger.info("Collecting MIDI files...")
    midi_files = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(dataset_path)
        for name in files
        if name.lower().endswith(('.mid', '.midi'))
    )
    logger.info(f"Found {len(midi_files)} MIDI files.")
    return midi_files

**Main**

In [None]:
logger.info("Starting rhythm detection....")

# Experiment configuration. Seeding Python, NumPy and TensorFlow together makes
# weight initialisation, dropout and the split repeatable across Restart & Run All.
SEED = 42
MAX_FILES = 100  # number of MIDI files used for the experiment
sequence_length = 50
tf.keras.utils.set_random_seed(SEED)

# Download dataset. Sort before slicing so the chosen subset does not depend on
# filesystem enumeration order.
dataset_path = download_lakh_midi_dataset()
midi_files = sorted(get_midi_files(dataset_path))[:MAX_FILES]

# Split Train/Test by *file*, not by window. Consecutive windows from one song
# overlap by `sequence_length - 1` durations, so a random window-level split puts
# near-duplicates of training samples into the test set and makes the test loss
# optimistic.
logger.info("Splitting MIDI files into train/test sets...")
train_files, test_files = train_test_split(midi_files, test_size=0.2,
                                           random_state=SEED)

# Preparing data
X_train, y_train = prepare_dataset(train_files, sequence_length)
X_test, y_test = prepare_dataset(test_files, sequence_length)
logger.info(f"Train Set: X: {X_train.shape}, y: {y_train.shape}")
logger.info(f"Test Set: X: {X_test.shape}, y: {y_test.shape}")

# Creating Model
model = build_rnn_model(sequence_length)
model.summary()
logger.info("Training Model...")
# Monitor on a slice of the training data so the test set is used only by the
# final evaluation below.
history = model.fit(X_train, y_train,
                    epochs=20,
                    batch_size=64,
                    validation_split=0.1)
logger.info("Training Completed..")

# Evaluating the model
logger.info("Evaluating model...")
loss = model.evaluate(X_test, y_test)
logger.info(f"Test loss: {loss}")

# Saving the model for future use. NOTE: `.h5` is Keras' legacy HDF5 format;
# the native `.keras` format is recommended for new artifacts.
model.save("rhythm_detection_model.h5")
logger.info("Model saved as rhythm_detection_model.h5")

ERROR:__main__:Error processing lmd_full/lmd_full/1/18052170c11b1b02209ccc4237b7a8c7.mid: badly formed midi string: missing leading MTrk
