In [2]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from diffusers import AudioLDM2Pipeline
from diffusers.pipelines.pipeline_utils import AudioPipelineOutput
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from peft import LoraConfig, get_peft_model
import librosa
import random
import numpy as np
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn as nn
import pandas as pd
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import scipy

## AudioLDM2의 Unet에 LoRA 구조를 추가하기

In [None]:
repo_id = "cvssp/audioldm2-music"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float32).to(device)
unet = pipeline.unet
noise_scheduler = pipeline.scheduler
vae = pipeline.vae



Loading pipeline components...:   0%|          | 0/11 [00:00<?, ?it/s]

'epsilon'

In [4]:
noise_scheduler.config

FrozenDict([('num_train_timesteps', 1000),
            ('beta_start', 0.0015),
            ('beta_end', 0.0195),
            ('beta_schedule', 'scaled_linear'),
            ('trained_betas', None),
            ('clip_sample', False),
            ('set_alpha_to_one', False),
            ('steps_offset', 1),
            ('prediction_type', 'epsilon'),
            ('thresholding', False),
            ('dynamic_thresholding_ratio', 0.995),
            ('clip_sample_range', 1.0),
            ('sample_max_value', 1.0),
            ('timestep_spacing', 'leading'),
            ('rescale_betas_zero_snr', False),
            ('_class_name', 'DDIMScheduler'),
            ('_diffusers_version', '0.20.0.dev0')])

In [None]:
# 일반적으로 cross-attention block에 LoRA를 적용함
# 이 모델에서 cross-attention layer는 "attn2"이고, 이 module에 LoRA를 적용하면 된다는 것을 파악함.

# 'attn2'를 포함하고, nn.Linear인 모든 모듈 감지
def find_cross_attention_modules(model):
    target_modules = []
    for name, module in model.named_modules():
        if "attn2" in name:
            if isinstance(module, torch.nn.Linear):  # torch.nn.Linear만 선택
                target_modules.append(name)
    return target_modules

# Cross-Attention 블록만 선택
cross_attention_modules = find_cross_attention_modules(unet)

In [5]:
lora_config = LoraConfig(
    r=16, 
    lora_alpha=32,  
    target_modules=cross_attention_modules,  
    lora_dropout=0.1,
    bias="none",  
    task_type="UNET" 
)

peft_unet = get_peft_model(unet, lora_config)
peft_unet.print_trainable_parameters()

trainable params: 3,170,304 || all params: 350,113,800 || trainable%: 0.9055


## Music-Caption Dataset 준비하기 (일단 piano로 진행)

In [6]:
from datasets import load_from_disk

def load_specific_instrument_dataset(instrument_name):
    # 폴더 이름 구성 (저장할 때 사용한 것과 동일한 형식)
    #folder_name = f'/workspace/dataset_by_instrument/{instrument_name}_train_preprocessed_dataset'
    folder_name = f'./{instrument_name}_train_preprocessed_dataset'
    
    # 폴더 존재 여부 확인
    if os.path.exists(folder_name):
        try:
            # 데이터셋 로드
            dataset = load_from_disk(folder_name)
            dataset.set_format(type="torch", columns=["log_mel_spectrogram"], output_all_columns=True) 
            print(f"Dataset for [{instrument_name}] loaded successfully from [{folder_name}]!")
            print(dataset)
            return dataset
        except Exception as e:
            print(f"Error loading dataset from [{folder_name}]: {e}")
            return None
    else:
        print(f"Folder [{folder_name}] does not exist. Please check the instrument name or the directory.")
        return None

In [7]:
#train_dataset = MusicCaptionDataset(load_specific_instrument_dataset("Sound_Piano"))
train_dataset = load_specific_instrument_dataset("Sound_Piano")

#train_dataset.set_format(type='torch', columns=['mel_spectrogram', 'prompt'])

train_loader = DataLoader(
    train_dataset,
    batch_size=2,  # Adjust based on your GPU memory
    shuffle=True,
)

train_dataset[0]

Dataset for [Sound_Piano] loaded successfully from [./Sound_Piano_train_preprocessed_dataset]!
Dataset({
    features: ['log_mel_spectrogram', 'text_caption', 'instrument_class'],
    num_rows: 161
})


{'log_mel_spectrogram': tensor([[[-80.0000, -80.0000, -80.0000,  ..., -36.7057, -38.4878, -34.3133],
          [-80.0000, -80.0000, -80.0000,  ..., -36.4729, -38.2550, -34.0804],
          [-80.0000, -80.0000, -80.0000,  ..., -29.8034, -38.9142, -37.8729],
          ...,
          [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000],
          [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000],
          [-80.0000, -80.0000, -80.0000,  ..., -80.0000, -80.0000, -80.0000]]]),
 'text_caption': '\nThis is an acoustic piano piece. It is an instrumental piece. There is a piano playing a melancholic melody. The atmosphere is lighthearted. This piece could be playing in the background at a coffee shop. \n \n',
 'instrument_class': 'Sound_Piano'}

In [8]:
for i in range(train_dataset.num_rows):
    print(torch.tensor(train_dataset[i]['log_mel_spectrogram']).shape)

  print(torch.tensor(train_dataset[i]['log_mel_spectrogram']).shape)


torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([1, 128, 800])
torch.Size([

## Train하기

In [9]:
# Define optimizer: Only LoRA parameters are trainable
optimizer = AdamW(peft_unet.parameters(), lr=1e-5)

# Define loss function: Mean Squared Error between predicted and target mel spectrograms
loss_fn = nn.MSELoss()

In [None]:
# Training parameters
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    peft_unet.train()
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in progress_bar:
        mel_spectrogram = batch['log_mel_spectrogram'].to(device)  # (batch_size, 1, n_mels, time)
        captions = batch['text_caption']
        instrument_classes = batch['instrument_class']
        
        prompt = [f'[only played by {instrument_class}] {caption}' for (instrument_class, caption)  in zip(instrument_classes, captions)]
        
        # Encode prompts
        with torch.no_grad():
            prompt_embeds, attention_mask_encoded, generated_prompt_embeds = pipeline.encode_prompt(
                prompt=prompt,
                device=device,
                num_waveforms_per_prompt=1,
                do_classifier_free_guidance=False
            )
        
        latent_image = vae.encode(mel_spectrogram).latent_dist.sample()
        latent_image*= vae.config.scaling_factor
        
        max_noise_steps = 1000
        
        timesteps = torch.randint(0, max_noise_steps, (latent_image.size(0),), device=device)
        noise = torch.randn_like(latent_image)
        noisy_latent_image = noise_scheduler.add_noise(latent_image, noise, timesteps)
        
        
        noise_pred = peft_unet(sample = noisy_latent_image, timestep=timesteps, 
                                encoder_hidden_states=generated_prompt_embeds, encoder_hidden_states_1=prompt_embeds, 
                                encoder_attention_mask_1=attention_mask_encoded).sample
        
        # Compute loss
        loss = loss_fn(noise_pred, noise)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})
    
    avg_epoch_loss = epoch_loss / len(train_loader)
    print(f"Average Training Loss: {avg_epoch_loss:.4f}")
    
    # Save model checkpoint after each epoch
    save_path = f"peft_audioldm2_epoch_{epoch + 1}.pt"
    torch.save(peft_unet.state_dict(), save_path)
    print(f"Saved PEFT model checkpoint to {save_path}")

Epoch 1/10


                                                

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 256 n 3200 k 256 mat1_ld 256 mat2_ld 256 result_ld 256 abcType 0 computeType 68 scaleType 0

## LoRA가 적용된 AudioLDM2으로 generate하기

In [None]:
import scipy

In [None]:
peft_unet_path = "peft_audioldm2_epoch_10.pt" # 가장 성능이 좋았던 epoch 선택
peft_unet.load_state_dict(torch.load(peft_unet_path))
peft_unet.eval()

In [19]:
pipeline.scheduler

DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.31.0",
  "beta_end": 0.0195,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.0015,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "steps_offset": 1,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}