In [5]:
import argparse
import os
import time

#from omnidata import OmniglotSetsDataset, load_mnist_test_batch
from omnidata import load_mnist
from omnimodel import Statistician
from omniplot import save_test_grid
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils import data
from tqdm import tqdm

import gzip
import numpy as np
import os
import pickle
import torch

from skimage.transform import rotate
from torch.utils import data

try:
    from utils import (kl_diagnormal_diagnormal, kl_diagnormal_stdnormal,
                       gaussian_log_likelihood)
except ModuleNotFoundError:
    # put parent directory in path for utils
    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    from utils import (kl_diagnormal_diagnormal, kl_diagnormal_stdnormal,
                       gaussian_log_likelihood)

In [46]:
class OmniglotSetsDataset(data.Dataset):
    def __init__(self, data_dir, sample_size=5, split='train', augment=False):
        self.split = split
        self.sample_size = sample_size
        path = os.path.join(data_dir, 'train_val_test_split.pkl')
        with open(path, 'rb') as file:
            splits = pickle.load(file)
        if split == 'train':
            images, labels = splits[:2]
            sets, set_labels = self.make_sets(images, labels)
            if augment:
                sets = self.augment_sets(sets)
            else:
                sets = np.random.binomial(1, p=sets, size=sets.shape).astype(np.float32)
        elif split == 'valid':
            images, labels = splits[2:4]
            sets, set_labels = self.make_sets(images, labels)
        elif split == 'test':
            images, labels = splits[4:]
            sets, set_labels = self.make_sets(images, labels)
        elif split == "kshot":
            sets, set_labels = splits[4:]
        else:
            "Unrecognized split, returning None."
            sets, set_labels = None, None
        if split != "kshot":
            sets = sets.reshape(-1, 5, 1, 28, 28)

        self.n = len(sets)
        self.data = {
            'inputs': sets,
            'targets': set_labels
        }

    def __getitem__(self, item):
        if self.split == "kshot":
            return (self.data['inputs'][item], self.data['targets'][item])
        else:
            return self.data['inputs'][item]

    def __len__(self):
        return self.n

    def augment_sets(self, sets):
        augmented = np.copy(sets)
        augmented = augmented.reshape(-1, self.sample_size, 28, 28)
        n_sets = len(augmented)

        for s in range(n_sets):
            flip_horizontal = np.random.choice([0, 1])
            flip_vertical = np.random.choice([0, 1])
            if flip_horizontal:
                augmented[s] = augmented[s, :, :, ::-1]
            if flip_vertical:
                augmented[s] = augmented[s, :, ::-1, :]

        for s in range(n_sets):
            angle = np.random.uniform(0, 360)
            for item in range(self.sample_size):
                augmented[s, item] = rotate(augmented[s, item], angle)
        augmented = np.concatenate([augmented.reshape(n_sets, self.sample_size, 28*28),
                                    sets])

        return augmented

    @staticmethod
    def one_hot(dense_labels, num_classes):
        num_labels = len(dense_labels)
        offset = np.arange(num_labels) * num_classes
        one_hot_labels = np.zeros((num_labels, num_classes))
        one_hot_labels.flat[offset + dense_labels.ravel()] = 1
        return one_hot_labels

    def make_sets(self, images, labels):
        num_classes = np.max(labels) + 1
        labels = self.one_hot(labels, num_classes)

        n = len(images)
        perm = np.random.permutation(n)
        images = images[perm]
        labels = labels[perm]

        image_sets = []
        set_labels = []

        for i in range(num_classes):
            ix = labels[:, i].astype(bool)
            num_instances_of_class = np.sum(ix)
            if num_instances_of_class < self.sample_size:
                pass
            else:
                remainder = num_instances_of_class % self.sample_size
                image_set = images[ix]
                if remainder > 0:
                    image_set = image_set[:-remainder]
                image_sets.append(image_set)
                k = len(image_set)
                set_labels.append(labels[ix][:int(k / self.sample_size)])

        x = np.concatenate(image_sets, axis=0).reshape(-1, self.sample_size, 28*28)
        y = np.concatenate(set_labels, axis=0)
        if np.max(x) > 1:
            x /= 255

        perm = np.random.permutation(len(x))
        x = x[perm]
        y = y[perm]

        return x, y

In [13]:
data_dir = "../omniglot-data"
batch_size = 32
train_dataset = OmniglotSetsDataset(data_dir=data_dir, split='train', augment=True)
test_dataset = OmniglotSetsDataset(data_dir=data_dir, split='test')
datasets = (train_dataset, test_dataset)

# create loaders
train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                               shuffle=True, num_workers=0, drop_last=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size,
                              shuffle=False, num_workers=0, drop_last=True)
loaders = (train_loader, test_loader)
    
train_dataset, test_dataset = datasets
train_loader, test_loader = loaders

In [49]:
kshot.data['targets']

array([1533, 1533, 1533, ..., 1592, 1592, 1592])

In [50]:
kshot.data['inputs']

array([[0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00],
       ...,
       [0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 4.440892e-16, 0.000000e+00, ..., 4.440892e-16,
        0.000000e+00, 0.000000e+00]], dtype=float32)

In [None]:
for i in range(K):
    idx = np.where(lb == i)[0]
    np.random.shuffle(idx)
    D.append(images[idx[:W]])
    samples = images[idx[W : W + n_test_samples]]
    x_test.append(samples)
    x_labels += ([i] * samples.shape[0])
x_labels = np.array(x_labels)
x_test = np.vstack(x_test)
D = np.array(D)

In [57]:
from collections import Counter


In [82]:
def get_omniglot_Kshot(K=5, support=1):
    kshot =  OmniglotSetsDataset(data_dir=data_dir, split='kshot')
    test_classes = list(Counter(kshot.data['targets']).keys())
    chosen_K = np.random.choice(test_classes, K)
    n_test_samples = 20 - support
    D = []
    x_test = []
    x_labels = []
    lb = kshot.data['targets']
    images = kshot.data['inputs']
    for i in range(K):
        idx = chosen_K[i]
        targets_idx = np.where(lb == idx)[0]
        actual_inputs = images[targets_idx[:support]]
        D.append(actual_inputs)
        samples = images[targets_idx[support : support + n_test_samples]]
        x_test.append(samples)
        x_labels += ([i] * samples.shape[0])
    x_labels = np.array(x_labels)
    x_test = np.vstack(x_test)
    D = np.array(D)
    return D, x_test, x_labels

In [90]:
import argparse
import os
import time

from omnidata import load_mnist
from omnimodel import Statistician
from omniplot import save_test_grid
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils import data
from tqdm import tqdm

import gzip
import numpy as np
import os
import pickle
import torch

from skimage.transform import rotate
from torch.utils import data

try:
    from utils import (kl_diagnormal_diagnormal, kl_diagnormal_stdnormal,
                       gaussian_log_likelihood)
except ModuleNotFoundError:
    # put parent directory in path for utils
    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    from utils import (kl_diagnormal_diagnormal, kl_diagnormal_stdnormal,
                       gaussian_log_likelihood)

n_features = 256 * 4 * 4  # output shape of convolutional encoder
# create model
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils import data
from tqdm import tqdm
model_kwargs = {
    'batch_size': 32,
    'sample_size': 5,
    'n_features': n_features,
    'c_dim': 512,
    'n_hidden_statistic': 3,
    'hidden_dim_statistic': 256,
    'n_stochastic':1,
    'z_dim': 16,
    'n_hidden': 3,
    'hidden_dim': 256,
    'nonlinearity': F.elu,
    'print_vars': False
}

model = Statistician(**model_kwargs)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

#filename = "../outputdilate4/checkpoints/15-02-2019-03:43:01-400.m"
filename = "../outputdropout/checkpoints/23-02-2019-16:39:28-400.m"
checkpoint = torch.load(filename)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.eval()


def classify_datapoint(x, D_means, D_vars, K, single_sample=True):
    dataset = torch.from_numpy(x)
    with torch.no_grad():
        inputs = Variable(dataset.cuda())
    h1 = model.shared_convolutional_encoder(inputs)
    c_mean_, c_logvar_ = model.statistic_network(h1, summarize=False, single_sample=single_sample)
    kl_divergences = []
    for i in range(K):
        kl = kl_diagnormal_diagnormal(D_means[i], D_vars[i], c_mean_, c_logvar_)
        kl_divergences.append(kl.data.item())
    best_index = kl_divergences.index(min(kl_divergences))
    return best_index

def omniglot_one_shot(K=10, support=1, n_trials=10):
    n_test_samples = 20 - support
    W = support
    accs = []
    for trial in tqdm(range(n_trials)):
        D, x_test, x_labels = get_omniglot_Kshot(K=K, support=support)
        D_means = []
        D_vars = []
        for i in range(K):
            dataset = torch.from_numpy(D[i])
            with torch.no_grad():
                inputs = Variable(dataset.cuda())
            h = model.shared_convolutional_encoder(inputs)
            model.eval()
            c_mean_full, c_logvar_full = model.statistic_network(h, summarize=True)
            D_means.append(c_mean_full)
            D_vars.append(c_logvar_full)
        test_loader = data.DataLoader(dataset=x_test, batch_size=n_test_samples,
                                  shuffle=False, num_workers=0, drop_last=False)
        preds = []
        for batch in test_loader:
            with torch.no_grad():
                inputs = Variable(batch.cuda())
            h1 = model.shared_convolutional_encoder(inputs)
            c_mean_, c_logvar_ = model.statistic_network(h1, single_sample=True)
            for bi, x in enumerate(batch):
                kl_divergences = []
                for i in range(K):
                    kl = kl_diagnormal_diagnormal(D_means[i], D_vars[i], c_mean_[bi], c_logvar_[bi])
                    kl_divergences.append(kl.data.item())
                best_index = kl_divergences.index(min(kl_divergences))
                preds.append(best_index)
        print(preds)
        print(x_labels)
        acc = np.mean(np.array(preds) == x_labels)
        print(acc)
        accs.append(acc)
    return accs, preds, x_labels

#accs_1 = mnist_one_shot(support=1, n_test_samples=10, n_trials=100)
# print("1-shot: {}".format(np.mean(accs_1)))
# accs_5 = mnist_one_shot(support=5, n_test_samples=10, n_trials=100)
# print("5-shot: {}".format(np.mean(accs_5)))


In [92]:
accs_1, b, c = omniglot_one_shot(K=5, support=5, n_trials=1)

100%|██████████| 1/1 [00:02<00:00,  2.80s/it]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4]
0.9866666666666667





In [89]:
np.array(c).shape

(94,)

In [5]:
n_features = 256 * 4 * 4  # output shape of convolutional encoder
# create model
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils import data
from tqdm import tqdm
model_kwargs = {
    'batch_size': 32,
    'sample_size': 5,
    'n_features': n_features,
    'c_dim': 512,
    'n_hidden_statistic': 3,
    'hidden_dim_statistic': 256,
    'n_stochastic':1,
    'z_dim': 16,
    'n_hidden': 3,
    'hidden_dim': 256,
    'nonlinearity': F.elu,
    'print_vars': False
}

model = Statistician(**model_kwargs)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
filename = "../outputdilate4/checkpoints/15-02-2019-03:43:01-400.m"
checkpoint = torch.load(filename)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.eval()

Statistician(
  (shared_convolutional_encoder): SharedConvolutionalEncoder(
    (conv_layers): ModuleList(
      (0): Conv2d3x3(
        (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): Conv2d3x3(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): Conv2d3x3(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (3): Conv2d3x3(
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (4): Conv2d3x3(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (5): Conv2d3x3(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (6): Conv2d3x3(
        (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (7): Conv2d3x3(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1

In [25]:
train_dataset.data['inputs'].shape

(9536, 5, 1, 28, 28)

In [24]:
train_dataset.data['targets'].shape

(4768, 1200)

In [26]:
9536/5

1907.2

In [27]:
train_dataset.n

9536

In [44]:
labels = test_dataset.data['targets']

In [32]:
test_dataset.data['inputs'].shape

(394, 5, 1, 28, 28)

In [52]:
a = np.argmax(labels, axis = 1)

In [53]:
a

array([1591, 1621, 1523, 1619, 1539, 1536, 1553, 1557, 1579, 1620, 1585,
       1571, 1584, 1607, 1578, 1583, 1531, 1603, 1585, 1582, 1568, 1523,
       1530, 1574, 1536, 1566, 1535, 1528, 1526, 1523, 1543, 1534, 1563,
       1558, 1573, 1598, 1597, 1524, 1600, 1524, 1593, 1533, 1617, 1547,
       1593, 1555, 1543, 1578, 1615, 1589, 1585, 1569, 1610, 1540, 1586,
       1608, 1600, 1590, 1616, 1549, 1604, 1590, 1558, 1607, 1589, 1610,
       1584, 1613, 1620, 1616, 1619, 1608, 1604, 1542, 1563, 1575, 1529,
       1576, 1588, 1582, 1523, 1612, 1592, 1583, 1619, 1561, 1622, 1546,
       1610, 1596, 1592, 1621, 1537, 1572, 1573, 1604, 1577, 1595, 1531,
       1598, 1583, 1598, 1608, 1561, 1617, 1539, 1535, 1620, 1622, 1532,
       1613, 1593, 1587, 1528, 1617, 1578, 1532, 1586, 1567, 1603, 1555,
       1559, 1606, 1540, 1551, 1605, 1540, 1562, 1544, 1621, 1581, 1580,
       1604, 1560, 1562, 1570, 1606, 1556, 1531, 1594, 1533, 1592, 1618,
       1557, 1599, 1547, 1579, 1589, 1541, 1564, 16

In [6]:

#     labels = np.argmax(one_hot_labels, axis=1)
#     ixs = [np.random.choice(np.where(labels == i)[0], size=5, replace=False)
#            for i in range(10)]
#     batch = np.array([images[ix] for ix in ixs])
#     return torch.from_numpy(batch).clone().repeat((batch_size // 10) + 1, 1, 1)[:batch_size]

In [24]:
a=[1,2,3]
a[:2]

[1, 2]

In [32]:
np.vstack(x_test).shape

(100, 784)

In [34]:
np.array(x_labels).shape

(100,)

In [87]:
from tqdm import tqdm
def classify_datapoint(x, D_means, D_vars, K, single_sample=True):
    dataset = torch.from_numpy(x)
    with torch.no_grad():
        inputs = Variable(dataset.cuda())
    h1 = model.shared_convolutional_encoder(inputs)
    c_mean_, c_logvar_ = model.statistic_network(h1, summarize=False, single_sample=single_sample)
    kl_divergences = []
    for i in range(K):
        kl = kl_diagnormal_diagnormal(D_means[i], D_vars[i], c_mean_, c_logvar_)
        kl_divergences.append(kl.data.item())
    best_index = kl_divergences.index(min(kl_divergences))
    return best_index

def mnist_one_shot(K=10, support=1, n_trials=100, n_test_samples=1):
    W = support
    data_dir = "../mnist-data"
    images, one_hot_labels = load_mnist(data_dir=data_dir)
    lb = np.argmax(one_hot_labels, 1)
    for trial in tqdm(range(n_trials)):
        D = []
        x_test = []
        x_labels = []
        for i in range(K):
            idx = np.where(lb == i)[0]
            np.random.shuffle(idx)
            D.append(images[idx[:W]])
            samples = images[idx[W : W + n_test_samples]]
            x_test.append(samples)
            x_labels += ([i] * samples.shape[0])
        x_labels = np.array(x_labels)
        x_test = np.vstack(x_test)
        D = np.array(D)
        D_means = []
        D_vars = []
        for i in range(K):
            dataset = torch.from_numpy(D[i])
            with torch.no_grad():
                inputs = Variable(dataset.cuda())
            h = model.shared_convolutional_encoder(inputs)
            c_mean_full, c_logvar_full = model.statistic_network(h, summarize=True)
            D_means.append(c_mean_full)
            D_vars.append(c_logvar_full)
        test_loader = data.DataLoader(dataset=x_test, batch_size=100,
                                  shuffle=False, num_workers=0, drop_last=False)
        preds = []
        for batch in test_loader:
            with torch.no_grad():
                inputs = Variable(batch.cuda())
            h1 = model.shared_convolutional_encoder(inputs)
            c_mean_, c_logvar_ = model.statistic_network(h1, single_sample=True)
            for bi, x in enumerate(batch):
                kl_divergences = []
                for i in range(K):
                    kl = kl_diagnormal_diagnormal(D_means[i], D_vars[i], c_mean_[bi], c_logvar_[bi])
                    kl_divergences.append(kl.data.item())
                best_index = kl_divergences.index(min(kl_divergences))
                preds.append(best_index)
        
    #for i in range(K):
        #xs = x_test[i]
        #label = i
#     for x in x_test:
#         pred = classify_datapoint(x, D_means, D_vars, K)
#         preds.append(pred)
        acc = np.mean(np.array(preds) == np.array(x_labels))
        print(acc)
        accs.append(acc)
    return accs

#def get_p

In [92]:
accs_5 = mnist_one_shot(support=5, n_test_samples=1000, n_trials=100)






  0%|          | 0/100 [00:00<?, ?it/s][A[A[A[A[A




  1%|          | 1/100 [00:23<39:01, 23.66s/it][A[A[A[A[A

0.8381







  2%|▏         | 2/100 [00:47<38:42, 23.70s/it][A[A[A[A[A

0.8098







  3%|▎         | 3/100 [01:11<38:42, 23.94s/it][A[A[A[A[A

0.828







  4%|▍         | 4/100 [01:36<38:36, 24.13s/it][A[A[A[A[A

0.8106







  5%|▌         | 5/100 [02:01<38:22, 24.23s/it][A[A[A[A[A

0.8191







  6%|▌         | 6/100 [02:25<38:04, 24.31s/it][A[A[A[A[A

0.8121







  7%|▋         | 7/100 [02:50<37:48, 24.40s/it][A[A[A[A[A

0.8112







  8%|▊         | 8/100 [03:14<37:20, 24.35s/it][A[A[A[A[A

0.8132







  9%|▉         | 9/100 [03:38<37:00, 24.40s/it][A[A[A[A[A

0.8292







 10%|█         | 10/100 [04:03<36:35, 24.40s/it][A[A[A[A[A

0.8142







 11%|█         | 11/100 [04:27<36:17, 24.46s/it][A[A[A[A[A

0.7889







 12%|█▏        | 12/100 [04:52<35:48, 24.41s/it][A[A[A[A[A

0.8146







 13%|█▎        | 13/100 [05:16<35:15, 24.32s/it][A[A[A[A[A

0.8058







 14%|█▍        | 14/100 [05:40<34:56, 24.38s/it][A[A[A[A[A

0.8334







 15%|█▌        | 15/100 [06:05<34:33, 24.39s/it][A[A[A[A[A

0.847







 16%|█▌        | 16/100 [06:29<34:16, 24.48s/it][A[A[A[A[A

0.8148







 17%|█▋        | 17/100 [06:54<33:48, 24.44s/it][A[A[A[A[A

0.8138







 18%|█▊        | 18/100 [07:18<33:22, 24.42s/it][A[A[A[A[A

0.8098







 19%|█▉        | 19/100 [07:42<32:56, 24.40s/it][A[A[A[A[A

0.8168







 20%|██        | 20/100 [08:07<32:36, 24.45s/it][A[A[A[A[A

0.8257







 21%|██        | 21/100 [08:31<32:09, 24.43s/it][A[A[A[A[A

0.8217







 22%|██▏       | 22/100 [08:56<31:45, 24.43s/it][A[A[A[A[A

0.795







 23%|██▎       | 23/100 [09:20<31:17, 24.38s/it][A[A[A[A[A

0.8236







 24%|██▍       | 24/100 [09:44<30:49, 24.34s/it][A[A[A[A[A

0.7782







 25%|██▌       | 25/100 [10:09<30:30, 24.40s/it][A[A[A[A[A

0.8301







 26%|██▌       | 26/100 [10:33<30:03, 24.37s/it][A[A[A[A[A

0.8404







 27%|██▋       | 27/100 [10:57<29:24, 24.18s/it][A[A[A[A[A

0.8186







 28%|██▊       | 28/100 [11:21<29:03, 24.21s/it][A[A[A[A[A

0.8447







 29%|██▉       | 29/100 [11:45<28:34, 24.14s/it][A[A[A[A[A

0.8262







 30%|███       | 30/100 [12:10<28:13, 24.20s/it][A[A[A[A[A

0.8355







 31%|███       | 31/100 [12:34<27:49, 24.19s/it][A[A[A[A[A

0.8466







 32%|███▏      | 32/100 [12:58<27:26, 24.22s/it][A[A[A[A[A

0.8263







 33%|███▎      | 33/100 [13:22<27:03, 24.23s/it][A[A[A[A[A

0.8526







 34%|███▍      | 34/100 [13:46<26:36, 24.19s/it][A[A[A[A[A

0.7984







 35%|███▌      | 35/100 [14:11<26:13, 24.21s/it][A[A[A[A[A

0.8282







 36%|███▌      | 36/100 [14:35<25:48, 24.20s/it][A[A[A[A[A

0.8383







 37%|███▋      | 37/100 [14:59<25:23, 24.19s/it][A[A[A[A[A

0.7686







 38%|███▊      | 38/100 [15:23<25:02, 24.23s/it][A[A[A[A[A

0.8377







 39%|███▉      | 39/100 [15:48<24:38, 24.24s/it][A[A[A[A[A

0.7819







 40%|████      | 40/100 [16:12<24:13, 24.22s/it][A[A[A[A[A

0.8213







 41%|████      | 41/100 [16:36<23:48, 24.21s/it][A[A[A[A[A

0.7991







 42%|████▏     | 42/100 [17:00<23:24, 24.22s/it][A[A[A[A[A

0.8429







 43%|████▎     | 43/100 [17:24<23:00, 24.22s/it][A[A[A[A[A

0.8084







 44%|████▍     | 44/100 [17:49<22:36, 24.23s/it][A[A[A[A[A

0.8273







 45%|████▌     | 45/100 [18:13<22:08, 24.16s/it][A[A[A[A[A

0.8388







 46%|████▌     | 46/100 [18:37<21:44, 24.16s/it][A[A[A[A[A

0.7959







 47%|████▋     | 47/100 [19:01<21:22, 24.20s/it][A[A[A[A[A

0.8095







 48%|████▊     | 48/100 [19:25<20:58, 24.20s/it][A[A[A[A[A

0.8137







 49%|████▉     | 49/100 [19:49<20:34, 24.20s/it][A[A[A[A[A

0.8504







 50%|█████     | 50/100 [20:14<20:08, 24.17s/it][A[A[A[A[A

0.8187







 51%|█████     | 51/100 [20:38<19:45, 24.18s/it][A[A[A[A[A

0.8287







 52%|█████▏    | 52/100 [21:02<19:23, 24.24s/it][A[A[A[A[A

0.8001







 53%|█████▎    | 53/100 [21:26<19:00, 24.26s/it][A[A[A[A[A

0.8186







 54%|█████▍    | 54/100 [21:51<18:37, 24.29s/it][A[A[A[A[A

0.8146







 55%|█████▌    | 55/100 [22:15<18:11, 24.26s/it][A[A[A[A[A

0.8083







 56%|█████▌    | 56/100 [22:39<17:48, 24.28s/it][A[A[A[A[A

0.7918







 57%|█████▋    | 57/100 [23:04<17:24, 24.29s/it][A[A[A[A[A

0.7962







 58%|█████▊    | 58/100 [23:28<17:01, 24.32s/it][A[A[A[A[A

0.7933







 59%|█████▉    | 59/100 [23:52<16:35, 24.27s/it][A[A[A[A[A

0.8185







 60%|██████    | 60/100 [24:16<16:10, 24.25s/it][A[A[A[A[A

0.7948







 61%|██████    | 61/100 [24:41<15:46, 24.27s/it][A[A[A[A[A

0.8083







 62%|██████▏   | 62/100 [25:05<15:24, 24.33s/it][A[A[A[A[A

0.7922







 63%|██████▎   | 63/100 [25:30<15:00, 24.34s/it][A[A[A[A[A

0.8137







 64%|██████▍   | 64/100 [25:54<14:36, 24.35s/it][A[A[A[A[A

0.7887







 65%|██████▌   | 65/100 [26:18<14:10, 24.31s/it][A[A[A[A[A

0.829







 66%|██████▌   | 66/100 [26:42<13:46, 24.32s/it][A[A[A[A[A

0.8122







 67%|██████▋   | 67/100 [27:07<13:21, 24.30s/it][A[A[A[A[A

0.8209







 68%|██████▊   | 68/100 [27:31<12:56, 24.27s/it][A[A[A[A[A

0.7813







 69%|██████▉   | 69/100 [27:55<12:33, 24.30s/it][A[A[A[A[A

0.8173







 70%|███████   | 70/100 [28:20<12:08, 24.30s/it][A[A[A[A[A

0.816







 71%|███████   | 71/100 [28:44<11:43, 24.26s/it][A[A[A[A[A

0.8357







 72%|███████▏  | 72/100 [29:08<11:18, 24.25s/it][A[A[A[A[A

0.8345







 73%|███████▎  | 73/100 [29:32<10:55, 24.27s/it][A[A[A[A[A

0.7845







 74%|███████▍  | 74/100 [29:56<10:30, 24.26s/it][A[A[A[A[A

0.8017







 75%|███████▌  | 75/100 [30:21<10:07, 24.30s/it][A[A[A[A[A

0.8341







 76%|███████▌  | 76/100 [30:45<09:42, 24.29s/it][A[A[A[A[A

0.8027







 77%|███████▋  | 77/100 [31:09<09:17, 24.26s/it][A[A[A[A[A

0.8463







 78%|███████▊  | 78/100 [31:34<08:53, 24.27s/it][A[A[A[A[A

0.8294







 79%|███████▉  | 79/100 [31:58<08:30, 24.32s/it][A[A[A[A[A

0.7867







 80%|████████  | 80/100 [32:22<08:05, 24.27s/it][A[A[A[A[A

0.7776







 81%|████████  | 81/100 [32:46<07:41, 24.27s/it][A[A[A[A[A

0.8162







 82%|████████▏ | 82/100 [33:11<07:16, 24.24s/it][A[A[A[A[A

0.8244







 83%|████████▎ | 83/100 [33:35<06:51, 24.21s/it][A[A[A[A[A

0.8319







 84%|████████▍ | 84/100 [33:59<06:27, 24.24s/it][A[A[A[A[A

0.8052







 85%|████████▌ | 85/100 [34:23<06:03, 24.25s/it][A[A[A[A[A

0.7894







 86%|████████▌ | 86/100 [34:48<05:39, 24.23s/it][A[A[A[A[A

0.8421







 87%|████████▋ | 87/100 [35:12<05:14, 24.23s/it][A[A[A[A[A

0.8184







 88%|████████▊ | 88/100 [35:36<04:50, 24.20s/it][A[A[A[A[A

0.827







 89%|████████▉ | 89/100 [36:00<04:26, 24.21s/it][A[A[A[A[A

0.8452







 90%|█████████ | 90/100 [36:24<04:02, 24.24s/it][A[A[A[A[A

0.8109







 91%|█████████ | 91/100 [36:49<03:38, 24.23s/it][A[A[A[A[A

0.8369







 92%|█████████▏| 92/100 [37:13<03:13, 24.21s/it][A[A[A[A[A

0.7835







 93%|█████████▎| 93/100 [37:37<02:49, 24.20s/it][A[A[A[A[A

0.8257







 94%|█████████▍| 94/100 [38:01<02:25, 24.22s/it][A[A[A[A[A

0.7747







 95%|█████████▌| 95/100 [38:26<02:01, 24.27s/it][A[A[A[A[A

0.8088







 96%|█████████▌| 96/100 [38:50<01:36, 24.24s/it][A[A[A[A[A

0.801







 97%|█████████▋| 97/100 [39:14<01:12, 24.23s/it][A[A[A[A[A

0.8091







 98%|█████████▊| 98/100 [39:38<00:48, 24.23s/it][A[A[A[A[A

0.8165







 99%|█████████▉| 99/100 [40:03<00:24, 24.25s/it][A[A[A[A[A

0.8419







100%|██████████| 100/100 [40:27<00:00, 24.27s/it][A[A[A[A[A




[A[A[A[A[A

0.8514


In [93]:
np.mean(accs_5)

0.7453090909090909

In [94]:
accs_1 = mnist_one_shot(support=1, n_test_samples=1000, n_trials=100)
np.mean(accs_1)






  0%|          | 0/100 [00:00<?, ?it/s][A[A[A[A[A




  1%|          | 1/100 [00:24<39:47, 24.12s/it][A[A[A[A[A

0.6144







  2%|▏         | 2/100 [00:48<39:25, 24.14s/it][A[A[A[A[A

0.6186







  3%|▎         | 3/100 [01:12<39:09, 24.22s/it][A[A[A[A[A

0.6241







  4%|▍         | 4/100 [01:37<38:48, 24.25s/it][A[A[A[A[A

0.6503







  5%|▌         | 5/100 [02:01<38:28, 24.30s/it][A[A[A[A[A

0.6082







  6%|▌         | 6/100 [02:25<38:07, 24.34s/it][A[A[A[A[A

0.6077







  7%|▋         | 7/100 [02:50<37:46, 24.37s/it][A[A[A[A[A

0.7002







  8%|▊         | 8/100 [03:14<37:21, 24.37s/it][A[A[A[A[A

0.6781







  9%|▉         | 9/100 [03:39<36:57, 24.37s/it][A[A[A[A[A

0.6567







 10%|█         | 10/100 [04:03<36:34, 24.39s/it][A[A[A[A[A

0.6007







 11%|█         | 11/100 [04:27<36:13, 24.42s/it][A[A[A[A[A

0.4348







 12%|█▏        | 12/100 [04:52<35:48, 24.41s/it][A[A[A[A[A

0.6305







 13%|█▎        | 13/100 [05:16<35:23, 24.41s/it][A[A[A[A[A

0.5094







 14%|█▍        | 14/100 [05:41<35:00, 24.43s/it][A[A[A[A[A

0.6345







 15%|█▌        | 15/100 [06:05<34:29, 24.35s/it][A[A[A[A[A

0.6911







 16%|█▌        | 16/100 [06:29<34:00, 24.29s/it][A[A[A[A[A

0.559







 17%|█▋        | 17/100 [06:53<33:28, 24.20s/it][A[A[A[A[A

0.6208







 18%|█▊        | 18/100 [07:17<32:59, 24.13s/it][A[A[A[A[A

0.62







 19%|█▉        | 19/100 [07:41<32:34, 24.13s/it][A[A[A[A[A

0.5833







 20%|██        | 20/100 [08:05<32:07, 24.09s/it][A[A[A[A[A

0.6824







 21%|██        | 21/100 [08:29<31:43, 24.10s/it][A[A[A[A[A

0.6729







 22%|██▏       | 22/100 [08:53<31:20, 24.10s/it][A[A[A[A[A

0.6601







 23%|██▎       | 23/100 [09:17<30:54, 24.09s/it][A[A[A[A[A

0.6014







 24%|██▍       | 24/100 [09:42<30:31, 24.09s/it][A[A[A[A[A

0.5801







 25%|██▌       | 25/100 [10:06<30:05, 24.08s/it][A[A[A[A[A

0.7081







 26%|██▌       | 26/100 [10:30<29:44, 24.11s/it][A[A[A[A[A

0.5598







 27%|██▋       | 27/100 [10:54<29:18, 24.09s/it][A[A[A[A[A

0.6484







 28%|██▊       | 28/100 [11:18<28:54, 24.08s/it][A[A[A[A[A

0.5572







 29%|██▉       | 29/100 [11:42<28:27, 24.04s/it][A[A[A[A[A

0.5149







 30%|███       | 30/100 [12:06<28:04, 24.07s/it][A[A[A[A[A

0.5232







 31%|███       | 31/100 [12:30<27:39, 24.05s/it][A[A[A[A[A

0.5965







 32%|███▏      | 32/100 [12:54<27:15, 24.05s/it][A[A[A[A[A

0.6499







 33%|███▎      | 33/100 [13:18<26:51, 24.04s/it][A[A[A[A[A

0.6456







 34%|███▍      | 34/100 [13:42<26:31, 24.12s/it][A[A[A[A[A

0.5735







 35%|███▌      | 35/100 [14:06<26:08, 24.13s/it][A[A[A[A[A

0.6







 36%|███▌      | 36/100 [14:30<25:37, 24.02s/it][A[A[A[A[A

0.4292







 37%|███▋      | 37/100 [14:54<25:13, 24.02s/it][A[A[A[A[A

0.6496







 38%|███▊      | 38/100 [15:18<24:47, 23.99s/it][A[A[A[A[A

0.5963







 39%|███▉      | 39/100 [15:42<24:22, 23.98s/it][A[A[A[A[A

0.6727







 40%|████      | 40/100 [16:06<24:01, 24.03s/it][A[A[A[A[A

0.5208







 41%|████      | 41/100 [16:31<23:41, 24.10s/it][A[A[A[A[A

0.6793







 42%|████▏     | 42/100 [16:55<23:24, 24.21s/it][A[A[A[A[A

0.5828







 43%|████▎     | 43/100 [17:19<23:00, 24.22s/it][A[A[A[A[A

0.5592







 44%|████▍     | 44/100 [17:43<22:34, 24.19s/it][A[A[A[A[A

0.5769







 45%|████▌     | 45/100 [18:08<22:13, 24.25s/it][A[A[A[A[A

0.5826







 46%|████▌     | 46/100 [18:34<22:13, 24.69s/it][A[A[A[A[A

0.5676







 47%|████▋     | 47/100 [18:58<21:43, 24.60s/it][A[A[A[A[A

0.607







 48%|████▊     | 48/100 [19:22<21:16, 24.54s/it][A[A[A[A[A

0.5386







 49%|████▉     | 49/100 [19:47<20:59, 24.70s/it][A[A[A[A[A

0.6117







 50%|█████     | 50/100 [20:12<20:30, 24.62s/it][A[A[A[A[A

0.5885







 51%|█████     | 51/100 [20:36<20:00, 24.51s/it][A[A[A[A[A

0.5934







 52%|█████▏    | 52/100 [21:01<19:35, 24.49s/it][A[A[A[A[A

0.6307







 53%|█████▎    | 53/100 [21:25<19:09, 24.47s/it][A[A[A[A[A

0.5924







 54%|█████▍    | 54/100 [21:49<18:43, 24.42s/it][A[A[A[A[A

0.5673







 55%|█████▌    | 55/100 [22:14<18:17, 24.38s/it][A[A[A[A[A

0.6849







 56%|█████▌    | 56/100 [22:38<17:51, 24.36s/it][A[A[A[A[A

0.5346







 57%|█████▋    | 57/100 [23:02<17:26, 24.34s/it][A[A[A[A[A

0.5266







 58%|█████▊    | 58/100 [23:26<16:59, 24.28s/it][A[A[A[A[A

0.6432







 59%|█████▉    | 59/100 [23:51<16:36, 24.31s/it][A[A[A[A[A

0.5052







 60%|██████    | 60/100 [24:15<16:13, 24.34s/it][A[A[A[A[A

0.6528







 61%|██████    | 61/100 [24:39<15:48, 24.32s/it][A[A[A[A[A

0.6413







 62%|██████▏   | 62/100 [25:04<15:24, 24.32s/it][A[A[A[A[A

0.5253







 63%|██████▎   | 63/100 [25:28<15:00, 24.34s/it][A[A[A[A[A

0.6558







 64%|██████▍   | 64/100 [25:52<14:35, 24.31s/it][A[A[A[A[A

0.5958







 65%|██████▌   | 65/100 [26:16<14:08, 24.25s/it][A[A[A[A[A

0.6359







 66%|██████▌   | 66/100 [26:40<13:41, 24.16s/it][A[A[A[A[A

0.6861







 67%|██████▋   | 67/100 [27:04<13:15, 24.10s/it][A[A[A[A[A

0.6243







 68%|██████▊   | 68/100 [27:28<12:49, 24.06s/it][A[A[A[A[A

0.7015







 69%|██████▉   | 69/100 [27:52<12:26, 24.08s/it][A[A[A[A[A

0.5905







 70%|███████   | 70/100 [28:16<12:02, 24.08s/it][A[A[A[A[A

0.5556







 71%|███████   | 71/100 [28:41<11:40, 24.14s/it][A[A[A[A[A

0.6937







 72%|███████▏  | 72/100 [29:05<11:15, 24.13s/it][A[A[A[A[A

0.6574







 73%|███████▎  | 73/100 [29:29<10:53, 24.22s/it][A[A[A[A[A

0.5035







 74%|███████▍  | 74/100 [29:53<10:28, 24.17s/it][A[A[A[A[A

0.6543







 75%|███████▌  | 75/100 [30:17<10:03, 24.12s/it][A[A[A[A[A

0.5084







 76%|███████▌  | 76/100 [30:41<09:39, 24.13s/it][A[A[A[A[A

0.59







 77%|███████▋  | 77/100 [31:06<09:14, 24.10s/it][A[A[A[A[A

0.6169







 78%|███████▊  | 78/100 [31:30<08:49, 24.08s/it][A[A[A[A[A

0.6612







 79%|███████▉  | 79/100 [31:54<08:25, 24.05s/it][A[A[A[A[A

0.6612







 80%|████████  | 80/100 [32:18<08:01, 24.08s/it][A[A[A[A[A

0.5126







 81%|████████  | 81/100 [32:42<07:37, 24.10s/it][A[A[A[A[A

0.5847







 82%|████████▏ | 82/100 [33:06<07:12, 24.05s/it][A[A[A[A[A

0.5964







 83%|████████▎ | 83/100 [33:30<06:49, 24.07s/it][A[A[A[A[A

0.5636







 84%|████████▍ | 84/100 [33:54<06:26, 24.15s/it][A[A[A[A[A

0.6428







 85%|████████▌ | 85/100 [34:18<06:01, 24.11s/it][A[A[A[A[A

0.6822







 86%|████████▌ | 86/100 [34:42<05:37, 24.10s/it][A[A[A[A[A

0.6959







 87%|████████▋ | 87/100 [35:07<05:13, 24.15s/it][A[A[A[A[A

0.5873







 88%|████████▊ | 88/100 [35:31<04:49, 24.14s/it][A[A[A[A[A

0.5985







 89%|████████▉ | 89/100 [35:55<04:25, 24.11s/it][A[A[A[A[A

0.5442







 90%|█████████ | 90/100 [36:19<04:00, 24.05s/it][A[A[A[A[A

0.5772







 91%|█████████ | 91/100 [36:43<03:36, 24.02s/it][A[A[A[A[A

0.639







 92%|█████████▏| 92/100 [37:07<03:12, 24.09s/it][A[A[A[A[A

0.6429







 93%|█████████▎| 93/100 [37:31<02:48, 24.11s/it][A[A[A[A[A

0.6074







 94%|█████████▍| 94/100 [37:55<02:24, 24.15s/it][A[A[A[A[A

0.628







 95%|█████████▌| 95/100 [38:19<02:00, 24.03s/it][A[A[A[A[A

0.5883







 96%|█████████▌| 96/100 [38:43<01:36, 24.04s/it][A[A[A[A[A

0.6999







 97%|█████████▋| 97/100 [39:07<01:11, 23.96s/it][A[A[A[A[A

0.5417







 98%|█████████▊| 98/100 [39:31<00:47, 23.96s/it][A[A[A[A[A

0.5787







 99%|█████████▉| 99/100 [39:55<00:24, 24.12s/it][A[A[A[A[A

0.6454







100%|██████████| 100/100 [40:19<00:00, 24.11s/it][A[A[A[A[A




[A[A[A[A[A

0.5531


0.7120105011933174

In [48]:
test_loader = data.DataLoader(dataset=x_test, batch_size=32,
                              shuffle=False, num_workers=0, drop_last=True)
dataset = next(iter(test_loader))
with torch.no_grad():
    inputs = Variable(dataset.cuda())
h1 = model.shared_convolutional_encoder(inputs)
c_mean_, c_logvar_ = model.statistic_network(h1, single_sample=True)
# kl_divergences = []
# for i in range(K):
#     kl = kl_diagnormal_diagnormal(D_means[i], D_vars[i], c_mean_, c_logvar_)
#     c_mean_, c_logvar_ = model.statistic_network(h1,summarize=False)

In [52]:
dataset[0]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
      

In [17]:
def get_dataset_statisctis()
dataset = torch.from_numpy(D[0])
with torch.no_grad():
    inputs = Variable(dataset.cuda())
inputs = inputs.view(1, 5, 784)
# get approximate posterior over full dataset
#c_mean_full, c_logvar_full = model.statistic_network(dataset, summarize=True)
h = model.shared_convolutional_encoder(inputs)
c_mean_full, c_logvar_full = model.statistic_network(h, summarize=True)


In [43]:
classify_datapoint(x,D_means,D_vars,K)

9

In [35]:


kl_divergences = []
for i in range(K):
    dataset = torch.from_numpy(D[i])
    with torch.no_grad():
        inputs = Variable(dataset.cuda())
    #inputs = inputs.view(1, 5, 784)
    # get approximate posterior over full dataset
    #c_mean_full, c_logvar_full = model.statistic_network(dataset, summarize=True)
    h = model.shared_convolutional_encoder(inputs)
    c_mean_full, c_logvar_full = model.statistic_network(h, summarize=True)
    kl = kl_diagnormal_diagnormal(c_mean_full, c_logvar_full, c_mean_, c_logvar_)
    kl_divergences.append(kl.data.item())

    

In [36]:
best_index = kl_divergences.index(min(kl_divergences))

In [37]:
kl_divergences

[195.96034240722656,
 132.67491149902344,
 250.48081970214844,
 168.60643005371094,
 137.26870727539062,
 198.39479064941406,
 242.20484924316406,
 97.00956726074219,
 137.3927764892578,
 1.6093254089355469e-06]

In [16]:
x = test[2] #lbl = 2
while dataset.size(1) != output_size:
    kl_divergences = []
    # need KL divergence between full approximate posterior and all
    # subsets of given size
    subset_indices = list(combinations(range(dataset.size(1)), dataset.size(1) - 1))

    for subset_index in subset_indices:
        # pull out subset, numpy indexing will make this much easier
        ix = Variable(torch.LongTensor(subset_index).cuda())
        subset = dataset.index_select(1, ix)

        # calculate approximate posterior over subset
        c_mean, c_logvar = self.statistic_network(subset, summarize=True)
        kl = kl_diagnormal_diagnormal(c_mean_full, c_logvar_full, c_mean, c_logvar)
        kl_divergences.append(kl.data.item())

    # determine which sample we want to remove
    best_index = kl_divergences.index(min(kl_divergences))

    # determine which samples to keep
    to_keep = subset_indices[best_index]
    to_keep = Variable(torch.LongTensor(to_keep).cuda())

    # keep only desired samples
    dataset = dataset.index_select(1, to_keep)

4096

In [112]:
torch.from_numpy(batch).clone().repeat(4,1,1)

torch.Size([40, 5, 784])

RuntimeError: Error(s) in loading state_dict for Statistician:
	While copying the parameter named "statistic_network.postpool.fc_layers.0.weight", whose dimensions in the model are torch.Size([256, 256]) and whose dimensions in the checkpoint are torch.Size([257, 257]).
	While copying the parameter named "statistic_network.postpool.fc_layers.0.bias", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.fc_layers.1.weight", whose dimensions in the model are torch.Size([256, 256]) and whose dimensions in the checkpoint are torch.Size([257, 257]).
	While copying the parameter named "statistic_network.postpool.fc_layers.1.bias", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.0.weight", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.0.bias", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.0.running_mean", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.0.running_var", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.1.weight", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.1.bias", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.1.running_mean", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.bn_layers.1.running_var", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([257]).
	While copying the parameter named "statistic_network.postpool.fc_params.weight", whose dimensions in the model are torch.Size([1024, 256]) and whose dimensions in the checkpoint are torch.Size([1024, 257]).