In [2]:
import os
import argparse
import torch
from tqdm import tqdm

import data_loader as data_loaders
import model as models
import trainer.loss as loss_functions
import trainer.metric as metric_functions

import utils.util as util
from utils import color_print as cp

import numpy as np
import matplotlib.pylab as plt

In [8]:
base_model_config = 'config/kws_res15_base.json'
fine_model_config = 'config/kws_res15_fine_tune.json'

task = "kws_res15_narrow"

# folder_name = 'test/softmax_bce'
folder_name = 'res15/sigmoid_bce_loss/0'


def latest_checkpoint(run_dir, filename='model_best.pth'):
    """Return the path of ``filename`` inside the most recent run sub-directory of ``run_dir``.

    Run sub-directories are compared by name (e.g. ``0404_214351``) and the greatest
    name is treated as the most recent run. Plain files in ``run_dir`` are ignored.
    NOTE(review): names look like ``MMDD_HHMMSS`` with no year, so runs spanning a
    year boundary would not sort chronologically — confirm the naming scheme.

    Raises:
        FileNotFoundError: if ``run_dir`` does not exist or has no sub-directories.
    """
    runs = [name for name in os.listdir(run_dir)
            if os.path.isdir(os.path.join(run_dir, name))]
    if not runs:
        raise FileNotFoundError('no run directories found in {}'.format(run_dir))
    return os.path.join(run_dir, max(runs), filename)


base_model_cpt = latest_checkpoint(os.path.join(folder_name, task + '_base'))
print('base_model', base_model_cpt)

model_1_cpt = latest_checkpoint(os.path.join(folder_name, task + '_fine_tune', '1'))
print('class 1 model', model_1_cpt)

model_2_cpt = latest_checkpoint(os.path.join(folder_name, task + '_fine_tune', '2'))
print('class 2 model', model_2_cpt)


base_model res15/sigmoid_bce_loss/0/kws_res15_narrow_base/0404_214351/model_best.pth
class 1 model res15/sigmoid_bce_loss/0/kws_res15_narrow_fine_tune/1/0404_222134/model_best.pth
class 2 model res15/sigmoid_bce_loss/0/kws_res15_narrow_fine_tune/2/0404_222230/model_best.pth


In [13]:
# When True, the base-model evaluation cell loads the base weights, builds a test
# loader and reports loss/metrics for the base model.
eval_all_model = True

# base model

In [22]:
# Root prepended to the checkpoint's data_dir on this machine. Override it with the
# DATA_ROOT_PREFIX environment variable instead of editing the notebook.
DATA_ROOT_PREFIX = os.environ.get("DATA_ROOT_PREFIX", "/media/brandon/SSD")

# torch.load unpickles the file, so only load checkpoints from trusted sources.
base_config = torch.load(base_model_cpt)['config']
# `loader_args` aliases the nested dict, so the assignment below updates base_config in place.
loader_args = base_config['data_loader']['args']
# Paths that already contain "media" are left untouched.
if "media" not in loader_args['data_dir']:
    loader_args['data_dir'] = DATA_ROOT_PREFIX + loader_args['data_dir']
base_config

{'name': 'kws_res15_narrow_base',
 'n_gpu': 1,
 'n_class': 30,
 'model': {'type': 'ResNarrowNet',
  'args': {'num_classes': 30,
   'n_layers': 13,
   'n_feature_maps': 19,
   'use_dilation': True}},
 'data_loader': {'type': 'GoogleKeywordDataLoader',
  'args': {'data_dir': '/media/brandon/SSD/data/speech_dataset',
   'batch_size': 128,
   'shuffle': True,
   'validation_split': 0.1,
   'num_workers': 2,
   'seed': 126}},
 'optimizer': {'type': 'SGD',
  'args': {'lr': 0.1, 'weight_decay': 0.0001, 'momentum': 0.9}},
 'loss': 'sigmoid_bce_loss',
 'metrics': ['pred_acc'],
 'lr_scheduler': {'type': 'MultiStepLR',
  'args': {'milestones': [10, 20], 'gamma': 0.1}},
 'trainer': {'epochs': 30,
  'save_dir': 'saved/',
  'save_period': 1,
  'verbosity': 2,
  'monitor': 'min val_loss',
  'early_stop': 10,
  'tensorboardX': True,
  'log_dir': 'saved/runs'}}

In [23]:
# Construct the base network described by base_config['model']. Its trained weights
# are loaded from state_dict_base in the evaluation cell (only when eval_all_model is True).
model_base = util.get_instance(models, 'model', base_config)

In [24]:
# NOTE: this re-reads the same checkpoint file that base_config was taken from.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint_base = torch.load(base_model_cpt)
state_dict_base = checkpoint_base['state_dict']

In [35]:
if eval_all_model:
    # Load the trained base weights and switch to inference mode (dropout/batch-norm in eval mode).
    model_base.load_state_dict(state_dict_base)
    model_base.eval()

    # Evaluate on every base class with a small, fixed number of test samples per class.
    base_target_class = list(np.arange(base_config['n_class']))
    base_size_per_class = 3

    base_data_loader = getattr(data_loaders, base_config['data_loader']['type'])(
        base_config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2,
        size_per_class=base_size_per_class,
        target_class=base_target_class,
        unknown=True
    )

    base_loss_fn = getattr(loss_functions, base_config['loss'])

    base_config['metrics'] = ["pred_acc"]

    base_metrics = [getattr(metric_functions, met) for met in base_config['metrics']]

    total_loss = 0.0
    total_metrics = torch.zeros(len(base_metrics))

    with torch.no_grad():
        for data, target in tqdm(base_data_loader):
            # The loss receives one-hot targets; the metrics receive the class indices.
            one_hot_target = torch.eye(len(base_target_class))[target]

            if "kws" not in task:
                # Preview each sample as a 28x28 grayscale image (image tasks only), laid
                # out `base_size_per_class` per row; subplot needs an integer row count.
                n_rows = int(np.ceil(len(data) / base_size_per_class))
                plt.figure(figsize=[15, 15])
                for index, image in enumerate(data):
                    plt.subplot(n_rows, base_size_per_class, index + 1)
                    plt.imshow(np.reshape(torch.squeeze(image), [28, 28]), cmap='gray')
                    plt.axis('off')
                    plt.title(target[index].item())

            output = model_base(data)

            # Weight by batch size so the totals below average per sample, not per batch.
            loss = base_loss_fn(output, one_hot_target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for metric_idx, metric in enumerate(base_metrics):
                total_metrics[metric_idx] += metric(output, target) * batch_size

    n_samples = len(base_data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(base_metrics)})

    test_result_str = 'TEST RESULTS\n'
    for key, val in log.items():
        test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

    cp.print_progress(test_result_str)

  0%|          | 0/1 [00:00<?, ?it/s]

< Dataset Summary >
	seed	: 0
	 bed 	: 0  ( 3 )
	 bird 	: 1  ( 3 )
	 cat 	: 2  ( 3 )
	 dog 	: 3  ( 3 )
	 down 	: 4  ( 3 )
	 eight 	: 5  ( 3 )
	 five 	: 6  ( 3 )
	 four 	: 7  ( 3 )
	 go 	: 8  ( 3 )
	 happy 	: 9  ( 3 )
	 house 	: 10  ( 3 )
	 left 	: 11  ( 3 )
	 marvin 	: 12  ( 3 )
	 nine 	: 13  ( 3 )
	 no 	: 14  ( 3 )
	 off 	: 15  ( 3 )
	 on 	: 16  ( 3 )
	 one 	: 17  ( 3 )
	 right 	: 18  ( 3 )
	 seven 	: 19  ( 3 )
	 sheila 	: 20  ( 3 )
	 six 	: 21  ( 3 )
	 stop 	: 22  ( 3 )
	 three 	: 23  ( 3 )
	 tree 	: 24  ( 3 )
	 two 	: 25  ( 3 )
	 up 	: 26  ( 3 )
	 wow 	: 27  ( 3 )
	 yes 	: 28  ( 3 )
	 zero 	: 29  ( 3 )
total data size :  90


100%|██████████| 1/1 [00:00<00:00,  1.20it/s]

[92m
[ PROGRESS ] ::  TEST RESULTS
	loss : 0.04041166603565216
	pred_acc : 0.9333333333333333

[0m





In [26]:
# loss_fn = getattr(loss_functions, base_config['loss'])

# base_config['metrics'] = ["pred_acc"]

# metrics = [getattr(metric_functions, met) for met in base_config['metrics']]

In [171]:
if eval_all_model:
    # Second pass over the base model's test loader. This cell uses only names created by
    # the base-model evaluation cell (base_data_loader, base_metrics, base_loss_fn,
    # base_target_class, base_size_per_class), so it runs on a fresh kernel after that cell.
    total_loss = 0.0
    total_metrics = torch.zeros(len(base_metrics))

    with torch.no_grad():
        for data, target in tqdm(base_data_loader):
            one_hot_target = torch.eye(len(base_target_class))[target]

            if "kws" not in task:
                # 28x28 grayscale previews only apply to image tasks; subplot needs an int row count.
                n_rows = int(np.ceil(len(data) / base_size_per_class))
                plt.figure(figsize=[15, 15])
                for index, image in enumerate(data):
                    plt.subplot(n_rows, base_size_per_class, index + 1)
                    plt.imshow(np.reshape(torch.squeeze(image), [28, 28]), cmap='gray')
                    plt.axis('off')
                    plt.title(target[index].item())

            output = model_base(data)

            # base_loss_fn takes one-hot targets, matching the base-model evaluation cell.
            loss = base_loss_fn(output, one_hot_target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for metric_idx, metric in enumerate(base_metrics):
                total_metrics[metric_idx] += metric(output, target) * batch_size

    n_samples = len(base_data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(base_metrics)})

    test_result_str = 'TEST RESULTS\n'
    for key, val in log.items():
        test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

    cp.print_progress(test_result_str)

# Fine-tuned model 1

In [172]:
# Training config stored in fine-tuned model 1's checkpoint. Its 'target_class' entry
# sizes model_1's head in the next cell. torch.load unpickles — load trusted files only.
config_1 = torch.load(model_1_cpt)['config']
config_1

{'name': 'mnist_fine_tune',
 'n_gpu': 1,
 'model': {'type': 'LeNet', 'args': {}},
 'data_loader': {'type': 'MnistDataLoader',
  'args': {'data_dir': '/media/brandon/SSD/data/mnist',
   'batch_size': 128,
   'shuffle': True,
   'validation_split': 0.1,
   'num_workers': 2,
   'target_class': [1],
   'unknown': True}},
 'optimizer': {'type': 'Adam',
  'args': {'lr': 0.001, 'weight_decay': 0, 'amsgrad': True}},
 'loss': 'bce_logits_loss',
 'metrics': ['pred_acc'],
 'lr_scheduler': {'type': 'StepLR', 'args': {'step_size': 50, 'gamma': 0.1}},
 'trainer': {'epochs': 10,
  'save_dir': 'saved/',
  'save_period': 1,
  'verbosity': 2,
  'monitor': 'min val_loss',
  'early_stop': 10,
  'tensorboardX': True,
  'log_dir': 'saved/runs'},
 'target_class': [1]}

In [173]:
# Build fine-tuned model 1 from its own config, then swap its final layer for one with
# len(target_class) + 1 outputs. The extra output presumably corresponds to the
# "unknown" class (config has unknown=True) — confirm against swap_fc.
# NOTE(review): state_dict_1 is never loaded into model_1 (its load_state_dict cell is
# commented out), so model_1 keeps the weights it was constructed with.
model_1 = util.get_instance(models, 'model', config_1)
layer_id = model_1.swap_fc(len(config_1['target_class']) + 1)
model_1

LeNet(
  (conv1): Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=80, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=2, bias=True)
)

In [174]:
# Re-reads model 1's checkpoint; its fc2 weight/bias rows are used when building the
# combined model. torch.load unpickles the file — load trusted checkpoints only.
checkpoint_1 = torch.load(model_1_cpt)
state_dict_1 = checkpoint_1['state_dict']

In [175]:
# model_1.load_state_dict(state_dict_1)
# model_1.eval()

In [176]:
# size_per_class_1 = 2
# print("target class :", config_1['target_class'])

In [177]:
# # data_loader = getattr(data_loaders, config_1['data_loader']['type'])(
# #     config_1['data_loader']['args']['data_dir'],
# #     batch_size=512,
# #     shuffle=False,
# #     validation_split=0.0,
# #     training=False,
# #     num_workers=2,
# #     size_per_class=size_per_class_1,
# #     target_class=config_1['target_class'],
# #     unknown=True
# # )

# data_loader = getattr(data_loaders, config_1['data_loader']['type'])(
#     config_1['data_loader']['args']['data_dir'],
#     batch_size=512,
#     shuffle=False,
#     validation_split=0.0,
#     training=False,
#     num_workers=2,
#     size_per_class=size_per_class_1,
#     target_class=[1,2],
#     unknown=False
# )

In [178]:
# total_loss = 0.0
# total_metrics = torch.zeros(len(metrics))

# with torch.no_grad():
#         for i, (data, target) in enumerate(tqdm(data_loader)):
#             one_hot_target = torch.eye(2)[target]
            
#             plt.figure(figsize=[5,5])

#             for index, image in enumerate(data):
#                 plt.subplot(len(data)/ size_per_class_1, size_per_class_1, index+1)
#                 plt.imshow(np.reshape(torch.squeeze(image), [28,28]), cmap='gray')
#                 plt.axis('off')
#                 plt.title(target[index].item())

#             output = model_1(data)
            
#             # computing loss, metrics on test set
# #             loss = loss_fn(output, one_hot_target)
#             loss = loss_fn(output, target)
#             batch_size = data.shape[0]
#             total_loss += loss.item() * batch_size
#             for i, metric in enumerate(metrics):
#                 total_metrics[i] += metric(output, target) * batch_size

# n_samples = len(data_loader.sampler)
# log = {'loss': total_loss / n_samples}
# log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metrics)})

# test_result_str = 'TEST RESULTS\n'
# for key, val in log.items():
#     test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

# cp.print_progress(test_result_str)

# Fine-tuned model 2

In [179]:
# Training config stored in fine-tuned model 2's checkpoint. Its 'target_class' entry
# sizes model_2's head in the next cell. torch.load unpickles — load trusted files only.
config_2 = torch.load(model_2_cpt)['config']
config_2

{'name': 'mnist_fine_tune',
 'n_gpu': 1,
 'model': {'type': 'LeNet', 'args': {}},
 'data_loader': {'type': 'MnistDataLoader',
  'args': {'data_dir': '/media/brandon/SSD/data/mnist',
   'batch_size': 128,
   'shuffle': True,
   'validation_split': 0.1,
   'num_workers': 2,
   'target_class': [2],
   'unknown': True}},
 'optimizer': {'type': 'Adam',
  'args': {'lr': 0.001, 'weight_decay': 0, 'amsgrad': True}},
 'loss': 'bce_logits_loss',
 'metrics': ['pred_acc'],
 'lr_scheduler': {'type': 'StepLR', 'args': {'step_size': 50, 'gamma': 0.1}},
 'trainer': {'epochs': 10,
  'save_dir': 'saved/',
  'save_period': 1,
  'verbosity': 2,
  'monitor': 'min val_loss',
  'early_stop': 10,
  'tensorboardX': True,
  'log_dir': 'saved/runs'},
 'target_class': [2]}

In [180]:
# Build fine-tuned model 2 from its own config, then swap its final layer for one with
# len(target_class) + 1 outputs. The extra output presumably corresponds to the
# "unknown" class (config has unknown=True) — confirm against swap_fc.
# NOTE(review): state_dict_2 is never loaded into model_2 (its load_state_dict cell is
# commented out), so model_2 keeps the weights it was constructed with.
model_2 = util.get_instance(models, 'model', config_2)
layer_id = model_2.swap_fc(len(config_2['target_class']) + 1)
model_2

LeNet(
  (conv1): Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=80, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=2, bias=True)
)

In [181]:
# Re-reads model 2's checkpoint; its fc2 weight/bias rows are used when building the
# combined model. torch.load unpickles the file — load trusted checkpoints only.
checkpoint_2 = torch.load(model_2_cpt)
state_dict_2 = checkpoint_2['state_dict']

In [182]:
# model_2.load_state_dict(state_dict_2)
# model_2.eval()

In [183]:
# size_per_class_2 = 2
# print("target class :", config_2['target_class'])

In [184]:
# data_loader = getattr(data_loaders, config_2['data_loader']['type'])(
#     config_2['data_loader']['args']['data_dir'],
#     batch_size=512,
#     shuffle=False,
#     validation_split=0.0,
#     training=False,
#     num_workers=2,
#     size_per_class=size_per_class_2,
#     target_class=[2,1],
#     unknown=False
# )

In [185]:
# total_loss = 0.0
# total_metrics = torch.zeros(len(metrics))

# with torch.no_grad():
#         for i, (data, target) in enumerate(tqdm(data_loader)):
#             one_hot_target = torch.eye(2)[target]
            
#             plt.figure(figsize=[5,5])

#             for index, image in enumerate(data):
#                 plt.subplot(len(data)/ size_per_class_2, size_per_class_2, index+1)
#                 plt.imshow(np.reshape(torch.squeeze(image), [28,28]), cmap='gray')
#                 plt.axis('off')
#                 plt.title(target[index].item())

#             output = model_2(data)
            
#             # computing loss, metrics on test set
# #             loss = loss_fn(output, one_hot_target)
#             loss = loss_fn(output, target)
#             batch_size = data.shape[0]
#             total_loss += loss.item() * batch_size
#             for i, metric in enumerate(metrics):
#                 total_metrics[i] += metric(output, target) * batch_size

# n_samples = len(data_loader.sampler)
# log = {'loss': total_loss / n_samples}
# log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metrics)})

# test_result_str = 'TEST RESULTS\n'
# for key, val in log.items():
#     test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

# cp.print_progress(test_result_str)

# combined model

In [186]:
# Combined model: the base network, whose fc2 head is replaced in the following cells
# by one row from each fine-tuned model. The class list follows the same order in which
# those rows are stacked (model 1, then model 2).
target_class = config_1['target_class'] + config_2['target_class']

# Load the base checkpoint once and take both its config and its weights from it.
# NOTE: this replaces base_config, dropping the data_dir prefix applied earlier; later
# cells only read its 'loss' and 'metrics' entries.
checkpoint_base = torch.load(base_model_cpt)
base_config = checkpoint_base['config']
state_dict_base = checkpoint_base['state_dict']

model = util.get_instance(models, 'model', base_config)
model.load_state_dict(state_dict_base)


In [187]:
# Take row 0 of each fine-tuned model's fc2 layer and stack them into a new head.
# NOTE(review): assumes row 0 is the target-class output after swap_fc(len(target_class) + 1)
# — confirm against swap_fc.
fine_tuned_state_dicts = (state_dict_1, state_dict_2)
weight_list = [sd["fc2.weight"][0] for sd in fine_tuned_state_dicts]
bias_list = [sd["fc2.bias"][0] for sd in fine_tuned_state_dicts]
weight_1, weight_2 = weight_list
bias_1, bias_2 = bias_list

weight = torch.stack(weight_list)
bias = torch.stack(bias_list)

In [188]:
layer_id = model.swap_fc(len(target_class))
# Install the stacked fine-tuned rows as the new head (one output per entry in target_class).
model.fc2.weight, model.fc2.bias = torch.nn.Parameter(weight), torch.nn.Parameter(bias)

# Module.to and Module.eval both return the module, so the chained call still displays `model`.
model.to(torch.device('cpu')).eval()

LeNet(
  (conv1): Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=80, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=2, bias=True)
)

In [189]:
# The combined model is scored with the base model's loss and a fixed metric list.
base_config['metrics'] = ["pred_acc"]
metrics = list(map(lambda name: getattr(metric_functions, name), base_config['metrics']))

loss_fn = getattr(loss_functions, base_config['loss'])

In [190]:
# Test loader for the combined model: `size_per_class` samples for each class in
# target_class, with no "unknown" class. The evaluation cell below iterates
# `data_loader`, so it must be defined here for Restart & Run All to work.
size_per_class = 1000
data_loader = getattr(data_loaders, config_1['data_loader']['type'])(
    config_1['data_loader']['args']['data_dir'],
    batch_size=512,
    shuffle=False,
    validation_split=0.0,
    training=False,
    num_workers=2,
    size_per_class=size_per_class,
    target_class=target_class,
    unknown=False
)

In [191]:
import torch.nn as nn

# Evaluate the combined model on `data_loader`: print every misclassified sample as
# "<target> : <output row>", then report the averaged loss/metrics.
total_loss = 0.0
total_metrics = torch.zeros(len(metrics))

print(folder_name)

with torch.no_grad():
    for data, target in tqdm(data_loader):
        # One one-hot column per combined head instead of a hard-coded 2.
        one_hot_target = torch.eye(len(target_class))[target]

        output = model(data)

        # Read misclassifications from the batched output rather than re-running the model per image.
        predictions = output.argmax(dim=1)
        for index in (predictions != target).nonzero().flatten().tolist():
            print(str(target[index].item()) + " : " + str(output[index].tolist()))

        # Weight by batch size so the totals below average per sample, not per batch.
        loss = loss_fn(output, one_hot_target)
        batch_size = data.shape[0]
        total_loss += loss.item() * batch_size
        for metric_idx, metric in enumerate(metrics):
            total_metrics[metric_idx] += metric(output, target) * batch_size

n_samples = len(data_loader.sampler)
log = {'loss': total_loss / n_samples}
log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metrics)})

test_result_str = 'TEST RESULTS\n'
for key, val in log.items():
    test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

cp.print_progress(test_result_str)

  0%|          | 0/4 [00:00<?, ?it/s]

test/identity_softmax_bce


  return F.softmax(x)


0 : [0.40989843010902405, 0.5901015996932983]
0 : [0.09660104662179947, 0.9033989906311035]
0 : [0.17841428518295288, 0.8215857148170471]
0 : [0.4459337890148163, 0.5540662407875061]
0 : [0.03756008297204971, 0.9624398946762085]
0 : [0.19705691933631897, 0.8029430508613586]
0 : [0.44178715348243713, 0.5582128167152405]
0 : [0.04545839875936508, 0.9545416235923767]
0 : [0.09829176217317581, 0.9017082452774048]


 25%|██▌       | 1/4 [00:00<00:01,  1.92it/s]

0 : [0.013788566924631596, 0.9862114191055298]
0 : [0.32151496410369873, 0.6784849762916565]
0 : [0.27290472388267517, 0.7270951867103577]
0 : [0.17582440376281738, 0.8241756558418274]
0 : [0.4444516897201538, 0.5555483102798462]
0 : [0.4328670799732208, 0.5671328902244568]


 50%|█████     | 2/4 [00:00<00:00,  2.48it/s]

0 : [0.38003700971603394, 0.6199629902839661]
0 : [0.4565279185771942, 0.5434721112251282]


100%|██████████| 4/4 [00:01<00:00,  2.70it/s]


softmax
[92m
[ PROGRESS ] ::  TEST RESULTS
	loss : 0.5243982200622559
	pred_acc : 0.9915

[0m



