In [1]:
import os
import pickle
import random
import sys
import warnings

import numpy as np
import torch
from transformers import GPT2Tokenizer

# custom
from util import *
from ClipCap_forAAC.CLIPCAP_forAAC import * # network
from Train import *

warnings.filterwarnings("ignore", category=DeprecationWarning) 

# --- Configuration -----------------------------------------------------------
# Dataset root. `data_dir[2:]` (the path without its leading "./") is what gets
# passed as the `Dataset` argument below.
data_dir = './Clotho'

# prefix_size was modified so that PANNs can be used.
temporal_prefix_size = 15
global_prefix_size = 11
prefix_size = temporal_prefix_size + global_prefix_size

transformer_num_layers = {"temporal_num_layers" : 4, "global_num_layers" : 4}
prefix_size_dict = {"temporal_prefix_size" : temporal_prefix_size, "global_prefix_size" : global_prefix_size}

# --- Tokenizer ---------------------------------------------------------------
# vocab_size stays None unless the custom-vocabulary tokenizer is selected.
vocab_size = None
tokenizer_type = 'Custom'

if tokenizer_type == 'Custom' :
    tokenizer = tokenizer_forCustomVocab(Dataset = data_dir[2:])
    vocab_size = len(tokenizer.vocab)
elif tokenizer_type == 'GPT2' :
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
else :
    # Fail here rather than with a NameError on `tokenizer` further down.
    raise ValueError(f"Unsupported tokenizer_type: {tokenizer_type!r} (expected 'Custom' or 'GPT2')")

# --- Reproducibility ---------------------------------------------------------
# Seed torch (CPU and every CUDA device), NumPy and Python's `random`, and
# turn off cuDNN autotuning in favour of deterministic kernels.
random_seed=2766
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.benchmark=False
torch.backends.cudnn.deterministic=True
np.random.seed(random_seed)
random.seed(random_seed)

print("random_seed :", random_seed)
print("vocab_size :", vocab_size)

# --- Test dataloaders --------------------------------------------------------
TEST_BATCH_SIZE = 5

# Fall back to a prefix length of 26 when the configured sizes sum to zero.
if prefix_size == 0 :
    prefix_size = 26

# The 'test' / 'evaluation' arguments are forwarded to CreateDataloader
# (presumably split names -- confirm against its definition).
test_dataloader_audiocaps  = CreateDataloader(tokenizer, './AudioCaps', TEST_BATCH_SIZE, 'test', prefix_size, is_TrainDataset = False, tokenizer_type = tokenizer_type)
test_dataloader_clotho = CreateDataloader(tokenizer, './Clotho', TEST_BATCH_SIZE, 'evaluation', prefix_size, is_TrainDataset = False, tokenizer_type = tokenizer_type)

random_seed : 2766
vocab_size : 10640


get dataset...: 100%|█████████████████████| 960/960 [00:00<00:00, 1383.66it/s]
get dataset...: 100%|████████████████████| 1045/1045 [00:06<00:00, 164.29it/s]


In [2]:
# Ask the CUDA caching allocator to release memory it is holding but not using.
torch.cuda.empty_cache()

# Run on the first GPU when CUDA is available, otherwise on the CPU.
USE_CUDA = torch.cuda.is_available()
device_name = 'cuda:0' if USE_CUDA else 'cpu'
device = torch.device(device_name)

In [3]:
# Keyword arguments for the network factory, collected in one place so each
# setting is visible on its own line.
model_kwargs = dict(
    vocab_size = vocab_size,
    Dataset = data_dir[2:],          # data_dir without its leading "./"
    prefix_size_dict = prefix_size_dict,
    transformer_num_layers = transformer_num_layers,
    encoder_freeze = False,
    decoder_freeze = True,
    pretrain_fromAudioCaps = True,
    device = device,
)

# Build the captioning model; the tokenizer is the only positional argument.
model = get_ClipCap_AAC(tokenizer, **model_kwargs)

  fft_window = librosa.util.pad_center(fft_window, n_fft)
  return f(*args, **kwargs)


use Custom Tokenizer
temporal feature's mapping network : num_head = 8 num_layers = 4
global feature ver's mapping network : num_head = 8 num_layers = 4
Get Pre-traiend Params
Get Pre-traiend language header
GPT2 freezing


In [4]:
# Checkpoint file whose weights are evaluated in the cells below.
model_path = './Train_record/params_Custom_header_10640_clotho_2766CustomHeader/Param_epoch_19.pt'

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source. Tensors are remapped onto `device` while loading.
params = torch.load(f = model_path, map_location = device)

# Kept as the cell's last expression so the notebook displays the
# key-matching report returned by load_state_dict.
model.load_state_dict(params)

<All keys matched successfully>

In [5]:
# get_pred_captions(model, test_dataloader_audiocaps, device, dataset = 'AudioCaps') 
# get_pred_captions(model, test_dataloader_clotho, device, dataset = 'Clotho') 

In [6]:
# Score the loaded model on the Clotho evaluation dataloader.
# NOTE(review): the last argument is 'AudioCaps' although the dataloader is the
# Clotho one -- confirm against eval_model's signature that this is intended.
# NOTE(review): 31 / 'test' / True are passed positionally; their meaning is
# defined by eval_model, which is not visible from this notebook.
eval_outputs = eval_model(
    model,
    test_dataloader_clotho,
    31,
    'test',
    True,
    device,
    'AudioCaps',
)
metrics, captions_pred, captions_gt = eval_outputs

Eval using dataset...: 100%|██████████████| 1045/1045 [03:48<00:00,  4.57it/s]


loading annotations into memory...
0:00:00.006910
creating index...
index created!
Loading and preparing results...     
DONE (t=0.00s)
creating index...
index created!
tokenization...


PTBTokenizer tokenized 70696 tokens at 475271.06 tokens per second.
PTBTokenizer tokenized 11930 tokens at 122243.48 tokens per second.


setting up scorers...
computing Bleu score...
{'testlen': 9840, 'reflen': 10288, 'guess': [9840, 8795, 7750, 6705], 'correct': [5772, 2326, 925, 286]}
ratio: 0.9564541213062834
Bleu_1: 0.560
Bleu_2: 0.376
Bleu_3: 0.253
Bleu_4: 0.160
computing METEOR score...
METEOR: 0.170
computing Rouge score...
ROUGE_L: 0.378
computing CIDEr score...
CIDEr: 0.392
computing SPICE score...


Parsing reference captions
Parsing test captions


SPICE evaluation took: 1.948 s
SPICE: 0.118
computing SPIDEr score...
SPIDEr: 0.255


In [8]:
# Persist the predicted captions to disk using the highest pickle protocol
# available in this interpreter.
with open('clotho_customheader_pred_captions.pickle', mode = 'wb') as pickle_file:
    pickle.dump(captions_pred, pickle_file, protocol = pickle.HIGHEST_PROTOCOL)

In [7]:
# Print the file name and predicted caption of every prediction entry whose
# 'file_name' matches the clip of interest.
target_file = 'Rain on awning, canopy.wav'

for idx in range(len(captions_pred)):
    entry = captions_pred[idx]
    if entry['file_name'] != target_file:
        continue  # not the clip we are looking for
    print(entry['file_name'])
    print(entry['caption_predicted'])

Rain on awning, canopy.wav
rain is falling down on the ground.


In [9]:
# Rank the per-item CIDEr scores from best to worst: `d` becomes a list of
# (key, score) pairs ordered by descending score (stable for ties).
d = list(metrics['cider']['scores'].items())
d.sort(key = lambda pair: pair[1], reverse = True)

In [35]:
# Display the entry at index 45 of `d`, the list sorted by descending CIDEr score.
d[45]

('Snow crunch.wav', 1.3038218118558262)