In [1]:
from PIL import Image
import requests
import torch

from transformers import ViltProcessor, ViltModel
import torch.nn as nn

class ViLT(nn.Module):

    def __init__(self, args = None):
        super().__init__()
        self.args = args
        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        self.model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")
        self.linear1 = nn.Linear(768, 312)
        self.classifier1 = nn.Sequential(nn.Linear(768, 312), nn.ReLU(), nn.BatchNorm1d(312), nn.Linear(312, 312))

        self.class_linear1= nn.Linear(768, 201)
        self.class_classifier1 = nn.Sequential(nn.Linear(768, 201), nn.ReLU(), nn.BatchNorm1d(201), nn.Linear(201, 201))

    def forward(self, prompts, images):
        images = [i.cpu() for i in images]
        inputs = self.processor(images, prompts, return_tensors="pt", padding = True)
        for i in inputs:
            inputs[i] = inputs[i].cuda()
        outputs = self.model(**inputs)
        out = outputs.pooler_output

        classifications = []

        classifications.append(self.linear1(out))
        classifications.append(self.classifier1(out))
        classifications.append(self.class_linear1(out))
        classifications.append(self.class_classifier1(out))


        return classifications

In [2]:
from transformers import ViltForMaskedLM, ViltProcessor
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm-itm")

Some weights of the model checkpoint at dandelin/vilt-b32-mlm-itm were not used when initializing ViltForMaskedLM: ['itm_score.fc.bias', 'itm_score.fc.weight']
- This IS expected if you are initializing ViltForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViltForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViltForMaskedLM were not initialized from the model checkpoint at dandelin/vilt-b32-mlm-itm and are newly initialized: ['mlm_score.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
import torch
(torch.tensor([1, 1,1,1,1])).topk(3)

torch.return_types.topk(
values=tensor([1, 1, 1]),
indices=tensor([2, 4, 3]))

In [3]:
from PIL import Image
sentence = ['the color of the birds wings are [MASK]']
#open image 
image = Image.open('../data/CUB_200_2011/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg')
#print image
inputs = processor(image, sentence, return_tensors="pt", padding = True)
out = model(**inputs)

In [4]:
from torch.nn import functional as F
import torch
with open('../data/CUB_200_2011/attributes.txt') as f:
    attributes = f.read().splitlines()

wing_colors = []
wing_color_idx = {}
for i in range(len(attributes)):
    if 'wing_color' in attributes[i]:
        wing_colors.append(attributes[i].split('::')[-1])
for i in wing_colors:
    if i == 'iridescent' or i == 'rufous': continue
    wing_color_idx[i] = processor.tokenizer.vocab[i]
print(wing_color_idx)
one_hot = F.one_hot(torch.tensor(2630), processor.tokenizer.vocab_size)
#one_hot[2665] = 1

final = out['logits'][:, -1, :]

F.cross_entropy(one_hot.unsqueeze(0).float(), final)


{'blue': 2630, 'brown': 2829, 'purple': 6379, 'grey': 4462, 'yellow': 3756, 'olive': 9724, 'green': 2665, 'pink': 5061, 'orange': 4589, 'black': 2304, 'white': 2317, 'red': 2417, 'buff': 23176}


tensor(-1860352.8750, grad_fn=<DivBackward1>)

In [5]:
wing_colors

['blue',
 'brown',
 'iridescent',
 'purple',
 'rufous',
 'grey',
 'yellow',
 'olive',
 'green',
 'pink',
 'orange',
 'black',
 'white',
 'red',
 'buff']

In [36]:
for idx, i in enumerate(processor.tokenizer.vocab):
    if idx % 1000 == 0: print(idx)
    if processor.tokenizer.vocab[i] == 0:
        print(i)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
[PAD]
23000
24000
25000
26000


KeyboardInterrupt: 

In [38]:
dir(processor.tokenizer)

['SPECIAL_TOKENS_ATTRIBUTES',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_tokens',
 '_additional_special_tokens',
 '_auto_class',
 '_batch_encode_plus',
 '_bos_token',
 '_call_one',
 '_cls_token',
 '_convert_encoding',
 '_convert_id_to_token',
 '_convert_token_to_id_with_added_voc',
 '_create_repo',
 '_decode',
 '_decode_use_source_tokenizer',
 '_encode_plus',
 '_eos_token',
 '_eventual_warn_about_too_long_sequence',
 '_eventually_correct_t5_max_length',
 '_from_pretrained',
 '_get_files_timestamps',
 '_get_padding_truncation_strategies',
 '_in_target_context_manager',
 '_mask_token',
 '_pad',
 '_pad_token',
 '_pad