## Load the pretrained base model

In [10]:
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image

# Pretrained InfographicVQA checkpoint; the model and its processor must come from the same one.
CHECKPOINT = "google/pix2struct-infographics-vqa-base"

processor = Pix2StructProcessor.from_pretrained(CHECKPOINT)
model = Pix2StructForConditionalGeneration.from_pretrained(CHECKPOINT)
model = model.to("cuda")


## Train

In [11]:
import os
import torch
from torch.utils.data import Dataset
import json
from PIL import Image

In [12]:
# from transformers import AutoProcessor
# auto_processor = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-base")

In [13]:
class Pix2StructDataset(Dataset):
    """Question/answer records from an InfographicVQA-style JSON file, encoded for Pix2Struct.

    Each item is the processor output for one question over its image (the
    processor is called with ``return_tensors="pt"``, so tensors keep a leading
    batch dimension of 1). When ``train`` is true, the first reference answer is
    tokenized into ``labels`` as well.

    Args:
        image_dir: directory holding the files named by each record's ``image_local_name``.
        json_dir: path to the annotation JSON; its ``"data"`` list holds one record
            per question (keys used: ``image_local_name``, ``question``, ``answers``).
        processor: Pix2Struct processor (image processor + ``tokenizer``).
        train: if true, attach ``labels`` built from ``answers[0]``.
        max_label_length: answers are padded / truncated to exactly this many tokens.
        device: device the returned tensors are moved to.
    """

    def __init__(self, image_dir, json_dir, processor, train, max_label_length=45, device="cuda"):
        self.img_dir = image_dir
        with open(json_dir) as f:
            self.json_data = json.load(f)
        self.processor = processor
        # Kept for backward compatibility only: the dataset length is driven by
        # the annotation records, not by how many image files the folder holds.
        self.file_list = os.listdir(image_dir)
        self.train = train
        self.max_label_length = max_label_length
        self.device = device

    def __getitem__(self, index):
        record = self.json_data["data"][index]
        image_path = os.path.join(self.img_dir, record["image_local_name"])
        # Decode eagerly inside a context manager so the file handle is closed
        # immediately instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            img = raw_image.convert("RGB")
        inputs = self.processor(images=img, text=record["question"], return_tensors="pt").to(self.device)
        if self.train:
            answer = record["answers"][0]
            labels = self.processor.tokenizer(
                text=answer,
                padding="max_length",
                # Without truncation a long answer yields a longer tensor and
                # torch.stack in the collator fails.
                truncation=True,
                max_length=self.max_label_length,
                add_special_tokens=True,
                return_tensors="pt",
            ).input_ids
            # Padding must not contribute to the loss: -100 is the ignore_index
            # of the Hugging Face seq2seq cross-entropy.
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
            inputs["labels"] = labels.to(self.device)
        return inputs

    def __len__(self):
        # One item per question record, matching the indexing in __getitem__.
        return len(self.json_data["data"])

In [14]:
def collator(batch):
    """Stack per-example processor outputs into a single batch for the Trainer.

    Every tensor in an item is expected to carry a leading batch dimension of 1
    (un-batched processor output); that dimension is dropped with ``[0]`` and the
    examples are stacked along a new first axis.

    ``labels`` is collated only when the items carry it, so datasets that do not
    build labels can be batched too.

    Args:
        batch: list of mappings with ``flattened_patches``, ``attention_mask`` and
            optionally ``labels``.

    Returns:
        dict mapping the same keys to stacked tensors.

    Raises:
        ValueError: if only some of the items carry ``labels``.
    """
    keys = ["flattened_patches", "attention_mask"]
    has_labels = ["labels" in item for item in batch]
    if all(has_labels):
        keys.append("labels")
    elif any(has_labels):
        raise ValueError("collator: either every item or no item may carry 'labels'")
    return {key: torch.stack([item[key][0] for item in batch]) for key in keys}

In [15]:
# All three splits read images from the same folder; each JSON file selects the questions of its split.
img_dir = './hy_info/task3//images/'
train_dataset = Pix2StructDataset(image_dir=img_dir, json_dir='./hy_info/task3/qas/infographicsVQA_train_v1.0.json', processor=processor, train=True)
# The validation file carries reference answers, so build labels for it as well:
# the Trainer needs `labels` to compute eval_loss, which load_best_model_at_end relies on.
val_dataset = Pix2StructDataset(image_dir=img_dir, json_dir='./hy_info/task3/qas/infographicsVQA_val_v1.0_withQT.json', processor=processor, train=True)
# Used for inference only here, so no labels are built.
test_dataset = Pix2StructDataset(image_dir=img_dir, json_dir='./hy_info/task3/qas/infographicsVQA_test_v1.0.json', processor=processor, train=False)


In [16]:
# import json
# with open('./hy_info/task3/qas/infographicsVQA_val_v1.0_withQT.json','r') as f:
#     data_json = json.load(f)

In [17]:
# answer_list = []
# for i,l in enumerate(data_json['data']):
#     answer_list.append(l['answers'][0])
# # data_json['data'][0]['answers']

In [18]:
# answer_list.sort(key=lambda x:len(x),reverse=True)

In [19]:
# max_lentgh = []
# for i, l in enumerate(answer_list):
#     max_lentgh.append(len(processor.tokenizer(text = l,return_tensors='pt').input_ids[0]))

In [20]:
# max_lentgh.sort(reverse=True)

In [21]:
# max_lentgh

In [22]:
# list.sort(reverse=True)

In [23]:
# list

In [24]:
# list=[]
# for i,l in enumerate(train_dataset):
#     list.append(l['labels'].size())


In [25]:
from transformers import TrainingArguments

# Fine-tuning hyper-parameters; checkpoints are written under `pix2struct_1`.
training_args = TrainingArguments(
    output_dir="pix2struct_1",
    # optimisation
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    # batching
    per_device_train_batch_size=3,
    per_device_eval_batch_size=1,
    # evaluate and checkpoint once per epoch, then restore the best checkpoint
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # presumably disabled because the dataset already moves its tensors to CUDA
    dataloader_pin_memory=False,
)

In [26]:
from transformers import Trainer
# The dataset creates CUDA tensors, so any DataLoader worker processes must be
# started with 'spawn'. force=True keeps this cell re-runnable: a second plain
# call raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collator,
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbigchoi3449[0m ([33mlevel2-cv-10-detection[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## Inference

In [None]:
# After trainer.train() (interrupted above) the model can be left in training
# mode with dropout active; switch to eval mode so generated answers are deterministic.
model.eval()

# Decode the image eagerly and close the file handle right away.
with Image.open("./hy_info/task3//images/10065.jpeg") as raw_image:
    image = raw_image.convert("RGB")
question = "Which market crash had the lowest impact on the S&P 500, Dot-com crash, Coronavirus crash, or Great recession ?"
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
# Allow answers as long as the labels used for fine-tuning (max_length=45).
predictions = model.generate(**inputs, max_new_tokens=45)
pred = processor.decode(predictions[0], skip_special_tokens=True)

In [None]:
# Show the question next to the model's decoded answer.
print(f"Question :  {question} \nAnswer   :  {pred}")

Question :  Which market crash had the lowest impact on the S&P 500, Dot-com crash, Coronavirus crash, or Great recession ? 
Answer   :  The coronavirus crash
