In [1]:
import json
from pathlib import Path
import random
import pprint

import matplotlib
import matplotlib.pyplot as plt

import os

import torch

from config import load_config
from data import (build_datasets,
                  CollatorForMaskedLanguageModeling,
                  CollatorForMaskedSelectedTokens,
                  CollatorForMaskedRandomSelectedTokens,
                  IdentityCollator)
from data import ALL_POSSIBLE_COLORS
from model import MultimodalModel, MultimodalPretrainingModel
from lightning import Trainer, seed_everything
from tqdm.auto import tqdm, trange

from PIL import Image

from torch.utils.data import DataLoader, Subset
from torch.nn.functional import softmax


# Shared pretty-printer for nested dicts / config objects.
PP_INDENT = 2
pp = pprint.PrettyPrinter(indent=PP_INDENT)

In [2]:
exp_name = 'mmlm--n_colors=8c--mlm_probability=0.15'


def to_clean_int(str_):
    """Return only the decimal digits of ``str_``, concatenated (may be '').

    ``str.isdecimal`` is used rather than ``str.isdigit``: the latter also
    accepts characters such as '²' that ``int()`` cannot parse.
    """
    return ''.join(filter(str.isdecimal, str_))


def get_version(p):
    """Return the integer version encoded in a checkpoint path's stem.

    ``last.ckpt`` -> 0, ``last-v3.ckpt`` -> 3. All digits in the stem are
    concatenated, so a stem with several digit runs yields one number.
    Used as a sort key to pick the highest-versioned ``last*.ckpt``.
    """
    digits = to_clean_int(p.stem)
    return int(digits) if digits else 0

checkpoint_dir = Path('outputs') / exp_name
# Highest version suffix first (see get_version): last.ckpt -> 0, last-vN.ckpt -> N.
checkpoint_paths = sorted(checkpoint_dir.glob('last*.ckpt'), key=get_version, reverse=True)
if not checkpoint_paths:
    raise FileNotFoundError(f'No last*.ckpt checkpoint found in {checkpoint_dir}')
resume_from_path = str(checkpoint_paths[0])
# map_location='cpu' lets a checkpoint saved from GPU tensors be opened on a
# CPU-only machine; load_state_dict later copies the values into the model's
# parameters wherever those live.
checkpoint = torch.load(resume_from_path, map_location='cpu')

print('Epoch:', checkpoint['epoch'])

Epoch: 999


In [3]:
config = load_config(exp_name)

# The stored config points at paths under /workspace/; remap them to /workspace1/
# (presumably where the data is mounted in this environment — adjust if not).
for path_attr in ('vocabulary_path', 'base_path'):
    setattr(config, path_attr, getattr(config, path_attr).replace('/workspace/', '/workspace1/'))

Loading mmlm--n_colors=8c--mlm_probability=0.15 last checkpoint config from outputs/mmlm--n_colors=8c--mlm_probability=0.15/last.ckpt
Add new arg: aug_zero_color = False


In [4]:
# Overrides applied on top of the loaded checkpoint config (the aug_* flags
# appear to control data augmentation — see data.build_datasets).
config_overrides = dict(
    aug_zero=4,
    aug_zero_independent=True,
    aug_zero_color=True,
    shuffle_object_identities=False,
)
for override_name, override_value in config_overrides.items():
    setattr(config, override_name, override_value)

In [5]:
# build_datasets returns four datasets; only train/test are used in the cells shown here.
(train_dataset, test_dataset,
 systematic_dataset, common_systematic_dataset) = build_datasets(config)

# Record the padding id on config before MultimodalModel(config) is built
# (presumably the model reads config.pad_idx — confirm in model.py).
config.pad_idx = train_dataset.pad_idx
processor = test_dataset.processor

In [6]:
model = MultimodalModel(config)
training_model = MultimodalPretrainingModel(model, config)
# Strict loading (the default) raises on any missing or unexpected key.
load_result = training_model.load_state_dict(checkpoint['state_dict'])
# This notebook inspects the trained model rather than training it, so switch
# off train-only behavior (e.g. dropout, if the model has any).
training_model.eval()
load_result

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


<All keys matched successfully>

In [7]:
def vocab_ids(words):
    """Sorted vocabulary ids for ``words`` (KeyError if any word is missing)."""
    return sorted(processor.vocabulary[w] for w in words)


relation_tokens = vocab_ids(['left', 'right', 'behind', 'front'])
# Colors: keep only those present in this vocabulary, dropping duplicate ids.
color_tokens = sorted(
    {processor.vocabulary[w] for w in ALL_POSSIBLE_COLORS if w in processor.vocabulary})
shapes_tokens = vocab_ids(['cylinder', 'sphere', 'cube'])
materials_tokens = vocab_ids(['metal', 'rubber'])
size_tokens = vocab_ids(['small', 'large'])

In [8]:
# Alternative collator: CollatorForMaskedLanguageModeling(config, processor).
collator = IdentityCollator(config, processor)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,  # sequential order: every fresh iter(train_loader) starts at example 0
    collate_fn=collator,
)

In [9]:
# Dataset token count vs. base vocabulary size (ids at or above the vocabulary size are folded back by translate_idx).
(train_dataset.n_tokens, len(processor.vocabulary))

(120, 96)

In [13]:
def translate_idx(idx, vocab_size=None, color_ids=None):
    """Map a scene token id back to an id in ``processor.vocabulary``.

    Ids below the vocabulary size are returned unchanged. Ids at or above it
    are folded back onto the color range: their offset past the vocabulary is
    taken modulo the number of colors and added to the smallest color id.
    NOTE(review): like the original arithmetic, this assumes the color ids form
    a contiguous range starting at ``min(color_ids)`` — confirm against the
    dataset code that produces the extra ids.

    Args:
        idx: integer token id.
        vocab_size: base vocabulary size; defaults to ``len(processor.vocabulary)``.
        color_ids: color token ids; defaults to the notebook's ``color_tokens``.

    Returns:
        An id suitable for lookup in ``processor.inv_vocabulary``.

    Raises:
        ValueError: if ``idx`` lies beyond the vocabulary but there are no color ids.
    """
    if vocab_size is None:
        vocab_size = len(processor.vocabulary)
    # In-vocabulary ids do not depend on the color set at all.
    if idx < vocab_size:
        return idx

    if color_ids is None:
        color_ids = color_tokens
    if not color_ids:
        raise ValueError(
            f'token id {idx} is outside the vocabulary and no color tokens are defined')

    offset = idx - vocab_size
    return min(color_ids) + offset % len(color_ids)

def scene_tensor_to_txt(tensor):
    """Decode a 1-D tensor of token ids into a space-separated string.

    Each id goes through ``translate_idx``; ids with no entry in
    ``processor.inv_vocabulary`` are rendered as their decimal value.
    """
    words = []
    for token_id in tensor.tolist():
        vocab_id = translate_idx(token_id)
        words.append(processor.inv_vocabulary.get(vocab_id, str(token_id)))
    return ' '.join(words)

def print_scene_tensor(tensor):
    """Print a decoded scene with padding removed and each [SEP] starting a new indented line."""
    text = scene_tensor_to_txt(tensor)
    for marker, replacement in (('[PAD]', ''), ('[SEP]', '\n     ')):
        text = text.replace(marker, replacement)
    print(text)
    
def print_parallel(tensor0, tensor1, tensor2, confidences, titles):
    """Print three aligned token sequences side by side with a per-row confidence.

    Each column shows the word for the id at that position, truncated/padded
    to 6 characters. A ``[SEP]`` in the first sequence prints a blank line; a
    ``[PAD]`` in the first sequence ends the listing. Rows whose first and
    third words differ are printed in bold.

    Args:
        tensor0, tensor1, tensor2: tensors of token ids, looked up directly in
            ``processor.inv_vocabulary`` (no ``translate_idx`` folding).
        confidences: tensor with one number per position.
        titles: exactly three column headings.
    """
    first_title, second_title, third_title = titles
    print(f'{first_title:6.6s} {second_title:6.6s} {third_title:6.6s}')

    rows = zip(tensor0.tolist(), tensor1.tolist(), tensor2.tolist(), confidences.tolist())
    for id0, id1, id2, confidence in rows:
        word0, word1, word2 = (processor.inv_vocabulary[i] for i in (id0, id1, id2))

        if word0 == '[PAD]':
            break
        if word0 == '[SEP]':
            print()
            continue

        line = f'{word0:6.6s} {word1:6.6s} {word2:6.6s} ({confidence:.4f})'
        print(bold(line) if word0 != word2 else line)
        
def bold(text):
    """Wrap ``text`` in ANSI escape codes so terminals render it bold."""
    start, reset = "\033[1m", "\033[0m"
    return start + text + reset

In [28]:
N_TRIALS = 2000  # number of paired draws to compare


def draw_first_scene():
    """Return element 1 of a freshly drawn first batch from ``train_loader``.

    A new iterator is created on every call. Because ``train_loader`` uses
    ``shuffle=False``, each call yields the same (first) training example.
    NOTE(review): element 1 is assumed to be the scene tensor — confirm against
    IdentityCollator's output layout.
    """
    batch = next(iter(train_loader))
    return batch[1]


for trial in range(N_TRIALS):
    scene1 = draw_first_scene()
    s1 = scene_tensor_to_txt(scene1[0])

    scene2 = draw_first_scene()
    s2 = scene_tensor_to_txt(scene2[0])

    # Two draws of the same example are expected to differ in raw token ids
    # (if anything fails, it should be this check)...
    assert (scene1 != scene2).any(), (
        f'trial {trial}: both draws produced identical token ids\n{s1}')
    # ...but to decode to the same text once translate_idx folds ids back (should never fail).
    assert s1 == s2, f'trial {trial}: decoded scenes differ\n{s1}\n{s2}'

AssertionError: 