In [1]:
import torch
from flamingo_pytorch import PerceiverResampler

# Resampler that shrinks each media sequence down to a fixed number of latents.
perceive = PerceiverResampler(
    dim=1024,
    depth=2,
    dim_head=64,
    heads=8,
    num_latents=64,      # length each media sequence is reduced to, perceiver style
    num_time_embeds=4,   # e.g. at most 4 images in a dialogue
)

# Random stand-in media features laid out as (batch, time, sequence length, dimension).
medias = torch.randn((1, 2, 256, 1024))

# Annotated result shape: (1, 2, 64, 1024) = (batch, time, num latents, dimension).
perceived = perceive(medias)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from flamingo_pytorch import GatedCrossAttentionBlock

# Gated cross-attention block; dim=1024 is the same feature width used for the
# resampler in the first cell.
cross_attn = GatedCrossAttentionBlock(dim=1024, dim_head=64, heads=8)

In [3]:
# Show the module repr to inspect the block's attention and feed-forward submodules.
cross_attn

GatedCrossAttentionBlock(
  (attn): MaskedCrossAttention(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (to_q): Linear(in_features=1024, out_features=512, bias=False)
    (to_kv): Linear(in_features=1024, out_features=1024, bias=False)
    (to_out): Linear(in_features=512, out_features=1024, bias=False)
  )
  (ff): Sequential(
    (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=1024, out_features=4096, bias=False)
    (2): GELU(approximate=none)
    (3): Linear(in_features=4096, out_features=1024, bias=False)
  )
)

In [9]:
# Random stand-ins for the inputs of the gated cross-attention block.
text = torch.randn((1, 512, 1024))
# Same shape as the resampler output annotated in the first cell:
# (batch, time, num latents, dimension).
perceived = torch.randn((1, 2, 64, 1024))
# Random True/False flags, one per position along the first two axes of `text`.
media_locations = torch.randint(low=0, high=2, size=(1, 512)).to(torch.bool)

In [14]:
# Apply the gated cross-attention block to the text features and the perceived
# media features, passing the media flags by keyword.
# NOTE(review): the result is bound back to `text`, so re-running this cell feeds
# the previous output in as the new input (the cell is not idempotent).
text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)
# Recorded output is torch.Size([1, 512, 1024]), i.e. the same shape as the input.
text.shape

torch.Size([1, 512, 1024])

In [18]:
from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

In [19]:
# Vision Transformer that is later wrapped and handed to FlamingoPaLM as the
# image encoder.
vit = ViT(
    image_size=256,
    patch_size=32,       # 256 / 32 = 8 patches along each side
    num_classes=1000,
    dim=1024,            # same width as the FlamingoPaLM model built below
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
)

In [20]:

# Wrap the ViT in an Extractor configured with return_embeddings_only=True.
# The isinstance guard keeps this cell idempotent: because the wrapper is bound
# back to the same name `vit`, re-running the unguarded cell would nest an
# Extractor inside another Extractor.
if not isinstance(vit, Extractor):
    vit = Extractor(vit, return_embeddings_only = True)

In [22]:
import torch
from flamingo_pytorch import FlamingoPaLM

# A PaLM-style language model with Flamingo components. (PaLM itself is the
# 540 billion parameter model from Google that shows signs of general intelligence;
# this instance is built with the much smaller settings below.)
flamingo_palm = FlamingoPaLM(
    num_tokens=20000,            # number of tokens (vocabulary size)
    dim=1024,                    # model dimension
    depth=12,                    # depth
    heads=8,                     # attention heads
    dim_head=64,                 # dimension per attention head
    img_encoder=vit,             # image encoder plugged in here (optional if image embeddings are passed in separately, but end-to-end training is probably wanted given the perceiver resampler)
    media_token_id=3,            # token id that represents [media] / [image]
    cross_attn_every=3,          # how often to cross attend
    perceiver_num_latents=64,    # number of perceiver latents; should be smaller than the image token sequence length
    perceiver_depth=2,           # perceiver resampler depth
)


In [23]:
# Show the module repr to inspect the assembled model (token embedding, image
# encoder, ...).
flamingo_palm

FlamingoPaLM(
  (token_emb): Embedding(20000, 1024)
  (img_encoder): Extractor(
    (vit): ViT(
      (to_patch_embedding): Sequential(
        (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
        (1): Linear(in_features=3072, out_features=1024, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (transformer): Transformer(
        (layers): ModuleList(
          (0): ModuleList(
            (0): PreNorm(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (fn): Attention(
                (attend): Softmax(dim=-1)
                (dropout): Dropout(p=0.1, inplace=False)
                (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1024, out_features=1024, bias=True)
                  (1): Dropout(p=0.1, inplace=False)
                )
              )
            )
            (1): PreNorm(
             

In [27]:
# Text-only pass: train your PaLM as usual.
# Random token ids below 20000 (the model's num_tokens), 2 sequences of length 512.
text = torch.randint(low=0, high=20000, size=(2, 512))
print("text shape :  ", text.shape)

# No images are supplied in this call.
palm_logits = flamingo_palm(text)
print("Palm logits shape : ", palm_logits.shape)

text shape :   torch.Size([2, 512])
Palm logits shape :  torch.Size([2, 512, 20000])


In [28]:
# Second stage, after much training off the regular PaLM logits: train
# Flamingo + PaLM. Passing in images makes the model automatically freeze
# everything but the perceiver and cross attention blocks, as in the paper.
dialogue = torch.randint(low=0, high=20000, size=(4, 512))
# Random images: 4 dialogues, 2 images each, 3 x 256 x 256 per image.
images = torch.randn((4, 2, 3, 256, 256))

flamingo_logits = flamingo_palm(dialogue, images)

# From here, apply your usual cross entropy loss.

In [29]:
# Standard-library imports come first so that `sys` and `os` exist before they
# are used below. (Previously `sys.path.append` ran before `import sys`, which
# raises NameError on a fresh kernel.)
import os
import sys
import json
import _pickle as cPickle

from PIL import Image
import numpy as np
from torchvision import transforms

# Root of the emic-vqa checkout. Defaults to the original author's location;
# set the EMIC_VQA_ROOT environment variable to use a different checkout
# without editing this cell.
PROJECT_ROOT = os.environ.get(
    "EMIC_VQA_ROOT", '/Users/caghankoksal/Desktop/SS2022/emic-vqa/'
)

# Make the project's `src` package importable. The membership check stops
# sys.path from growing every time the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.datasets.vq_rad_dataset import VQ_Rad_Dataset


In [31]:
import os  # local import so this cell works on its own

# Directory holding the VQA-RAD data. Defaults to the original author's path;
# set the VQA_RAD_DATAROOT environment variable to point somewhere else instead
# of editing a hard-coded absolute path.
dataroot = os.environ.get(
    "VQA_RAD_DATAROOT",
    '/Users/caghankoksal/Desktop/SS2022/emic-vqa/data/external/data_RAD',
)

In [32]:
# Training split of the VQ-RAD dataset, with the image transform pipeline below.
# NOTE(review): images are resized to 224x224 while the ViT above was created
# with image_size=256. The recorded run succeeded, but confirm this is intended.
# NOTE(review): random horizontal flips mirror left/right, and some questions
# ask about a side (e.g. "on the right side"). Confirm the flip is wanted.
train_dataset = VQ_Rad_Dataset(
    root=dataroot,
    transform=transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # per-channel mean and std handed to Normalize
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    ),
    split='train',
    question_tokenize=False,  # recorded batches below contain questions as raw strings
)

In [72]:
from torch.utils.data import DataLoader
# Mini-batch size; also reused for the random dialogue tensors further down.
batch_size = 2
# Shuffled loader over the training dataset.
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [73]:
# Pull a single batch from the loader. Per the outputs recorded below,
# train_features is an image tensor of shape (2, 3, 224, 224) and
# questions / answers are tuples of strings.
# A fresh iterator is created on every run and the loader shuffles, so each
# execution can return a different batch.
train_features, questions, answers = next(iter(train_dataloader))

In [74]:
# Check the image batch shape: batch_size images, 3 channels, 224x224 after Resize.
train_features.shape


torch.Size([2, 3, 224, 224])

In [75]:
# Random token ids standing in for real dialogue text: one length-512 sequence
# per batch item, ids drawn below 20000 (the model's num_tokens).
dialogue = torch.randint(low=0, high=20000, size=(batch_size, 512))

In [76]:
# Scratch check: displays a freshly drawn random tensor. The result is not
# assigned, so this is not the same tensor as `dialogue`.
torch.randint(0, 20000, (batch_size, 512))

tensor([[ 7951, 16511,  3728,  ..., 10789,  6294, 10876],
        [17651,  6312,  4399,  ..., 13858,  9706,   515]])

In [77]:
# Inspect the first dialogue's token ids.
# NOTE(review): this prints all 512 ids; a slice such as dialogue[0, :10]
# would keep the output short.
dialogue[0]

tensor([ 9371,  3759,  4601, 15279, 19101,  9200, 10932,  6262,  3960,  4517,
        19970,  1811,  6833, 13687,  1719, 18717, 13864, 17647,  8979,  1287,
        16146, 19115, 16185, 19267, 10832, 16590, 12013,  4944,  1766, 16188,
        11231,  9270,  9480,  6573,  8595, 11527,  7139,  9485, 13493,  2015,
         7621,  6743, 15780,  6575,  3489, 16592, 11258,  4916,  1901, 18272,
         7876, 15779,  8727, 10864,  1043,  7612, 11702, 12588, 13673, 14100,
         7308, 11839,  7253, 14456,  9238,  2126,  1893,  7938,  4282,  3568,
        12089,  7174, 14488, 16398, 17294,  3178,  3637,  2625, 10731, 11047,
          413,  2220, 10449, 16959,  2959,   958, 16813,  4172, 11279, 18049,
        17363,  8292, 19531, 19374,  5192, 16038, 13807, 16454, 12272,  2722,
         8971, 10444, 19277,  7975,  5954, 19652, 17068,  8838,  6801,  4513,
         5479, 10516,  9131, 12774,  3071, 13605,  8305, 18306, 15422,  7926,
        16436,  1664, 13699,  4699, 19846, 12723,  3502, 17299, 

In [78]:
# Confirm the dialogue tensor is (batch_size, 512).
dialogue.shape


torch.Size([2, 512])

In [79]:
# Forward pass with real images from the loader. Indexing with [:, None] inserts
# a singleton axis at position 1, turning (batch, 3, 224, 224) into
# (batch, 1, 3, 224, 224). The earlier random-image call used
# (4, 2, 3, 256, 256), so this axis is presumably the per-dialogue image index.
# TODO confirm against the model's expected input layout.
flamingo_logits = flamingo_palm(dialogue, train_features[:, None])

In [80]:
# Show the raw question and answer strings of the current batch.
questions,answers

(('What is the observed sign of pulmonary consolidation on the right side?',
  'What is abnormal with the ventricles?'),
 ('blunting of the costophrenic angle, loss of the right hemidiaphragm and right heart border',
  'Lateral and third ventricular hydrocephalus'))

In [81]:
# Logits shape; the recorded output is (2, 512, 20000), i.e. one score per
# vocabulary entry at each of the 512 positions for each batch item.
flamingo_logits.shape

torch.Size([2, 512, 20000])

In [82]:
from transformers import AutoTokenizer

In [96]:
from transformers import AutoTokenizer

# Identifier of the pretrained tokenizer to load.
checkpoint = "allenai/scibert_scivocab_uncased"
# NOTE(review): this tokenizer reports vocab_size=31090 and produces ids such as
# 30113 in the outputs below, while flamingo_palm was built with num_tokens=20000.
# Reconcile the vocabulary sizes before feeding these ids to the model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [97]:
# Look at the first question string of the batch.
questions[0]

'What is the observed sign of pulmonary consolidation on the right side?'

In [98]:
# Tokenize the batch of question strings into padded PyTorch tensors.
# max_length is passed explicitly: this tokenizer has no usable predefined
# maximum length, so truncation=True on its own only emits a warning and
# truncates nothing. 512 matches the dialogue sequence length used with
# flamingo_palm in this notebook.
inputs = tokenizer(
    list(questions),   # the batch holds the questions as a tuple; pass a list
    padding=True,      # pad to the longest question in the batch
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [99]:
# Inspect the tokenizer output: input_ids, token_type_ids and attention_mask.
# Padded positions show 0 in both input_ids and attention_mask.
inputs

{'input_ids': tensor([[  102,  1792,   165,   111,  1058,   423,   131,  5186, 19074,   191,
           111,  2083,  2480,  3912,   103],
        [  102,  1792,   165,  4592,   190,   111, 15077, 30113,  3912,   103,
             0,     0,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])}

PreTrainedTokenizerFast(name_or_path='allenai/scibert_scivocab_uncased', vocab_size=31090, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})