In [None]:
import torch
import torch.utils.data
import matplotlib.pyplot as plt

import yaml

from configs.trainer import TrainingConfig

from models.vision_encoder_decoder import VisionEncoderDecoder

from transformers import AutoTokenizer, PreTrainedTokenizer
from deeplake import load, Dataset
from torchvision.models import ViT_B_16_Weights

In [None]:
def get_dataloader(batch_size, shuffle):
    """Build a dataloader over the tail of the Flickr30k Deep Lake dataset.

    Every sample from index 27000 onward of ``hub://activeloop/flickr30k`` is
    used as the held-out (validation) split.

    Args:
        batch_size: number of samples per batch.
        shuffle: forwarded to ``Dataset.pytorch`` to control sample order.

    Returns:
        The loader returned by ``Dataset.pytorch`` for that slice.
    """
    dataset: Dataset = load('hub://activeloop/flickr30k')
    validation_split = dataset[27000:]
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        buffer_size=32,
        use_local_cache=True,
    )
    return validation_split.pytorch(**loader_options)


In [None]:
# Training config to reuse, and the checkpoint file that the next cell assigns
# to config.model.chkpt_path (paths are relative to the working directory).
config_file, chkpt_file = (
    'training_configs/local/nano.yaml',
    'checkpoints/nano.pt',
)

In [None]:
# Parse the YAML training config into a TrainingConfig and point the model
# section at the checkpoint to evaluate. The `with` block closes the file
# handle, which the bare open() previously leaked.
with open(config_file, 'r') as config_fh:
    obj = yaml.safe_load(config_fh)
config: TrainingConfig = TrainingConfig.parse_obj(obj)
config.model.chkpt_path = chkpt_file
print(config)

In [None]:
# First load: the stock tokenizer, used only to see which special tokens it lacks.
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(config.tokenizer_str)

# Placeholder strings for each missing special token. A mask token is only
# requested when the training config uses masking (mask_fraction > 0).
kwargs = {
    token_name: placeholder
    for token_name, placeholder, wanted in (
        ('eos_token', '<EOS>', tokenizer.eos_token_id is None),
        ('bos_token', '<BOS>', tokenizer.bos_token_id is None),
        ('mask_token', '<MSK>',
         tokenizer.mask_token_id is None and config.trainer.mask_fraction > 0),
    )
    if wanted
}

# Second load: pass the placeholders as special-token overrides.
# NOTE(review): presumably mirrors the training script's tokenizer setup so
# token ids match the checkpoint — confirm.
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
    config.tokenizer_str, **kwargs)


In [None]:
# One image per batch, in a fixed (unshuffled) order.
val_dl = get_dataloader(batch_size=1, shuffle=False)

In [None]:
# Prefer CUDA, then Apple's Metal (MPS) backend, then CPU.
# torch.backends.mps.is_available() checks that MPS is usable at runtime; the
# deprecated torch.has_mps only reports that PyTorch was *built* with MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Build the captioning model from the config's model section (whose
# chkpt_path was set earlier), then move it to the selected device.
# NOTE(review): presumably the constructor loads weights from chkpt_path — confirm.
model = VisionEncoderDecoder(config.model)
model = model.to(device)


In [None]:
# Put the model in evaluation mode before generating captions (for a
# torch.nn.Module this disables training-only behaviour such as dropout).
model.eval()

In [None]:
# Beam-search settings, only needed if the commented-out
# BeamSearchTokenGenerator cell is re-enabled:
# num_beams = 4
# num_new_tokens = 64
# top_k = 16
# temperature = 1.0
# consolidation_temperature = 100.0
# batch_size = 1
# beam_expansion_factor = 8
# length_boost = 1.0

# Number of captions sampled per image; the image is replicated this many times.
num_candidates = 8

# -100 is PyTorch's CrossEntropyLoss default ignore_index; presumably mirrors
# the training-time label padding value. Currently unused here.
ignore_index = -100

In [None]:
# generator = BeamSearchTokenGenerator(
#     model,
#     beam_width=num_beams,
#     temperature=temperature,
#     consolidation_temperature=consolidation_temperature,
#     max_new_tokens=num_new_tokens,
#     no_repeat_n_grams=(2, 3, 4, 5),
#     top_k=top_k,
#     beam_expansion_factor=beam_expansion_factor,
#     eos_token_id=tokenizer.eos_token_id,
#     length_boost=length_boost,
# )

In [None]:
tx = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()

# Qualitative check: for the first `num_examples` validation images, show the
# image, its five reference captions, and `num_candidates` sampled captions.
num_examples = 20

# The prompt is the same for every image, so tokenize it once and replicate it
# once per candidate caption.
prompt = tokenizer.bos_token
prompt_ids = (
    torch.tensor(tokenizer(text=prompt).input_ids, dtype=torch.long, device=device)
    .unsqueeze(0)
    .expand(num_candidates, -1)
    .contiguous()
)

# zip() stops as soon as range() is exhausted, so no extra batch is fetched.
for _, batch in zip(range(num_examples), val_dl):
    # Index 0: the loader is built with batch_size=1.
    image = torch.as_tensor(batch['image'])[0]
    plt.imshow(image)
    plt.show()

    # assumes `image` is (H, W, C); the ViT transforms take (C, H, W).
    pixels = tx(image.permute(2, 0, 1)).unsqueeze(0)
    # Replicate the image so all candidates are sampled in a single batch.
    pixels = pixels.to(device).expand(num_candidates, -1, -1, -1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated = model.generate(images=pixels,
                                   prompt_ids=prompt_ids,
                                   temperature=1.0,
                                   max_new_tokens=64,
                                   top_k=16)
    # assumes generate() returns prompt + new tokens; drop the prompt part.
    captions = tokenizer.batch_decode(generated[:, prompt_ids.size(1):])

    references = [batch[f'caption_{k}'][0] for k in range(5)]
    print('truth', *references, '\n')
    for caption in captions:
        # Keep only the text before the first EOS. partition() returns the
        # whole string when no EOS was generated, whereas slicing at
        # find() == -1 would drop the final character.
        print(caption.partition(tokenizer.eos_token)[0])
    print("========================================================")