# Imports

In [None]:
#pip install basic-pitch pretty_midi frechet_audio_distance

In [None]:
import sys

# Temporarily blank out sys.argv while these modules are imported.
# NOTE(review): presumably one of them parses sys.argv at import time and
# chokes on the notebook kernel's arguments — confirm.
argv_backup = sys.argv
sys.argv = ['']
try:
    from frechet_audio_distance import FrechetAudioDistance
    from gtt.eval import create_eval_baseline_dir, test_model
finally:
    # Always restore the real arguments, even when an import fails.
    sys.argv = argv_backup

In [None]:
import time
import os
import torch
import torchaudio
import numpy as np
from torch.utils.data import DataLoader

from gtt.dataloader import GttDataset
from gtt.model import GttNet
from gtt.utilities.utils import train_test_split, get_mean_std_loudness
from gtt.train import train_epoch, train_loop
from gtt.eval import mfcc_fro_distance, loudness_l2_distance
from PolyDDSP.modules.losses import SpectralLoss

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Construct Baseline DIRs for FAD comparison

## GuitarSet

In [None]:
# Export every GuitarSet dataset item as its own WAV file; the resulting
# folder is used as a FAD baseline directory by the evaluation cells.
source_dir = 'Data/GuitarSet/audio_pickup'
target_dir = 'Data/GuitarSet/eval_fad'

# Create the output folder up front; the loop below writes straight into it.
os.makedirs(target_dir, exist_ok=True)

# NOTE(review): train_pct=1 / valid_pct=0 presumably places every file in the
# 'train' split, i.e. the whole dataset is exported — confirm.
split = train_test_split(source_dir, train_pct=1, valid_pct=0)
files = split['train']

file_ds = GttDataset(audio_dir=source_dir,
                     midi_dir='Data/GuitarSet/hr_labels',
                     file_list=files,
                     segment_length_seconds=4,
                     max_n_pitch=30,
                     random_crop=False,
                     device='cpu')

# Batch size 1, no shuffling: one output file per item, in a stable order.
file_dl = DataLoader(file_ds, batch_size=1, shuffle=False)

for index, batch in enumerate(file_dl):
    fpath = os.path.join(target_dir, '{}.wav'.format(index))
    # NOTE(review): 22050 is assumed to be the sample rate of the audio that
    # GttDataset yields — confirm.
    torchaudio.save(fpath, batch['audio'], 22050)

# Number of files written (same final value as the original running counter).
file_count = len(file_dl)

## EGDB

In [None]:
# Export every EGDB dataset item as its own WAV file; the resulting folder is
# used as a FAD baseline directory by the evaluation cells.
source_dir = 'Data/EGDB/audio'
target_dir = 'Data/EGDB/eval_fad'

# Create the output folder up front; the loop below writes straight into it.
os.makedirs(target_dir, exist_ok=True)

# NOTE(review): train_pct=1 / valid_pct=0 presumably places every file in the
# 'train' split, i.e. the whole dataset is exported — confirm.
split = train_test_split(source_dir, train_pct=1, valid_pct=0)
files = split['train']

file_ds = GttDataset(audio_dir=source_dir,
                     midi_dir='',
                     file_list=files,
                     segment_length_seconds=4,
                     max_n_pitch=30,
                     random_crop=False,
                     device='cpu')

# Batch size 1, no shuffling: one output file per item, in a stable order.
file_dl = DataLoader(file_ds, batch_size=1, shuffle=False)

for index, batch in enumerate(file_dl):
    fpath = os.path.join(target_dir, '{}.wav'.format(index))
    # NOTE(review): 22050 is assumed to be the sample rate of the audio that
    # GttDataset yields — confirm.
    torchaudio.save(fpath, batch['audio'], 22050)

# Number of files written (same final value as the original running counter).
file_count = len(file_dl)

## SynthTab Acoustic

In [None]:
# Export every SynthTab acoustic dataset item as its own WAV file; the
# resulting folder is a FAD baseline directory.
source_dir = 'Data/synthtab_acoustic/audio'
target_dir = 'Data/synthtab_acoustic/eval_fad'

# Create the output folder up front; the loop below writes straight into it.
os.makedirs(target_dir, exist_ok=True)

# NOTE(review): train_pct=1 / valid_pct=0 presumably places every file in the
# 'train' split, i.e. the whole dataset is exported — confirm.
split = train_test_split(source_dir, train_pct=1, valid_pct=0)
files = split['train']

file_ds = GttDataset(audio_dir=source_dir,
                     midi_dir='',
                     file_list=files,
                     segment_length_seconds=4,
                     max_n_pitch=30,
                     random_crop=False,
                     device='cpu')

# Batch size 1, no shuffling: one output file per item, in a stable order.
file_dl = DataLoader(file_ds, batch_size=1, shuffle=False)

for index, batch in enumerate(file_dl):
    fpath = os.path.join(target_dir, '{}.wav'.format(index))
    # NOTE(review): 22050 is assumed to be the sample rate of the audio that
    # GttDataset yields — confirm.
    torchaudio.save(fpath, batch['audio'], 22050)

# Number of files written (same final value as the original running counter).
file_count = len(file_dl)

## SynthTab Electric

In [None]:
# Export every SynthTab electric dataset item as its own WAV file; the
# resulting folder is a FAD baseline directory.
source_dir = 'Data/synthtab_electric/audio'
target_dir = 'Data/synthtab_electric/eval_fad'

# Create the output folder up front; the loop below writes straight into it.
os.makedirs(target_dir, exist_ok=True)

# NOTE(review): train_pct=1 / valid_pct=0 presumably places every file in the
# 'train' split, i.e. the whole dataset is exported — confirm.
split = train_test_split(source_dir, train_pct=1, valid_pct=0)
files = split['train']

file_ds = GttDataset(audio_dir=source_dir,
                     midi_dir='',
                     file_list=files,
                     segment_length_seconds=4,
                     max_n_pitch=30,
                     random_crop=False,
                     device='cpu')

# Batch size 1, no shuffling: one output file per item, in a stable order.
file_dl = DataLoader(file_ds, batch_size=1, shuffle=False)

for index, batch in enumerate(file_dl):
    fpath = os.path.join(target_dir, '{}.wav'.format(index))
    # NOTE(review): 22050 is assumed to be the sample rate of the audio that
    # GttDataset yields — confirm.
    torchaudio.save(fpath, batch['audio'], 22050)

# Number of files written (same final value as the original running counter).
file_count = len(file_dl)

# Guitarset HRG

## Set Up

In [None]:
# Datasets, loaders and model for the GuitarSet run that reads labels from
# 'Data/GuitarSet/hr_labels'.
gs_dir = 'Data/GuitarSet/audio_pickup'

split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)
train_files, valid_files, test_files = split['train'], split['valid'], split['test']

# Loudness statistics are computed from the training files only and shared
# by all three datasets.
loudness_metrics = get_mean_std_loudness(gs_dir, train_files)

# Arguments common to the train / valid / test datasets.
_gs_hrg_common = dict(audio_dir=gs_dir,
                      midi_dir='Data/GuitarSet/hr_labels',
                      segment_length_seconds=4,
                      max_n_pitch=30,
                      loud_mean=loudness_metrics['mean'],
                      loud_std=loudness_metrics['std'],
                      device=device)
# Valid/test switch off random cropping and pitch permutation; the training
# dataset keeps GttDataset's defaults for both.
_gs_hrg_eval_only = dict(random_crop=False, permute_pitch=False)

gs_hrg_train_ds = GttDataset(file_list=train_files, **_gs_hrg_common)
gs_hrg_valid_ds = GttDataset(file_list=valid_files, **_gs_hrg_common, **_gs_hrg_eval_only)
gs_hrg_test_ds = GttDataset(file_list=test_files, **_gs_hrg_common, **_gs_hrg_eval_only)

# Only the training loader is shuffled.
gs_hrg_train_dl = DataLoader(gs_hrg_train_ds, batch_size=3, shuffle=True)
gs_hrg_valid_dl = DataLoader(gs_hrg_valid_ds, batch_size=3, shuffle=False)
gs_hrg_test_dl = DataLoader(gs_hrg_test_ds, batch_size=3, shuffle=False)

gs_hrg_ckpt_dir = 'model_checkpoints/guitarset_hrg'

# NOTE(review): the evaluation cell for this model additionally passes
# input_length_seconds, harmonic_n_controls, gru_features, noise_initial_bias
# and mlp_blocks — confirm those equal GttNet's defaults so checkpoints load.
gs_hrg_model = GttNet(
    device=device,
    hop_length=128,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    use_amp_latent=False,
    gru_cat=False,
    use_reverb=False,
).to(device)

## Train

In [None]:
# Train the GuitarSet HRG model and print the elapsed wall-clock time.
start_time = time.time()

# NOTE(review): epochs=40200 is far above the 4000 / 20000 used by the other
# training cells — confirm it is intended (train_hours=8 is passed as well).
train_loop(
    model=gs_hrg_model,
    train_loader=gs_hrg_train_dl,
    valid_loader=gs_hrg_valid_dl,
    epochs=40200,
    valid_freq=10,
    loud_epoch_freq=10,
    train_hours=8,
    ckpt_dir=gs_hrg_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time
print('Run Time: {}'.format(total_time))

## Timbre Transfer Evaluate

In [None]:
# Rebuild the GuitarSet HRG model, load its saved weights and recreate the
# test loader, so this cell does not depend on the setup/training cells.
ckpt_path = 'model_checkpoints/guitarset_hrg/model_epoch_final_config2.pt'

gs_hrg_model = GttNet(device=device,
                      input_length_seconds=4,
                      hop_length=128,
                      harmonic_n_controls=101,
                      gru_features=512,
                      noise_initial_bias=-5.0,
                      mlp_blocks=3,
                      timbre_enc_size=15,
                      use_timbre_encoder=True,
                      use_amp_latent=False,
                      gru_cat=False,
                      use_reverb=False).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
gs_hrg_model.load_state_dict(torch.load(ckpt_path, map_location=device))

gs_dir = 'Data/GuitarSet/audio_pickup'

split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)

train_files = split['train']
valid_files = split['valid']
test_files = split['test']

# Loudness statistics are recomputed from the training files.
loudness_metrics = get_mean_std_loudness(gs_dir, train_files)

gs_hrg_test_ds = GttDataset(audio_dir=gs_dir,
                            midi_dir='Data/GuitarSet/hr_labels',
                            file_list=test_files,
                            segment_length_seconds=4,
                            max_n_pitch=30,
                            random_crop=False,
                            permute_pitch=False,
                            loud_mean=loudness_metrics['mean'],
                            loud_std=loudness_metrics['std'],
                            device=device)

gs_hrg_test_dl = DataLoader(gs_hrg_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 2166.895033094618

Mean loudness l2 Distance: 67.18081707424588

Mean MS spectral loss: 0.27293260395526886


In [None]:
# Reconstruction metrics for the GuitarSet HRG model on its own test split.
#
# Each accumulator receives one value per batch, so the mean divides by the
# number of batches (length of the DataLoader), as the EGDB metric cells in
# this notebook do. The original divided by len(gs_hrg_test_ds), the number
# of samples, which shrinks the result by roughly the batch size.
# NOTE(review): this assumes the metric functions return per-batch means —
# confirm their reduction. Results recorded with the old divisor are not
# comparable to values produced by this cell.
test_len = len(gs_hrg_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

gs_hrg_model.eval()
msstft_calc = SpectralLoss()

# Inference only: no autograd graph is needed (or kept alive by the sums).
with torch.no_grad():
    for batch in gs_hrg_test_dl:
        out_audio = gs_hrg_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### EGDB

In [None]:
# Run the GuitarSet HRG model over the EGDB test split; test_model is asked
# for FAD (calc_fad=True) with the GuitarSet export as the baseline directory.
gs_baseline_dir = 'Data/GuitarSet/eval_fad'
egdb_dir = 'Data/EGDB/audio'
egdb_test_dir = 'Data/EGDB/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
egdb_test_ds = GttDataset(
    audio_dir=egdb_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=3,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
egdb_test_dl = DataLoader(egdb_test_ds, batch_size=3, shuffle=False)

test_model(
    model=gs_hrg_model,
    test_loader=egdb_test_dl,
    calc_fad=True,
    baseline_dir=gs_baseline_dir,
    test_dir=egdb_test_dir,
)

### Synthtab acoustic

In [None]:
# Run the GuitarSet HRG model over the SynthTab acoustic test split;
# test_model is asked for FAD (calc_fad=True) with the GuitarSet export as
# the baseline directory.
gs_baseline_dir = 'Data/GuitarSet/eval_fad'

st_dir = 'Data/synthtab_acoustic/audio'
st_test_dir = 'Data/synthtab_acoustic/test'

split = train_test_split(st_dir, train_pct=.8, valid_pct=.1)

test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
st_test_ds = GttDataset(audio_dir=st_dir,
                        midi_dir='',
                        file_list=test_files,
                        trim_duration=15,
                        segment_length_seconds=4,
                        max_n_pitch=30,
                        random_crop=False,
                        permute_pitch=False,
                        loud_mean=loudness_metrics['mean'],
                        loud_std=loudness_metrics['std'],
                        device=device)

# Evaluation loaders are not shuffled anywhere else in this notebook; the
# original shuffle=True here made the batch order differ between runs.
st_test_dl = DataLoader(st_test_ds, 3, shuffle=False)


test_model(model=gs_hrg_model,
           test_loader=st_test_dl,
           calc_fad=True,
           baseline_dir=gs_baseline_dir,
           test_dir=st_test_dir)

# Guitarset BP

## Set up

In [None]:
# Datasets, loaders and model for the GuitarSet run that passes an empty
# midi_dir to GttDataset.
gs_dir = 'Data/GuitarSet/audio_pickup'

split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)
train_files, valid_files, test_files = split['train'], split['valid'], split['test']

# Loudness statistics are computed from the training files only and shared
# by all three datasets.
loudness_metrics = get_mean_std_loudness(gs_dir, train_files)

# Arguments common to the train / valid / test datasets.
_gs_bp_common = dict(audio_dir=gs_dir,
                     midi_dir='',
                     segment_length_seconds=4,
                     max_n_pitch=30,
                     loud_mean=loudness_metrics['mean'],
                     loud_std=loudness_metrics['std'],
                     device=device)
# Valid/test switch off random cropping and pitch permutation; the training
# dataset keeps GttDataset's defaults for both.
_gs_bp_eval_only = dict(random_crop=False, permute_pitch=False)

gs_bp_train_ds = GttDataset(file_list=train_files, **_gs_bp_common)
gs_bp_valid_ds = GttDataset(file_list=valid_files, **_gs_bp_common, **_gs_bp_eval_only)
gs_bp_test_ds = GttDataset(file_list=test_files, **_gs_bp_common, **_gs_bp_eval_only)

# Only the training loader is shuffled.
gs_bp_train_dl = DataLoader(gs_bp_train_ds, batch_size=3, shuffle=True)
gs_bp_valid_dl = DataLoader(gs_bp_valid_ds, batch_size=3, shuffle=False)
gs_bp_test_dl = DataLoader(gs_bp_test_ds, batch_size=3, shuffle=False)

gs_bp_ckpt_dir = 'model_checkpoints/guitarset_bp'

gs_bp_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    gru_cat=False,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    use_amp_latent=True,
    trainable_velocity=False,
    use_reverb=False,
).to(device)

## Train

Mean MFCC distance scores of the different configurations:

Configuration 1: 2526.36

Configuration 2: 2297

In [None]:
# Train the GuitarSet BP model and print the elapsed wall-clock time.
start_time = time.time()

train_loop(
    model=gs_bp_model,
    train_loader=gs_bp_train_dl,
    valid_loader=gs_bp_valid_dl,
    epochs=4000,
    valid_freq=10,
    loud_epoch_freq=10,
    train_hours=8,
    ckpt_dir=gs_bp_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time
print('Run Time: {}'.format(total_time))

## Timbre Transfer Evaluate

In [None]:
# Rebuild the GuitarSet BP model, load its saved weights and recreate the
# test loader, so this cell does not depend on the setup/training cells.
ckpt_path = 'model_checkpoints/guitarset_bp/model_epoch_final_config2.pt'

# NOTE(review): trainable_velocity=True here, while the setup cell builds the
# model with trainable_velocity=False — confirm which one the checkpoint
# was trained with.
gs_bp_model = GttNet(device=device,
                     input_length_seconds=4,
                     hop_length=128,
                     harmonic_n_controls=101,
                     gru_features=512,
                     noise_initial_bias=-5.0,
                     mlp_blocks=3,
                     timbre_enc_size=15,
                     use_timbre_encoder=True,
                     use_amp_latent=True,
                     trainable_velocity=True,
                     gru_cat=False,
                     use_reverb=False).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
gs_bp_model.load_state_dict(torch.load(ckpt_path, map_location=device))

gs_dir = 'Data/GuitarSet/audio_pickup'

split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)

train_files = split['train']
test_files = split['test']

# Loudness statistics are recomputed from the training files.
loudness_metrics = get_mean_std_loudness(gs_dir, train_files)

gs_bp_test_ds = GttDataset(audio_dir=gs_dir,
                           midi_dir='',
                           file_list=test_files,
                           segment_length_seconds=4,
                           max_n_pitch=30,
                           random_crop=False,
                           permute_pitch=False,
                           loud_mean=loudness_metrics['mean'],
                           loud_std=loudness_metrics['std'],
                           device=device)

gs_bp_test_dl = DataLoader(gs_bp_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 2367.958062065972

Mean loudness l2 Distance: 92.15786955091689

Mean MS spectral loss: 0.29661866029103595

In [None]:
# Reconstruction metrics for the GuitarSet BP model on its own test split.
#
# Each accumulator receives one value per batch, so the mean divides by the
# number of batches (length of the DataLoader), as the EGDB metric cells in
# this notebook do. The original divided by len(gs_bp_test_ds), the number
# of samples, which shrinks the result by roughly the batch size.
# NOTE(review): this assumes the metric functions return per-batch means —
# confirm their reduction. Results recorded with the old divisor are not
# comparable to values produced by this cell.
test_len = len(gs_bp_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

gs_bp_model.eval()
msstft_calc = SpectralLoss()

# Inference only: no autograd graph is needed (or kept alive by the sums).
with torch.no_grad():
    for batch in gs_bp_test_dl:
        out_audio = gs_bp_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### EGDB

FAD: 

In [None]:
# Run the GuitarSet BP model over the EGDB test split; test_model is asked
# for FAD (calc_fad=True) with the GuitarSet export as the baseline directory.
gs_baseline_dir = 'Data/GuitarSet/eval_fad'
egdb_dir = 'Data/EGDB/audio'
egdb_test_dir = 'Data/EGDB/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# NOTE(review): trim_duration=2 here, whereas the other EGDB datasets in this
# notebook use 3 — confirm this is deliberate.
# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
egdb_test_ds = GttDataset(
    audio_dir=egdb_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=2,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
egdb_test_dl = DataLoader(egdb_test_ds, batch_size=3, shuffle=False)

test_model(
    model=gs_bp_model,
    test_loader=egdb_test_dl,
    calc_fad=True,
    baseline_dir=gs_baseline_dir,
    test_dir=egdb_test_dir,
)

### SynthTab

FAD:

In [None]:
# Run the GuitarSet BP model over the SynthTab acoustic test split;
# test_model is asked for FAD (calc_fad=True) with the GuitarSet export as
# the baseline directory.
gs_baseline_dir = 'Data/GuitarSet/eval_fad'
st_dir = 'Data/synthtab_acoustic/audio'
st_test_dir = 'Data/synthtab_acoustic/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(st_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
st_test_ds = GttDataset(
    audio_dir=st_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=15,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
st_test_dl = DataLoader(st_test_ds, batch_size=3, shuffle=False)

test_model(
    model=gs_bp_model,
    test_loader=st_test_dl,
    calc_fad=True,
    baseline_dir=gs_baseline_dir,
    test_dir=st_test_dir,
)

# EGDB BP

## Set up

In [None]:
# Datasets, loaders and model for the EGDB run that passes an empty midi_dir
# to GttDataset.
egdb_dir = 'Data/EGDB/audio'

split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)
train_files, valid_files, test_files = split['train'], split['valid'], split['test']

# Loudness statistics are computed from the training files only and shared
# by all three datasets.
loudness_metrics = get_mean_std_loudness(egdb_dir, train_files)

# Arguments common to the train / valid / test datasets.
_egdb_bp_common = dict(audio_dir=egdb_dir,
                       midi_dir='',
                       segment_length_seconds=4,
                       max_n_pitch=30,
                       trim_duration=3,
                       loud_mean=loudness_metrics['mean'],
                       loud_std=loudness_metrics['std'],
                       device=device)
# Valid/test switch off random cropping and pitch permutation; the training
# dataset keeps GttDataset's defaults for both.
_egdb_bp_eval_only = dict(random_crop=False, permute_pitch=False)

egdb_bp_train_ds = GttDataset(file_list=train_files, **_egdb_bp_common)
egdb_bp_valid_ds = GttDataset(file_list=valid_files, **_egdb_bp_common, **_egdb_bp_eval_only)
egdb_bp_test_ds = GttDataset(file_list=test_files, **_egdb_bp_common, **_egdb_bp_eval_only)

# Only the training loader is shuffled.
egdb_bp_train_dl = DataLoader(egdb_bp_train_ds, batch_size=3, shuffle=True)
egdb_bp_valid_dl = DataLoader(egdb_bp_valid_ds, batch_size=3, shuffle=False)
egdb_bp_test_dl = DataLoader(egdb_bp_test_ds, batch_size=3, shuffle=False)

egdb_bp_ckpt_dir = 'model_checkpoints/egdb_bp'

egdb_bp_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    use_amp_latent=True,
    gru_cat=False,
    trainable_velocity=False,
    use_reverb=False,
).to(device)

## Train

Configuration 1: 2644

Configuration 2: 2323

In [None]:
# Train the EGDB BP model and print the elapsed wall-clock time.
start_time = time.time()

train_loop(
    model=egdb_bp_model,
    train_loader=egdb_bp_train_dl,
    valid_loader=egdb_bp_valid_dl,
    epochs=4000,
    valid_freq=1,
    loud_epoch_freq=10,
    train_hours=8,
    ckpt_dir=egdb_bp_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time
print('Run Time: {}'.format(total_time))

## Evaluate

In [None]:
# Rebuild the EGDB BP model, load its saved weights and recreate the test
# loader, so this cell does not depend on the setup/training cells.
ckpt_path = 'model_checkpoints/egdb_bp/model_epoch_final_config2.pt'

# NOTE(review): trainable_velocity=True here, while the setup cell builds the
# model with trainable_velocity=False — confirm which one the checkpoint
# was trained with.
egdb_bp_model = GttNet(device=device,
                       input_length_seconds=4,
                       hop_length=128,
                       harmonic_n_controls=101,
                       gru_features=512,
                       noise_initial_bias=-5.0,
                       mlp_blocks=3,
                       timbre_enc_size=15,
                       use_timbre_encoder=True,
                       use_amp_latent=True,
                       gru_cat=False,
                       trainable_velocity=True,
                       use_reverb=False).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
egdb_bp_model.load_state_dict(torch.load(ckpt_path, map_location=device))

egdb_dir = 'Data/EGDB/audio'

split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)

train_files = split['train']
# Evaluate on the held-out test files. The original assigned split['train']
# here, so the "test" dataset was actually the training set; metrics
# recorded from that version may not reflect held-out performance.
test_files = split['test']

# Loudness statistics are recomputed from the training files.
loudness_metrics = get_mean_std_loudness(egdb_dir, train_files)

egdb_bp_test_ds = GttDataset(audio_dir=egdb_dir,
                             midi_dir='',
                             file_list=test_files,
                             segment_length_seconds=4,
                             max_n_pitch=30,
                             random_crop=False,
                             permute_pitch=False,
                             trim_duration=3,
                             loud_mean=loudness_metrics['mean'],
                             loud_std=loudness_metrics['std'],
                             device=device)
egdb_bp_test_dl = DataLoader(egdb_bp_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 6956.451264880952

Mean loudness l2 Distance: 634.1826414078001

Mean MS spectral loss: 1.0165660381317139

In [None]:
# Reconstruction metrics for the EGDB BP model: one value per batch is
# accumulated, then averaged over the number of batches.
test_len = len(egdb_bp_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

egdb_bp_model.eval()
msstft_calc = SpectralLoss()

# Inference only: no autograd graph is needed (or kept alive by the sums).
with torch.no_grad():
    for batch in egdb_bp_test_dl:
        out_audio = egdb_bp_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### Guitar Set HRG

In [None]:
# Run the EGDB BP model over the GuitarSet test split (labels read from
# hr_labels); test_model is asked for FAD (calc_fad=True) with the EGDB
# export as the baseline directory.
egdb_baseline_dir = 'Data/EGDB/eval_fad'
gs_dir = 'Data/GuitarSet/audio_pickup'
gs_test_dir = 'Data/GuitarSet/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
gs_hrg_test_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='Data/GuitarSet/hr_labels',
    file_list=test_files,
    trim_duration=5,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
gs_hrg_test_dl = DataLoader(gs_hrg_test_ds, batch_size=3, shuffle=False)

test_model(
    model=egdb_bp_model,
    test_loader=gs_hrg_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=gs_test_dir,
)

### Guitar Set BP

In [None]:
# Run the EGDB BP model over the GuitarSet test split (empty midi_dir);
# test_model is asked for FAD (calc_fad=True) with the EGDB export as the
# baseline directory.
egdb_baseline_dir = 'Data/EGDB/eval_fad'
gs_dir = 'Data/GuitarSet/audio_pickup'
gs_test_dir = 'Data/GuitarSet/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(gs_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
gs_bp_test_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=5,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
gs_bp_test_dl = DataLoader(gs_bp_test_ds, batch_size=3, shuffle=False)

test_model(
    model=egdb_bp_model,
    test_loader=gs_bp_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=gs_test_dir,
)

### Synthtab

In [None]:
# Run the EGDB BP model over the SynthTab acoustic test split; test_model is
# asked for FAD (calc_fad=True) with the EGDB export as the baseline
# directory.
egdb_baseline_dir = 'Data/EGDB/eval_fad'
st_dir = 'Data/synthtab_acoustic/audio'
st_test_dir = 'Data/synthtab_acoustic/test'

# Only the 'test' portion of the split is used here.
split = train_test_split(st_dir, train_pct=.8, valid_pct=.1)
test_files = split['test']

# loudness_metrics is not computed in this cell; it holds whatever an earlier
# cell last assigned.
st_test_ds = GttDataset(
    audio_dir=st_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=5,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)
st_test_dl = DataLoader(st_test_ds, batch_size=3, shuffle=False)

test_model(
    model=egdb_bp_model,
    test_loader=st_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=st_test_dir,
)

# EGDB MIDI

## Setup

In [None]:
# Datasets, loaders and model for the EGDB run whose training set reads
# labels from 'Data/EGDB/labels'.
egdb_dir = 'Data/EGDB/audio'

split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)
train_files, valid_files, test_files = split['train'], split['valid'], split['test']

# Loudness statistics are computed from the training files only and shared
# by all three datasets.
loudness_metrics = get_mean_std_loudness(egdb_dir, train_files)

# Arguments common to the train / valid / test datasets (midi_dir differs).
_egdb_midi_common = dict(audio_dir=egdb_dir,
                         segment_length_seconds=4,
                         trim_duration=3,
                         max_n_pitch=30,
                         loud_mean=loudness_metrics['mean'],
                         loud_std=loudness_metrics['std'],
                         device=device)
# Valid/test switch off random cropping and pitch permutation; the training
# dataset keeps GttDataset's defaults for both.
_egdb_midi_eval_only = dict(random_crop=False, permute_pitch=False)

egdb_midi_train_ds = GttDataset(midi_dir='Data/EGDB/labels',
                                file_list=train_files,
                                **_egdb_midi_common)

# NOTE(review): valid/test pass midi_dir='' while training uses
# 'Data/EGDB/labels' — confirm this difference is intended.
egdb_midi_valid_ds = GttDataset(midi_dir='',
                                file_list=valid_files,
                                **_egdb_midi_common,
                                **_egdb_midi_eval_only)
egdb_midi_test_ds = GttDataset(midi_dir='',
                               file_list=test_files,
                               **_egdb_midi_common,
                               **_egdb_midi_eval_only)

# Only the training loader is shuffled.
egdb_midi_train_dl = DataLoader(egdb_midi_train_ds, batch_size=3, shuffle=True)
egdb_midi_valid_dl = DataLoader(egdb_midi_valid_ds, batch_size=3, shuffle=False)
egdb_midi_test_dl = DataLoader(egdb_midi_test_ds, batch_size=3, shuffle=False)

egdb_midi_ckpt_dir = 'model_checkpoints/egdb_midi'

egdb_midi_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    gru_cat=False,
    # The EGDB paper is cited as stating that its MIDI labels do not
    # transcribe velocity.
    use_amp_latent=False,
    use_reverb=False,
).to(device)

## Train

In [None]:
# Train the EGDB MIDI model and print the elapsed wall-clock time.
start_time = time.time()

train_loop(
    model=egdb_midi_model,
    train_loader=egdb_midi_train_dl,
    valid_loader=egdb_midi_valid_dl,
    epochs=20000,
    valid_freq=10,
    loud_epoch_freq=10,
    train_hours=12,
    ckpt_dir=egdb_midi_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time
print('Run Time: {}'.format(total_time))

## Evaluate

In [None]:
# Rebuild the EGDB MIDI model, load its saved weights and recreate the test
# loader, so this cell does not depend on the setup/training cells.
ckpt_path = 'model_checkpoints/egdb_midi/model_epoch_final_config2.pt'

egdb_midi_model = GttNet(device=device,
                         input_length_seconds=4,
                         hop_length=128,
                         harmonic_n_controls=101,
                         gru_features=512,
                         noise_initial_bias=-5.0,
                         mlp_blocks=3,
                         timbre_enc_size=15,
                         use_timbre_encoder=True,
                         use_amp_latent=False,
                         gru_cat=False,
                         use_reverb=False).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
egdb_midi_model.load_state_dict(torch.load(ckpt_path, map_location=device))


egdb_dir = 'Data/EGDB/audio'

split = train_test_split(egdb_dir, train_pct=.8, valid_pct=.1)

test_files = split['test']
train_files = split['train']

# Loudness statistics are recomputed from the training files.
loudness_metrics = get_mean_std_loudness(egdb_dir, train_files)

# NOTE(review): midi_dir='' here although the training dataset reads
# 'Data/EGDB/labels' — confirm this difference is intended.
egdb_midi_test_ds = GttDataset(audio_dir=egdb_dir,
                               midi_dir='',
                               file_list=test_files,
                               segment_length_seconds=4,
                               trim_duration=3,
                               max_n_pitch=30,
                               random_crop=False,
                               permute_pitch=False,
                               loud_mean=loudness_metrics['mean'],
                               loud_std=loudness_metrics['std'],
                               device=device)

egdb_midi_test_dl = DataLoader(egdb_midi_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 5544.938530815973

Mean loudness l2 Distance: 279.251953125

Mean MS spectral loss: 0.9871127605438232

In [None]:
# Reconstruction metrics for the EGDB MIDI model on its own test loader.
# len() of a DataLoader is the number of batches, so the values printed below
# are per-batch means.
test_len = len(egdb_midi_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

egdb_midi_model.eval()
msstft_calc = SpectralLoss()

# Gradients are never used during evaluation. no_grad() stops autograd from
# recording a graph for each forward pass, which otherwise stays alive for as
# long as the accumulated metric values reference it.
with torch.no_grad():
    for batch in egdb_midi_test_dl:
        out_audio = egdb_midi_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### GuitarSet HRG

In [None]:
# Timbre-transfer evaluation: run the EGDB-trained model on GuitarSet test
# files (using the hr_labels MIDI) and score against the EGDB FAD baseline.
gs_test_dir = 'Data/GuitarSet/test'
gs_dir = 'Data/GuitarSet/audio_pickup'
egdb_baseline_dir = 'Data/EGDB/eval_fad'

split = train_test_split(gs_dir, train_pct=0.8, valid_pct=0.1)
test_files = split['test']

# loudness_metrics is not defined in this cell; it must already exist in the
# notebook session (the Evaluate cell above derives it from EGDB train files).
gs_hrg_test_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='Data/GuitarSet/hr_labels',
    file_list=test_files,
    trim_duration=5,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

gs_hrg_test_dl = DataLoader(gs_hrg_test_ds, batch_size=3, shuffle=False)

# NOTE(review): test_dir is presumably where test_model stores the generated
# audio used for the FAD comparison — confirm in gtt.eval.
test_model(
    model=egdb_midi_model,
    test_loader=gs_hrg_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=gs_test_dir,
)

### GuitarSet BP

In [None]:
# Timbre-transfer evaluation: EGDB-trained model on GuitarSet test files with
# no MIDI directory supplied, scored against the EGDB FAD baseline.
gs_test_dir = 'Data/GuitarSet/test'
gs_dir = 'Data/GuitarSet/audio_pickup'
egdb_baseline_dir = 'Data/EGDB/eval_fad'

split = train_test_split(gs_dir, train_pct=0.8, valid_pct=0.1)
test_files = split['test']

# NOTE(review): midi_dir='' presumably makes GttDataset obtain pitch data by
# transcription (the "BP" in the section title) rather than from label files —
# confirm in gtt.dataloader.
# loudness_metrics must already exist in the notebook session.
gs_bp_test_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=5,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

gs_bp_test_dl = DataLoader(gs_bp_test_ds, batch_size=3, shuffle=False)

test_model(
    model=egdb_midi_model,
    test_loader=gs_bp_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=gs_test_dir,
)

### SynthTab Acoustic

In [None]:
# Timbre-transfer evaluation: EGDB-trained model on the SynthTab acoustic
# audio, scored against the EGDB FAD baseline.
st_test_dir = 'Data/synthtab_acoustic/test'
st_dir = 'Data/synthtab_acoustic/audio'
egdb_baseline_dir = 'Data/EGDB/eval_fad'

# With both percentages at zero, nothing is assigned to train or valid; the
# 'test' entry of the split is what gets evaluated.
split = train_test_split(st_dir, train_pct=0, valid_pct=0)
test_files = split['test']

# loudness_metrics must already exist in the notebook session.
st_test_ds = GttDataset(
    audio_dir=st_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=10,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_test_dl = DataLoader(st_test_ds, batch_size=3, shuffle=False)

test_model(
    model=egdb_midi_model,
    test_loader=st_test_dl,
    calc_fad=True,
    baseline_dir=egdb_baseline_dir,
    test_dir=st_test_dir,
)

# SynthTab Acoustic

## Set Up

In [None]:
# Data, loaders and model for training on the SynthTab acoustic audio.
st_acoustic_dir = 'Data/synthtab_acoustic/audio'

split = train_test_split(st_acoustic_dir, train_pct=0.8, valid_pct=0.1)
train_files, valid_files, test_files = (
    split['train'],
    split['valid'],
    split['test'],
)

# Loudness statistics are computed from the training files only and reused
# for all three datasets below.
loudness_metrics = get_mean_std_loudness(st_acoustic_dir, train_files)

# Training set: random_crop / permute_pitch are not passed, so GttDataset's
# own defaults apply.
st_acoustic_train_ds = GttDataset(
    audio_dir=st_acoustic_dir,
    midi_dir='',
    file_list=train_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Validation and test sets explicitly disable random cropping and pitch
# permutation.
st_acoustic_valid_ds = GttDataset(
    audio_dir=st_acoustic_dir,
    midi_dir='',
    file_list=valid_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_acoustic_test_ds = GttDataset(
    audio_dir=st_acoustic_dir,
    midi_dir='',
    file_list=test_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Only the training loader shuffles.
st_acoustic_train_dl = DataLoader(st_acoustic_train_ds, batch_size=3, shuffle=True)
st_acoustic_valid_dl = DataLoader(st_acoustic_valid_ds, batch_size=3, shuffle=False)
st_acoustic_test_dl = DataLoader(st_acoustic_test_ds, batch_size=3, shuffle=False)

st_acoustic_ckpt_dir = 'model_checkpoints/synthtab_acoustic'

st_acoustic_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    gru_cat=False,
    trainable_velocity=False,
    use_amp_latent=True,
    use_reverb=False,
).to(device)

## Train

Config 1: 3611

Config 2: 3562

In [None]:
# Train the SynthTab acoustic model and report how long the run took.
start_time = time.time()

train_loop(
    model=st_acoustic_model,
    train_loader=st_acoustic_train_dl,
    valid_loader=st_acoustic_valid_dl,
    epochs=100000,
    valid_freq=100,
    loud_epoch_freq=100,
    train_hours=8,  # presumably a wall-clock budget in hours — confirm in gtt.train
    early_stop_epochs=200,
    ckpt_dir=st_acoustic_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time

# time.time() deltas are expressed in seconds.
print('Run Time: {}'.format(total_time))

## Timbre Transfer Evaluate

In [None]:
# Checkpoint file whose weights are evaluated in the cells below.
ckpt_path = 'model_checkpoints/synthtab_acoustic/model_epoch_final_config2.pt'

# Rebuild the network before restoring weights. load_state_dict() is strict by
# default, so these arguments have to reproduce the architecture that was saved.
# NOTE(review): trainable_velocity=True / use_amp_latent=False are the opposite
# of the values in the Set Up cell of this section (False / True). Presumably
# that is what "config2" refers to — confirm it matches the checkpoint.
st_acoustic_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    gru_cat=False,
    trainable_velocity=True,
    use_amp_latent=False,
    use_reverb=False,
).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can still be restored when `device` is 'cpu'.
st_acoustic_model.load_state_dict(torch.load(ckpt_path, map_location=device))


st_acoustic_dir = 'Data/synthtab_acoustic/audio'

# Re-create the split to recover the held-out files.
# NOTE(review): assumes train_test_split is deterministic so this matches the
# split used during training — confirm.
split = train_test_split(st_acoustic_dir, train_pct=.8, valid_pct=.1)

train_files = split['train']
test_files = split['test']

# Loudness statistics come from the training files only and are then applied
# to the test dataset.
loudness_metrics = get_mean_std_loudness(st_acoustic_dir, train_files)

st_acoustic_test_ds = GttDataset(
    audio_dir=st_acoustic_dir,
    midi_dir='',
    file_list=test_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_acoustic_test_dl = DataLoader(st_acoustic_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 6835.857421875

Mean loudness l2 Distance: 367.476806640625

Mean MS spectral loss: 1.0953319072723389


In [None]:
# Reconstruction metrics for the SynthTab acoustic model on its own test
# loader. len() of a DataLoader is the number of batches, so the values
# printed below are per-batch means.
test_len = len(st_acoustic_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

st_acoustic_model.eval()
msstft_calc = SpectralLoss()

# Gradients are never used during evaluation. no_grad() stops autograd from
# recording a graph for each forward pass, which otherwise stays alive for as
# long as the accumulated metric values reference it.
with torch.no_grad():
    for batch in st_acoustic_test_dl:
        out_audio = st_acoustic_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### SynthTab Electric

In [None]:
# Timbre-transfer evaluation: acoustic-trained model on SynthTab electric test
# files, scored against the SynthTab acoustic FAD baseline.
st_electric_test_dir = 'Data/synthtab_electric/test'
st_electric_dir = 'Data/synthtab_electric/audio'
st_acoustic_baseline_dir = 'Data/synthtab_acoustic/eval_fad'

split = train_test_split(st_electric_dir, train_pct=0.8, valid_pct=0.1)
test_files = split['test']

# loudness_metrics is not defined in this cell; it must already exist in the
# notebook session (the evaluate cell above derives it from the acoustic
# training files).
st_electric_test_ds = GttDataset(
    audio_dir=st_electric_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=10,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_electric_test_dl = DataLoader(st_electric_test_ds, batch_size=3, shuffle=False)

test_model(
    model=st_acoustic_model,
    test_loader=st_electric_test_dl,
    calc_fad=True,
    baseline_dir=st_acoustic_baseline_dir,
    test_dir=st_electric_test_dir,
)

# SynthTab Electric

## Set Up

In [None]:
# Data, loaders and model for training on the SynthTab electric audio.
st_electric_dir = 'Data/synthtab_electric/audio'

split = train_test_split(st_electric_dir, train_pct=0.8, valid_pct=0.1)
train_files, valid_files, test_files = (
    split['train'],
    split['valid'],
    split['test'],
)

# Loudness statistics are computed from the training files only and reused
# for all three datasets below.
loudness_metrics = get_mean_std_loudness(st_electric_dir, train_files)

# Training set: random_crop / permute_pitch are not passed, so GttDataset's
# own defaults apply.
st_electric_train_ds = GttDataset(
    audio_dir=st_electric_dir,
    midi_dir='',
    file_list=train_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Validation and test sets explicitly disable random cropping and pitch
# permutation.
st_electric_valid_ds = GttDataset(
    audio_dir=st_electric_dir,
    midi_dir='',
    file_list=valid_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_electric_test_ds = GttDataset(
    audio_dir=st_electric_dir,
    midi_dir='',
    file_list=test_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Only the training loader shuffles.
st_electric_train_dl = DataLoader(st_electric_train_ds, batch_size=3, shuffle=True)
st_electric_valid_dl = DataLoader(st_electric_valid_ds, batch_size=3, shuffle=False)
st_electric_test_dl = DataLoader(st_electric_test_ds, batch_size=3, shuffle=False)

st_electric_ckpt_dir = 'model_checkpoints/synthtab_electric'

st_electric_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    gru_cat=False,
    trainable_velocity=False,
    use_amp_latent=True,
    use_reverb=False,
).to(device)

## Train

Config 1: 3981

Config 2: 4232

In [None]:
# Train the SynthTab electric model and report how long the run took.
start_time = time.time()

train_loop(
    model=st_electric_model,
    train_loader=st_electric_train_dl,
    valid_loader=st_electric_valid_dl,
    epochs=100000,
    valid_freq=100,
    loud_epoch_freq=100,
    train_hours=8,  # presumably a wall-clock budget in hours — confirm in gtt.train
    early_stop_epochs=200,
    ckpt_dir=st_electric_ckpt_dir,
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time

# time.time() deltas are expressed in seconds.
print('Run Time: {}'.format(total_time))

## Evaluate

In [None]:
# Checkpoint file whose weights are evaluated in the cells below.
ckpt_path = 'model_checkpoints/synthtab_electric/model_epoch_final_config1.pt'

# Rebuild the network before restoring weights. load_state_dict() is strict by
# default, so these arguments have to reproduce the architecture that was saved.
st_electric_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    gru_cat=False,
    trainable_velocity=False,
    use_amp_latent=True,
    use_reverb=False,
).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can still be restored when `device` is 'cpu'.
st_electric_model.load_state_dict(torch.load(ckpt_path, map_location=device))

st_electric_dir = 'Data/synthtab_electric/audio'

# Re-create the split to recover the held-out files.
# NOTE(review): assumes train_test_split is deterministic so this matches the
# split used during training — confirm.
split = train_test_split(st_electric_dir, train_pct=.8, valid_pct=.1)

train_files = split['train']
test_files = split['test']

# Loudness statistics come from the training files only and are then applied
# to the test dataset.
loudness_metrics = get_mean_std_loudness(st_electric_dir, train_files)

# Build the test loader here, as the other Evaluate cells in this notebook do,
# so the metrics cell that follows does not depend on a loader left over from
# another cell (which may have been normalised with a different dataset's
# loudness statistics, or may not exist in a fresh session).
st_electric_test_ds = GttDataset(
    audio_dir=st_electric_dir,
    midi_dir='',
    file_list=test_files,
    segment_length_seconds=4,
    trim_duration=10,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_electric_test_dl = DataLoader(st_electric_test_ds, 3, shuffle=False)

### Reconstruction Metrics

Mean MFCC Distance: 6389.4375

Mean loudness l2 Distance: 314.3045349121094

Mean MS spectral loss: 1.0757502317428589


In [None]:
# Reconstruction metrics for the SynthTab electric model on its own test
# loader. len() of a DataLoader is the number of batches, so the values
# printed below are per-batch means.
test_len = len(st_electric_test_dl)

mfcc_dist = 0
loudness_l2 = 0
msstft = 0

# Put the model that is actually evaluated below into eval mode (this cell
# previously called .eval() on st_acoustic_model by mistake).
st_electric_model.eval()
msstft_calc = SpectralLoss()

# Gradients are never used during evaluation. no_grad() stops autograd from
# recording a graph for each forward pass, which otherwise stays alive for as
# long as the accumulated metric values reference it.
with torch.no_grad():
    for batch in st_electric_test_dl:
        out_audio = st_electric_model(batch)

        mfcc_dist += mfcc_fro_distance(batch['audio'], out_audio, device=device)
        loudness_l2 += loudness_l2_distance(batch['audio'], out_audio, device=device)
        msstft += msstft_calc(out_audio, batch['audio']).item()


print('Mean MFCC Distance: {}'.format(mfcc_dist/test_len))
print('Mean loudness l2 Distance: {}'.format(loudness_l2/test_len))
print('Mean MS spectral loss: {}'.format(msstft/test_len))

### SynthTab Acoustic

In [None]:
# Timbre-transfer evaluation: electric-trained model on SynthTab acoustic test
# files, scored against the SynthTab electric FAD baseline.
st_acoustic_test_dir = 'Data/synthtab_acoustic/test'
st_acoustic_dir = 'Data/synthtab_acoustic/audio'
st_electric_baseline_dir = 'Data/synthtab_electric/eval_fad'

split = train_test_split(st_acoustic_dir, train_pct=0.8, valid_pct=0.1)
test_files = split['test']

# loudness_metrics is not defined in this cell; it must already exist in the
# notebook session (the Evaluate cell above derives it from the electric
# training files).
st_acoustic_test_ds = GttDataset(
    audio_dir=st_acoustic_dir,
    midi_dir='',
    file_list=test_files,
    trim_duration=10,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

st_acoustic_test_dl = DataLoader(st_acoustic_test_ds, batch_size=3, shuffle=False)

test_model(
    model=st_electric_model,
    test_loader=st_acoustic_test_dl,
    calc_fad=True,
    baseline_dir=st_electric_baseline_dir,
    test_dir=st_acoustic_test_dir,
)

# GuitarSet Harmonic

Additional experiment that tests harmonic synthesis only.

In [None]:
# Data, loaders and model for the GuitarSet harmonic-only experiment.
gs_dir = 'Data/GuitarSet/audio_pickup'

split = train_test_split(gs_dir, train_pct=0.8, valid_pct=0.1)
train_files, valid_files, test_files = (
    split['train'],
    split['valid'],
    split['test'],
)

# Loudness statistics are computed from the training files only and reused
# for all three datasets below.
loudness_metrics = get_mean_std_loudness(gs_dir, train_files)

# Training set: random_crop / permute_pitch are not passed, so GttDataset's
# own defaults apply. synth_unit_frames=90 is passed to every dataset and to
# the model below.
gs_hrg_train_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='Data/GuitarSet/hr_labels',
    file_list=train_files,
    segment_length_seconds=4,
    synth_unit_frames=90,
    max_n_pitch=30,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Validation and test sets explicitly disable random cropping and pitch
# permutation.
gs_hrg_valid_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='Data/GuitarSet/hr_labels',
    file_list=valid_files,
    synth_unit_frames=90,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

gs_hrg_test_ds = GttDataset(
    audio_dir=gs_dir,
    midi_dir='Data/GuitarSet/hr_labels',
    file_list=test_files,
    synth_unit_frames=90,
    segment_length_seconds=4,
    max_n_pitch=30,
    random_crop=False,
    permute_pitch=False,
    loud_mean=loudness_metrics['mean'],
    loud_std=loudness_metrics['std'],
    device=device,
)

# Only the training loader shuffles.
gs_hrg_train_dl = DataLoader(gs_hrg_train_ds, batch_size=3, shuffle=True)
gs_hrg_valid_dl = DataLoader(gs_hrg_valid_ds, batch_size=3, shuffle=False)
gs_hrg_test_dl = DataLoader(gs_hrg_test_ds, batch_size=3, shuffle=False)

gs_hrg_ckpt_dir = 'model_checkpoints/guitarset_hrg'

gs_hrg_model = GttNet(
    device=device,
    input_length_seconds=4,
    hop_length=128,
    synth_unit_frames=90,
    harmonic_n_controls=101,
    gru_features=512,
    noise_initial_bias=-5.0,
    mlp_blocks=3,
    timbre_enc_size=15,
    use_timbre_encoder=True,
    use_amp_latent=False,
    gru_cat=True,
    trainable_velocity=False,
    use_reverb=False,
).to(device)

In [None]:
# Train the GuitarSet harmonic-only model and report how long the run took.
start_time = time.time()

train_loop(
    model=gs_hrg_model,
    train_loader=gs_hrg_train_dl,
    valid_loader=gs_hrg_valid_dl,
    epochs=40200,
    valid_freq=10,
    ckpt_dir=gs_hrg_ckpt_dir,
    loud_epoch_freq=10,
    train_hours=8,  # presumably a wall-clock budget in hours — confirm in gtt.train
    loud_batch=False,
)

end_time = time.time()
total_time = end_time - start_time

# time.time() deltas are expressed in seconds.
print('Run Time: {}'.format(total_time))