In [1]:
from hparams import HParams
hparams = HParams('.', name="efficient_vdvae")

Found existing ./logs-imagnet_64_baseline/hparams-imagnet_64_baseline.cfg! Resuming run using primary parameters!


In [2]:
import pickle
import numpy as np
import os
import torch
from numpy.random import seed
import random
from PIL import Image
from tqdm import tqdm

from utils.utils import assert_CUDA_and_hparams_gpus_are_equal, create_checkpoint_manager_and_load_if_exists, \
        get_logdir, get_variate_masks, transpose_dicts
from data.generic_data_loader import synth_generic_data, encode_generic_data, stats_generic_data
from data.cifar10_data_loader import synth_cifar_data, encode_cifar_data, stats_cifar_data
from data.imagenet_data_loader import synth_imagenet_data, encode_imagenet_data, stats_imagenet_data
from data.mnist_data_loader import synth_mnist_data, encode_mnist_data, stats_mnist_data
from model.def_model import UniversalAutoEncoder
from model.model import reconstruction_step, generation_step, encode_step
from model.losses import StructureSimilarityIndexMap
from utils import temperature_functions
from model.div_stats_utils import KLDivergenceStats

from sklearn.utils import shuffle

# Fix random seeds
torch.manual_seed(hparams.run.seed)
torch.manual_seed(hparams.run.seed)
torch.cuda.manual_seed(hparams.run.seed)
torch.cuda.manual_seed_all(hparams.run.seed)  # if you are using multi-GPU.
seed(hparams.run.seed)  # Numpy module.
random.seed(hparams.run.seed)  # Python random module.
torch.manual_seed(hparams.run.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

2024-06-18 17:32:53.638767: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
device = torch.device('cuda:0')

<torch.autograd.profiler.emit_nvtx at 0x7fcf146b01f0>

In [4]:
embeddings = pickle.load(open('logs-imagnet_64_baseline/latents/encodings_temp_first100.pkl', 'rb'))

In [5]:
embeddings.keys()

dict_keys(['images', 'latent_codes'])

In [6]:
np.array(list(embeddings['images'].values())).shape

(100, 3, 64, 64)

In [7]:
embeddings['latent_codes']['image_1093444'].keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 22, 24, 25, 26, 27, 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74, 76, 77, 79, 80, 81, 82])

In [8]:
for i in embeddings['latent_codes']['image_1093444']:
    print(i, embeddings['latent_codes']['image_1093444'][i].shape, embeddings['latent_codes']['image_1093444'][i].size)

0 (1, 1, 2, 2) 4
1 (1, 1, 1, 2) 2
2 (1, 1, 1, 2) 2
3 (1, 1, 2, 2) 4
4 (1, 1, 1, 2) 2
5 (1, 1, 2, 2) 4
6 (4, 4, 2, 2) 64
7 (4, 4, 2, 2) 64
8 (4, 4, 2, 2) 64
9 (4, 4, 2, 2) 64
10 (4, 4, 2, 2) 64
11 (4, 4, 2, 2) 64
12 (4, 4, 1, 2) 32
14 (8, 8, 1, 2) 128
15 (8, 8, 1, 2) 128
16 (8, 8, 2, 2) 256
17 (8, 8, 2, 2) 256
18 (8, 8, 2, 2) 256
19 (8, 8, 2, 2) 256
20 (8, 8, 1, 2) 128
22 (8, 8, 1, 2) 128
24 (8, 8, 1, 2) 128
25 (8, 8, 1, 2) 128
26 (8, 8, 1, 2) 128
27 (8, 8, 1, 2) 128
30 (8, 8, 1, 2) 128
31 (8, 8, 1, 2) 128
36 (16, 16, 1, 2) 512
37 (16, 16, 1, 2) 512
38 (16, 16, 1, 2) 512
39 (16, 16, 1, 2) 512
40 (16, 16, 1, 2) 512
41 (16, 16, 1, 2) 512
42 (16, 16, 1, 2) 512
43 (16, 16, 1, 2) 512
44 (16, 16, 1, 2) 512
45 (16, 16, 1, 2) 512
46 (16, 16, 1, 2) 512
47 (16, 16, 1, 2) 512
48 (16, 16, 1, 2) 512
49 (16, 16, 1, 2) 512
50 (16, 16, 1, 2) 512
51 (16, 16, 1, 2) 512
52 (16, 16, 1, 2) 512
53 (16, 16, 1, 2) 512
56 (16, 16, 1, 2) 512
58 (32, 32, 1, 2) 2048
59 (32, 32, 1, 2) 2048
60 (32, 32, 1, 2) 2048
61

In [9]:
embedding_size = 0
for i in embeddings['latent_codes']['image_1093444']:
    embedding_size += embeddings['latent_codes']['image_1093444'][i].size
embedding_size

100530

In [10]:
model = UniversalAutoEncoder()
model = model.to(device)
with torch.no_grad():
    ones = torch.ones((1, hparams.data.channels, hparams.data.target_res, hparams.data.target_res)).cuda(0)
    _ = model(ones)

In [11]:
checkpoint, checkpoint_path = create_checkpoint_manager_and_load_if_exists(rank=0)

In [12]:
if hparams.synthesis.load_ema_weights:
    assert checkpoint['ema_model_state_dict'] is not None
    model.load_state_dict(checkpoint['ema_model_state_dict'])
    print('EMA model is loaded')
else:
    assert checkpoint['model_state_dict'] is not None
    model.load_state_dict(checkpoint['model_state_dict'])
    print('Model Checkpoint is loaded')
print(checkpoint_path)

EMA model is loaded
./checkpoints-imagnet_64_baseline


In [13]:
def encode_data():
    if hparams.data.dataset_source in ['ffhq', 'celebAHQ', 'celebA', 'custom']:
        return encode_generic_data()
    elif hparams.data.dataset_source == 'cifar-10':
        return encode_cifar_data()
    elif hparams.data.dataset_source == 'binarized_mnist':
        return encode_mnist_data()
    elif hparams.data.dataset_source == 'imagenet':
        return encode_imagenet_data()
    else:
        raise ValueError(f'Dataset {hparams.data.dataset_source} is not included.')
data_loader = encode_data()

Number of Images: 1281167
Path:  ../datasets/imagenet_64/train_data/


In [14]:
def get_variate_masks(stats):
    thresh = np.quantile(stats, 1 - 0.03)
    return stats > thresh

In [15]:
div_stats = np.load('logs-imagnet_64_baseline/latents/div_stats.npy')
variate_masks = get_variate_masks(div_stats)
div_stats.shape, variate_masks.shape

((84, 32), (84, 32))

In [16]:
def reshape_distribution(dist_list, variate_mask):
    """
    :param dist_list: n_layers, 2*  [ batch_size n_variates, H , W]
    :return: Tensors  of shape batch_size, H, W ,n_variates, 2
    H, W , n_variates will be different from each other in the list depending on which layer you are in.
    """
    dist = torch.stack(dist_list, dim=0)  # 2, batch_size, n_variates, H ,W
    dist = dist[:, :, variate_mask, :, :]  # Only take effective variates
    dist = torch.permute(dist, (1, 3, 4, 2, 0))  # batch_size, H ,W ,n_variates (subset), 2
    # dist = torch.unbind(dist, dim=0)  # Return a list of tensors of length batch_size
    return dist

In [17]:
model = model.eval()
with torch.no_grad():
    for step, (inputs, filenames) in enumerate(tqdm(data_loader)):
        print(len(filenames), len(inputs))
        inputs = inputs.to(device, non_blocking=True)
        predictions, posterior_dist_list, prior_kl_dist_list = model(inputs, variate_masks)

        # If the mask states all variables of a layer are not effective we don't collect any latents from that layer
        # n_layers , batch_size, [H, W, n_variates, 2]
        dist_dict = {}
        for i, (dist_list, variate_mask) in enumerate(zip(posterior_dist_list, variate_masks)):
            if variate_mask.any():
                x = reshape_distribution(dist_list, variate_mask).detach().cpu().numpy()
                v = {name: xa for name, xa in zip(filenames, list(x))}
                dist_dict[i] = v

        if step == 0:
            break

  0%|          | 0/10010 [00:00<?, ?it/s]

128 128


  0%|          | 0/10010 [00:10<?, ?it/s]


In [18]:
predictions.cpu().numpy().shape

(128, 100, 64, 64)

In [19]:
for i in range(len(posterior_dist_list)):
    print(i, posterior_dist_list[i][0][0].cpu().numpy().shape, posterior_dist_list[i][0][0].cpu().numpy().size)
# posterior_dist_list[83][0].cpu().numpy().shape

0 (32, 1, 1) 32
1 (32, 1, 1) 32
2 (32, 1, 1) 32
3 (32, 1, 1) 32
4 (32, 1, 1) 32
5 (32, 1, 1) 32
6 (32, 4, 4) 512
7 (32, 4, 4) 512
8 (32, 4, 4) 512
9 (32, 4, 4) 512
10 (32, 4, 4) 512
11 (32, 4, 4) 512
12 (32, 4, 4) 512
13 (32, 8, 8) 2048
14 (32, 8, 8) 2048
15 (32, 8, 8) 2048
16 (32, 8, 8) 2048
17 (32, 8, 8) 2048
18 (32, 8, 8) 2048
19 (32, 8, 8) 2048
20 (32, 8, 8) 2048
21 (32, 8, 8) 2048
22 (32, 8, 8) 2048
23 (32, 8, 8) 2048
24 (32, 8, 8) 2048
25 (32, 8, 8) 2048
26 (32, 8, 8) 2048
27 (32, 8, 8) 2048
28 (32, 8, 8) 2048
29 (32, 8, 8) 2048
30 (32, 8, 8) 2048
31 (32, 8, 8) 2048
32 (32, 16, 16) 8192
33 (32, 16, 16) 8192
34 (32, 16, 16) 8192
35 (32, 16, 16) 8192
36 (32, 16, 16) 8192
37 (32, 16, 16) 8192
38 (32, 16, 16) 8192
39 (32, 16, 16) 8192
40 (32, 16, 16) 8192
41 (32, 16, 16) 8192
42 (32, 16, 16) 8192
43 (32, 16, 16) 8192
44 (32, 16, 16) 8192
45 (32, 16, 16) 8192
46 (32, 16, 16) 8192
47 (32, 16, 16) 8192
48 (32, 16, 16) 8192
49 (32, 16, 16) 8192
50 (32, 16, 16) 8192
51 (32, 16, 16) 8192
5

In [20]:
embedding_size = 0
for i in range(len(posterior_dist_list)):
    embedding_size += posterior_dist_list[i][0][0].cpu().numpy().size
embedding_size

2213568

In [21]:
# for i in range(len(prior_kl_dist_list)):
#     print(prior_kl_dist_list[i][0].cpu().numpy().shape)

In [22]:
x.shape, v['image_1093444'].shape

((128, 64, 64, 1, 2), (64, 64, 1, 2))

In [23]:
dist_dict.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 22, 24, 25, 26, 27, 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74, 76, 77, 79, 80, 81, 82])

In [24]:
for i in dist_dict.keys():
    print(i, dist_dict[i]['image_1093444'].shape, dist_dict[i]['image_1093444'].size)

0 (1, 1, 2, 2) 4
1 (1, 1, 1, 2) 2
2 (1, 1, 1, 2) 2
3 (1, 1, 2, 2) 4
4 (1, 1, 1, 2) 2
5 (1, 1, 2, 2) 4
6 (4, 4, 2, 2) 64
7 (4, 4, 2, 2) 64
8 (4, 4, 2, 2) 64
9 (4, 4, 2, 2) 64
10 (4, 4, 2, 2) 64
11 (4, 4, 2, 2) 64
12 (4, 4, 1, 2) 32
14 (8, 8, 1, 2) 128
15 (8, 8, 1, 2) 128
16 (8, 8, 2, 2) 256
17 (8, 8, 2, 2) 256
18 (8, 8, 2, 2) 256
19 (8, 8, 2, 2) 256
20 (8, 8, 1, 2) 128
22 (8, 8, 1, 2) 128
24 (8, 8, 1, 2) 128
25 (8, 8, 1, 2) 128
26 (8, 8, 1, 2) 128
27 (8, 8, 1, 2) 128
30 (8, 8, 1, 2) 128
31 (8, 8, 1, 2) 128
36 (16, 16, 1, 2) 512
37 (16, 16, 1, 2) 512
38 (16, 16, 1, 2) 512
39 (16, 16, 1, 2) 512
40 (16, 16, 1, 2) 512
41 (16, 16, 1, 2) 512
42 (16, 16, 1, 2) 512
43 (16, 16, 1, 2) 512
44 (16, 16, 1, 2) 512
45 (16, 16, 1, 2) 512
46 (16, 16, 1, 2) 512
47 (16, 16, 1, 2) 512
48 (16, 16, 1, 2) 512
49 (16, 16, 1, 2) 512
50 (16, 16, 1, 2) 512
51 (16, 16, 1, 2) 512
52 (16, 16, 1, 2) 512
53 (16, 16, 1, 2) 512
56 (16, 16, 1, 2) 512
58 (32, 32, 1, 2) 2048
59 (32, 32, 1, 2) 2048
60 (32, 32, 1, 2) 2048
61

In [29]:
embedding_size = 0
for i in dist_dict.keys():
    embedding_size += dist_dict[i]['image_1093444'].size
    if i == 30:
        break
embedding_size

2610

In [26]:
100530 / 2213568

0.045415365599791827

In [27]:
for i, (dist_list, variate_mask) in enumerate(zip(posterior_dist_list, variate_masks)):
        # x = reshape_distribution(dist_list, variate_mask).detach().cpu().numpy()
    dist = torch.stack(dist_list, dim=0)  # 2, batch_size, n_variates, H ,W
    if variate_mask.any():
        dist2 = dist[:, :, variate_mask, :, :]  # Only take effective variates
        x = torch.permute(dist2, (1, 3, 4, 2, 0)).detach().cpu().numpy()  # batch_size, H ,W ,n_variates (subset), 2
        # print(i, dist_list[0][0].cpu().numpy().shape, dist_list[0][0].cpu().numpy().size * 2, dist[:,0].detach().cpu().numpy().shape, x[0].shape, x[0].size)
        print(i, dist[:,0].detach().cpu().numpy().shape, dist[:,0].detach().cpu().numpy().size, variate_mask.sum(), dist2[:,0].detach().cpu().numpy().shape, x[0].shape, x[0].size)
    else:
        print(i, dist[:,0].detach().cpu().numpy().shape, dist[:,0].detach().cpu().numpy().size, 'No effective variates')

0 (2, 32, 1, 1) 64 2 (2, 2, 1, 1) (1, 1, 2, 2) 4
1 (2, 32, 1, 1) 64 1 (2, 1, 1, 1) (1, 1, 1, 2) 2
2 (2, 32, 1, 1) 64 1 (2, 1, 1, 1) (1, 1, 1, 2) 2
3 (2, 32, 1, 1) 64 2 (2, 2, 1, 1) (1, 1, 2, 2) 4
4 (2, 32, 1, 1) 64 1 (2, 1, 1, 1) (1, 1, 1, 2) 2
5 (2, 32, 1, 1) 64 2 (2, 2, 1, 1) (1, 1, 2, 2) 4
6 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
7 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
8 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
9 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
10 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
11 (2, 32, 4, 4) 1024 2 (2, 2, 4, 4) (4, 4, 2, 2) 64
12 (2, 32, 4, 4) 1024 1 (2, 1, 4, 4) (4, 4, 1, 2) 32
13 (2, 32, 8, 8) 4096 No effective variates
14 (2, 32, 8, 8) 4096 1 (2, 1, 8, 8) (8, 8, 1, 2) 128
15 (2, 32, 8, 8) 4096 1 (2, 1, 8, 8) (8, 8, 1, 2) 128
16 (2, 32, 8, 8) 4096 2 (2, 2, 8, 8) (8, 8, 2, 2) 256
17 (2, 32, 8, 8) 4096 2 (2, 2, 8, 8) (8, 8, 2, 2) 256
18 (2, 32, 8, 8) 4096 2 (2, 2, 8, 8) (8, 8, 2, 2) 256
19 (2, 32, 8, 8) 4096 2 (

In [116]:
# dist = torch.stack(dist_list, dim=0)  # 2, batch_size, n_variates, H ,W
# len(dist_list), dist_list[0].shape, dist.shape

In [114]:
# dist = dist[:, :, variate_mask, :, :]  # Only take effective variates
# dist.shape

In [115]:
# dist = torch.permute(dist, (1, 3, 4, 2, 0))  # batch_size, H ,W ,n_variates (subset), 2
# dist.shape

In [113]:
# for i in range(len(variate_masks)):
#     print(i, variate_masks[i].sum())

In [17]:
# device = next(model.parameters()).device
# print(device)

In [14]:
# a = torch.ones((1, hparams.data.channels, hparams.data.target_res, hparams.data.target_res)).cuda(7)

In [15]:
# a

In [16]:
# ones