This notebook checks that the model trained with <https://github.com/leelabcnbc/thesis-yimeng-v1/blob/25471e6e80f7acd0f2ec82bb9c577c58bfdd7171/3rdparty/PCN-with-Local-Recurrent-Processing/main_imagenet_fp16.py> can be loaded back from its checkpoints.

In [1]:
import numpy as np

import time

import torch.utils.data
import torchvision.transforms as transforms 
import torchvision.datasets as datasets
import torch.backends.cudnn as cudnn

from thesis_v2 import dir_dict, join
from thesis_v2.models.pcn_local.reference import loader

In [2]:
# Show the project directory mapping from thesis_v2; its 'root' entry is used
# below to locate the thesis-yimeng-v1 checkpoint directory.
dir_dict

{'root': '/my_data/thesis-yimeng-v2',
 'results': '/my_data/thesis-yimeng-v2/results',
 'datasets': '/my_data/thesis-yimeng-v2/results/datasets',
 'features': '/my_data/thesis-yimeng-v2/results/features',
 'models': '/my_data/thesis-yimeng-v2/results/models',
 'analyses': '/my_data/thesis-yimeng-v2/results/analyses',
 'plots': '/my_data/thesis-yimeng-v2/results/plots',
 'visualization': '/my_data/thesis-yimeng-v2/results/visualization',
 'private_data': '/my_data/thesis-yimeng-v2/private_data',
 'private_data_supp': '/my_data/thesis-yimeng-v2/private_data_supp',
 'debug_data': '/my_data/thesis-yimeng-v2/debug_data',
 'trash': '/my_data/thesis-yimeng-v2/trash'}

In [3]:
# Model variant to load. The second argument of `load_pcn_imagenet` and the
# 'NCLS' suffix of the checkpoint file name refer to the same variant (the loader
# prints 'PredNetBpE_3CLS' for 3), so derive both from one constant to keep them
# in sync. NOTE(review): presumably the number of recurrent cycles — confirm.
N_CLS = 3
best_file = join(
    dir_dict['root'], '..', 'thesis-yimeng-v1', '3rdparty',
    'PCN-with-Local-Recurrent-Processing', 'checkpoint',
    'model_best.pth.tar.{}CLS'.format(N_CLS),
)
trained_model, checkpoint = loader.load_pcn_imagenet('PredNetBpE', N_CLS, checkpoint_path=best_file)

=> creating model 'PredNetBpE_3CLS'


In [4]:
# Display the module structure of the model reloaded from the checkpoint.
trained_model

PredNetBpE(
  (baseconv): features2(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (featBN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (PcConvs): ModuleList(
    (0): PcConvBp(
      (FFconv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b0): ParameterList(  (0): Parameter containing: [torch.FloatTensor of size 1x64x1x1])
      (relu): ReLU(inplace=True)
      (bypass): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (resp_init): Lambda()
      (resp_loop): Lambda()
    )
    (1): PcConvBp(
      (FFconv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b0): Paramete

In [5]:
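# Next, create a validation-set data loader to check the reloaded model against the
# validation accuracy stored in `checkpoint`. (The cells below only compute and save
# the model outputs for the first batch of 100 images; accuracy is not recomputed here.)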

In [6]:
# List the entries stored in the checkpoint dict returned by the loader.
checkpoint.keys()

dict_keys(['epoch', 'name', 'state_dict', 'best_prec1', 'prec1', 'prec5', 'optimizer'])

In [7]:
# Saved epoch plus the stored validation metrics (presumably top-1 / top-5
# accuracy in percent, judging by the key names and values).
tuple(checkpoint[key] for key in ('epoch', 'best_prec1', 'prec1', 'prec5'))

(92, 74.838, 74.838, 92.232)

In [8]:
def get_val_loader(valdir=None, batch_size=100, num_workers=0):
    """Build a DataLoader over the ImageNet (ILSVRC2015 CLS-LOC) validation images.

    Images are resized to 256, center-cropped to 224, converted to tensors and
    normalized per channel. The loader is never shuffled, so the first batch
    always contains the same images.

    Args:
        valdir: ImageFolder root of the validation set. Defaults to the
            ILSVRC2015 CLS-LOC ``val`` directory used originally.
        batch_size: images per batch. 100 is the original value.
        num_workers: DataLoader worker processes (0 = load in the main
            process, the DataLoader default used originally).

    Returns:
        torch.utils.data.DataLoader yielding ``(images, targets)`` batches.
    """
    if valdir is None:
        valdir = join('/my_data_2/standard_datasets/ILSVRC2015/Data/CLS-LOC', 'val')

    # Standard torchvision ImageNet per-channel mean/std. NOTE(review): presumably
    # matches the training preprocessing in main_imagenet_fp16.py — verify.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    return torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_transform),
        # a batch size of 100 should be safe regardless of network type (3CLS or 5CLS)
        batch_size=batch_size,
        # keep order fixed: the saved "first 100" outputs must always correspond
        # to the same images
        shuffle=False,
        num_workers=num_workers,
    )

In [9]:
def validate(val_loader, model, crop_num=1):
    """Run ``model`` on the first batch of ``val_loader`` and return its outputs.

    Despite the name, this does not compute accuracy. It evaluates only the
    first batch and returns the raw model outputs as a NumPy array.

    Args:
        val_loader: iterable yielding ``(input, target)`` pairs (e.g. the
            DataLoader from ``get_val_loader``). Only the first pair is used.
        model: torch module to evaluate; it is switched to eval mode.
        crop_num: unused; kept so existing callers keep working.

    Returns:
        numpy.ndarray with the model outputs for the first batch (on the CPU).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    # Eval mode: e.g. BatchNorm layers use their running statistics.
    model.eval()

    # Move the batch to wherever the model's parameters live instead of
    # hard-coding CUDA; parameter-less models fall back to CUDA as before.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    with torch.no_grad():
        for batch_input, _target in val_loader:
            batch_input = batch_input.to(device, non_blocking=True)
            # Deliberately return after the first batch only.
            return model(batch_input).cpu().numpy()

    # Previously this fell through and returned None silently.
    raise ValueError('val_loader yielded no batches')

In [10]:
# Deterministic cuDNN: turn off kernel autotuning and require deterministic
# algorithms, so repeated runs of the model give reproducible outputs.
cudnn.benchmark, cudnn.deterministic = False, True

In [11]:
# Run the reloaded model on the first validation batch and save its raw outputs,
# presumably as a reference for comparing against other loading code later.
trained_model.cuda()
validate_data = validate(get_val_loader(), trained_model)

# Fail loudly rather than saving a None / malformed result to disk.
assert validate_data is not None and validate_data.ndim == 2, 'expected a 2-D (batch, classes) array'

np.save('check_pcn_loading_first100.npy', validate_data)

In [12]:
# Shape of the saved outputs: (images in the first batch, model outputs per image).
validate_data.shape

(100, 1000)
