In [1]:
import numpy as np

from torch.backends import cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from thesis_v2.models.pcn_local.feature_extraction import (
    process_one_case_wrapper
)

from thesis_v2.models.pcn_local.reference.loader import get_pretrained_network

from thesis_v2 import dir_dict, join

# Keep cuDNN enabled but turn off its benchmark-mode autotuning, which
# reportedly avoids out-of-memory errors.
# See <https://github.com/pytorch/pytorch/issues/1230>
cudnn.enabled = True
cudnn.benchmark = False


def load_image_dataset(image_dataset_key,
                       imagenet_root='/my_data_2/standard_datasets/ILSVRC2015/Data/CLS-LOC'):
    """Load a fixed, reproducible sample of ImageNet validation images.

    Parameters
    ----------
    image_dataset_key : str
        Name of the sample to load. Only ``'first1000'`` is supported: the
        first 1000 images produced by a ``DataLoader`` that shuffles the
        validation set after ``torch.manual_seed(0)``.
    imagenet_root : str, optional
        Directory that contains the ``val`` folder in ``ImageFolder`` layout.
        Defaults to the path this notebook was originally run against.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1000, 3, 224, 224)`` holding the images after
        resize(256), center-crop(224), ``ToTensor`` and normalization.

    Notes
    -----
    Reseeds torch's global RNG (``torch.manual_seed(0)``) as a side effect.
    The loader construction and iteration order below determine which images
    are drawn; changing them changes the sample, and any outputs saved from
    it will no longer match.
    """
    from itertools import islice

    assert image_dataset_key == 'first1000'
    n_images = 1000
    # this batch size should be safe regardless of network type (3CLS or 5CLS)
    batch_size = 100
    n_batches = n_images // batch_size

    torch.manual_seed(0)
    valdir = join(imagenet_root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        # shuffle to make sure we do not get images from the same class.
        batch_size=batch_size, shuffle=True)

    # islice pulls exactly n_batches batches from one iterator, i.e. the same
    # sequence of next() calls as the original enumerate/break loop.
    batches = [x.numpy() for x, _ in islice(val_loader, n_batches)]

    images = np.concatenate(batches)
    assert images.shape == (n_images, 3, 224, 224)
    return images

In [2]:
# Preprocessed image sample shared by the forward-pass check below.
a = load_image_dataset(image_dataset_key='first1000')

In [3]:
# Checkpoints live in the PCN-with-Local-Recurrent-Processing checkout
# that sits next to this repository.
checkpoint_root = join(
    dir_dict['root'], '..', 'thesis-yimeng-v1', '3rdparty',
    'PCN-with-Local-Recurrent-Processing', 'checkpoint',
)
model = get_pretrained_network('PredNetBpE_3CLS', root_dir=checkpoint_root)

=> creating model 'PredNetBpE_3CLS'


In [4]:
# Run the first 100 images of the sample through the network on the GPU.
# Autotuning is off and deterministic cuDNN kernels are requested so that
# the result can be compared exactly against the saved reference.
cudnn.benchmark = False
cudnn.deterministic = True

from torch import tensor, no_grad  # NOTE(review): `tensor` is unused here

first_hundred = a[:100]
input_test = torch.tensor(first_hundred).cuda()
model.cuda()
model.eval()

with no_grad():
    validate_data_debug = model(input_test).cpu().numpy()

In [5]:
# Expected network output for the first 100 images; presumably saved from an
# earlier run of this same check — TODO confirm provenance.
reference_path = './check_pcn_loading_first100.npy'
validate_data = np.load(reference_path)

In [6]:
# Exact comparison against the saved reference. Unlike a bare
# `assert np.array_equal(...)`, this still raises AssertionError on mismatch,
# but the message reports shape differences or the number and magnitude of
# mismatched elements. NaNs in matching positions count as equal.
np.testing.assert_array_equal(validate_data_debug, validate_data)