In [1]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[0])

# Add to sys.path
sys.path.insert(0, parent_dir)


In [2]:
from utils.experiment_utils import get_all_experiments_info, load_best_model

device = 'cuda'
configs = get_all_experiments_info('../outputs/', False)
cfgs = [
    c for c in configs if 'gmm_exp' in c['name'] 
        and c['config']['experiment']['latent_dim'] == 32
        and c['config']['experiment']['hidden_dim'] == 128
        and c['config']['dataset']['prior_mu'] == [0, 5]
        and hasattr(c['config']['encoder'], 'layers')
        and c['config']['encoder']['layers'] == 4 
]   



In [3]:
cfgs = [cfgs[1]]

In [4]:
import hydra
# load + prep dataset
def prepare_dataset_and_mixer(cfg, set_size=None, n_sets=None, n_mixed_sets=None):
    # probs = np.column_stack((np.linspace(0, 1, num_probs), 1 - np.linspace(0, 1, num_probs)))
    if set_size is not None:
        cfg['dataset']['set_size'] = set_size
    if n_sets is not None:
        cfg['dataset']['n_sets'] = n_sets
    if n_mixed_sets is not None:
        cfg['mixer']['n_mixed_sets'] = n_mixed_sets
    dataset = hydra.utils.instantiate(cfg['dataset'])
    mixer = hydra.utils.instantiate(cfg['mixer'])
    return dataset, mixer


# load encoder and move to device
def load_model(cfg, path, device):
    enc = hydra.utils.instantiate(cfg['encoder'])
    gen = hydra.utils.instantiate(cfg['generator'])
    state = load_best_model(path)
    enc.load_state_dict(state['encoder_state_dict'])
    gen.model.load_state_dict(state['generator_state_dict'])
    enc.eval()
    gen.eval()
    enc.to(device)
    gen.to(device)
    return enc, gen

enc, gen = load_model(cfgs[0]['config'], cfgs[0]['dir'], device)


In [5]:
from torch.utils.data import DataLoader
ds, mx = prepare_dataset_and_mixer(cfgs[0]['config'], set_size=10_000, n_sets=1_000, n_mixed_sets=1)
dl = DataLoader(ds, batch_size=3, shuffle=False, collate_fn=mx.collate_fn)

In [6]:
dl_iter = iter(dl)
samples = [next(dl_iter)['samples'].squeeze() for _ in range(100)]

In [7]:
from utils.eval_utils import compute_encodings_and_resamples, compute_metrics

In [8]:
results = compute_encodings_and_resamples(
    enc, gen, samples, device, 
    encode_batch_size=10, max_encode_samples=10_000,
    resample_batch_size=10, num_resamples=10_000,
)

Step 1/3: Encoding original samples


  torch.tensor(s[:max_encode_samples], dtype=torch.float32)
Encoding samples: 100%|██████████| 10/10 [00:00<00:00, 69.04it/s]


Step 2/3: Generating samples from latents


Generating samples:   0%|          | 0/10 [00:00<?, ?it/s]

sampling timestep 323

  batch_tensor = torch.tensor(batch_latents, dtype=torch.float32).to(device)


sampling timestep 770

Generating samples:  10%|█         | 1/10 [00:00<00:02,  4.01it/s]

sampling timestep 650

Generating samples:  20%|██        | 2/10 [00:00<00:01,  4.04it/s]

sampling timestep 139

Generating samples:  30%|███       | 3/10 [00:00<00:01,  4.05it/s]

sampling timestep 214

Generating samples:  40%|████      | 4/10 [00:00<00:01,  4.06it/s]

sampling timestep 400

Generating samples:  50%|█████     | 5/10 [00:01<00:01,  4.07it/s]

sampling timestep 260

Generating samples:  60%|██████    | 6/10 [00:01<00:00,  4.07it/s]

sampling timestep 980

Generating samples:  70%|███████   | 7/10 [00:01<00:00,  4.07it/s]

sampling timestep 169

Generating samples:  80%|████████  | 8/10 [00:01<00:00,  4.07it/s]

sampling timestep 100

Generating samples:  90%|█████████ | 9/10 [00:02<00:00,  4.07it/s]

sampling timestep 690

Generating samples: 100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


Step 3/3: Re-encoding generated samples


Encoding samples: 100%|██████████| 10/10 [00:00<00:00, 168.00it/s]


In [9]:
metrics = compute_metrics(results, batch_size=1_000)

  torch.tensor(pair[0][:sample_batch_size], dtype=torch.float32)
Computing distribution metrics: 100%|██████████| 10/10 [00:00<00:00, 17.42it/s]


In [10]:
metrics

{'latent_recon_error': {'mean': 0.1312226,
  'std': 0.0,
  'per_set': array(0.1312226, dtype=float32)},
 'mmd': {'mean': 0.00199424147605896,
  'std': 0.0019327218960941259,
  'per_set': array([0.00182664, 0.00090182, 0.00174415, 0.00074124, 0.00010455,
         0.00126559, 0.00069761, 0.00090051, 0.00214827, 0.00569415,
         0.00032723, 0.00197875, 0.00132   , 0.00300479, 0.00271773,
         0.00184464, 0.00023079, 0.00260997, 0.00185871, 0.00460327,
         0.00196469, 0.00211346, 0.00195205, 0.00185609, 0.00122714,
         0.00149918, 0.00393069, 0.00024259, 0.00025058, 0.00292897,
         0.00032818, 0.00159204, 0.00923777, 0.00028932, 0.00025475,
         0.00042617, 0.00199354, 0.00306427, 0.00063765, 0.00033736,
         0.00066614, 0.00073111, 0.00038934, 0.00092077, 0.00268739,
         0.00092649, 0.0018307 , 0.00049198, 0.00210047, 0.00056505,
         0.00739872, 0.00458705, 0.00050592, 0.00153506, 0.00102139,
         0.00112307, 0.00204301, 0.00173414, 0.00092602,