In [1]:
import torch
import numpy as np
import torch.nn.functional as F

from utils.utils import *
from datetime import datetime
from evaluation.train_inception_model import SpectrogramInception3
from data.preprocessing import AudioProcessor
from torch.utils.data import DataLoader
from evaluation.metrics.inception_score import InceptionScore
from evaluation.inception_models import DEFAULT_FOOTSTEPS_INCEPTION_MODEL
from data.audio_transforms import MelScale
from tqdm import trange
import ipdb

from os.path import dirname, realpath, join
import logging
from data.loaders import get_data_loader

In [2]:
config = {
    "model_name": "footsteps_inception_model_best_2021-04-29.pt",
    "comments": "inception trained on footsteps dataset",
    "state_dict_path": "evaluation/inception_models/footsteps_inception_model_best_2021-04-29.pt",
    
    "real_samples_path": "audio/footsteps_real",
    "synth_samples_path": "audio/footsteps_generated_23-04-2021_15h",
    
    "output_path": "evaluation",
    "output_folder": "evaluation_metrics",
    
    "batch_size": 20,

    "real_samples_loader_config": {
        "dbname": "footsteps",
        "data_path": "audio/footsteps_real/",
        "criteria": {},
        "shuffle": True,
        "tr_val_split": 1.0
    },

    "synth_samples_loader_config": {
        "dbname": "footsteps",
        "data_path": "audio/footsteps_generated_23-04-2021_15h/",
        "criteria": {},
        "shuffle": True,
        "tr_val_split": 1.0
    },
    
    "transform_config": {
        "transform": "stft",
        "fade_out": True,
        "fft_size": 1024,
        "win_size": 1024,
        "n_frames": 64,
        "hop_size": 256,
        "log": False,
        "ifreq": False,
        "sample_rate": 16000,
        "audio_length": 16000
    }
}

In [3]:
model_name = config['model_name']
state_dict_path = config['state_dict_path']
output_path = mkdir_in_path(config['output_path'], config['output_folder'])
# output_log = join(output_path, f"{model_name}_evaluation.log")
# logging.basicConfig(filename=output_log, level=logging.INFO)

In [4]:
real_samples_loader_config = config['real_samples_loader_config']

transform_config = config['transform_config']
transform = transform_config['transform']

dbname = real_samples_loader_config['dbname']

batch_size = config['batch_size']

processor = AudioProcessor(**transform_config)

loader_module = get_data_loader(dbname)

real_samples_loader = loader_module(name=dbname + '_' + transform, preprocessing=processor, **real_samples_loader_config)

n_real_samples = len(real_samples_loader)
print('n_real_samples: ', n_real_samples)

real_samples_data_loader = DataLoader(real_samples_loader,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=2)

device = 'cuda' if GPU_is_available() else 'cpu'

state_dict = torch.load(state_dict_path, map_location=device)
inception_footsteps = SpectrogramInception3(state_dict['fc.weight'].shape[0], aux_logits=False)
inception_footsteps.load_state_dict(state_dict)
inception_footsteps = inception_footsteps.to(device)

mel = MelScale(sample_rate=transform_config['sample_rate'],
                fft_size=transform_config['fft_size'],
                n_mel=transform_config.get('n_mel', 256),
                rm_dc=True)
mel = mel.to(device)

Configuring stft transform...
Dataset audio/footsteps_real/processed/footsteps_stft/footsteps_stft.pt exists. Reloading...
n_real_samples:  720
Cuda not available. Running on CPU


In [5]:
is_maker_real_samples = InceptionScore()
inception_score_real_samples = []

with torch.no_grad():
    for batch_idx, data in enumerate(real_samples_data_loader):
        input, labels = data
        input.to(device)
        input = mel(input.float())
        mag_input = F.interpolate(input[:, 0:1], (299, 299))
        
        preds = inception_footsteps(mag_input.float())
        
        is_maker_real_samples.updateWithMiniBatch(preds)
        inception_score_real_samples.append(is_maker_real_samples.getScore())
        
        print('batch: ', batch_idx, 'IS: ', is_maker_real_samples.getScore())

batch:  0 IS:  5.466772467532699
batch:  1 IS:  5.565402963191995
batch:  2 IS:  5.488650793942982
batch:  3 IS:  5.39592375210772
batch:  4 IS:  5.530355001448792
batch:  5 IS:  5.511362250136727
batch:  6 IS:  5.569491659003954
batch:  7 IS:  5.596020796236649
batch:  8 IS:  5.571151441673725
batch:  9 IS:  5.590743335479327
batch:  10 IS:  5.622973009719555
batch:  11 IS:  5.640172493268067
batch:  12 IS:  5.6603077568569855
batch:  13 IS:  5.66947609021406
batch:  14 IS:  5.651755544234304
batch:  15 IS:  5.645156396813439
batch:  16 IS:  5.645800953091852
batch:  17 IS:  5.657693016402988
batch:  18 IS:  5.6632517700655125
batch:  19 IS:  5.635337063481726
batch:  20 IS:  5.619325566882319
batch:  21 IS:  5.61160862551869
batch:  22 IS:  5.626786617103338
batch:  23 IS:  5.638231170166783
batch:  24 IS:  5.644710638177991
batch:  25 IS:  5.651486517388761
batch:  26 IS:  5.651865390349704
batch:  27 IS:  5.646568196427036
batch:  28 IS:  5.65654520139305
batch:  29 IS:  5.64827223

In [6]:
IS_mean = np.mean(inception_score_real_samples)
IS_std = np.std(inception_score_real_samples)
output_file = f'{output_path}/IS_real_{str(n_real_samples)}_{model_name}_{datetime.now().strftime("%d-%m-%y_%H_%M")}.txt'

with open(output_file, 'w') as f:
    f.write(str(IS_mean) + '\n')
    f.write(str(IS_std))
    f.close()

In [7]:
synth_samples_loader_config = config['synth_samples_loader_config']
synth_samples_loader = loader_module(name=dbname + '_' + transform, preprocessing=processor, **synth_samples_loader_config)

n_synth_samples = len(synth_samples_loader)
print('n_synth_samples: ', n_synth_samples)

synth_samples_data_loader = DataLoader(synth_samples_loader,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=2)

Dataset audio/footsteps_generated_audio_23-04-2021_15h/processed/footsteps_stft/footsteps_stft.pt exists. Reloading...
n_synth_samples:  600


In [8]:
is_maker_synth_samples = InceptionScore()
inception_score_synth_samples = []

with torch.no_grad():
    for batch_idx, data in enumerate(synth_samples_data_loader):
        input, labels = data
        input.to(device)
        input = mel(input.float())
        mag_input = F.interpolate(input[:, 0:1], (299, 299))
        
        preds = inception_footsteps(mag_input.float())
        
        is_maker_synth_samples.updateWithMiniBatch(preds)
        inception_score_synth_samples.append(is_maker_synth_samples.getScore())
        
        print('batch: ', batch_idx, 'IS: ', is_maker_synth_samples.getScore())

batch:  0 IS:  4.423433516736267
batch:  1 IS:  4.739492460773463
batch:  2 IS:  4.780054009446706
batch:  3 IS:  4.794234848508344
batch:  4 IS:  4.881266130791967
batch:  5 IS:  4.897593552653293
batch:  6 IS:  4.862352724544507
batch:  7 IS:  4.876063027039243
batch:  8 IS:  4.918181825607562
batch:  9 IS:  4.881561095075136
batch:  10 IS:  4.855938893865848
batch:  11 IS:  4.82972756597471
batch:  12 IS:  4.8460322472590756
batch:  13 IS:  4.826242596787764
batch:  14 IS:  4.82129198938978
batch:  15 IS:  4.85694043447846
batch:  16 IS:  4.8497129942704555
batch:  17 IS:  4.820516677409252
batch:  18 IS:  4.830111953674338
batch:  19 IS:  4.839483279567941
batch:  20 IS:  4.866913317115213
batch:  21 IS:  4.843082987051767
batch:  22 IS:  4.835472558572825
batch:  23 IS:  4.839133574784806
batch:  24 IS:  4.839520770586575
batch:  25 IS:  4.8305785479729435
batch:  26 IS:  4.84058180697983
batch:  27 IS:  4.85227207247436
batch:  28 IS:  4.843070559135143
batch:  29 IS:  4.83337969

In [9]:
IS_mean = np.mean(inception_score_synth_samples)
IS_std = np.std(inception_score_synth_samples)
output_file = f'{output_path}/IS_synth_{str(n_synth_samples)}_{model_name}_{datetime.now().strftime("%d-%m-%y_%H_%M")}.txt'

with open(output_file, 'w') as f:
    f.write(str(IS_mean) + '\n')
    f.write(str(IS_std))
    f.close()