In [77]:
import os
import argparse
import torch
from tqdm import tqdm
import pprint
import matplotlib.image as mpimg

import data_loader as data_loaders
import model as models
import trainer.loss as loss_functions
import trainer.metric as metric_functions

import utils.util as util
from utils import color_print as cp

import numpy as np
import matplotlib.pylab as plt

In [78]:
# task = "kws_res15_narrow"
task = "cifar100"
eval_target_class = list(np.arange(7))

base_model_config = 'config/{}_base.json'.format(task)
fine_model_config = 'config/{}_fine_tune.json'.format(task)

# folder_name = 'dev'
# folder_name = 'cifar100_40/logsoftmax_nll_loss/0'
folder_name = 'cifar100_40/sigmoid_bce_loss/0'
# folder_name = 'cifar100_40/softmax_bce_loss/0'
most_recent_model = max(os.listdir(folder_name+"/"+task+"_base"))
base_model_cpt = os.path.join(folder_name, task+'_base', most_recent_model, 'model_best.pth')
print('base_model', base_model_cpt)

cpts = {}

for i in eval_target_class:
    most_recent_model = max(os.listdir(folder_name+"/"+task+"_fine_tune/"+str(i)+"/"))
    cpts[i] = os.path.join(folder_name, task+'_fine_tune/'+str(i)+'/', most_recent_model, 'model_best.pth')
    print('class '+str(i)+' model', cpts[i])



base_model cifar100_40/sigmoid_bce_loss/0/cifar100_base/0405_051920/model_best.pth
class 0 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/0/0405_074201/model_best.pth
class 1 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/1/0405_074304/model_best.pth
class 2 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/2/0405_074327/model_best.pth
class 3 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/3/0405_074346/model_best.pth
class 4 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/4/0405_074441/model_best.pth
class 5 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/5/0405_074510/model_best.pth
class 6 model cifar100_40/sigmoid_bce_loss/0/cifar100_fine_tune/6/0405_074535/model_best.pth


In [79]:
eval_all_model = False

# base model

In [80]:
base_config = torch.load(base_model_cpt)['config']
if "media" not in base_config['data_loader']['args']['data_dir']:
    base_config['data_loader']['args']['data_dir'] = "/media/brandon/SSD" + base_config['data_loader']['args']['data_dir']

In [81]:
model_base = util.get_instance(models, 'model', base_config)

In [82]:
checkpoint_base = torch.load(base_model_cpt)
state_dict_base = checkpoint_base['state_dict']

In [83]:
if eval_all_model:
    model_base.load_state_dict(state_dict_base)
    model_base.eval()
    
    base_target_class = list(np.arange(base_config['n_class']))
    base_size_per_class = 3

    base_data_loader = getattr(data_loaders, base_config['data_loader']['type'])(
        base_config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2,
        size_per_class=base_size_per_class,
        target_class=base_target_class,
        unknown=True
    )
    
    base_loss_fn = getattr(loss_functions, base_config['loss'])

    base_config['metrics'] = ["pred_acc"]

    base_metrics = [getattr(metric_functions, met) for met in base_config['metrics']]
    
    total_loss = 0.0
    total_metrics = torch.zeros(len(base_metrics))

    with torch.no_grad():
            for i, (data, target) in enumerate(tqdm(base_data_loader)):
                one_hot_target = torch.eye(len(base_target_class))[target]

#                 if "cifar" in task:
#                     plt.figure(figsize=[10,10])

#                     for index, image in enumerate(data):
#                         plt.subplot(len(data)/ base_size_per_class, base_size_per_class, index+1)
#                         plt.imshow((np.moveaxis(image.numpy(), 0, -1) * 255).astype(np.uint8))
#                         plt.axis('off')
#                         plt.title(target[index].item())

                output = model_base(data)

                # computing loss, metrics on test set
                loss = base_loss_fn(output, one_hot_target)
#                 loss = base_loss_fn(output, target)
                batch_size = data.shape[0]
                total_loss += loss.item() * batch_size
                for i, metric in enumerate(base_metrics):
                    total_metrics[i] += metric(output, target) * batch_size

    n_samples = len(base_data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(base_metrics)})

    test_result_str = 'TEST RESULTS\n'
    for key, val in log.items():
        test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

    cp.print_progress(test_result_str)

# fine tuned model

In [84]:
fine_tuned_config = {}
fine_tuned_models = {}
fine_tuned_state_dicts = {}

for eval_target in eval_target_class:
    cpt = torch.load(cpts[eval_target])
    fine_tuned_config[eval_target] = cpt['config']
    
    if "media" not in fine_tuned_config[eval_target]['data_loader']['args']['data_dir']:
        fine_tuned_config[eval_target]['data_loader']['args']['data_dir'] = "/media/brandon/SSD" + fine_tuned_config[eval_target]['data_loader']['args']['data_dir']
    
    fine_tuned_models[eval_target] = util.get_instance(models, 'model', fine_tuned_config[eval_target])
    fine_tuned_models[eval_target].swap_fc(len(fine_tuned_config[eval_target]['target_class']) + 1)
    fine_tuned_state_dicts[eval_target] = cpt['state_dict']
    
#     print(eval_target)
#     pprint.pprint(fine_tuned_config[eval_target])

In [85]:
if eval_all_model:
    for eval_target in eval_target_class:
        model = fine_tuned_models[eval_target]
        config = fine_tuned_config[eval_target]
        sd = fine_tuned_state_dicts[eval_target]
        
        model.load_state_dict(sd)
        model.eval()
        
        size_per_class = 20
        fine_tuned_target_class = [eval_target]
        
        fine_tuned_data_loader = getattr(data_loaders, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=512,
            shuffle=False,
            validation_split=0.0,
            training=False,
            num_workers=2,
            size_per_class=size_per_class,
            target_class=fine_tuned_target_class,
            unknown=True
        )
        

        loss_fn = getattr(loss_functions, config['loss'])

        config['metrics'] = ["pred_acc"]

        metrics = [getattr(metric_functions, met) for met in config['metrics']]

        total_loss = 0.0
        total_metrics = torch.zeros(len(base_metrics))

#         pprint.pprint(fine_tuned_data_loader.dataset.classes)
#         pprint.pprint(fine_tuned_data_loader.dataset.keyword_audios)
        
        with torch.no_grad():
            for i, (data, target) in enumerate(tqdm(fine_tuned_data_loader)):
                one_hot_target = torch.eye(len(fine_tuned_target_class) + 1)[target]

#                 if "cifar" in task:
#                     plt.figure(figsize=[15,15])

#                     for index, image in enumerate(data):
#                         plt.subplot(len(data)/ size_per_class, size_per_class, index+1)
#                         plt.imshow((np.moveaxis(image.numpy(), 0, -1) * 255).astype(np.uint8))
#                         plt.axis('off')
#                         plt.title(target[index].item())

                output = model(data)

                # computing loss, metrics on test set
                loss = loss_fn(output, one_hot_target)
#                 loss = base_loss_fn(output, target)
                batch_size = data.shape[0]
                total_loss += loss.item() * batch_size
                for i, metric in enumerate(metrics):
                    total_metrics[i] += metric(output, target) * batch_size

        n_samples = len(fine_tuned_data_loader.sampler)
        log = {'loss': total_loss / n_samples}
        log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metrics)})

        test_result_str = 'TEST RESULTS\n'
        for key, val in log.items():
            test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

        cp.print_progress(test_result_str)

# combined model

In [86]:
combined_config = torch.load(base_model_cpt)['config']
combined_model = util.get_instance(models, 'model', combined_config)
combined_checkpoint = torch.load(base_model_cpt)
combined_state_dict = checkpoint_base['state_dict']

combined_model.load_state_dict(combined_state_dict)

if "media" not in combined_config['data_loader']['args']['data_dir']:
        combined_config['data_loader']['args']['data_dir'] = "/media/brandon/SSD" + combined_config['data_loader']['args']['data_dir']
        
# print('base')
# for k,v in model_base.named_parameters():
#     print(k, torch.mean(v).item())

# print('old')
# for k,v in model_1.named_parameters():
#     print(k, torch.mean(v).item())
    
# print('old2')
# for k,v in model_2.named_parameters():
#     print(k, torch.mean(v).item())
    
# print('combined')
# for k,v in model.named_parameters():
#     print(k, torch.mean(v).item())

In [87]:
weight_list = []
bias_list = []

for eval_target in eval_target_class:
    model = fine_tuned_models[eval_target]
    config = fine_tuned_config[eval_target]
    weight_list.append(fine_tuned_state_dicts[eval_target]["fc.weight"][0])
    bias_list.append(fine_tuned_state_dicts[eval_target]["fc.bias"][0])
#     weight_list.append(fine_tuned_state_dicts[eval_target]["fc.weight"][1])
#     bias_list.append(fine_tuned_state_dicts[eval_target]["fc.bias"][1])

weight = torch.stack(weight_list)
bias = torch.stack(bias_list)

In [88]:
layer_id = combined_model.swap_fc(len(eval_target_class))
combined_model.fc.weight = torch.nn.Parameter(weight)
combined_model.fc.bias = torch.nn.Parameter(bias)

combined_model.to(torch.device('cpu'))
combined_model.eval()

DenseNet(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv1): Conv2d(24, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv1): Conv2d(36, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): BasicBlock(
        (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv1): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): BasicBlock(
        (bn1): BatchNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_ru

In [89]:
combined_loss_fn = getattr(loss_functions, combined_config['loss'])

combined_config['metrics'] = ["pred_acc"]

combined_metrics = [getattr(metric_functions, met) for met in combined_config['metrics']]

In [90]:
combined_size_per_class = 100
combined_data_loader = getattr(data_loaders, combined_config['data_loader']['type'])(
    combined_config['data_loader']['args']['data_dir'],
    batch_size=512,
    shuffle=False,
    validation_split=0.0,
    training=False,
    num_workers=2,
    size_per_class=combined_size_per_class,
    target_class=eval_target_class,
    unknown=False,
    seed=10
)

Files already downloaded and verified
< Dataset Summary >
	seed	: 10
	 0 - apple 	: 0  ( 100 )
	 1 - aquarium_fish 	: 1  ( 100 )
	 2 - baby 	: 2  ( 100 )
	 3 - bear 	: 3  ( 100 )
	 4 - beaver 	: 4  ( 100 )
	 5 - bed 	: 5  ( 100 )
	 6 - bee 	: 6  ( 100 )
total data size :  700


In [91]:
total_loss = 0.0
total_metrics = torch.zeros(len(combined_metrics))

# pprint.pprint(fine_tuned_data_loader.dataset.classes)
# pprint.pprint(fine_tuned_data_loader.dataset.keyword_audios)

with torch.no_grad():
    for i, (data, target) in enumerate(tqdm(combined_data_loader)):
        one_hot_target = torch.eye(len(eval_target_class))[target]

#         if "cifar" in task:
#             plt.figure(figsize=[15,15])

#             print(len(data))
#             for index, image in enumerate(data):
#                 plt.subplot((len(data)/ combined_size_per_class)+1, combined_size_per_class, index+1)
#                 plt.imshow((np.moveaxis(image.numpy(), 0, -1) * 255).astype(np.uint8))
#                 plt.axis('off')
#                 plt.title(target[index].item())

        output = combined_model(data)

        # computing loss, metrics on test set
        loss = combined_loss_fn(output, one_hot_target)
#                 loss = combined_loss_fn(output, target)
        batch_size = data.shape[0]
        total_loss += loss.item() * batch_size
        for i, metric in enumerate(combined_metrics):
            total_metrics[i] += metric(output, target) * batch_size

n_samples = len(combined_data_loader.sampler)
log = {'loss': total_loss / n_samples}
log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(combined_metrics)})

test_result_str = 'TEST RESULTS\n'
for key, val in log.items():
    test_result_str += ('\t' + str(key) + ' : ' + str(val) + '\n')

cp.print_progress(test_result_str)


  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:15<00:15, 15.70s/it][A
100%|██████████| 2/2 [00:21<00:00, 10.74s/it][A
[A

[92m
[ PROGRESS ] ::  TEST RESULTS
	loss : 1.1529671798433576
	pred_acc : 0.7871428571428571

[0m
