In [13]:
from typing import Union

import torch
import torch.nn as nn
import torch.utils.data


import yaml

from configs.trainer import TrainingConfig
from configs.models import ViTConfig
from training.utils import WrapperDataLoader

from models.vision_encoder_decoder import VisionEncoderDecoder

from accelerate import Accelerator
from transformers import AutoTokenizer, PreTrainedTokenizer
from deeplake import load, Dataset
from torchvision import transforms
from torchvision.models import ViT_B_16_Weights

from models.generation_utils import BeamSearchTokenGenerator
from operator import itemgetter

In [2]:
def get_dataloader(tokenizer, batch_size, shuffle, is_vit):
    txforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((128, 128)),
        transforms.Normalize(mean=(0.4274, 0.4218, 0.3878), std=(0.2754, 0.2705, 0.2874)),
    ]) if not is_vit else ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()
    ds: Dataset = load('hub://activeloop/flickr30k')

    tokenizer.pad_token = tokenizer.eos_token

    def _tok(x):
        return tokenizer(text=x[0], return_tensors='pt',
                         max_length=256,
                         truncation='longest_first',
                         padding='max_length')

    def _transform(x):
        image = x['image'] if not is_vit else torch.tensor(x['image']).permute(2, 0, 1)
        result = {
            'image': txforms(image),
        }
        for k in range(5):
            data = x[f'caption_{k}']
            tokenized = _tok(data)
            result[f'input_ids_{k}'] = tokenized.input_ids.squeeze(0)
            result[f'attn_mask_{k}'] = tokenized.attention_mask.squeeze(0)
        return result

    train_dl = ds.query("SELECT * WHERE ROW_NUMBER() < 27000"). \
        pytorch(batch_size=batch_size, shuffle=shuffle, num_workers=0, transform=_transform,
                buffer_size=256, use_local_cache=True)
    val_dl = ds.query("SELECT * WHERE ROW_NUMBER() >= 27000 "). \
        pytorch(batch_size=batch_size, shuffle=shuffle, num_workers=0, transform=_transform,
                buffer_size=32, use_local_cache=True)
    return train_dl, val_dl


In [3]:
config_file = 'training_configs/local/nano.yaml'
chkpt_file = 'checkpoints/nano.pt'

In [4]:
obj = yaml.safe_load(open(config_file, 'r'))
config: TrainingConfig = TrainingConfig.parse_obj(obj)
print(config)

model=VisionEncoderDecoderConfig(vision_encoder_config=ViTConfig(refine_base_model=False, n_embd_out_vit=768, n_cls=16, gate_sizes=(1024,)), decoder_config=TransformerDecoderConfig(use_advanced_pos_emb=False, advanced_pos_emb_gate_sizes=None, pretrained_model=<ModelType.GPT2: 'gpt2'>, enable_gradient_checkpointing=False, n_layer=12, skip_alternate_cross_attn=True, block_size=1024, vocab_size=50257, transformer_config=TransformerConfig(rotator_config=MLPConfig(ff_mult=4.0), is_causal=False, is_cross_attn=True, max_block_size=None, is_sparse_attn=False, sparsity_factor=0.5, attn_config=SelfAttentionConfig(attn_dropout=0.1, bias=True, dropout=0.1, n_head=12, n_embd=768, attn_type=<SelfAttentionType.MULTI_HEAD: 'multi_head'>))), loose_match_decoder_state_dict=True, chkpt_path=None, use_cross_attn=True, use_soft_prompting=True, no_repeat_n_grams=(2, 3, 4, 5)) disable_flash=False ignore_index=-100 batch_size=4 dataloader_buffer_size=5 shuffle=True gradient_accumulation_steps=2 epochs=1 num_s

/var/folders/1n/tqgq4c0j5p7cs5dk9n284h7c0000gn/T/ipykernel_45659/3403858762.py:2: PydanticDeprecatedSince20: The `parse_obj` method is deprecated; use `model_validate` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.0.3/migration/
  config: TrainingConfig = TrainingConfig.parse_obj(obj)


In [5]:
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(config.tokenizer_str)
kwargs = {}
if tokenizer.eos_token_id is None:
    kwargs['eos_token'] = '<EOS>'
if tokenizer.mask_token_id is None and config.trainer.mask_fraction > 0:
    kwargs['mask_token'] = '<MSK>'
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
    config.tokenizer_str, **kwargs)


In [6]:
train_dl, val_dl = get_dataloader(tokenizer,
                                  config.dataloader_buffer_size * config.batch_size,
                                  config.shuffle,
                                  isinstance(config.model.vision_encoder_config, ViTConfig))
train_dl, val_dl = WrapperDataLoader(train_dl, batch_size=config.batch_size, ignore_idx=config.ignore_index,
                                     epochs=config.epochs), \
                   WrapperDataLoader(val_dl, batch_size=config.batch_size, ignore_idx=config.ignore_index,
                                     epochs=100000)

\

Opening dataset in read-only mode as you don't have write permissions.


\

This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/flickr30k



/

hub://activeloop/flickr30k loaded successfully.





In [7]:
device = 'cpu'  # 'cuda' if torch.cuda.is_available() else ('mps' if torch.has_mps else 'cpu')

In [8]:
config.model.chkpt_path = chkpt_file
model = VisionEncoderDecoder(config.model).to(device)


In [9]:
model.eval()

VisionEncoderDecoder(
  (decoder): TransformerDecoder(
    (transformer): ModuleDict(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): TransformerBlock(
          (ln_1): LayerNorm()
          (attn): MultiHeadAttention(
            (c_attn): Linear(in_features=768, out_features=2304, bias=True)
            (c_proj): Linear(in_features=768, out_features=768, bias=True)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): _MLP(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='tanh')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (cross_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuan

In [15]:
num_beams = 4
num_new_tokens = 64
top_k = 16
temperature = 1.0
consolidation_temperature = 100.0
batch_size = 1
ignore_index = -100
beam_expansion_factor = 8
length_boost = 1.0

In [16]:
generator = BeamSearchTokenGenerator(
    model,
    beam_width=num_beams,
    temperature=temperature,
    consolidation_temperature=consolidation_temperature,
    max_new_tokens=num_new_tokens,
    no_repeat_n_grams=(2, 3, 4),
    top_k=top_k,
    beam_expansion_factor=beam_expansion_factor,
    eos_token_id=tokenizer.eos_token_id,
    length_boost=length_boost,
)

In [None]:
for i, (x, labels) in enumerate(val_dl):
    if i == 20:
        break
    x = x.to(device)
    label_ = labels[0]

    prompt = 'A'
    prompt_ids = torch.tensor(tokenizer(text=prompt).input_ids, dtype=torch.long).to(x.device).unsqueeze(0).expand(x.size(0), -1).contiguous()
    result_ids, scores_model = generator(x, prompt_ids)
    result_ids, scores_model = result_ids.reshape(-1, result_ids.size(-1)), scores_model.reshape(-1)
    result = tokenizer.batch_decode(result_ids)
    reference = tokenizer.batch_decode([label_[label_ != ignore_index]])[0]

    truth = reference
    print('truth', reference, '\n')
    for k, (gen, score_model) in enumerate(sorted(zip(result, scores_model.tolist()), key=itemgetter(1), reverse=True)):
        i = gen.find(tokenizer.eos_token)
        gen = gen[:i]
        print(gen, score_model)
    print("========================================================")

truth Several campers are standing eating near a campfire.<|endoftext|> 

A child holding up two cows on grass on the grass side, one in the background -54.39771270751953
A child holding up two cows on grass on the grass side, one in the background -54.39771270751953
A child holding up two cows on grass on the grass side, one in the background -54.39771270751953
A child holding up two cows on grass on the grass side, one in the background -54.39771270751953
A procession stands among other signs as they walk among trees in red cars in blue.

 


  -82.3949966430664
A procession stands among other signs as they walk among trees in red cars in blue.

 


  -82.3949966430664
A procession stands among other signs as they walk among trees in red cars in blue.

 


  -82.3949966430664
A procession stands among other signs as they walk among trees in red cars in blue.

 


  -82.3949966430664
A woman on the backpacking.. and a couple are taking an afternoon away in another woman in a fire on s