In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision as tv
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision import tv_tensors  # we'll describe this a bit later, bare with us

import torchvision.datasets as datasets
from pathlib import Path

from torchview import draw_graph

import constants
import dataset
import util
import json
import pandas as pd
import models 
from models import VQANet
import matplotlib.pyplot as plt
import numpy as np
import time
import gc

from transformers import AutoTokenizer
import traceback

USE_GPU = True
dtype = torch.float32 # We will be using float throughout this tutorial.

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():      
    device = 'mps'                         
else:
    device = torch.device('cpu')


device = torch.device('cpu')
# Constant to control how frequently we print train loss.
print_every = 100
print('using device:', device)
    



using device: cpu


In [2]:

%load_ext autoreload
%autoreload 2

In [3]:

def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


In [4]:
# with open(constants.CAPTION_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["annotations"][0])

# with open(constants.VQA_OPEN_ENDED_QUESTION_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["questions"][0])

# with open(constants.VQA_OPEN_ENDED_ANSWER_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["annotations"][0])
    
# with open(constants.CAPTION_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

# with open(constants.VQA_OPEN_ENDED_QUESTION_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

# with open(constants.VQA_OPEN_ENDED_ANSWER_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

#dataset.load(constants.VQA_OPEN_ENDED_QUESTION_TRAIN, ['image_id', 'id', 'caption'])

In [5]:
train = dataset.Coco()
val = dataset.Coco("validation")
test = dataset.Coco("test")

Downloading split 'train' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/train' if necessary
Found annotations at '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/raw/instances_train2017.json'
Images already downloaded
Existing download of split 'train' is sufficient
Loading existing dataset 'coco-2017-train'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'validation' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/validation' if necessary
Found annotations at '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/raw/instances_val2017.json'
Images already downloaded
Existing download of split 'validation' is sufficient
Loading existing dataset 'coco-2017-validation'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'test' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/test' if necessary
Found test 

In [6]:
print(len(train))
print(len(train.captions))

118287
591753


In [8]:
if False: # debug
    img = train.__getitem__(1)
    print(img)
    print(img.image_id)
    print(img.image_path)

    print(">>>>")
    print(img.captions())

    print(">>>>")
    print(img.qa())
    print("shape", img.image_tensor().shape)

    show([img.image_tensor()])

#plt.imshow(  img.image_tensor().permute(1, 2, 0)  )


In [9]:
tokenizer  = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# Add the Q and A token as special token
tokenizer.add_special_tokens(constants.QA_TOKEN_DICT)



4

In [10]:
def collate_fn2(batch):
    result = {}
    
    result['image_ids'] = []
    result['images'] = []
    raw_captions = []  # plain text 
    raw_qa = []   # plain text
    result['c2i'] = [] # index for images for a given caption. same len as 'caption'
    result['qa2i'] = [] # index of corresponding image for a given qa. same len as 'qa'
    target  = [] # the corresponding target for the qa.
    result['images']
    for idx, data in enumerate(batch):
        result['image_ids'].append(data.image_id)
        result['images'].append(data.image_tensor())
        caption_list = data.captions()
        if caption_list is not None:
            raw_captions += caption_list
            for c in range(len(caption_list)):
                result['c2i'].append(idx)
        
        qa_list = data.qa()
        if qa_list is not None:
            raw_qa += qa_list
            for c in range(len(qa_list)):
                result['qa2i'].append(idx)
    #print("raw_cap", len(raw_captions))
    #print("raw_qa", len(raw_qa))
    
    result['raw_cap'] = raw_captions
    result['captions'] = None if len(raw_captions) == 0 else \
                                tokenizer(raw_captions, padding=True , return_tensors="pt").to(device)
    result['raw_qa'] = raw_qa
    if len(raw_qa) != 0:
#        print("raw_qa:", raw_qa)
        result['qa'] =  tokenizer(raw_qa, padding=True , return_tensors="pt")['input_ids'].to(device, dtype=torch.int64)
        end_padding = torch.broadcast_to(torch.zeros(1), (result['qa'].shape[0], 1)).to(device, dtype=torch.int64)
        #print(end_padding.shape)
        # return a shape {seq, batch}
        target = torch.column_stack((result['qa'][:, 1:], end_padding)).transpose(0, 1)
    else:
        result['qa'] = None
        target = None
    return result, target

In [11]:
from torch.utils.data import DataLoader
batch_size = 1
#fn = collate_fn
fn = collate_fn2 
shuffle = False  # True
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
val_dataloader = DataLoader(val, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)

In [12]:
def spot_check(dataloader):
    x = None
    it = iter(dataloader)
    while x is None:
        x, target= next(it)
        if len(x["raw_qa"]) == 0:
            x = None
    print(x)
    show(x["images"])

In [None]:
spot_check(val_dataloader)

In [None]:
spot_check(test_dataloader)

In [13]:
ce_fn = nn.CrossEntropyLoss( reduction='none')
cos_fn = nn.CosineSimilarity(dim=1)

In [14]:
# out = model(x, device)
# image_embedding, captions_embeddings, output_logits = out
# print(captions_embeddings.shape)
# a = output_logits.reshape(-1, len(tokenizer))
# b = target.reshape(-1)
# print("a", a.shape, a)
# print("b", b.shape, b)

# ce_loss = ce_fn(a, b)
# print(ce_loss.shape)
# N = len(x['images'])
# M = len(x['qa2i'])
# ce = ce_loss.reshape(-1, M).transpose(0, 1)
# print(ce.shape)
# print(ce)
# per_qa  = torch.mean(ce, axis = 1)
# print(per_qa.shape)

In [15]:
def cal_average(size, blown_loss, replicas):
    result= torch.zeros(size).to(device)
    counts = torch.zeros(size).to(device)
    for index, val in enumerate(replicas):
        result[val] += blown_loss[index]
        counts[val] += 1
        
    for index in range(size):
        if counts[index] == 0:
            counts[index] = 1  # so that result / counts still makes sense.
    #print("result", result)
    #print("counts:", counts)
    result /= counts
    return result

In [None]:
# blown = models.blow_to(image_embedding, result['c2i'])
# print(image_embedding.shape)
# print(image_embedding)
# print(blown.shape)
# print(blown)
# print("captions_embedding:", captions_embeddings.shape)
# print(result['c2i'])

In [None]:
# print(blown)
# print(captions_embeddings)
# cos= nn.CosineSimilarity(dim = 0)
# print(cos(blown[1], captions_embeddings[1]))

# per_caption_loss = cos_fn(blown, captions_embeddings)
# print(per_caption_loss)
# per_image_caption_loss = cal_average(len(result['images']), per_caption_loss, result['c2i'])
# print(per_image_caption_loss.shape)

# print(per_image_caption_loss)

In [22]:
model = VQANet(tokenizer).to(device)
#optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

In [23]:
gamma = 0.9
DEBUG = False
def do_train(model, idx, x, target):
        N = len(x['images'])
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        if DEBUG:
            image_embedding_for_captions, captions_embedding, output_logits = None, None, None
        else:
            image_embedding_for_captions, captions_embedding, output_logits  = model(x, device)

#        image_embedding, captions_embedding, output_logits = None, None, None
        per_image_qa_loss = torch.zeros(N).to(device)
        per_image_caption_loss = torch.zeros(N).to(device)
        
        if output_logits is not None:
            print("out_logits argmax", torch.argmax(output_logits.transpose(0,1), axis=2))
            print("target", target)
            a = output_logits.reshape(-1, len(tokenizer))

            b = target.reshape(-1)
            print("a", a.shape)
            print("b", b.shape)

            K = len(x['qa2i'])
            # back to (K, seq)
            qa_loss = ce_fn(a, b).reshape(-1, K).transpose(0, 1)
            #print("qa_loss", qa_loss.shape)
            # qa loss, shape of (K) (different images can have diff counts of qas)
            per_qa_loss = torch.mean(qa_loss, axis = 1)

            # per image qa loss, shape of (N)
            per_image_qa_loss = cal_average(N, per_qa_loss, x['qa2i'])
            #print("per_qa_loss", per_qa_loss.shape)
            #print("per_image_qa_loss", per_image_qa_loss.shape)

        if captions_embedding is not None:
            # loss per caption, shape of (M) (different images can have diff counts of captions)
            per_caption_loss = cos_fn(image_embedding_for_captions, captions_embedding)
            # per image loss on the caption scale. shape of (N)
            per_image_caption_loss = cal_average(N, per_caption_loss, x['c2i'])

        #print("per_caption_loss", per_caption_loss.shape)
        #print("per_image_caption_loss", per_image_caption_loss.shape)
        total_loss = gamma * per_image_caption_loss + per_image_qa_loss

        loss = torch.sum(total_loss)

        if not DEBUG:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()


        del per_image_caption_loss
        del per_image_qa_loss
        del x
        del total_loss
        return loss
            
def training(model, writer, epoches, early_terminate = None, empty_catch_after_every_n = 200, gc_every = 20, print_every = 100):
    start_time = time.time()
    model.train()
    for epoch_idx in range(epoches):
        print ("----- Start Epoch %s -----" % epoch_idx)
        epoch_loss = 0
        for idx, (x, target) in enumerate(train_dataloader):
            should_print = print_every is not None and (print_every == 1 or idx % (print_every -1) == 0)
            if should_print:
                print(">>>> Batch # ", idx,  x['image_ids'] )
            if early_terminate is not None:
                if idx > early_terminate - 1:
                    print("early terminating. at ", idx)
                    break;
            try:
                loss = do_train(model, idx, x, target).detach()
                epoch_loss += loss.item()
                if should_print :
                    print("loss:", loss)
            except Exception as e:
                print(">>>> FAILED! Batch # ", idx,  x['image_ids'])
                traceback.print_exc()
                break;

            if empty_catch_after_every_n is not None and (idx + 1) % empty_catch_after_every_n == 0:
                print(">>>empty torch mps cache")
                torch.mps.empty_cache()
            if gc_every is not None and (idx + 1) % gc_every == 0:
                    print("explictly calling GC:")
                    gc.collect()
            print("--- %s Per batch time ---" % (time.time() - start_time))
        epoch_loss /= len(train_dataloader) if early_terminate is None else early_terminate
        writer.add_scalar("Loss/train", epoch_loss, epoch_idx)

        print(f"---DONE: {epoch_idx} epoch, {(time.time() - start_time)} seconds, loss {epoch_loss} ---")
    

In [24]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

training(model, writer, 1000, empty_catch_after_every_n=None, print_every = 1, early_terminate = 1)
#training(model, writer, 2, empty_catch_after_every_n=None)
writer.flush()


----- Start Epoch 0 -----
>>>> Batch #  0 [9]
qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[26049,  3776, 24057, 30333, 19047,  9564,  1890, 13428,  8652,  9093,
         14377, 12139, 24315, 29698],
        [12614,  3776,  9352, 19435,  3556, 21975,  6625, 23739,  4935, 22553,
          1421, 25248, 12039, 16449],
        [28989, 20760,  7381,  8938, 22088, 10123, 19836, 19993,  4725,  9484,
         13787, 22973, 24304, 22004]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  10

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102,   0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [30523,  5061, 22953],
        [ 1016,  1998, 21408],
        [30524,  3756,  3669],
        [  102, 30524, 30524

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  1996,  1029,  1029,  1996,  1029,  1029, 30523, 30524,
         30524,   102,     0,     0],
        [30522,  2054,  1996,  1029,  1996,  1996,  1029, 30523,  1029, 30524,
             0, 30524,   102,     0],
        [30522,  2054,  1029,  1996,  1996, 30524,  1029, 30523,  1029, 30524,
          1029, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2129,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2129,  3609,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  2003,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0],
        [  101, 30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,
          1998,  3756, 30524,   102],
        [  101, 30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953,
         21408,  3669, 30524,   102]])
out_logits argmax tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0,     0],
        [30522,  2054,  3609,  2024,  1996, 10447,  1029, 30523,  5061,  1998,
          3756, 30524,   102,     0],
        [30522,  2054,  2003,  1996,  2665,  4933,  1029, 30523, 22953, 21408,
          3669, 30524,   102,     0]])
target tensor([[30522, 30522, 30522],
        [ 2129,  2054,  2054],
        [ 2116,  3609,  2003],
        [16324,  2024,  1996],
        [ 2064,  1996,  2665],
        [ 2022, 10447,  4933],
        [ 2464,  1029,  1029],
        [ 1029, 30523, 30523],
        [3

KeyboardInterrupt: 

In [25]:
def manual(dataset):
    item = dataset.__getitem__(0)
    item.annotations['qa'] = ['[QUESTION] how many cookies can be seen? [ANSWER]']
    return collate_fn2([item])
manual(train)

({'image_ids': [9],
  'images': [tensor([[[  2,   1,   1,  ..., 140, 136, 133],
            [  1,   1,   1,  ..., 144, 138, 135],
            [  0,   0,   1,  ..., 142, 139, 137],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]],
   
           [[ 23,  22,  23,  ..., 176, 172, 170],
            [ 21,  21,  23,  ..., 179, 175, 173],
            [ 22,  23,  25,  ..., 180, 178, 177],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]],
   
           [[110, 111, 114,  ..., 202, 198, 197],
            [111, 112, 115,  ..., 206, 201, 199],
            [113, 112, 115,  ..., 206, 203, 202],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]]], dtype=t

In [30]:
test1x, target = manual(train)
model.eval()
model.answer(test1x, device, max_length = 3)

>>>>>0
{'image_ids': [9], 'images': [tensor([[[  2,   1,   1,  ..., 140, 136, 133],
         [  1,   1,   1,  ..., 144, 138, 135],
         [  0,   0,   1,  ..., 142, 139, 137],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[ 23,  22,  23,  ..., 176, 172, 170],
         [ 21,  21,  23,  ..., 179, 175, 173],
         [ 22,  23,  25,  ..., 180, 178, 177],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[110, 111, 114,  ..., 202, 198, 197],
         [111, 112, 115,  ..., 206, 201, 199],
         [113, 112, 115,  ..., 206, 203, 202],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]]], dtype=torch.uint8)], 'c2i': [0, 0, 0, 0, 0], 'qa2i': [0], 'raw_cap': 

qa_input_ids tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102]])
output_logits torch.Size([1, 13, 30526])
word torch.Size([1, 13]) tensor([[30522,  2054,  2116, 16324,  2064,  2022,  2464,  1029, 30523,  1016,
         30524,   102,     0]])
after word torch.Size([1]) tensor([102])
new_qa, 1 tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0]])
new_qa, 2 tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,     0]])
new_qa tensor([[  101, 30522,  2129,  2116, 16324,  2064,  2022,  2464,  1029, 30523,
          1016, 30524,   102,   102]])


{'image_ids': [9],
 'images': [tensor([[[  2,   1,   1,  ..., 140, 136, 133],
           [  1,   1,   1,  ..., 144, 138, 135],
           [  0,   0,   1,  ..., 142, 139, 137],
           ...,
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0]],
  
          [[ 23,  22,  23,  ..., 176, 172, 170],
           [ 21,  21,  23,  ..., 179, 175, 173],
           [ 22,  23,  25,  ..., 180, 178, 177],
           ...,
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0]],
  
          [[110, 111, 114,  ..., 202, 198, 197],
           [111, 112, 115,  ..., 206, 201, 199],
           [113, 112, 115,  ..., 206, 203, 202],
           ...,
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0],
           [  0,   0,   0,  ...,   0,   0,   0]]], dtype=torch.uint8)],
 'c2i': [0

In [None]:
print(tokenizer("[QUESTION] what is the players number? [ANSWER]")["input_ids"])

In [None]:
print(tokenizer("what is the players number? [ANSWER]")["input_ids"])

In [None]:
print(tokenizer("[QUESTION] what is the players number?")["input_ids"])

In [None]:
print(tokenizer("")["input_ids"])
print(tokenizer("[CLS][SEP]")["input_ids"])

In [None]:
a = torch.tensor([[  101, 30522,  2054,  2326,  2515,  1996,  2482,  9083,  2012,  1996,
         13730,  3073,  1029, 30523,     0,   120,     0,     0,     0,     0,
             0],
        [  101, 30522,  2054,  3609,  2003,  1996,  4744,  1029, 30523,     0,
           120,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101, 30522,  2054,  2003,  1996,  8638,  2081,  1997,  1029, 30523,
             0,   120,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101, 30522,  2054,  5127,  2003,  1996,  2447,  3061,  2012,  1029,
         30523,     0,   120,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101, 30522,  2054,  2003,  1996,  2867,  2193,  1029, 30523,     0,
           120,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101, 30522,  2054,  2003,  2006,  1996,  2879,  1005,  1055,  2132,
          1029, 30523,     0,   120,     0,     0,     0,     0,     0,     0,
             0]])
print(a[a== 102])

In [None]:
print(a.shape)
indices = (a == 102)
print(indices)
print(a)
b = torch.cat((a, torch.zeros((a.shape[0], 1), dtype=torch.int64)), dim = 1)
print(b)
#print(b.index_select(indices))
c = b.index_put(tuple(indices.t()), torch.ones(indices.shape[0], dtype=torch.int64))
print("c", c)
new_indices = indices + torch.tensor([0, 1])
print(new_indices)
d = c.index_put(tuple(new_indices.t()), 120 * torch.ones(indices.shape[0], dtype=torch.int64))
print("d", d)

a[indices[:,0], indices[:, 1]]