In [None]:
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import utils.utils as utils
import evaluation.inception_network as inception_network
import evaluation.metrics.inception_score as inception_score
import data.preprocessing as preprocessing
import data.audio_transforms as audio_transforms
import data.loaders as loaders

In [None]:
# Target device for this notebook; passed to torch.load as map_location below.
device = "cpu"

In [None]:
# Configuration for the IS evaluation run.
#
# All machine-specific absolute paths are built from REPOS_ROOT. Override it with the
# FOOTSTEPS_REPOS_ROOT environment variable to run on another machine without editing
# every path; when the variable is unset, the default reproduces the original paths.
REPOS_ROOT = os.environ.get(
    "FOOTSTEPS_REPOS_ROOT",
    "/Users/Marco/Documents/OneDrive - Queen Mary, University of London/PHD/REPOS",
)
MODEL_NAME = "footsteps_inception_model_best_2021-09-26.pt"

config = {
    "model_name": MODEL_NAME,
    "comments": "inception trained on footsteps dataset",
    # "state_dict_path": "/homes/mc309/hifi-wavegan/drumgan_evaluation/evaluation/inception_models/footsteps_inception_model_best_2021-09-26.pt",
    # The checkpoint file name is the model name, so the two cannot drift apart.
    "state_dict_path": f"{REPOS_ROOT}/hifi-wavegan/drumgan_evaluation/evaluation/inception_models/{MODEL_NAME}",

    # real samples used to train inception model
    # "real_samples_path": f"{REPOS_ROOT}/_footsteps_data/zapsplat_misc_shoes_misc_surfaces_inception_network/",
    # real samples used to train gan
    # "real_samples_path": f"{REPOS_ROOT}/_footsteps_data/zapsplat_pack_footsteps_high_heels_1s_aligned_for_inception_score/",
    # gan synthesised samples
    # "synth_samples_path": f"{REPOS_ROOT}/hifi-wavegan/checkpoints/2021-09-20_13h23m-hifi/120k_generated_audio_large_for_is_and_kid/",
    # "synth_samples_path": f"{REPOS_ROOT}/hifi-wavegan/checkpoints/2021-09-20_19h46m-wave/120k_generated_audio_large_for_is_and_kid/",

    # Results are written to <output_path>/<output_folder>.
    "output_path": "evaluation",
    "output_folder": "evaluation_metrics",

    "batch_size": 20,

    "samples_loader_config": {
        "dbname": "footsteps",
        # "data_path": "/homes/mc309/ccwavegan-hifigan-fresh/checkpoints/2021-09-20_19h46m-wave/120k_generated_audio_large_for_is/",
        # Folder of samples to score (here: GAN-generated audio).
        "data_path": f"{REPOS_ROOT}/hifi-wavegan/checkpoints/2021-09-20_19h46m-wave/120k_generated_audio_large_for_is/",
        "criteria": {},
        "shuffle": True,
        "tr_val_split": 1.0
    },

    # Passed verbatim as keyword arguments to the audio preprocessor.
    "transform_config": {
        "transform": "stft",
        "fade_out": True,
        "fft_size": 1024,
        "win_size": 1024,
        "n_frames": 64,
        "hop_size": 256,
        "log": False,
        "ifreq": False,
        "sample_rate": 16000,
        "audio_length": 8192
    }
}

In [None]:
# Model identity (also used in the results file name) and checkpoint location.
model_name, state_dict_path = (config[key] for key in ('model_name', 'state_dict_path'))
# Results directory <output_path>/<output_folder>; presumably created by the helper if missing.
output_path = utils.mkdir_in_path(config['output_path'], config['output_folder'])

In [None]:
# Set up preprocessing, dataset and DataLoader for the samples to be scored.
# NOTE(review): data_path in the config points at GAN-generated audio, so these are
# not necessarily real samples.

samples_loader_config = config['samples_loader_config']

transform_config = config['transform_config']
transform = transform_config['transform']

dbname = samples_loader_config['dbname']

batch_size = config['batch_size']

# Audio preprocessing; every transform_config entry is forwarded as a keyword argument.
processor = preprocessing.AudioProcessor(**transform_config)

loader_module = loaders.get_data_loader(dbname)

samples_loader = loader_module(name=dbname + '_' + transform, preprocessing=processor, **samples_loader_config)

n_samples = len(samples_loader)
print('n_samples: ', n_samples)

samples_data_loader = DataLoader(samples_loader,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=2)


# Load the pretrained inception classifier. The class count is taken from dim 0 of the
# checkpoint's 'fc.weight' (presumably nn.Linear's out_features).
# torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(state_dict_path, map_location=device)
inception_footsteps = inception_network.SpectrogramInception3(state_dict['fc.weight'].shape[0], aux_logits=False)
inception_footsteps.load_state_dict(state_dict)
# A freshly built nn.Module is in training mode, and torch.no_grad() does not change
# that. Switch to inference mode so that any BatchNorm/Dropout layers use their learned
# running statistics. Otherwise predictions (and the IS) depend on batch composition.
inception_footsteps.eval()

# inception model is trained on mel spectrograms; n_mel falls back to 256 if not configured
mel = audio_transforms.MelScale(sample_rate=transform_config['sample_rate'],
                                fft_size=transform_config['fft_size'],
                                n_mel=transform_config.get('n_mel', 256),
                                rm_dc=True)

In [None]:
# Compute the Inception Score (IS) of the samples with the footsteps classifier.
# The scorer is fed one mini-batch of predictions at a time. After every batch, its
# current (presumably cumulative) score is recorded in inception_score_samples; the
# next cell saves the mean/std of those recorded values.
# NOTE(review): is_maker_samples is not reset between iterations, so with n_iter > 1
# the statistics keep accumulating over repeated passes. Confirm this is intended.
# Everything runs on the CPU here. Note that Tensor.to()/Module.to() return the moved
# object for tensors, so moving to `device` requires re-assignment
# (x = x.to(device)) for both the inputs and the model.
n_iter = 1
is_maker_samples = inception_score.InceptionScore()
inception_score_samples = []

for i in range(n_iter):
    print("iter: ", i)
    with torch.no_grad():
        for batch_idx, (audio_batch, _labels) in enumerate(samples_data_loader):
            mel_batch = mel(audio_batch.float())
            # Keep only the first channel and resize to the 299x299 input fed to the classifier.
            mag_input = F.interpolate(mel_batch[:, 0:1], (299, 299))

            preds = inception_footsteps(mag_input.float())

            is_maker_samples.updateWithMiniBatch(preds)
            running_score = is_maker_samples.getScore()
            inception_score_samples.append(running_score)

            print('batch: ', batch_idx, 'IS: ', running_score)

In [None]:
# Save the mean and standard deviation of the recorded IS values to a text file
# (line 1: mean, line 2: std) inside output_path, then print them.
if not inception_score_samples:
    # np.mean([]) would silently produce nan (with only a RuntimeWarning) and write
    # "nan" to the results file; fail loudly instead.
    raise RuntimeError('No IS values were computed - check that the samples data_path contains audio.')

IS_mean = np.mean(inception_score_samples)
IS_std = np.std(inception_score_samples)
timestamp = datetime.now().strftime("%d-%m-%y_%H_%M")
output_file = os.path.join(output_path, f'IS_{n_samples}_{model_name}_{timestamp}.txt')

# The with-block closes the file on exit; no explicit close() needed.
with open(output_file, 'w') as f:
    f.write(str(IS_mean) + '\n')
    f.write(str(IS_std))

print("IS_mean: ", IS_mean)
print("IS_std: ", IS_std)