In [None]:
! pip install --editable '/private/home/rdessi/visualize_EGG/'

In [1]:
import argparse
from collections import defaultdict, Counter
import random
import json

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
import torch

from egg.zoo.emcom_as_ssl.games import build_game

# Seed every random number generator the notebook relies on (Python, PyTorch
# and NumPy, in that order) with the same value.
seed = 111
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(seed)

# CUDA generators are seeded only when a GPU is actually available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Print where the `egg` package is imported from.
import egg
print(egg.__path__)

['/private/home/rdessi/EGG/egg']


In [1]:
# TODO: change path to checkpoint here and fix the params_dict
# Alternative checkpoints (swap one in to inspect a different run):
#    "/private/home/rdessi/neurips_exp/new_exps/fixed_temp_5_augmentations_shared_vision_xent/40360153_0/final.tar"
#    "/private/home/rdessi/neurips_exp/new_exps/fixed_temp_5_augmentations_no_shared_vision_xent/40335995_0/final.tar"
checkpoint_path = "/private/home/rdessi/neurips_exp/new_exps/simclr_augmentations_shared_vision_ntxent/40329294_0/final.tar"

# Options that differ between runs; presumably they must match the run that
# produced `checkpoint_path` (see the TODO above) -- verify when switching.
params = {
    "loss_type": "ntxent",
    "shared_vision": True,
    "use_augmentations": False,
    "simclr_sender": False,
}

# Options kept at the same value for every checkpoint listed above.
params_fixed = {
    "pretrain_vision": False,
    "model_name": "resnet50",
    "similarity": "cosine",
    "loss_temperature": 1.0,
    "batch_size": 128,
    "random_seed": 111,
    "gs_temperature": 5.0,
    "gs_temperature_decay": 1.0,
    "minimum_gs_temperature": 1.0,
    "update_gs_temp_frequency": 1,
}

# Single-process run: no distributed context.
distributed_context = argparse.Namespace(is_distributed=False)

# Remaining options, filled in by hand so that `params` looks like a complete
# set of parsed command-line options.
other_params = {
    "checkpoint_dir": "",
    "checkpoint_freq": 1,
    "cuda": True,
    "dataset_dir": "/datasets01/imagenet_full_size/061417/train",
    "device": torch.device(type="cuda"),
    "distributed_context": distributed_context,
    "distributed_port": 18363,
    "fp16": True,
    "image_size": 224,
    "load_from_checkpoint": None,
    "max_len": 1,
    "n_epochs": 100,
    "no_cuda": False,
    "num_workers": 4,
    "optimizer": "adam",
    "pdb": False,
    "preemptable": True,
    "projection_hidden_dim": 2048,
    "projection_output_dim": 2048,
    "straight_through": False,
    "tensorboard": False,
    "tensorboard_dir": "runs/",
    "train_gs_temperature": False,
    "update_freq": 1,
    "validation_dataset_dir": "/datasets01/imagenet_full_size/061417/val",
    "validation_freq": 5,
    "vocab_size": 10,
    "wandb": False,
    "weight_decay": 1e-05,
}

# Merge the three groups (run-specific, fixed, remaining -- in that order)
# into one attribute-style namespace.
params = argparse.Namespace(**{**params, **params_fixed, **other_params})

NameError: name 'argparse' is not defined

In [5]:
# Rebuild the game from the options above and restore the trained weights.
game = build_game(params)
# Deserialize the checkpoint onto the CPU: without `map_location`, torch.load
# restores every tensor to the device it was saved from, which needs that
# device to exist and keeps a second copy of the weights on it.
# load_state_dict copies the values into the game's own parameters.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
game.load_state_dict(checkpoint.model_state_dict)

<All keys matched successfully>

In [6]:
from egg.zoo.emcom_as_ssl.data import get_dataloader

# `opts` is an alias for the `params` namespace.
opts = params

# Options forwarded to get_dataloader under the same name they have in `opts`.
shared_keys = (
    "dataset_dir",
    "image_size",
    "batch_size",
    "validation_dataset_dir",
    "num_workers",
    "use_augmentations",
)
loader_kwargs = {key: getattr(opts, key) for key in shared_keys}

# The two remaining arguments are named differently on the `opts` side.
train_loader, validation_loader = get_dataloader(
    is_distributed=opts.distributed_context.is_distributed,
    seed=opts.random_seed,
    **loader_kwargs,
)

# Move the game to the GPU and switch it to evaluation mode.
game = game.cuda()
game = game.eval()

In [None]:
# Run the sender over the validation set and record, for every image:
#   * msg2img[sym]         -- original images grouped by the argmax index of
#                             the sender's `message_like` output,
#   * resnet_output_list   -- the sender's third output for that image (CPU),
#   * original_image_list  -- the original image, in the same order.
msg2img = defaultdict(list)

resnet_output_list = []
original_image_list = []

# `labels` is unpacked but not used.
for (x_i, x_j, original_image), labels in validation_loader:
    x_i = x_i.cuda()
    x_j = x_j.cuda()

    with torch.no_grad():
        # Only the sender side is needed; `receiver_encoded_input` and
        # `message` are unused.
        sender_encoded_input, receiver_encoded_input = game.vision_module(x_i, x_j)
        message, message_like, resnet_output = game.game.sender(sender_encoded_input)

    # One argmax and one device-to-host transfer per batch, instead of a
    # `.item()` call (a GPU sync) per image. Each sample is flattened first so
    # the index matches torch.argmax(symbol) on the individual sample.
    # Assumes `message_like` is a tensor whose first dimension indexes the
    # batch -- TODO confirm.
    symbols = message_like.reshape(message_like.shape[0], -1).argmax(dim=1).tolist()

    # Keep the features on the CPU so they do not accumulate in GPU memory
    # across the whole validation set.
    resnet_output_cpu = resnet_output.cpu()

    for i, sym in enumerate(symbols):
        msg2img[sym].append(original_image[i])

        resnet_output_list.append(resnet_output_cpu[i])
        original_image_list.append(original_image[i])

In [None]:
# Number of distinct messages that were recorded in msg2img.
print(len(msg2img))

# Tally how many images were collected under each message.
cnt_msg = Counter({symbol: len(imgs) for symbol, imgs in msg2img.items()})

# The 20 most frequent messages with their image counts.
cnt_msg.most_common(20)

In [None]:
# Message to inspect.
msg = 1547

# Use .get() so an unused message id does not insert an empty entry into the
# defaultdict, and cap the sample size so a message with fewer than 30 images
# does not make random.sample raise ValueError.
candidates = msg2img.get(msg, [])
images = random.sample(candidates, min(30, len(candidates)))

if not images:
    print(f"No images were assigned to message {msg}")

# Show the sample in rows of up to three images. Slicing lets the last row be
# shorter when the sample size is not a multiple of three.
# Assumes each image is a (channels, height, width) tensor, so the permute
# gives (height, width, channels) and dim=1 concatenates side by side -- TODO confirm.
for img_idx in range(0, len(images), 3):
    row = images[img_idx:img_idx + 3]
    v = torch.cat([img.permute(1, 2, 0) for img in row], dim=1)
    plt.imshow(v)
    plt.show()

In [None]:
# Stack the first 10000 collected feature vectors into one NumPy matrix.
resnet_output_np = torch.stack(resnet_output_list[:10000]).cpu().numpy()

# Cluster them into 1000 groups; random_state is fixed for repeatability.
kmeans = KMeans(n_clusters=1000, random_state=0)
kmeans = kmeans.fit(resnet_output_np)

In [None]:
# Group the original images by the k-means cluster assigned to their feature
# vector. kmeans.labels_ and original_image_list are paired by position.
cluster2img = defaultdict(list)
for cluster, image in zip(kmeans.labels_, original_image_list):
    cluster2img[cluster].append(image)

# Number of images per cluster.
cnt_kmeans = Counter(kmeans.labels_)

# The 10 largest clusters with their sizes.
cnt_kmeans.most_common(10)

In [None]:
# k-means cluster to inspect (the variable is named `msg` although it holds a
# cluster id here).
msg = 75

# Use .get() so an unused cluster id does not insert an empty entry into the
# defaultdict, and cap the sample size so a cluster with fewer than 30 images
# does not make random.sample raise ValueError.
candidates_kmeans = cluster2img.get(msg, [])
images_kmeans = random.sample(candidates_kmeans, min(30, len(candidates_kmeans)))

if not images_kmeans:
    print(f"No images were assigned to cluster {msg}")

# Show the sample in rows of up to three images. Slicing lets the last row be
# shorter when the sample size is not a multiple of three.
# Assumes each image is a (channels, height, width) tensor, so the permute
# gives (height, width, channels) and dim=1 concatenates side by side -- TODO confirm.
for img_idx in range(0, len(images_kmeans), 3):
    row = images_kmeans[img_idx:img_idx + 3]
    v = torch.cat([img.permute(1, 2, 0) for img in row], dim=1)
    plt.imshow(v)
    plt.show()