## Base Inference

In [1]:
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image

# Processor and weights are both taken from the same pretrained checkpoint.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-infographics-vqa-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-base")
model = model.to("cuda")


In [2]:
# One sample infographic and a question about it, preprocessed for the model.
question = "Which market crash had the lowest impact on the S&P 500, Dot-com crash, Coronavirus crash, or Great recession ?"
image = Image.open("/home/jjh/level3-cv-productserving-cv-10/data/images/10065.jpeg")
inputs = processor(
    text=question,
    images=image,
    return_tensors="pt",
).to("cuda")

In [3]:
# Inspect the size of the patch tensor produced by the processor.
inputs["flattened_patches"].size()

torch.Size([1, 2048, 770])

In [4]:
# Same processor call as the one that builds `inputs`, stored under a second name.
ins = processor(
    text=question,
    images=image,
    return_tensors="pt",
).to("cuda")

In [5]:
# Generate token ids for the question, then turn the first sequence back into text.
predictions = model.generate(**inputs)
pred = processor.decode(
    predictions[0],
    skip_special_tokens=True,
)
print(pred)

Great recession




## Train

In [1]:
import os
import torch
from torch.utils.data import Dataset
import json
from PIL import Image

In [2]:
from transformers import Pix2StructForConditionalGeneration, AutoProcessor

# Processor and weights are both taken from the same pretrained checkpoint.
auto_processor = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-base")
model = model.to("cuda")

In [2]:
class Pix2StructDataset(Dataset):
    """Image/question (and optionally answer) samples prepared for Pix2Struct.

    Every entry of ``json_data["data"]`` is read for the keys
    ``"image_local_name"`` and ``"question"``; when ``train`` is true the
    first element of ``"answers"`` is used as the target text.

    Args:
        image_dir: Directory that holds the image files.
        json_data: Parsed annotation dictionary containing a ``"data"`` list.
        processor: Callable that turns ``images=``/``text=`` into model
            inputs; its ``tokenizer`` attribute encodes the answer.
        train: When true, each sample additionally carries ``"labels"``.
        max_label_length: Fixed token length of the encoded answer. Shorter
            answers are padded and longer ones truncated to this length, so
            label rows of different samples can be stacked into one batch.
    """

    def __init__(self, image_dir, json_data, processor, train, max_label_length=45):
        self.img_dir = image_dir
        self.json_data = json_data
        self.processor = processor
        # Directory listing; not read by the other methods of this class.
        self.file_list = os.listdir(image_dir)
        self.train = train
        self.max_label_length = max_label_length

    def __getitem__(self, index):
        """Return the processor output for sample ``index``.

        When ``self.train`` is true the result also has a ``"labels"`` entry
        with the token ids of the first answer.
        """
        data = self.json_data["data"][index]
        image_path = os.path.join(self.img_dir, data["image_local_name"])

        # Close the image file as soon as the processor has consumed it.
        with Image.open(image_path) as img:
            inputs = self.processor(images=img, text=data["question"], return_tensors="pt")

        if self.train:
            tokenizer = self.processor.tokenizer
            label = tokenizer(
                text=data["answers"][0],
                padding="max_length",
                # Without truncation an answer longer than max_length yields a
                # longer row, and rows of unequal length cannot be stacked.
                truncation=True,
                max_length=self.max_label_length,
                add_special_tokens=True,
                return_tensors="pt",
            ).input_ids

            # Mark padding positions so they do not contribute to the loss.
            # NOTE(review): relies on the model's loss ignoring the label
            # value -100 (the Hugging Face convention) - confirm for the
            # installed transformers version.
            pad_id = getattr(tokenizer, "pad_token_id", None)
            if pad_id is not None:
                label = label.masked_fill(label == pad_id, -100)

            inputs["labels"] = label

        return inputs

    def __len__(self):
        """Number of entries in ``json_data["data"]``."""
        return len(self.json_data["data"])

In [5]:
def collator(batch):
    """Merge per-sample feature dicts into one batch of stacked tensors.

    Each item is indexed with ``[0]`` on every field, i.e. the leading
    dimension of the per-sample tensors is dropped before stacking.

    Args:
        batch: Sequence of mappings with ``"flattened_patches"`` and
            ``"attention_mask"`` tensors and, for labelled samples,
            ``"labels"``.

    Returns:
        Dict with ``"flattened_patches"`` and ``"attention_mask"``, plus
        ``"labels"`` when the batch carries labels.

    Raises:
        KeyError: If only some of the items carry ``"labels"``.
    """
    new_batch = {
        "flattened_patches": torch.stack([item["flattened_patches"][0] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"][0] for item in batch]),
    }

    # Samples built without answers have no "labels" entry; only stack labels
    # when the batch actually provides them.
    if any("labels" in item for item in batch):
        new_batch["labels"] = torch.stack([item["labels"][0] for item in batch])

    return new_batch

In [6]:
# Locations of the images and of the per-split question/answer files.
img_dir = '/home/jjh/level3-cv-productserving-cv-10/data/images/'
train_dir = '../data/qas/infographicsVQA_train_v1.0.json'
val_dir = '../data/qas/infographicsVQA_val_v1.0_withQT.json'
test_dir = '../data/qas/infographicsVQA_test_v1.0.json'

# Read every split from its own annotation file. The files are read inline
# here because this cell comes before the cell that defines `load_json`.
annotations = {}
for split_name, split_path in (("train", train_dir), ("val", val_dir), ("test", test_dir)):
    with open(split_path, encoding="utf-8") as f:
        annotations[split_name] = json.load(f)

# Train and validation samples carry labels; test samples do not.
train_dataset = Pix2StructDataset(image_dir=img_dir, json_data=annotations["train"], processor=auto_processor, train=True)
val_dataset = Pix2StructDataset(image_dir=img_dir, json_data=annotations["val"], processor=auto_processor, train=True)
test_dataset = Pix2StructDataset(image_dir=img_dir, json_data=annotations["test"], processor=auto_processor, train=False)


In [None]:
def load_json(json_dir):
    """Read a JSON file and return its parsed content.

    Args:
        json_dir: Path of the JSON file (it is passed directly to ``open``).

    Returns:
        The object produced by ``json.load`` for that file.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale's
    # default text encoding.
    with open(json_dir, encoding="utf-8") as f:
        return json.load(f)

In [7]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Output location and experiment tracking.
    output_dir="pix2struct_1",
    run_name="pix2struct_jjh",
    report_to="wandb",
    # Optimisation settings.
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [8]:
import wandb
# Start the Weights & Biases run that the trainer reports to.
wandb.init(
    project="pix2struct",
    entity="level2-cv-10-detection",
    name="pix2struct_jjh",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maurantiacus1220[0m ([33mlevel2-cv-10-detection[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from datasets import load_metric
import numpy as np

# Accuracy metric object; metrics from the datasets library have a `compute` method.
accuracy_metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Compute token-level accuracy for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` holds
            per-class scores on its last axis; ``labels`` holds the reference
            ids with the same shape as ``predictions`` minus that last axis.

    Returns:
        The result of ``accuracy_metric.compute`` on the flattened ids.
    """
    predictions, labels = eval_pred

    # NOTE(review): when the model returns several outputs the predictions
    # may arrive as a tuple; the scores are assumed to be its first element -
    # confirm against the Trainer version in use.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]

    # Pick the best class along the last axis. For sequence outputs shaped
    # (batch, seq, vocab), axis=1 would reduce over sequence positions
    # instead of over the vocabulary.
    predicted_ids = np.argmax(predictions, axis=-1)
    labels = np.asarray(labels)

    # Boolean indexing flattens both arrays to 1-D and drops any position
    # whose label is the ignore value -100.
    keep = labels != -100
    return accuracy_metric.compute(predictions=predicted_ids[keep], references=labels[keep])

In [9]:
from transformers import Trainer

# Wire the model, datasets, batching function and metric into one Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collator,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Run fine-tuning with the configuration above.
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
