In [1]:
from utils.experiment_utils import get_all_experiments_info, load_best_model
import torch
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from notebooks.mnist_classifier.mnist_tiny_cnn import TinyCNN

from mixer.mixer import SetMixer
from datasets.mnist import MNISTDataset

from torch.utils.data import DataLoader
from itertools import product

In [2]:
device = 'cuda'
configs = get_all_experiments_info('outputs/', False)
cfg = [c for c in configs if 'mnist_multinomial' in c['name'] 
                    and c['config']['experiment']['batch_size'] == 8]

In [3]:
# load + prep dataset
def prepare_dataset(dataset_cfg):
    # probs = np.column_stack((np.linspace(0, 1, num_probs), 1 - np.linspace(0, 1, num_probs)))
    dataset = hydra.utils.instantiate(dataset_cfg)
    # dataset.probs = probs
    # dataset.data, _, _ = dataset.make_sets()
    return dataset

# load encoder and move to device
def load_model(cfg, path, device):
    enc = hydra.utils.instantiate(cfg['encoder'])
    gen = hydra.utils.instantiate(cfg['generator'])
    state = load_best_model(path)
    enc.load_state_dict(state['encoder_state_dict'])
    gen.model.load_state_dict(state['generator_state_dict'])
    enc.eval()
    gen.eval()
    enc.to(device)
    gen.to(device)
    return enc, gen

In [6]:
classy = TinyCNN()
classy.load_state_dict(torch.load('notebooks/mnist_classifier/mnist_tinycnn.pth'))
classy.to('cuda')

TinyCNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=9216, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [7]:
def simplex_grid(dim, points_per_dim):
    lin = np.linspace(0, 1, points_per_dim)
    grid = np.array(list(product(*([lin] * dim))))
    grid = grid[np.isclose(grid.sum(axis=1), 1)]  # keep only rows that sum to 1
    return grid

points_per_dim = 5
k = 3

set_size = 100

mix_probs_labels = simplex_grid(k, points_per_dim)

# fr_dist = multinomial_fr(mix_probs)

n_sets = len(mix_probs_labels)

mix_probs = torch.tensor(np.repeat(mix_probs_labels, n_sets//k, axis=1))

dataset = MNISTDataset(n_classes=k, n_sets=n_sets, set_size=5000)
mixer = SetMixer(k=k, mixed_set_size=set_size, n_mixed_sets=n_sets)


In [10]:
d = {
    "Encoder" : [],
    "Generation class error" : []
}

with torch.no_grad():

    for c in cfg:


        encoder, generator = load_model(c['config'], c['dir'], 'cuda')
        
        rec = generator.sample(encoder(mixed_sets.reshape(n_sets, set_size, 1, 28, 28)), num_samples=100)

        preds = classy(rec.reshape(set_size*n_sets, 1, 28, 28)).argmax(dim=1).reshape(n_sets, set_size)

        compositions = torch.stack([(set_preds.bincount(minlength=10+1)) for set_preds in preds])

        est = compositions[:, :3].cpu().numpy()/100


        error = est - mix_probs_labels


        d['Encoder'].append(c['encoder'])
        d['Generation class error'].append((error**2).mean())
        print(d)

{'Encoder': ['ConvDistributionEncoder'], 'Generation class error': [np.float64(0.026155555555555557)]}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [00:01<00:00, 8.17MB/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 761kB/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:00<00:00, 4.74MB/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 4.06MB/s]


{'Encoder': ['ConvDistributionEncoder', 'MNISTPCAEncoder'], 'Generation class error': [np.float64(0.026155555555555557), np.float64(0.007762222222222223)]}
{'Encoder': ['ConvDistributionEncoder', 'MNISTPCAEncoder', 'WormholeEncoder'], 'Generation class error': [np.float64(0.026155555555555557), np.float64(0.007762222222222223), np.float64(0.002288888888888889)]}
{'Encoder': ['ConvDistributionEncoder', 'MNISTPCAEncoder', 'WormholeEncoder', 'KMEEncoder'], 'Generation class error': [np.float64(0.026155555555555557), np.float64(0.007762222222222223), np.float64(0.002288888888888889), np.float64(0.10325555555555559)]}


In [11]:
pd.DataFrame(d)

Unnamed: 0,Encoder,Generation class error
0,ConvDistributionEncoder,0.026156
1,MNISTPCAEncoder,0.007762
2,WormholeEncoder,0.002289
3,KMEEncoder,0.103256


In [4]:
from utils.eval_utils import compute_encodings_and_resamples, compute_metrics



In [7]:
enc, gen = load_model(cfg[0]['config'], cfg[0]['dir'], device)
dataset = prepare_dataset(cfg[0]['config']['dataset'])
mixer = hydra.utils.instantiate(cfg[0]['config']['mixer'])

dl = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=mixer.collate_fn)
dl_iter = iter(dl)
samples = [next(dl_iter)['samples'] for _ in range(100)]

results = compute_encodings_and_resamples(
    enc, gen, samples, device, 
    encode_batch_size=10, max_encode_samples=1_000,
    resample_batch_size=10, num_resamples=1_000,
)

Step 1/3: Encoding original samples


  torch.tensor(s[:max_encode_samples], dtype=torch.float32)
Encoding samples:   0%|                                                                                                                                                             | 0/10 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 5)

In [9]:
next(dl_iter)['samples'].squeeze().shape

torch.Size([3, 100, 28, 28])