<a href="https://colab.research.google.com/github/gocgodman/M2M/blob/main/MT3_YT_Piano_Render.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================================
# 셀 1: 패키지 설치 및 초기 설정
# ============================================================================

# YourMT3+ 설치
!pip install awscli
!mkdir -p amt
!aws s3 sync s3://amt-deploy-public/amt/ /content/amt --no-sign-request
!aws s3 sync s3://amt-deploy-public/examples/ /content/examples --no-sign-request

# M2M 및 공통 패키지 설치
!pip install piano_transcription_inference pytorch_lightning mir_eval pytube torchcodec
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q yt-dlp gdown gradio pretty_midi librosa numpy soundfile pyfluidsynth scipy
!pip install transformers==4.45.1
!apt-get update -qq
!apt-get install -y -qq ffmpeg fluidsynth sox p7zip-full
!pip install librosa==0.9.2 --upgrade

print("✓ 모든 패키지 설치 완료")

Collecting awscli
  Downloading awscli-1.44.20-py3-none-any.whl.metadata (11 kB)
Collecting botocore==1.42.30 (from awscli)
  Downloading botocore-1.42.30-py3-none-any.whl.metadata (5.9 kB)
Collecting docutils<=0.19,>=0.18.1 (from awscli)
  Downloading docutils-0.19-py3-none-any.whl.metadata (2.7 kB)
Collecting s3transfer<0.17.0,>=0.16.0 (from awscli)
  Downloading s3transfer-0.16.0-py3-none-any.whl.metadata (1.7 kB)
Collecting colorama<0.4.7,>=0.2.5 (from awscli)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Collecting rsa<4.8,>=3.1.2 (from awscli)
  Downloading rsa-4.7.2-py3-none-any.whl.metadata (3.6 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from botocore==1.42.30->awscli)
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Downloading awscli-1.44.20-py3-none-any.whl (4.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading botocore-1.42.30-py3-none-any.whl (14.6 MB)


In [2]:
# ============================================================================
# 셀 2: Google Drive 마운트 및 경로 설정
# ============================================================================

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, glob, shutil, zipfile, tarfile, subprocess
import gdown

# 기본 경로 설정
WORK_ROOT = "/content/ytd_pipeline_work"
DRIVE_SF2_DIR = "/content/drive/MyDrive/sf2_library"
DRIVE_RESULTS_DIR = "/content/drive/MyDrive/ytd_pipeline_results"
STATE_FILE = os.path.join(DRIVE_RESULTS_DIR, "pipeline_state.json")

os.makedirs(WORK_ROOT, exist_ok=True)
os.makedirs(DRIVE_SF2_DIR, exist_ok=True)
os.makedirs(DRIVE_RESULTS_DIR, exist_ok=True)

print("WORK_ROOT:", WORK_ROOT)
print("DRIVE_SF2_DIR:", DRIVE_SF2_DIR)
print("DRIVE_RESULTS_DIR:", DRIVE_RESULTS_DIR)

# 공유 폴더에서 SF2 다운로드 (선택사항)
SHARED_FOLDER_ID = "1JkTMvPwM_XURqG2114n4Qj0rR83WEucL"
OUT_DIR = "/content/sf2_from_shared"
os.makedirs(OUT_DIR, exist_ok=True)

if SHARED_FOLDER_ID:
    try:
        gdown.download_folder(id=SHARED_FOLDER_ID, output=OUT_DIR, quiet=False)
        print("✓ 공유 폴더 다운로드 완료")
    except Exception as e:
        print("공유 폴더 다운로드 실패:", e)

# 압축 해제 함수
def try_extract_archive(path, dest):
    """압축 파일 자동 해제"""
    path_lower = path.lower()
    os.makedirs(dest, exist_ok=True)
    try:
        if path_lower.endswith(".zip"):
            with zipfile.ZipFile(path, 'r') as zf:
                zf.extractall(dest)
                return True
        if path_lower.endswith((".tar.gz", ".tgz", ".tar")):
            with tarfile.open(path, 'r:*') as tf:
                tf.extractall(dest)
                return True
        if path_lower.endswith(".7z"):
            cmd = ['7z', 'x', '-y', '-o' + dest, path]
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
    except Exception as e:
        print("압축 해제 실패:", e)
    return False

# SF2 파일 압축 해제
print("\n[SF2 압축 파일 해제 중...]")
for root, dirs, files in os.walk(DRIVE_SF2_DIR):
    for fn in files:
        if fn.lower().endswith((".zip", ".tar.gz", ".tgz", ".tar", ".7z")):
            extract_dest = os.path.join(root, fn + "_extracted")
            if try_extract_archive(os.path.join(root, fn), extract_dest):
                print(f"  ✓ {fn}")

for root, dirs, files in os.walk(OUT_DIR):
    for fn in files:
        if fn.lower().endswith((".zip", ".tar.gz", ".tgz", ".tar", ".7z")):
            extract_dest = os.path.join(root, fn + "_extracted")
            if try_extract_archive(os.path.join(root, fn), extract_dest):
                print(f"  ✓ {fn}")

# SF2 파일 수집 및 복사
local_sf2_files = glob.glob(os.path.join(DRIVE_SF2_DIR, "**/*.sf2"), recursive=True)
local_names = {os.path.basename(f) for f in local_sf2_files}
shared_sf2_files = glob.glob(os.path.join(OUT_DIR, "**/*.sf2"), recursive=True)

print(f"\n내 드라이브 sf2 수: {len(local_sf2_files)}")
print(f"공유 드라이브 sf2 수: {len(shared_sf2_files)}")

# 공유 폴더에서 새 SF2 복사
copied = []
for p in shared_sf2_files:
    fname = os.path.basename(p)
    if fname in local_names:
        continue
    dest = os.path.join(DRIVE_SF2_DIR, fname)
    try:
        if not os.path.exists(dest):
            shutil.copy(p, dest)
            copied.append(dest)
            print(f"  복사: {fname}")
    except Exception as e:
        print(f"  복사 실패: {fname}, {e}")

print(f"새로 복사된 sf2 수: {len(copied)}")

# 잘못 배치된 SF2 정리
misplaced = glob.glob(os.path.join(DRIVE_RESULTS_DIR, "**/*.sf2"), recursive=True)
if misplaced:
    print(f"\n결과 폴더에 잘못 들어간 SF2 수: {len(misplaced)}")
    for p in misplaced:
        fname = os.path.basename(p)
        dst = os.path.join(DRIVE_SF2_DIR, fname)
        try:
            if not os.path.exists(dst):
                shutil.move(p, dst)
                print(f"  이동: {fname}")
            else:
                os.remove(p)
                print(f"  중복 제거: {fname}")
        except Exception as e:
            print(f"  이동/삭제 실패: {fname}, {e}")
else:
    print("\n결과 폴더에 잘못된 SF2 없음")

# 최종 SF2 목록
final_sf2_files = glob.glob(os.path.join(DRIVE_SF2_DIR, "**/*.sf2"), recursive=True)
print(f"\n✓ 최종 SF2 라이브러리 파일 수: {len(final_sf2_files)}")
for f in final_sf2_files[:20]:
    print(f"  * {os.path.basename(f)}")

if len(final_sf2_files) > 20:
    print(f"  ... 외 {len(final_sf2_files) - 20}개")

Mounted at /content/drive
WORK_ROOT: /content/ytd_pipeline_work
DRIVE_SF2_DIR: /content/drive/MyDrive/sf2_library
DRIVE_RESULTS_DIR: /content/drive/MyDrive/ytd_pipeline_results


Retrieving folder contents


Retrieving folder 11BHbbZws4leslR81OSluH0d7VhSYC4N2 Black Midi Soundfonts
Processing file 1kJoVKGS3ovON0OzNP3bodhlsTSo_J42m Brilliant CFX Concert Grand V2.3.7z
Processing file 17Iei7jUjFgJPrwdgWm22NhE_Ap6mRYN1 Brilliant CFX II Concert Grand V.1.4.7z
Processing file 1RaZkbNjG0UwWofmAHIVSsYUAl-iqCgmM LSP Mixable Concert Grand 1.2.9.7z
Processing file 1I_VeR7G0dHXASFFOLTfz00bhjBi7Ii8m Mustafa Concert Grand V1.0.7z
Processing file 1jL4FDCM7YHtCWh4WdlY_45eYm0Bhjy1O Mustafa F308XP Concert Grand V1.0.7z
Processing file 1gucI6eeNxEPBMdsMpJnXHOxKR-gXKSr- Ordinary D274 Concert Grand V1.0.7z
Processing file 1Q0Ljo55-Nezfupc-UBCPygI7kWk7Ep6H Regal S275 Concert Grand V1.2.7z
Processing file 1zCaOTXDqE_AAEuxcF0FaIKHiNJ43WktR Retroid 8SQ Digital Grand V2.sf2
Processing file 16-o3ntN3CF3K8He3LFUgMMQkGsa2wAcl Retroid D274SQ Digital Grand V1.1.7z
Processing file 1K4-14qdkrA3A0sIFQ50M5-DYk_qRM9X1 Supernova Concert Grand V1.4.rar
Processing file 1M26dnhABD45JHrs_zrY73_x7LQu-gxS0 Zayyan Concert Grand V1.2.

Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=1kJoVKGS3ovON0OzNP3bodhlsTSo_J42m
To: /content/sf2_from_shared/Black Midi Soundfonts/Brilliant CFX Concert Grand V2.3.7z
100%|██████████| 16.5M/16.5M [00:00<00:00, 81.4MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=17Iei7jUjFgJPrwdgWm22NhE_Ap6mRYN1
From (redirected): https://drive.google.com/uc?id=17Iei7jUjFgJPrwdgWm22NhE_Ap6mRYN1&confirm=t&uuid=615fae85-198f-4ba9-a569-73c156ded08b
To: /content/sf2_from_shared/Black Midi Soundfonts/Brilliant CFX II Concert Grand V.1.4.7z
100%|██████████| 152M/152M [00:03<00:00, 46.0MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1RaZkbNjG0UwWofmAHIVSsYUAl-iqCgmM
From (redirected): https://drive.google.com/uc?id=1RaZkbNjG0UwWofmAHIVSsYUAl-iqCgmM&confirm=t&uuid=32ad2a8f-42e0-4e08-b0a9-807d86437abb
To: /content/sf2_from_shared/Black Midi Soundfonts/LSP Mixabl

공유 폴더 다운로드 실패: Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=1G9CUt53d5CrmM8MmUv8ohVGYgWbfd9CO

but Gdown can't. Please check connections and permissions.

[SF2 압축 파일 해제 중...]
  ✓ Kawai MP11SE.7z
  ✓ LSPModel 290 Bosendorfer [Pro] (1.3).7z
  ✓ LSPModel CFX Yamaha [Pro] v1.8.sf2.7z
  ✓ LSPModel M450i Klavins [Pro] v2.1.sf2.7z
  ✓ Zayyan II Concert Grand V1.0.7z
  ✓ Brilliant CFX II Concert Grand V.1.4.7z
  ✓ Regal S275 Concert Grand V1.2.7z
  ✓ Retroid D274SQ Digital Grand V1.1.7z
  ✓ الْبِطِّيْخْ CSII Concert Grand (BETA 0.8).7z
  ✓ Mustafa Concert Grand V1.0.7z
  ✓ Brilliant CFX Concert Grand V2.3.7z
  ✓ LSP Mixable Concert Grand 1.2.9.7z
  ✓ Mustafa F308XP Concert Grand V1.0.7z
  ✓ Ordinary D274 Conce

In [19]:
# @title Model helper
%cd /content/amt/src
from collections import Counter
import argparse
import torch
import numpy as np

from model.init_train import initialize_trainer, update_config
from utils.task_manager import TaskManager
from config.vocabulary import drum_vocab_presets
from utils.utils import str2bool
from utils.utils import Timer
from utils.audio import slice_padded_array
from utils.note2event import mix_notes
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
from model.ymt3 import YourMT3

def load_model_checkpoint(args=None):
    parser = argparse.ArgumentParser(description="YourMT3")
    # General
    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTFIf None, default value defined in config.py will be used.')
    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
    # Model configurations
    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
    # Perceiver-TF configurations
    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
    # Decoder configurations
    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    # Task and Evaluation configurations
    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=False). True or False')
    # Trainer configurations
    parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
    # Debug
    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
    args = parser.parse_args(args)
    # yapf: enable
    if torch.__version__ >= "1.13":
        torch.set_float32_matmul_precision("high")
    args.epochs = None

    # Initialize and update config
    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')

    if args.eval_drum_vocab != None:  # override eval_drum_vocab
        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]

    # Initialize task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        optimizer=None,
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_subtask_key=args.eval_subtask_key,
        write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
        ).to(device)
    checkpoint = torch.load(dir_info["last_ckpt_path"], weights_only=False) # fix model loading error in torch 2.6
    state_dict = checkpoint['state_dict']
    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
    model.load_state_dict(new_state_dict, strict=False)
    return model.eval()


def transcribe(model, audio_info):
    t = Timer()

    # Converting Audio
    t.start()
    audio, sr = torchaudio.load(uri=audio_info['filepath'])
    audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
    audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1) # (n_seg, 1, seg_sz)
    t.stop(); t.print_elapsed_time("converting audio");

    # Inference
    t.start()
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
    t.stop(); t.print_elapsed_time("model inference");

    # Post-processing
    t.start()
    num_channels = model.task_manager.num_decoding_channels
    n_items = audio_segments.shape[0]
    start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
    pred_notes_in_file = []
    n_err_cnt = Counter()
    for ch in range(num_channels):
        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
        zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
            pred_token_arr_ch, start_secs_file, return_events=True)
        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
        pred_notes_in_file.append(pred_notes_ch)
        n_err_cnt += n_err_cnt_ch
    pred_notes = mix_notes(pred_notes_in_file)  # This is the mixed notes from all channels

    # Write MIDI
    write_model_output_as_midi(pred_notes, '/content/',
                              audio_info['track_name'], model.midi_output_inverse_vocab)
    t.stop(); t.print_elapsed_time("post processing");
    midifile =  os.path.join('/content/model_output/', audio_info['track_name']  + '.mid')
    assert os.path.exists(midifile)
    return midifile

/content/amt/src


In [4]:

# ============================================================================
# 셀 4: GradIO Helper (torchaudio 버전 호환)
# ============================================================================

import subprocess
from typing import Literal, Dict

def prepare_media(source_path_or_url: str,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True) -> Dict:
    """미디어 준비 (오디오 파일 또는 YouTube) - torchaudio 버전 호환"""

    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        # 이 부분은 download_youtube_audio_single로 이미 처리됨
        raise ValueError("youtube_url은 download_youtube_audio_single 함수를 사용하세요")
    else:
        raise ValueError(source_type)

    # torchaudio 버전별 처리
    try:
        # 최신 버전 (2.0+)
        import torchaudio
        metadata = torchaudio.info(audio_file)

        return {
            "filepath": audio_file,
            "track_name": os.path.basename(audio_file).split('.')[0],
            "sample_rate": int(metadata.sample_rate),
            "bits_per_sample": int(metadata.bits_per_sample) if hasattr(metadata, 'bits_per_sample') else 16,
            "num_channels": int(metadata.num_channels),
            "num_frames": int(metadata.num_frames),
            "duration": int(metadata.num_frames / metadata.sample_rate),
            "encoding": str(metadata.encoding).lower() if hasattr(metadata, 'encoding') else "unknown",
        }
    except AttributeError:
        # 구버전 또는 info 없는 경우 - librosa로 대체
        import librosa
        import soundfile as sf

        try:
            # soundfile로 메타데이터 읽기
            info = sf.info(audio_file)

            return {
                "filepath": audio_file,
                "track_name": os.path.basename(audio_file).split('.')[0],
                "sample_rate": int(info.samplerate),
                "bits_per_sample": 16,  # 기본값
                "num_channels": int(info.channels),
                "num_frames": int(info.frames),
                "duration": int(info.duration),
                "encoding": str(info.subtype).lower() if hasattr(info, 'subtype') else "unknown",
            }
        except Exception as e:
            # 최후의 수단 - librosa로 로드
            print(f"soundfile 실패, librosa 사용: {e}")
            y, sr = librosa.load(audio_file, sr=None, mono=False)

            if y.ndim == 1:
                num_channels = 1
                num_frames = len(y)
            else:
                num_channels = y.shape[0]
                num_frames = y.shape[1]

            return {
                "filepath": audio_file,
                "track_name": os.path.basename(audio_file).split('.')[0],
                "sample_rate": int(sr),
                "bits_per_sample": 16,
                "num_channels": num_channels,
                "num_frames": num_frames,
                "duration": int(num_frames / sr),
                "encoding": "unknown",
            }

print("✓ GradIO Helper 로드 완료 (torchaudio 버전 호환)")

✓ GradIO Helper 로드 완료 (torchaudio 버전 호환)


In [5]:
# ============================================================================
# 셀 5: M2M 페달 검출 (Document 2 기반)
# ============================================================================

import librosa
import pretty_midi
from scipy.ndimage import uniform_filter1d, binary_closing

def detect_pedal_rule(audio_path,
                      sr=22050,
                      hop_length=512,
                      low_freq_cut=500,
                      energy_smooth=0.5,
                      on_z=1.0,
                      off_z=0.7,
                      min_event_len=0.06,
                      merge_gap=0.06,
                      closing_size=3):
    """규칙 기반 페달 검출"""

    y, _ = librosa.load(audio_path, sr=sr, mono=True)

    # 에너지 계산
    frame_energy = librosa.feature.rms(y=y, frame_length=2048, hop_length=hop_length)[0]

    # 저주파 에너지
    S = np.abs(librosa.stft(y, n_fft=2048, hop_length=hop_length))
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    low_idx = np.where(freqs <= low_freq_cut)[0]
    low_energy = S[low_idx, :].sum(axis=0) if len(low_idx) > 0 else np.zeros_like(frame_energy)

    # 정규화
    e = frame_energy / (frame_energy.max() + 1e-8)
    le = low_energy / (low_energy.max() + 1e-8) if low_energy.max() > 0 else low_energy

    # 결합 (저주파 에너지 60% + 전체 에너지 40%)
    combined = 0.6 * le + 0.4 * e

    # 스무딩
    window = int(max(1, energy_smooth * (sr / hop_length)))
    combined_smooth = uniform_filter1d(combined, size=window)

    # 임계값 계산 (평균 + Z * 표준편차)
    mu = combined_smooth.mean()
    sigma = combined_smooth.std() + 1e-8
    on_thr = mu + on_z * sigma
    off_thr = mu + off_z * sigma

    # 마스크 생성 (히스테리시스 임계값)
    mask = np.zeros_like(combined_smooth, dtype=bool)
    state = False
    for i, v in enumerate(combined_smooth):
        if not state and v >= on_thr:
            state = True
            mask[i] = True
        elif state:
            mask[i] = True
            if v < off_thr:
                state = False

    # 모폴로지 연산 (작은 구멍 메우기)
    if closing_size > 1:
        mask = binary_closing(mask, structure=np.ones(closing_size))

    # 프레임을 시간으로 변환
    times = librosa.frames_to_time(np.arange(len(mask)), sr=sr, hop_length=hop_length)

    # 이벤트 추출 (연속된 True 구간)
    events = []
    prev = False
    start = None
    for t, m in zip(times, mask):
        if m and not prev:
            start = t
        if (not m) and prev and start is not None:
            events.append((start, t))
            start = None
        prev = m

    if prev and start is not None:
        events.append((start, times[-1]))

    # 필터링 및 병합
    filtered = []
    for s, e in events:
        if (e - s) >= min_event_len:
            if filtered and s - filtered[-1][1] <= merge_gap:
                # 이전 이벤트와 병합
                filtered[-1] = (filtered[-1][0], e)
            else:
                filtered.append((s, e))

    return filtered

def insert_pedal_cc_into_midi(midi_in_path, midi_out_path, pedal_events, piano_program=0):
    """MIDI에 페달 CC64 삽입"""

    pm = pretty_midi.PrettyMIDI(midi_in_path)

    # 피아노 프로그램 설정
    for inst in pm.instruments:
        if not inst.is_drum:
            inst.program = piano_program

    # 대상 악기 선택 (첫 번째 악기 또는 새로 생성)
    target_inst = pm.instruments[0] if pm.instruments else pretty_midi.Instrument(program=piano_program)
    if not pm.instruments:
        pm.instruments.append(target_inst)

    # 페달 CC 추가 (CC#64 = Sustain Pedal)
    for (s, e) in pedal_events:
        on_time = max(0.0, s - 0.02)  # 약간 일찍 페달 밟기
        off_time = e + 0.02            # 약간 늦게 페달 떼기
        target_inst.control_changes.append(
            pretty_midi.ControlChange(number=64, value=127, time=on_time))
        target_inst.control_changes.append(
            pretty_midi.ControlChange(number=64, value=0, time=off_time))

    # CC 정렬 (시간 순서대로)
    for inst in pm.instruments:
        inst.control_changes.sort(key=lambda cc: cc.time)

    pm.write(midi_out_path)

print("✓ M2M 페달 검출 로드 완료")

✓ M2M 페달 검출 로드 완료


In [6]:
# ============================================================================
# 셀 6 수정: YouTube 다운로드 개선 (Bot 우회)
# ============================================================================

import json
import time
import uuid
import traceback
from urllib.parse import urlparse, parse_qs

def render_midi_to_wav(midi_path, wav_out_path, sf2_path, sample_rate=44100, timeout=300):
    """FluidSynth로 MIDI → WAV 렌더링"""

    if sf2_path is None or not os.path.exists(sf2_path):
        return None, "sf2_not_found"

    try:
        cmd = ['fluidsynth', '-ni', sf2_path, midi_path, '-F', wav_out_path, '-r', str(sample_rate)]
        proc = subprocess.run(cmd, check=False, timeout=timeout,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        log = proc.stdout + "\n" + proc.stderr

        if proc.returncode != 0:
            return None, log

        if not os.path.exists(wav_out_path):
            return None, "fluidsynth finished but wav not created\n" + log

        return wav_out_path, log

    except subprocess.TimeoutExpired as e:
        return None, f"fluidsynth timeout: {e}"
    except Exception as e:
        return None, f"fluidsynth exception: {e}\n{traceback.format_exc()}"

def run_cmd(cmd, check=False, timeout=None):
    """명령어 실행 헬퍼"""
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         text=True, timeout=timeout)
    if check and proc.returncode != 0:
        raise RuntimeError(f"Command failed: {' '.join(cmd)}\nSTDOUT:{proc.stdout}\nSTDERR:{proc.stderr}")
    return proc.returncode, proc.stdout, proc.stderr

def remove_extension(filepath):
    """확장자 제거"""
    return ".".join(os.path.basename(filepath).split('.')[:-1])

def check_tool(name):
    """도구 설치 확인"""
    code, out, err = run_cmd(["which", name])
    return code == 0

def load_state():
    """상태 파일 로드"""
    if os.path.exists(STATE_FILE):
        with open(STATE_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    return {"processed": [], "failed": [], "items": {}, "last_update": None}

def save_state(state):
    """상태 파일 저장"""
    state["last_update"] = time.time()
    with open(STATE_FILE, "w", encoding="utf-8") as f:
        json.dump(state, f, ensure_ascii=False, indent=2)

def resolve_sf2_path(sf2_choice, uploaded_sf2=None, sf2_library_dir=DRIVE_SF2_DIR):
    """SF2 경로 해석 (우선순위: 절대경로 > 업로드 > 라이브러리)"""

    # 1) 절대 경로 지정
    if sf2_choice and sf2_choice != "None" and os.path.exists(sf2_choice):
        return sf2_choice

    # 2) 업로드된 파일
    if uploaded_sf2 and isinstance(uploaded_sf2, list) and len(uploaded_sf2) > 0:
        up = uploaded_sf2[0]
        src = up['name'] if isinstance(up, dict) and 'name' in up else up
        if os.path.exists(src):
            os.makedirs(sf2_library_dir, exist_ok=True)
            dst = os.path.join(sf2_library_dir, os.path.basename(src))
            if not os.path.exists(dst):
                shutil.copy(src, dst)
            return dst

    # 3) 라이브러리 내 첫 번째 파일
    if os.path.exists(sf2_library_dir):
        candidates = sorted([os.path.join(sf2_library_dir, f)
                           for f in os.listdir(sf2_library_dir)
                           if f.endswith('.sf2')])
        if candidates:
            return candidates[0]

    return None

def expand_playlist_to_video_urls(playlist_url):
    """재생목록을 비디오 URL 목록으로 확장"""
    cmd = ["yt-dlp", "--flat-playlist", "-J", playlist_url]
    code, out, err = run_cmd(cmd)
    if code != 0:
        raise RuntimeError(f"yt-dlp playlist expand failed: {err}")
    j = json.loads(out)
    entries = j.get("entries", [])
    urls = []
    for e in entries:
        vid = e.get("id")
        if vid:
            urls.append(f"https://www.youtube.com/watch?v={vid}")
    return urls

def download_youtube_audio_single(url, outdir, fmt="mp3", audio_bitrate="192k", retries=3):
    """
    YouTube 단일 영상 오디오 다운로드 (Bot 우회 개선)

    Args:
        url: YouTube URL
        outdir: 출력 디렉터리
        fmt: 포맷 (기본: mp3)
        audio_bitrate: 비트레이트
        retries: 재시도 횟수
    """
    os.makedirs(outdir, exist_ok=True)
    out_template = os.path.join(outdir, "%(id)s.%(ext)s")  # ID 기반 파일명

    # Bot 우회를 위한 yt-dlp 옵션
    cmd = [
        "yt-dlp",
        "-x",  # 오디오만 추출
        "--audio-format", fmt,
        "--audio-quality", audio_bitrate,
        "-o", out_template,
        "--no-warnings",
        "--quiet",  # 경고 줄이기
        "--no-playlist",  # 재생목록 비활성화
        "--extractor-retries", str(retries),
        "--socket-timeout", "30",
        # User-Agent 설정
        "--user-agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        # 속도 제한 (너무 빠르면 차단됨)
        "--sleep-interval", "2",
        "--max-sleep-interval", "5",
        url
    ]

    for attempt in range(retries):
        try:
            print(f"  다운로드 시도 {attempt + 1}/{retries}...")
            code, out, err = run_cmd(cmd, timeout=120)

            if code == 0:
                # 다운로드된 파일 찾기
                files = sorted(glob.glob(os.path.join(outdir, f"*.{fmt}")))
                if files:
                    return files[-1]

            # 실패 시 대기
            if attempt < retries - 1:
                wait_time = (attempt + 1) * 5  # 5초, 10초, 15초...
                print(f"  실패. {wait_time}초 대기 중...")
                time.sleep(wait_time)

        except Exception as e:
            print(f"  다운로드 예외 (시도 {attempt + 1}): {e}")
            if attempt < retries - 1:
                time.sleep((attempt + 1) * 5)

    # 모든 재시도 실패
    raise RuntimeError(f"YouTube 다운로드 실패 (모든 재시도 소진): {url}\n"
                      f"YouTube가 Colab IP를 차단했을 가능성이 있습니다.\n"
                      f"해결 방법: 오디오 파일을 직접 업로드하거나 나중에 다시 시도하세요.")

def extract_video_id_from_url(url):
    """URL에서 YouTube 비디오 ID 추출"""
    try:
        p = urlparse(url)
        if p.hostname and "youtube" in p.hostname:
            qs = parse_qs(p.query)
            if "v" in qs and qs["v"]:
                return qs["v"][0]
        if p.hostname and "youtu.be" in p.hostname:
            return p.path.lstrip("/")
    except Exception:
        return None
    return None

print("✓ FluidSynth 렌더링 및 유틸리티 로드 완료 (YouTube Bot 우회 개선)")

✓ FluidSynth 렌더링 및 유틸리티 로드 완료 (YouTube Bot 우회 개선)


In [15]:
# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = 'bf16-mixed' # @param ["32", "bf16-mixed", "16"]
project = '2024'

if model_name == "YMT3+":
    checkpoint = "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt"
    args = [checkpoint, '-p', project, '-pr', precision]
elif model_name == "YPTF+Single (noPS)":
    checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
    args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
            '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF+Multi (PS)":
    checkpoint = "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
            '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (noPS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (PS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
else:
    raise ValueError(model_name)

model = load_model_checkpoint(args=args)

INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..


Resuming from ../logs/2024/mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops/checkpoints/last.ckpt
Task: mc13_full_plus_256, Max Shift Steps: 206
"add_melody_metric_to_singing": True
"add_pitch_class_metric":       None
"audio_cfg":                    {'codec': 'spec', 'hop_length': 300, 'audio_backend': 'torchaudio', 'sample_rate': 16000, 'input_frames': 32767, 'n_fft': 2048, 'n_mels': 512, 'f_min': 50.0, 'f_max': 8000.0}
"base_lr":                      None
"eval_drum_vocab":              None
"eval_subtask_key":             default
"eval_vocab":                   None
"init_factor":                  None
"max_steps":                    None
"model_cfg":                    {'encoder_type': 'perceiver-tf', 'decoder_type': 'multi-t5', 'pre_encoder_type': 'conv', 'pre_encoder_type_default': {'t5': None, 'perceiver-tf': 'conv', 'conformer': None}, 'pre_decoder_type': 'linear', 'pre_decoder_type_default': {'t5': {'t5': None}, 'perceiver-tf': {'t5': 'linear', 'multi-t5': '

In [16]:

# ============================================================================
# 셀 8: 통합 파이프라인 함수 (오류 추적 강화)
# ============================================================================

def process_item_from_url_or_path(
    item,
    sf2_choice=None,
    uploaded_sf2=None,
    use_pedal_model=False,
    pedal_model_path=None,
    tmp_dir=None,
    sustain_tolerance=0.2,
    pedal_energy_smooth=0.5,
    pedal_low_freq_cut=500,
    pedal_on_threshold=1.0,
    pedal_off_threshold=0.7,
    pedal_min_event_len=0.08,
    pedal_merge_gap=0.08
):
    """
    단일 항목 처리: YouTube URL 또는 로컬 오디오 파일
    YourMT3+ 전사 → M2M 페달 검출 → FluidSynth 렌더링
    """

    if tmp_dir is None:
        tmp_dir = os.path.join(WORK_ROOT, "tmp_" + uuid.uuid4().hex)
    os.makedirs(tmp_dir, exist_ok=True)

    result = {
        "status": "error",
        "mp3": None,
        "midi": None,
        "wav": None,
        "pedals": [],
        "transcribe_log": None,
        "render_log": None,
        "error": None,
        "trace": None
    }

    try:
        # 1) 오디오 확보
        if isinstance(item, str) and item.startswith("http"):
            print(f"[다운로드] YouTube: {item[:50]}...")
            mp3 = download_youtube_audio_single(item, tmp_dir, fmt="mp3")
        else:
            mp3 = item
        result["mp3"] = mp3

        # 2) YourMT3+ 전사
        print(f"[전사] YourMT3+: {os.path.basename(mp3)}")

        try:
            # prepare_media 호출
            audio_info = prepare_media(mp3, source_type='audio_filepath')
            print(f"  오디오 정보: {audio_info['sample_rate']}Hz, {audio_info['duration']}초")
        except Exception as e:
            print(f"✗ prepare_media 실패: {e}")
            raise

        try:
            # 전사 실행
            print(f"  전사 시작...")
            mid = transcribe(model, audio_info)
            print(f"  전사 완료: {mid}")
        except Exception as e:
            print(f"✗ transcribe 실패: {e}")
            import traceback
            traceback.print_exc()
            raise

        result["midi"] = mid
        result["transcribe_log"] = "YourMT3+ transcription completed"

        # 3) 페달 검출
        print(f"[페달] M2M 검출 중...")
        try:
            pedals = detect_pedal_rule(
                mp3,
                energy_smooth=pedal_energy_smooth,
                low_freq_cut=pedal_low_freq_cut,
                on_z=pedal_on_threshold,
                off_z=pedal_off_threshold,
                min_event_len=pedal_min_event_len,
                merge_gap=pedal_merge_gap
            )
            result["pedals"] = pedals
            print(f"  → {len(pedals)}개 페달 이벤트 검출")
        except Exception as e:
            print(f"✗ 페달 검출 실패: {e}")
            raise

        # 4) MIDI에 CC 삽입
        print(f"[페달] CC64 삽입 중...")
        try:
            mid_pedal = os.path.join(tmp_dir, os.path.splitext(os.path.basename(mid))[0] + "_pedal.mid")
            insert_pedal_cc_into_midi(mid, mid_pedal, pedals, piano_program=0)
            result["midi"] = mid_pedal
        except Exception as e:
            print(f"✗ CC64 삽입 실패: {e}")
            raise

        # 5) Drive에 저장 (파일명 충돌 방지)
        print(f"[저장] Drive에 저장 중...")
        try:
            unique_id = uuid.uuid4().hex[:6]

            saved_mp3 = os.path.join(DRIVE_RESULTS_DIR,
                                     f"{remove_extension(os.path.basename(mp3))}_{unique_id}.mp3")
            shutil.copy(mp3, saved_mp3)
            result["mp3"] = saved_mp3

            saved_midi = os.path.join(DRIVE_RESULTS_DIR,
                                      f"{remove_extension(os.path.basename(mid_pedal))}_{unique_id}.mid")
            shutil.copy(mid_pedal, saved_midi)
            result["midi"] = saved_midi
            print(f"  → MIDI 저장: {os.path.basename(saved_midi)}")
        except Exception as e:
            print(f"✗ 저장 실패: {e}")
            raise

        # 6) FluidSynth 렌더링
        chosen_sf2 = resolve_sf2_path(sf2_choice, uploaded_sf2)
        saved_wav = None
        render_log = None

        if chosen_sf2:
            print(f"[렌더] FluidSynth: {os.path.basename(chosen_sf2)}")
            try:
                wav_out = os.path.join(tmp_dir, f"{remove_extension(os.path.basename(mid_pedal))}.wav")
                rendered, render_log = render_midi_to_wav(mid_pedal, wav_out, chosen_sf2)
                result["render_log"] = render_log

                if rendered:
                    saved_wav = os.path.join(DRIVE_RESULTS_DIR,
                                            f"{remove_extension(os.path.basename(rendered))}_{unique_id}.wav")
                    shutil.copy(rendered, saved_wav)
                    result["wav"] = saved_wav
                    print(f"  → 렌더링 완료: {os.path.basename(saved_wav)}")
                else:
                    print(f"  → 렌더링 실패")
            except Exception as e:
                print(f"✗ 렌더링 예외: {e}")
                # 렌더링 실패는 치명적이지 않음
        else:
            print(f"[렌더] SF2 없음, 렌더링 건너뜀")

        result["status"] = "ok"
        print(f"✓ 처리 완료: {os.path.basename(mp3)}")
        return result

    except Exception as e:
        result["error"] = str(e)
        result["trace"] = traceback.format_exc()
        print(f"✗ 처리 실패: {e}")
        print(result["trace"])
        return result

def process_files(
    gr_files,
    sf2_choice,
    uploaded_sf2,
    sustain_tolerance=0.2,
    pedal_energy_smooth=0.5,
    pedal_low_freq_cut=500,
    pedal_on_threshold=1.0,
    pedal_off_threshold=0.7,
    pedal_min_event_len=0.08,
    pedal_merge_gap=0.08,
    **kwargs
):
    """Gradio 파일 업로드 처리"""

    chosen_sf2 = resolve_sf2_path(sf2_choice, uploaded_sf2)

    if not gr_files:
        return "파일을 선택하세요", None, None

    work_dir = "/tmp/transcribe_" + uuid.uuid4().hex
    os.makedirs(work_dir, exist_ok=True)

    local_paths = []
    for f in gr_files:
        if isinstance(f, dict) and 'name' in f:
            local_paths.append(f['name'])
        else:
            local_paths.append(f)

    midi_outs = []
    wav_outs = []

    for audio_path in local_paths:
        res = process_item_from_url_or_path(
            audio_path,
            sf2_choice=chosen_sf2,
            uploaded_sf2=None,
            pedal_energy_smooth=pedal_energy_smooth,
            pedal_low_freq_cut=pedal_low_freq_cut,
            pedal_on_threshold=pedal_on_threshold,
            pedal_off_threshold=pedal_off_threshold,
            pedal_min_event_len=pedal_min_event_len,
            pedal_merge_gap=pedal_merge_gap
        )

        if res.get("status") != "ok":
            return f"처리 실패: {res.get('error')}\nTrace:\n{res.get('trace')}", None, None

        midi_outs.append(res.get("midi"))
        if res.get("wav"):
            wav_outs.append(res.get("wav"))

    # 결과 반환
    if len(midi_outs) == 1:
        final_midi = midi_outs[0]
    else:
        # 여러 파일 → ZIP으로 압축
        zip_path = os.path.join(work_dir, "results_with_pedal.zip")
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
            for p in midi_outs:
                zf.write(p, arcname=os.path.basename(p))
        final_midi = zip_path

    wav_path = wav_outs[-1] if wav_outs else None
    status_msg = f"✓ {len(midi_outs)}개 파일 처리 완료"

    return status_msg, final_midi, wav_path

def playlist_pipeline_generator(playlist_text, sf2_choice_path, **kwargs):
    """재생목록 배치 처리 제너레이터"""

    lines = [ln.strip() for ln in str(playlist_text).splitlines() if ln.strip()]
    items = []

    # 재생목록 확장
    for ln in lines:
        if ln.startswith("http") and ("playlist" in ln or "list=" in ln):
            yield f"재생목록 확장 중: {ln[:50]}...", None
            try:
                urls = expand_playlist_to_video_urls(ln)
                if not urls:
                    yield f"재생목록에서 URL을 찾지 못했습니다: {ln}", None
                else:
                    items.extend(urls)
                    yield f"재생목록에서 {len(urls)}개 항목 발견", None
            except Exception as e:
                yield f"재생목록 확장 실패: {e}", None
        else:
            items.append(ln)

    total = len(items)
    if total == 0:
        yield "처리할 항목이 없습니다.", None
        return

    # 상태 로드
    try:
        state = load_state()
    except Exception as e:
        yield f"상태 파일 로드 실패: {e}", None
        return

    yield f"총 항목: {total}. 이미 처리된 항목: {len(state.get('processed', []))}", None

    # 각 항목 처리
    for idx, item in enumerate(items, start=1):
        # 이미 처리됨 확인
        if item in state.get("processed", []):
            yield f"[{idx}/{total}] 건너뜀(이미 처리됨): {item[:50]}...", None
            continue

        yield f"[{idx}/{total}] 다운로드/전사 시작: {item[:50]}...", None

        try:
            res = process_item_from_url_or_path(
                item,
                sf2_choice=sf2_choice_path,
                uploaded_sf2=None,
                **kwargs
            )
        except Exception as e:
            res = {"status": "error", "error": str(e)}

        if res.get("status") == "ok":
            state.setdefault("processed", []).append(item)
            state.setdefault("items", {})[item] = {
                "mp3": res.get("mp3"),
                "midi": res.get("midi"),
                "wav": res.get("wav"),
                "pedals": len(res.get("pedals", [])),
                "time": time.time()
            }
            try:
                save_state(state)
            except Exception as e:
                yield f"[{idx}/{total}] 완료했으나 상태 저장 실패: {e}", res
                continue

            yield f"[{idx}/{total}] ✓ 완료: {os.path.basename(res.get('midi', 'N/A'))}", res
        else:
            state.setdefault("failed", []).append({"item": item, "error": res.get("error")})
            try:
                save_state(state)
            except Exception as e:
                yield f"[{idx}/{total}] 실패: 에러 저장 실패: {e}", None
                continue

            yield f"[{idx}/{total}] ✗ 실패: {res.get('error')}", None

    yield f"모든 항목 처리 완료. 결과 위치: {DRIVE_RESULTS_DIR}", None

print("✓ 통합 파이프라인 함수 로드 완료 (오류 추적 강화)")

✓ 통합 파이프라인 함수 로드 완료 (오류 추적 강화)


In [18]:
# ============================================================================
# 셀 9: Gradio UI
# ============================================================================

import gradio as gr

# SF2 목록 생성
def list_sf2_choices(sf2_dir=DRIVE_SF2_DIR):
    """SF2 파일 목록을 Gradio Dropdown 형식으로 반환"""
    if not os.path.exists(sf2_dir):
        return [("None", "None")]
    files = sorted(glob.glob(os.path.join(sf2_dir, "**/*.sf2"), recursive=True))
    if not files:
        return [("None", "None")]
    choices = [("None", "None")] + [(os.path.basename(p), p) for p in files]
    return choices

sf2_choices = list_sf2_choices()
print(f"사용 가능한 SF2: {len(sf2_choices) - 1}개")

# UI 콜백 함수
def on_run_upload(gr_files, sf2_choice_val, uploaded_sf2_val,
                  pedal_energy_smooth, pedal_low_freq_cut,
                  pedal_on_threshold, pedal_off_threshold,
                  pedal_min_event_len, pedal_merge_gap,
                  sf2_direct_input):
    """업로드 파일 처리 콜백"""
    try:
        # SF2 경로 결정
        chosen_sf2 = sf2_direct_input if sf2_direct_input else sf2_choice_val
        chosen_sf2 = resolve_sf2_path(chosen_sf2, uploaded_sf2_val)
        chosen_name = os.path.basename(chosen_sf2) if chosen_sf2 else "없음"

        print(f"\n{'='*60}")
        print(f"업로드 처리 시작")
        print(f"SF2: {chosen_name}")
        print(f"{'='*60}\n")

        final_midi, wav_path = process_files(
            gr_files,
            sf2_choice=chosen_sf2,
            uploaded_sf2=None,
            pedal_energy_smooth=pedal_energy_smooth,
            pedal_low_freq_cut=pedal_low_freq_cut,
            pedal_on_threshold=pedal_on_threshold,
            pedal_off_threshold=pedal_off_threshold,
            pedal_min_event_len=pedal_min_event_len,
            pedal_merge_gap=pedal_merge_gap
        )

        return final_midi, wav_path

    except Exception as e:
        tb = traceback.format_exc()
        print(f"업로드 처리 중 오류:\n{tb}")
        return None, None

def playlist_wrapper(playlist_text, sf2_choice_val, uploaded_sf2_val,
                    pedal_energy_smooth, pedal_low_freq_cut,
                    pedal_on_threshold, pedal_off_threshold,
                    pedal_min_event_len, pedal_merge_gap,
                    sf2_direct_input):
    """재생목록 처리 래퍼"""
    try:
        chosen_sf2 = sf2_direct_input if sf2_direct_input else sf2_choice_val
        chosen_sf2 = resolve_sf2_path(chosen_sf2, uploaded_sf2_val)

        yield from playlist_pipeline_generator(
            playlist_text,
            chosen_sf2,
            pedal_energy_smooth=pedal_energy_smooth,
            pedal_low_freq_cut=pedal_low_freq_cut,
            pedal_on_threshold=pedal_on_threshold,
            pedal_off_threshold=pedal_off_threshold,
            pedal_min_event_len=pedal_min_event_len,
            pedal_merge_gap=pedal_merge_gap
        )
    except Exception as e:
        yield f"재생목록 처리 중 예외 발생: {e}\n{traceback.format_exc()}", None

# Gradio Blocks UI
with gr.Blocks(title="YourMT3+ & M2M 통합 파이프라인") as demo:
    gr.Markdown("# 🎹 YourMT3+ & M2M 통합 파이프라인")
    gr.Markdown("**YourMT3+ 전사 + M2M 페달 검출 + FluidSynth 렌더링**")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### SoundFont 설정")
            sf2_dropdown = gr.Dropdown(
                choices=sf2_choices,
                value=sf2_choices[0][1] if sf2_choices else "None",
                label="드라이브 SF2 선택"
            )
            sf2_upload = gr.Files(
                file_types=[".sf2"],
                label="또는 SF2 업로드"
            )
            sf2_direct = gr.Textbox(
                label="(고급) SF2 절대경로 직접 입력",
                placeholder="/content/drive/MyDrive/sf2_library/piano.sf2"
            )
            gr.Markdown("**우선순위:** 직접입력 > 업로드 > 드롭다운")

        with gr.Column(scale=1):
            gr.Markdown("### 페달 검출 파라미터")
            pedal_smooth = gr.Slider(0.1, 1.0, value=0.5, step=0.05,
                                    label="에너지 스무딩")
            pedal_lowcut = gr.Slider(200, 2000, value=500, step=50,
                                    label="저주파 컷오프 (Hz)")
            pedal_on = gr.Slider(0.5, 2.0, value=1.0, step=0.1,
                                label="페달 ON 임계값 (Z)")
            pedal_off = gr.Slider(0.3, 1.5, value=0.7, step=0.05,
                                 label="페달 OFF 임계값 (Z)")
            pedal_min = gr.Slider(0.02, 0.5, value=0.08, step=0.01,
                                 label="최소 이벤트 길이 (초)")
            pedal_merge = gr.Slider(0.01, 0.3, value=0.08, step=0.01,
                                   label="병합 간격 (초)")

    with gr.Tabs():
        # ===== 탭 1: 파일 업로드 =====
        with gr.TabItem("📁 파일 업로드"):
            gr.Markdown("### 오디오 파일을 업로드하여 처리")

            upload_files = gr.Files(
                file_types=[".wav", ".mp3", ".flac", ".m4a"],
                label="오디오 파일 선택"
            )

            run_upload_btn = gr.Button("🎵 처리 시작", variant="primary", size="lg")

            with gr.Row():
                upload_result_file = gr.File(label="📥 결과 MIDI 다운로드")
                upload_preview = gr.Audio(label="🔊 WAV 미리듣기", type="filepath")

            run_upload_btn.click(
                fn=on_run_upload,
                inputs=[
                    upload_files, sf2_dropdown, sf2_upload,
                    pedal_smooth, pedal_lowcut,
                    pedal_on, pedal_off,
                    pedal_min, pedal_merge,
                    sf2_direct
                ],
                outputs=[upload_result_file, upload_preview]
            )

        # ===== 탭 2: 재생목록 처리 =====
        with gr.TabItem("📺 YouTube 재생목록"):
            gr.Markdown("### YouTube 재생목록 또는 개별 영상 URL")

            playlist_input = gr.Textbox(
                lines=6,
                label="YouTube URL (줄바꿈으로 여러 개 입력 가능)",
                placeholder="https://www.youtube.com/watch?v=...\nhttps://www.youtube.com/playlist?list=..."
            )

            run_playlist_btn = gr.Button("🚀 재생목록 처리 시작", variant="primary", size="lg")

            playlist_log = gr.Textbox(
                label="📋 진행 로그",
                lines=10,
                max_lines=20
            )
            playlist_last = gr.JSON(label="📊 마지막 처리 결과")

            run_playlist_btn.click(
                fn=playlist_wrapper,
                inputs=[
                    playlist_input, sf2_dropdown, sf2_upload,
                    pedal_smooth, pedal_lowcut,
                    pedal_on, pedal_off,
                    pedal_min, pedal_merge,
                    sf2_direct
                ],
                outputs=[playlist_log, playlist_last]
            )

    # 하단 안내
    gr.Markdown(f"""
    ---
    ### 📁 결과 저장 위치
    - **Google Drive**: `{DRIVE_RESULTS_DIR}`
    - **상태 파일**: `{STATE_FILE}`

    ### ℹ️ 사용 방법
    1. **SF2 설정**: 상단에서 SoundFont 선택 (렌더링 필요시)
    2. **파일 업로드**: 로컬 오디오 파일 처리
    3. **재생목록**: YouTube URL을 입력하여 배치 처리

    ### 🎯 파이프라인 흐름
```
    오디오 입력 → YourMT3+ 전사 → M2M 페달 검출 → CC64 삽입 → FluidSynth 렌더링 → Drive 저장
```

    ### 🔧 모델 정보
    - **전사 엔진**: {model_name}
    - **페달 검출**: M2M 규칙 기반
    - **렌더러**: FluidSynth
    """)

# UI 실행
print("\n" + "="*60)
print("Gradio UI 시작 중...")
print("="*60 + "\n")

demo.launch(
    share=True,  # 공유 링크 생성
    debug=True,
    show_error=True
)

사용 가능한 SF2: 15개

Gradio UI 시작 중...

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://9d13ce1b7bfeea4e81.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


[다운로드] YouTube: https://www.youtube.com/watch?v=rA_2B7Yj4QE...
  다운로드 시도 1/3...
[전사] YourMT3+: rA_2B7Yj4QE.mp3
  오디오 정보: 48000Hz, 286초
  전사 시작...
⏰ converting audio: 0m 1s 843.10ms
✗ transcribe 실패: 'NoneType' object is not callable
✗ 처리 실패: 'NoneType' object is not callable
Traceback (most recent call last):
  File "/tmp/ipython-input-359947423.py", line 64, in process_item_from_url_or_path
    mid = transcribe(model, audio_info)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-2239082707.py", line 135, in transcribe
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/amt/src/model/ymt3.py", line 566, in inference_file
    preds = self.inference(x, task_tokens).detach().cpu().numpy()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/amt/src/model/ymt3.py", line 486, in inference
    enc_hs = self.pre_decoder(enc_hs)  # (B, task_l

Traceback (most recent call last):
  File "/tmp/ipython-input-359947423.py", line 64, in process_item_from_url_or_path
    mid = transcribe(model, audio_info)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-2239082707.py", line 135, in transcribe
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/amt/src/model/ymt3.py", line 566, in inference_file
    preds = self.inference(x, task_tokens).detach().cpu().numpy()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/amt/src/model/ymt3.py", line 486, in inference
    enc_hs = self.pre_decoder(enc_hs)  # (B, task_len + 256, 512)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/p

Keyboard interruption in main thread... closing server.
  실패. 5초 대기 중...


KeyboardInterrupt: 

In [17]:
# 테스트 셀 (임시로 실행)
print("모델 확인:")
print(f"  - model 타입: {type(model)}")
print(f"  - pre_decoder 타입: {type(model.pre_decoder)}")
print(f"  - pre_decoder: {model.pre_decoder}")

# Sequential 내부 확인
if hasattr(model.pre_decoder, '__iter__'):
    for i, layer in enumerate(model.pre_decoder):
        print(f"    [{i}] {type(layer).__name__}: {layer}")
# 모델 설정 확인
print("=" * 60)
print("모델 설정 확인")
print("=" * 60)
print(f"encoder.num_latents: {model.encoder.num_latents if hasattr(model.encoder, 'num_latents') else 'N/A'}")
print(f"encoder 타입: {type(model.encoder).__name__}")

# audio_cfg 확인
print(f"\naudio_cfg:")
for k, v in model.audio_cfg.items():
    print(f"  {k}: {v}")

모델 확인:
  - model 타입: <class 'model.ymt3.YourMT3'>
  - pre_decoder 타입: <class 'torch.nn.modules.container.Sequential'>
  - pre_decoder: Sequential(
  (0): None
)
    [0] NoneType: None
모델 설정 확인
encoder.num_latents: N/A
encoder 타입: PerceiverTFEncoder

audio_cfg:
  codec: spec
  hop_length: 300
  audio_backend: torchaudio
  sample_rate: 16000
  input_frames: 32767
  n_fft: 2048
  n_mels: 512
  f_min: 50.0
  f_max: 8000.0
