This notebook checks that the model trained with <https://github.com/leelabcnbc/thesis-yimeng-v1/blob/25471e6e80f7acd0f2ec82bb9c577c58bfdd7171/3rdparty/PCN-with-Local-Recurrent-Processing/main_imagenet_fp16.py> can be loaded back from its checkpoints.

In [1]:
import time

import torch.utils.data
import torchvision.transforms as transforms 
import torchvision.datasets as datasets
import torch.backends.cudnn as cudnn

from thesis_v2 import dir_dict, join
from thesis_v2.models.pcn_local.reference import loader

In [2]:
# Show the project's directory layout (a dict of named paths); its 'root' entry is used below to locate the checkpoint.
dir_dict

{'root': '/my_data/thesis-yimeng-v2',
 'results': '/my_data/thesis-yimeng-v2/results',
 'datasets': '/my_data/thesis-yimeng-v2/results/datasets',
 'features': '/my_data/thesis-yimeng-v2/results/features',
 'models': '/my_data/thesis-yimeng-v2/results/models',
 'analyses': '/my_data/thesis-yimeng-v2/results/analyses',
 'plots': '/my_data/thesis-yimeng-v2/results/plots',
 'visualization': '/my_data/thesis-yimeng-v2/results/visualization',
 'private_data': '/my_data/thesis-yimeng-v2/private_data',
 'private_data_supp': '/my_data/thesis-yimeng-v2/private_data_supp',
 'debug_data': '/my_data/thesis-yimeng-v2/debug_data',
 'trash': '/my_data/thesis-yimeng-v2/trash'}

In [3]:
# Model variant to load. The same value forms the checkpoint file suffix ('...<n>CLS') and is
# passed to the loader, which reports the model as 'PredNetBpE_<n>CLS'; keeping it in one
# place prevents loading a checkpoint into a mismatched architecture.
MODEL_CLS = 3
best_file = join(dir_dict['root'], '..', 'thesis-yimeng-v1', '3rdparty',
                 'PCN-with-Local-Recurrent-Processing', 'checkpoint',
                 'model_best.pth.tar.{}CLS'.format(MODEL_CLS))
trained_model, checkpoint = loader.load_pcn_imagenet('PredNetBpE', MODEL_CLS, checkpoint_path=best_file)

=> creating model 'PredNetBpE_3CLS'


In [4]:
# Display the module structure of the reloaded network.
trained_model

PredNetBpE(
  (baseconv): features2(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (featBN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (PcConvs): ModuleList(
    (0): PcConvBp(
      (FFconv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b0): ParameterList(  (0): Parameter containing: [torch.FloatTensor of size 1x64x1x1])
      (relu): ReLU(inplace=True)
      (bypass): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (resp_init): Lambda()
      (resp_loop): Lambda()
    )
    (1): PcConvBp(
      (FFconv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b0): Paramete

In [5]:
# Next, build a validation-set data loader and check that we can reproduce the
# validation accuracy stored in `checkpoint`.

In [6]:
# List the fields saved in the checkpoint dict.
checkpoint.keys()

dict_keys(['epoch', 'name', 'state_dict', 'best_prec1', 'prec1', 'prec5', 'optimizer'])

In [7]:
# Epoch and validation metrics recorded when the checkpoint was saved; the validation run below should reproduce prec1 / prec5.
checkpoint['epoch'], checkpoint['best_prec1'], checkpoint['prec1'], checkpoint['prec5']

(92, 74.838, 74.838, 92.232)

In [8]:
class AverageMeter:
    """Keep the latest value together with a count-weighted running mean.

    ``val`` is the most recent value passed to :meth:`update`; ``sum`` and
    ``count`` accumulate ``val * n`` and ``n`` respectively, and ``avg`` is
    ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighting it by ``n`` in the mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [14]:
def get_val_loader(data_root='/my_data_2/standard_datasets/ILSVRC2015/Data/CLS-LOC',
                   batch_size=128, num_workers=0, pin_memory=True):
    """Build an unshuffled DataLoader over the ImageNet (ILSVRC) validation split.

    Images go through the single-crop evaluation pipeline: resize to 256,
    center-crop 224x224, convert to a tensor and normalize with the fixed
    per-channel mean/std below.

    Args:
        data_root: directory containing the ``val`` ImageFolder tree.
        batch_size: images per batch. 128 should be safe regardless of network
            type (3CLS or 5CLS).
        num_workers: number of loader subprocesses. 0 (the original behavior)
            loads in the main process; raising it can speed evaluation up.
        pin_memory: return batches in page-locked memory so that the
            ``non_blocking=True`` host-to-device copies in ``validate`` can
            actually run asynchronously.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    valdir = join(data_root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    return torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_transform),
        batch_size=batch_size,
        # keep sample order fixed so evaluation is reproducible
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)

In [15]:
def accuracy(output, target, topk=(1,)):
    """Compute precision@k: the percentage of samples whose label is in the top-k scores.

    Args:
        output: score tensor; class scores along dim 1, one row per sample.
        target: tensor of integer class labels, one per row of ``output``.
        topk: the values of k to evaluate.

    Returns:
        A list with one 1-element float tensor per k, holding a percentage in [0, 100].
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample, transposed so that row j marks
    # whether the j-th ranked prediction matches the label for every sample.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`, so
        # `.view(-1)` on a multi-row slice fails in recent PyTorch; `.reshape(-1)`
        # copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [16]:
def validate(val_loader, model, crop_num=1, print_freq=10, device='cuda'):
    """Evaluate ``model`` on ``val_loader`` and report top-1 / top-5 precision.

    Args:
        val_loader: iterable of ``(images, labels)`` batches.
        model: network to evaluate; it is switched to eval mode (and left there).
        crop_num: only used as a label in the progress messages.
        print_freq: print a progress line every ``print_freq`` batches;
            0 or None disables the per-batch lines.
        device: device the batches are moved to; must match where ``model`` lives.

    Returns:
        ``(top1_avg, top5_avg)``: batch-size-weighted mean precision@1 and
        precision@5, in percent.
    """
    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(images)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            # weight each batch by its size so a short final batch does not skew the mean
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # wall time per iteration, including data loading
            batch_time.update(time.time() - end)
            end = time.time()

            if print_freq and i % print_freq == 0:
                print('{0}-crop-validation\t'
                      'Test: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       crop_num, i, len(val_loader), batch_time=batch_time,
                       top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg

In [17]:
# cuDNN configuration. The default (DETERMINISTIC = False) matches the setting used in
# training: `benchmark` lets cuDNN pick kernels by timing them. Set DETERMINISTIC = True
# to disable that and request deterministic kernels instead. All flags are set
# explicitly so re-running this cell always yields the same configuration.
DETERMINISTIC = False

cudnn.enabled = True
cudnn.benchmark = not DETERMINISTIC
cudnn.deterministic = DETERMINISTIC

In [18]:
# Re-evaluate the reloaded model on the validation set and check that it reproduces the
# accuracy recorded in `checkpoint`. (In the saved run this gave 74.852 / 92.240, marginally
# above the recorded 74.838 / 92.232.)
# Tolerance in percentage points, chosen by hand to allow small numeric differences
# (e.g. from cuDNN kernel selection).
ACC_TOLERANCE = 0.1
trained_model.cuda()
val_prec1, val_prec5 = validate(get_val_loader(), trained_model)
assert abs(val_prec1 - checkpoint['prec1']) < ACC_TOLERANCE, (val_prec1, checkpoint['prec1'])
assert abs(val_prec5 - checkpoint['prec5']) < ACC_TOLERANCE, (val_prec5, checkpoint['prec5'])
val_prec1, val_prec5

1-crop-validation	Test: [0/391]	Time 4.491 (4.491)	Prec@1 88.281 (88.281)	Prec@5 96.875 (96.875)
1-crop-validation	Test: [10/391]	Time 1.840 (1.873)	Prec@1 75.781 (89.276)	Prec@5 95.312 (97.372)
1-crop-validation	Test: [20/391]	Time 1.535 (1.731)	Prec@1 85.938 (84.003)	Prec@5 95.312 (96.168)
1-crop-validation	Test: [30/391]	Time 1.639 (1.683)	Prec@1 80.469 (80.494)	Prec@5 90.625 (95.312)
1-crop-validation	Test: [40/391]	Time 1.561 (1.650)	Prec@1 86.719 (82.698)	Prec@5 92.969 (95.694)
1-crop-validation	Test: [50/391]	Time 1.517 (1.641)	Prec@1 95.312 (82.384)	Prec@5 97.656 (95.374)
1-crop-validation	Test: [60/391]	Time 1.475 (1.628)	Prec@1 78.906 (83.376)	Prec@5 91.406 (95.543)
1-crop-validation	Test: [70/391]	Time 1.521 (1.617)	Prec@1 74.219 (82.152)	Prec@5 96.875 (95.412)
1-crop-validation	Test: [80/391]	Time 1.538 (1.607)	Prec@1 77.344 (81.414)	Prec@5 96.875 (95.448)
1-crop-validation	Test: [90/391]	Time 1.540 (1.600)	Prec@1 74.219 (81.439)	Prec@5 94.531 (95.441)
1-crop-validation	Tes

(74.852, 92.24)