In [1]:
from collections import defaultdict, Counter
import torchvision.transforms as transforms
import torchvision
import torch
from models import Encoder, Decoder, CategoricalVAE
from models import gumbel_softmax
import itertools
from tqdm import tqdm
import numpy as np

In [2]:
transform = transforms.Compose([transforms.ToTensor()])
training_images = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)

In [3]:
batch_size = 1
train_dataset = torch.utils.data.DataLoader(
    dataset=training_images, batch_size=batch_size, shuffle=True
)

image_shape = next(iter(train_dataset))[0][0].shape  # [1, 28, 28]
K = 26  # number of classes
N = 3  # number of categorical distributions

encoder = Encoder(N, K, image_shape)
decoder = Decoder(N, K, image_shape)
model = CategoricalVAE(encoder, decoder)

state_dict = torch.load("outputs/default/save_49999.pt", weights_only=True)
model.load_state_dict(state_dict)

N = 3 and K = 26


<All keys matched successfully>

In [4]:
latents = []
labels = []
for batch in tqdm(train_dataset):
    with torch.no_grad():
        phi, x_hat = model(batch[0], temperature=1.0)
        z_given_x = gumbel_softmax(phi, temperature=1.0, hard=True, batch=True)
        latents.append(z_given_x.argmax(axis=2)[0].numpy())
        labels.append(batch[1].item())

100%|██████████| 60000/60000 [00:20<00:00, 2970.82it/s]


In [5]:
n_grams = defaultdict(dict)
for x in [1, 2, 3]:
    for n_gram in list(itertools.product([x for x in range(26)], repeat=x)):
        n_grams[n_gram]["length"] = x

In [6]:
total = len(labels)
obs_probs = {k:v/total for k,v in dict(Counter(labels)).items()}

In [7]:
latents_uniq, counts =  np.unique(latents,axis=0,return_counts=True)

In [8]:
joint_probs = defaultdict(lambda: defaultdict(int))
for latent_id in tqdm(range(latents_uniq.shape[0])):
    for idx, latent in enumerate(latents):
        if np.array_equal(latent,latents_uniq[latent_id]):
            joint_probs[latent_id][labels[idx]] += 1

100%|██████████| 3470/3470 [03:52<00:00, 14.90it/s]


In [9]:
for latent_id in tqdm(range(latents_uniq.shape[0])):
    for obs_type in joint_probs[latent_id].keys():
        joint_probs[latent_id][obs_type] /= total

100%|██████████| 3470/3470 [00:00<00:00, 2326816.13it/s]


In [10]:
np.seterr(divide="ignore", invalid="ignore")

for latent_id in range(latents_uniq.shape[0]):
    latent_prob = counts[latent_id]/total
    for obs_type in list(obs_probs.keys()):
        obs_prob = obs_probs[obs_type]
        joint_prob = joint_probs[latent_id][obs_type]
        joint_self_information = -np.log2(joint_prob)
        npmi = np.log2(joint_prob / (latent_prob * obs_prob)) / joint_self_information
        if npmi > 0.3:
            print(f"Possible hit for latent {latents_uniq[latent_id]} and observation type {obs_type}")
            

Possible hit for latent [0 1 7] and observation type 1
Possible hit for latent [0 4 2] and observation type 2
Possible hit for latent [ 0  7 25] and observation type 5
Possible hit for latent [0 9 6] and observation type 0
Possible hit for latent [ 0 12 25] and observation type 5
Possible hit for latent [ 0 14  5] and observation type 2
Possible hit for latent [ 0 18 23] and observation type 1
Possible hit for latent [ 0 19  8] and observation type 4
Possible hit for latent [ 0 19  9] and observation type 4
Possible hit for latent [ 0 19 10] and observation type 4
Possible hit for latent [ 0 20  0] and observation type 4
Possible hit for latent [ 0 20  9] and observation type 4
Possible hit for latent [ 0 20 10] and observation type 4
Possible hit for latent [ 0 22  6] and observation type 0
Possible hit for latent [1 0 4] and observation type 7
Possible hit for latent [1 4 0] and observation type 7
Possible hit for latent [1 6 9] and observation type 9
Possible hit for latent [ 1  6 1

In [11]:
for latent_id in range(latents_uniq.shape[0]):
    latent_prob = counts[latent_id]/total
    for obs_type in list(obs_probs.keys()):
        obs_prob = obs_probs[obs_type]
        joint_prob = joint_probs[latent_id][obs_type]
        joint_self_information = -np.log2(joint_prob)
        npmi = np.log2(joint_prob / (latent_prob * obs_prob)) / joint_self_information
        if npmi > 0.4:
            print(f"Possible hit for latent {latents_uniq[latent_id]} and observation type {obs_type}")
            

Possible hit for latent [ 0  7 25] and observation type 5
Possible hit for latent [ 3 13 18] and observation type 1
Possible hit for latent [ 3 13 23] and observation type 1
Possible hit for latent [ 4  9 25] and observation type 5
Possible hit for latent [ 7 11  2] and observation type 6
Possible hit for latent [8 1 8] and observation type 1
Possible hit for latent [12 13  4] and observation type 7
Possible hit for latent [24 20 17] and observation type 6
Possible hit for latent [25 13  7] and observation type 1
Possible hit for latent [25 13 23] and observation type 1
Possible hit for latent [25 18 23] and observation type 1
