In [1]:
import os

# The relative paths used below (outputs/, exports/) and, presumably, the project-local
# modules (config, data, model, utils) are resolved from the working directory, so move to
# the repository root first. Set SYSTEMATIC_TEXT_REPR_ROOT to run from another checkout.
PROJECT_ROOT = os.environ.get(
    'SYSTEMATIC_TEXT_REPR_ROOT',
    '/mnt/ialabnas/homes/fidelrio/systematic-text-representations/')
os.chdir(PROJECT_ROOT)

import json
from pathlib import Path
import random
import pprint

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on matplotlib < 3.2


import torch

from config import load_config
from data import build_datasets, build_loader, build_detailed_test_dataloaders
from data import CollatorForMaskedSelectedTokens, CollatorForMaskedRandomSelectedTokens, IdentityCollator
from data import ALL_POSSIBLE_COLORS
from model import MultimodalModel, MultimodalPretrainingModel
from utils import load_checkpoint
from lightning import Trainer, seed_everything
from tqdm.auto import tqdm, trange

from PIL import Image

from torch.utils.data import DataLoader, Subset
from torch.nn.functional import softmax

import numpy as np
from sklearn.decomposition import PCA

pp = pprint.PrettyPrinter(indent=2)

In [2]:
# %matplotlib inline 

In [3]:
def scene_tensor_to_txt(tensor):
    """Decode a flat tensor of token ids into a space-separated string of vocabulary words.

    Each id is looked up in the global ``processor.inv_vocabulary``; special tokens such as
    '[PAD]' and '[SEP]' are kept verbatim.
    """
    return ' '.join(processor.inv_vocabulary[token_id] for token_id in tensor.tolist())

def print_scene_tensor(tensor):
    """Print a decoded scene with padding removed and each '[SEP]' turned into an indented line break."""
    text = scene_tensor_to_txt(tensor)
    for marker, replacement in (('[PAD]', ''), ('[SEP]', '\n     ')):
        text = text.replace(marker, replacement)
    print(text)
    
def print_parallel(tensor0, tensor1, tensor2, confidences, titles):
    """Print three token-id sequences side by side, one token per row, plus a confidence column.

    Args:
        tensor0, tensor1, tensor2: flat tensors of token ids, decoded via ``processor.inv_vocabulary``.
        confidences: flat tensor of per-position scores, printed with 4 decimals.
        titles: three column headers (each truncated/padded to 6 characters).

    Output stops at the first '[PAD]' in ``tensor0``; a '[SEP]' in ``tensor0`` prints a blank
    line. Rows where the first and third columns disagree are printed in bold.
    """
    head0, head1, head2 = titles
    print(f'{head0:6.6s} {head1:6.6s} {head2:6.6s}')
    rows = zip(tensor0.tolist(), tensor1.tolist(), tensor2.tolist(), confidences.tolist())
    for id0, id1, id2, conf in rows:
        word0, word1, word2 = (processor.inv_vocabulary[i] for i in (id0, id1, id2))
        if word0 == '[PAD]':
            break
        if word0 == '[SEP]':
            print()
            continue
        line = f'{word0:6.6s} {word1:6.6s} {word2:6.6s} ({conf:.4f})'
        print(line if word0 == word2 else bold(line))
        
def bold(text):
    """Wrap ``text`` in ANSI escape sequences that switch bold on and then reset formatting."""
    return ''.join(("\033[1m", text, "\033[0m"))

In [4]:
# Fall back to CPU so the notebook can at least be inspected on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed python `random`, numpy and torch so the randomly sampled examples below are reproducible.
SEED = 0
seed_everything(SEED)

n_colors = 8
epoch = None  # None: let load_checkpoint choose (presumably the last checkpoint; printed below)
exp_name = f'mmlm--n_colors={n_colors}c--mlm_probability=0.15'

checkpoint = load_checkpoint(exp_name, epoch=epoch)
print('Epoch:', checkpoint['epoch'])

Epoch: 999


In [5]:
# List the checkpoint files saved for this experiment (path relative to the working directory).
!ls outputs/$exp_name

'epoch=109-step=16170.ckpt'  'epoch=579-step=85260.ckpt'
'epoch=119-step=17640.ckpt'  'epoch=589-step=86730.ckpt'
'epoch=129-step=19110.ckpt'  'epoch=599-step=88200.ckpt'
'epoch=139-step=20580.ckpt'  'epoch=59-step=8820.ckpt'
'epoch=149-step=22050.ckpt'  'epoch=609-step=89670.ckpt'
'epoch=159-step=23520.ckpt'  'epoch=619-step=91140.ckpt'
'epoch=169-step=24990.ckpt'  'epoch=629-step=92610.ckpt'
'epoch=179-step=26460.ckpt'  'epoch=639-step=94080.ckpt'
'epoch=189-step=27930.ckpt'  'epoch=649-step=95550.ckpt'
'epoch=199-step=29400.ckpt'  'epoch=659-step=97020.ckpt'
'epoch=19-step=2940.ckpt'    'epoch=669-step=98490.ckpt'
'epoch=209-step=30870.ckpt'  'epoch=679-step=99960.ckpt'
'epoch=219-step=32340.ckpt'  'epoch=689-step=101430.ckpt'
'epoch=229-step=33810.ckpt'  'epoch=699-step=102900.ckpt'
'epoch=239-step=35280.ckpt'  'epoch=69-step=10290.ckpt'
'epoch=249-step=36750.ckpt'  'epoch=709-step=104370.ckpt'
'epoch=259-step=38220.ckpt'  'epoch=719-step=105840.ckpt'
'epoch=269-st

In [6]:
config = load_config(exp_name)

# The saved config points at '/workspace/'; presumably the data lives under '/workspace1/' on
# this machine, so remap both data paths.
config.vocabulary_path, config.base_path = (
    saved_path.replace('/workspace/', '/workspace1/')
    for saved_path in (config.vocabulary_path, config.base_path)
)

Loading mmlm--n_colors=8c--mlm_probability=0.15 last checkpoint config from outputs/mmlm--n_colors=8c--mlm_probability=0.15/last.ckpt
Add new arg: aug_zero_color = False


In [7]:
# pp.pprint(vars(config))

In [11]:
train_dataset, test_dataset, systematic_dataset, common_systematic_dataset = build_datasets(config)
# Copy the training vocabulary's padding index and size onto the config (presumably read when
# the model is built from `config` below).
config.pad_idx, config.n_tokens = train_dataset.pad_idx, train_dataset.n_tokens


In [12]:
# One mapping of DataLoaders per split, keyed by the type of tokens under test
# (per-attribute entries are added further down).
test_loaders, systematic_loaders = (
    build_detailed_test_dataloaders(split_dataset, config)
    for split_dataset in (test_dataset, systematic_dataset)
)

In [13]:
model = MultimodalModel(config).to(device)
training_model = MultimodalPretrainingModel(model, config).to(device)
load_result = training_model.load_state_dict(checkpoint['state_dict'])

# Analysis only: switch off train-mode behaviour (e.g. dropout, if the model has any) so
# features and predictions are deterministic. `model` is called directly below, so put it in
# eval mode explicitly as well.
training_model.eval()
model.eval()

load_result

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


<All keys matched successfully>

In [14]:
# Vocabulary/transform helper of the test split. The '[MASK]' id is used below to locate
# masked positions in collated scenes.
processor = test_dataset.processor
mask_token_idx = test_dataset.processor.vocabulary['[MASK]']

In [19]:
class RandomPixelShuffle(object):
    """Image transform that randomly permutes pixel positions, destroying spatial structure.

    One random permutation of the H*W spatial positions is drawn per call and applied to every
    channel (and every other leading index), so the values of a pixel stay together across
    channels while the pixel locations are scrambled.
    """

    def __call__(self, img):
        """Return a pixel-shuffled copy of ``img``.

        Args:
            img: tensor whose last two dimensions are (height, width), e.g. (C, H, W).

        Returns:
            Tensor with the same shape as ``img``.
        """
        *leading, height, width = img.shape
        # Draw the permutation on the image's device so indexing needs no host/device copy.
        perm = torch.randperm(height * width, device=img.device)
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose, are accepted.
        flat = img.reshape(*leading, height * width)
        return flat[..., perm].reshape(img.shape)

from torchvision import transforms

# Remember the processor's untouched transform the first time this cell runs, so re-running it
# wraps the original again instead of stacking a second shuffle.
# NOTE(review): this mutates `processor` (test_dataset's processor); if other datasets share the
# same object they are affected too — confirm.
if 'original_transform' not in globals():
    original_transform = processor.image_transform

processor.image_transform = transforms.Compose([original_transform, RandomPixelShuffle()])

In [20]:
# Sorted vocabulary ids per attribute family. Words missing from this experiment's vocabulary
# are skipped instead of raising KeyError, matching the filtered lists built further down.
relation_tokens = sorted(
    [processor.vocabulary[w] for w in ['left', 'right', 'behind', 'front'] if w in processor.vocabulary])
color_tokens = sorted(
    [processor.vocabulary[w] for w in ALL_POSSIBLE_COLORS if w in processor.vocabulary])
shapes_tokens = sorted(
    [processor.vocabulary[w] for w in ['cylinder', 'sphere', 'cube'] if w in processor.vocabulary])
materials_tokens = sorted(
    [processor.vocabulary[w] for w in ['metal', 'rubber'] if w in processor.vocabulary])
size_tokens = sorted(
    [processor.vocabulary[w] for w in ['small', 'large'] if w in processor.vocabulary])

In [21]:
# Collator for the single-sample demo below; presumably masks the tokens listed in color_tokens.
# The commented lines are alternatives for quick switching. NOTE: CollatorForMaskedLanguageModeling
# is not imported in the first cell, so add that import before uncommenting its line.
# collator = CollatorForMaskedLanguageModeling(config, processor)
collator = CollatorForMaskedSelectedTokens(config, processor, tokens=color_tokens)
# collator = CollatorForMaskedRandomSelectedTokens(config, processor, tokens=shapes_tokens, p=0.2)
# collator = IdentityCollator(config, processor)

In [22]:
# Draw one random test example (assign a fixed index, e.g. 333, to inspect a specific scene).
# randrange excludes its upper bound; randint(0, len(test_dataset)) could return an
# out-of-range index.
# sample_idx = 333
sample_idx = random.randrange(len(test_dataset))
image, scene = test_dataset.retrieve_raw(sample_idx)
image_tensor, scene_tensor = test_dataset[sample_idx]

# Collate the single (image, scene) pair into a batch of one and move it to the device.
collated_images, collated_scenes, collated_labels = collator([(image_tensor, scene_tensor)])
collated_images = collated_images.to(device)
collated_scenes = collated_scenes.to(device)
collated_labels = collated_labels.to(device)

print(sample_idx)

13979


In [23]:
# Inference only: no autograd graph is needed.
with torch.no_grad():
    output_logits = model(collated_images, collated_scenes)

# Per-position top-1 probability and predicted token id.
confidences = softmax(output_logits, dim=-1).max(dim=-1).values
predictions = output_logits.argmax(dim=-1)

In [24]:
# Lightning Trainer across all visible GPUs (no later cell in this notebook uses it).
trainer = Trainer(accelerator="gpu", devices=torch.cuda.device_count(), max_epochs=config.max_epochs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
vocab = processor.vocabulary

# Sorted vocabulary ids per attribute family; words absent from this experiment's vocabulary
# are skipped.
relation_tokens = sorted(
    [vocab[w] for w in ['left', 'right', 'behind', 'front'] if w in vocab])
colors_tokens = sorted(
    [vocab[w] for w in ALL_POSSIBLE_COLORS if w in vocab])
shapes_tokens = sorted(
    [vocab[w] for w in ['cylinder', 'sphere', 'cube'] if w in vocab])
materials_tokens = sorted(
    [vocab[w] for w in ['metal', 'rubber'] if w in vocab])
sizes_tokens = sorted(
    [vocab[w] for w in ['small', 'large'] if w in vocab])

# 1 / number of candidate tokens per family (presumably chance accuracy of a uniform guess).
# Built only from the lists defined in this cell, so it does not depend on an earlier cell's
# color_tokens / size_tokens.
random_baseline = {
    'relation':  1 / len(relation_tokens),
    'color':  1 / len(colors_tokens),
    'shapes':  1 / len(shapes_tokens),
    'materials':  1 / len(materials_tokens),
    'size':  1 / len(sizes_tokens),
    'identity':  1 / len(processor.vocabulary),
}

In [26]:
batch_size = 256

# The same indices are used for both splits, so draw them from the range valid for BOTH
# datasets (an index >= len(systematic_dataset) would read past its end) and never request
# more samples than exist (random.sample raises ValueError otherwise).
n_shared = min(len(test_dataset), len(systematic_dataset))
test_indices = random.sample(range(n_shared), k=min(batch_size, n_shared))
pc_subset_test = Subset(test_dataset, test_indices)
pc_subset_systematic = Subset(systematic_dataset, test_indices)

colors_collator = CollatorForMaskedSelectedTokens(config, processor, tokens=colors_tokens)
shapes_collator = CollatorForMaskedSelectedTokens(config, processor, tokens=shapes_tokens)
materials_collator = CollatorForMaskedSelectedTokens(config, processor, tokens=materials_tokens)
sizes_collator = CollatorForMaskedSelectedTokens(config, processor, tokens=sizes_tokens)
dlkwargs = {
    'batch_size': batch_size,
    'num_workers': int(os.environ.get("SLURM_CPUS_PER_TASK", 4)),
    'pin_memory': torch.cuda.is_available(),
}

# One extra loader per attribute task; the whole subset fits in a single batch. The loop
# variable is `task_collator` so the demo `collator` defined earlier is not overwritten.
for task, task_collator in [('colors', colors_collator),
                            ('shapes', shapes_collator),
                            ('materials', materials_collator),
                            ('sizes', sizes_collator)]:
    test_loaders[task] = DataLoader(
        pc_subset_test, shuffle=False, collate_fn=task_collator, **dlkwargs)
    systematic_loaders[task] = DataLoader(
        pc_subset_systematic, shuffle=False, collate_fn=task_collator, **dlkwargs)
    

In [27]:
# Receives the output of the hooked module; after each forward pass it holds exactly one entry.
feature_maps = []


def hook_feat_map(mod, inp, out):
    """PyTorch forward hook that stores the module's latest output in ``feature_maps``.

    The list is overwritten in place (never rebound), so code holding a reference to it always
    sees the most recent output. Returns None, which leaves the module's output unchanged.
    """
    feature_maps[:] = [out]

# Register the hook once: when this cell is re-run, remove the previous registration first so
# hooks do not pile up on model.transformer.
try:
    feat_map_hook_handle.remove()
except NameError:
    pass
feat_map_hook_handle = model.transformer.register_forward_hook(hook_feat_map)
feat_map_hook_handle

<torch.utils.hooks.RemovableHandle at 0x7f8dad491b10>

In [28]:
# images.shape, scenes.shape, labels.shape

In [29]:
# Attribute tasks to extract features for. The plotting cells below index the 'shapes' entry,
# so keep 'shapes' in the list; use the commented line to extract every attribute.
# tasks = ['colors', 'shapes', 'materials', 'sizes']
tasks = ['shapes']

In [None]:
# For each split and task, run one batch and keep, for every masked position, the hooked
# transformer feature and the ground-truth token id.
feats_by_set = {}
gt_by_set = {}
for test_name, loaders in [('test', test_loaders), ('systematic', systematic_loaders)]:
    feats_by_task = {}
    gt_by_task = {}
    for task in tasks:
        images, scenes, labels = next(iter(loaders[task]))
        images, scenes, labels = images.to(device), scenes.to(device), labels.to(device)
        # Clear first so a missing hook cannot silently reuse a map from an earlier forward pass.
        feature_maps.clear()
        with torch.no_grad():
            output_logits = model(images, scenes)
            if not feature_maps:
                raise RuntimeError(
                    'hook_feat_map did not fire; run the cell that registers it on model.transformer')
            features = feature_maps[0]

            # The transpose implies the hooked output is (sequence, batch, ...) — TODO confirm.
            # The last max_scene_size positions are taken to be the scene tokens.
            scene_features = features.transpose(1, 0)[:, -config.max_scene_size:]
            mask_idxs = scenes == mask_token_idx
            gt_by_task[task] = labels[:, -config.max_scene_size:][mask_idxs].cpu()
            feats_by_task[task] = scene_features[mask_idxs].cpu()

    feats_by_set[test_name] = feats_by_task
    gt_by_set[test_name] = gt_by_task

In [None]:
# Rows of model.classifier.weight for the token ids that occur as masked ground truth in the
# IID test batch, pooled over tasks and per task (presumably the output embedding of each token,
# assuming the classifier is a linear layer — confirm).
all_clf_vectors = model.classifier.weight.data.cpu()
clf_idxs_by_task = {t: torch.unique(gt_by_set['test'][t]) for t in tasks}
clf_idxs = torch.unique(torch.cat([gt_by_set['test'][t] for t in tasks]))
clf_vectors = all_clf_vectors[clf_idxs]
clf_vectors_by_task = {t: all_clf_vectors[idxs] for t, idxs in clf_idxs_by_task.items()}

In [None]:
def scatter_pca(X,
                y,
                title='',
                special_X=None,
                special_y=None,
                don_t_label_these=(),
                labels_to_use=(),
                special_labels_to_use=(),
                ax=None,
                show=False):
    """Scatter-plot projected points, one series per label, plus optional starred "special" points.

    Args:
        X: array of shape (n, 2) or (n, 3); a last dimension of 3 draws a 3-D scatter.
        y: length-n array of label ids, one per row of ``X``.
        title: optional axes title (skipped when empty).
        special_X: optional array with the same columns as ``X``, drawn with a '*' marker
            (size 200 in 2-D).
        special_y: label ids for ``special_X``; required when ``special_X`` is given.
        don_t_label_these: label names left out of the legend (their points are still drawn).
        labels_to_use: id -> name lookup for ``y``; defaults to ``processor.inv_vocabulary``.
        special_labels_to_use: id -> name lookup for ``special_y``; falls back to
            ``labels_to_use``, then to ``processor.inv_vocabulary``.
        ax: axes to draw on. If None, a new figure is created (3-D axes when ``X`` has 3 columns).
        show: call ``plt.show()`` at the end. Off by default because with the inline backend
            showing closes the figure, and callers ``plt.savefig`` after this function returns.

    Returns:
        The axes that were drawn on.

    Raises:
        ValueError: if ``special_X`` is given without ``special_y``.
    """
    if special_X is not None and special_y is None:
        raise ValueError('special_y is required when special_X is given')

    is_3d = X.shape[-1] == 3
    # np.asarray so `y == label_idx` is an element-wise mask even for list/tensor input.
    y = np.asarray(y)

    if ax is None:
        fig = plt.figure(figsize=(9*1.75, 5*1.75))
        ax = fig.add_subplot(projection='3d') if is_3d else fig.add_subplot()

    def _coords(points, mask):
        # Per-axis coordinates of the selected rows, in the positional order ax.scatter expects.
        coords = [points[:, 0][mask], points[:, 1][mask]]
        if is_3d:
            coords.append(points[:, 2][mask])
        return coords

    label_namer = labels_to_use or processor.inv_vocabulary
    for label_idx in np.unique(y):
        label = label_namer[label_idx]
        plot_kwargs = {}
        if label not in don_t_label_these:
            plot_kwargs['label'] = label
        ax.scatter(*_coords(X, y == label_idx), **plot_kwargs)

    if special_X is not None:
        special_y = np.asarray(special_y)
        special_label_namer = special_labels_to_use or labels_to_use or processor.inv_vocabulary
        for label_idx in np.unique(special_y):
            label = special_label_namer[label_idx]
            plot_kwargs = {'marker': '*'}
            if not is_3d:
                plot_kwargs['s'] = 200
            if label not in don_t_label_these:
                plot_kwargs['label'] = label
            ax.scatter(*_coords(special_X, special_y == label_idx), **plot_kwargs)

    if title:
        ax.set_title(title)
    ax.legend(framealpha=1, loc='upper left')

    if show:
        plt.show()
    return ax

In [None]:
# Colour names to hide from legends once there are too many colours for a readable legend.
# (The plotting calls below pass don_t_label_these=[] explicitly, so this value is currently unused.)
don_t_label_these = ALL_POSSIBLE_COLORS if n_colors > 27 else []

### Shape Task

In [None]:
# Render figures inline in the notebook.
%matplotlib inline

In [None]:
# Classifier weight rows (special_X) and their token ids (special_gts) for the shape tokens,
# plotted as stars and projected with the same PCA as the features. Replace ['shapes'] with
# `tasks` to include every extracted attribute.
special_X, special_gts = (
    torch.cat([per_task[t] for t in ['shapes']]).numpy()
    for per_task in (clf_vectors_by_task, clf_idxs_by_task)
)

In [None]:
# One panel per split. PCA is fitted separately on each split's features, so the two panels'
# axes are NOT comparable; the shape classifier rows (stars) are projected with each panel's PCA.
fig, axs = plt.subplots(1, 2, figsize=(1.5*9*1.75, 5*1.75))

for ax, (set_name, panel_title) in zip(axs, [('test', 'IID Test'), ('systematic', 'Systematic Test')]):
    X = feats_by_set[set_name]['shapes'].numpy()
    all_gts = gt_by_set[set_name]['shapes'].numpy()

    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X)

    scatter_pca(
        X_2d,
        all_gts,
        special_X=pca.transform(special_X),
        special_y=special_gts,
        don_t_label_these=[],
        title=panel_title,
        ax=ax)

# savefig raises if the output directory does not exist yet.
Path('exports').mkdir(exist_ok=True)
plt.savefig(f'exports/shape-embeddings-not-mixed-epoch={epoch}.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Render figures inline (repeats the earlier magic; harmless if the backend is already inline).
%matplotlib inline

In [None]:
# Pool the IID-test and systematic shape features, fit ONE PCA on the union, and colour each
# point by its split rather than by its shape; stars are the shape classifier rows.
X = torch.cat([feats_by_set[set_]['shapes'] for set_ in ['test', 'systematic']]).numpy()
# Label = split index (0 = test, 1 = systematic), in the same order as the rows of X.
all_gts = [torch.full_like(gt_by_set[set_]['shapes'], tidx) for tidx, set_ in enumerate(['test', 'systematic'])]
all_gts = torch.cat(all_gts).numpy()

pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)

scatter_pca(X_2d,
            all_gts,
            labels_to_use=['test', 'systematic'],
            special_X=pca.transform(special_X),
            special_y=special_gts,
            special_labels_to_use=processor.inv_vocabulary,
            don_t_label_these=[],
            title='Test and Systematic Test')

# savefig raises if the output directory does not exist yet.
Path('exports').mkdir(exist_ok=True)
plt.savefig(
    f'exports/shape-embeddings-mixed-by-set-epoch={epoch}.pdf', format='pdf', dpi=300, bbox_inches='tight')

In [None]:
# Same pooled features and single PCA as the by-split figure, but coloured by ground-truth shape
# token so the two splits' clusters for the same shape overlap in one plot.
X = torch.cat([feats_by_set[set_]['shapes'] for set_ in ['test', 'systematic']]).numpy()
all_gts = [gt_by_set[set_]['shapes'] for set_ in ['test', 'systematic']]
all_gts = torch.cat(all_gts).numpy()

pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)

scatter_pca(X_2d,
            all_gts,
            special_X=pca.transform(special_X),
            special_y=special_gts,
            special_labels_to_use=processor.inv_vocabulary,
            don_t_label_these=[],
            title='Test and Systematic Test')

# savefig raises if the output directory does not exist yet.
Path('exports').mkdir(exist_ok=True)
plt.savefig(
    f'exports/shape-embeddings-mixed-by-shape-epoch={epoch}.pdf', format='pdf', dpi=300, bbox_inches='tight')