# Test Pix2Struct model on WebUI2Code dataset, version 4096

## Setup Environment

In [None]:
!pip install transformers==4.33.1

Collecting transformers==4.33.1
  Downloading transformers-4.33.1-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers==4.33.1)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.33.1)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m111.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers==4.33.1)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m

## Import necessary libraries

In [None]:
from google.colab import drive
import os
import zipfile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import re
from transformers import Pix2StructForConditionalGeneration, AutoProcessor, GenerationConfig
import torch
from torch.nn import functional as F
from pathlib import Path
from nltk import edit_distance
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from torch.utils.data import random_split
import random

## Define variables and parameters

In [None]:
# --- Locations on Google Drive and on the local Colab disk -------------------
G_DRIVE_FOLDER = "/content/drive/MyDrive/Datasets/"
G_DRIVE_FOLDER_CHECKPOINTS = "/content/drive/MyDrive/Checkpoints/"
DATASET_NAME = "WebUI2Code_4096_preprocessed"
ZIP_NAME = f"{DATASET_NAME}.zip"

DESTINATION_FOLDER = "/content/data/"
DATASET_FOLDER = f"{DESTINATION_FOLDER}{DATASET_NAME}"
OUTPUT_FOLDER = "/content/drive/MyDrive/Testing_output/webUI2Code_4096"

EXPERIMENT_NAME = "Pix2Struct_WebUI2Code_Complete_4096_FULL_TEST"

# --- Sequence and image sizes ------------------------------------------------
# Label sequences are padded / truncated to this many tokens.
MAX_SENTENCE_LEN = 4096

# Generation proceeds in chunks of CHUNK_LENGTH tokens; CONTEXT_OVERLAP_LENGTH
# tokens of the previous output are fed back as decoder context for the next chunk.
CHUNK_LENGTH = 1024
CONTEXT_OVERLAP_LENGTH = 256

# Number of image patches requested from the Pix2Struct processor.
MAX_PATCHES = 1024

# --- Runtime switches --------------------------------------------------------
DEBUG = False   # when True the dataset shows each image and prints each text
VERBOSE = True

BATCH_SIZE = 8

# --- Train / validation / test split -----------------------------------------
# Each subset gets int(percentage * number_of_samples) items; whatever is left
# after the train and validation parts becomes the test set.
TRAIN_SET_PERCENTAGE = 0.898
VALID_SET_PERCENTAGE = 0.002

RANDOM_SEED = 100

# --- Checkpoint to evaluate --------------------------------------------------
LOAD_FROM_CHECKPOINT = True
LAST_CHECKPOINT_NAME = "Pix2Struct_WebUI2Code_Complete_4096_FULL_epoch[9].pth"

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Number of generation rounds needed to cover MAX_SENTENCE_LEN tokens: the first
# round yields up to CHUNK_LENGTH tokens and every later round adds
# CHUNK_LENGTH - CONTEXT_OVERLAP_LENGTH new ones (the overlap is re-fed context).
# Ceiling division (-(-a // b)) so a partial final chunk is still counted, and
# max(0, ...) so the result never drops below one chunk.
MAX_N_CHUNKS_PER_SENTENCE = 1 + -(
    -max(0, MAX_SENTENCE_LEN - CHUNK_LENGTH) // (CHUNK_LENGTH - CONTEXT_OVERLAP_LENGTH)
)
print("MAX_N_CHUNKS_PER_SENTENCE", MAX_N_CHUNKS_PER_SENTENCE)

MAX_N_CHUNKS_PER_SENTENCE 5



## Load WebUI2Code Dataset

### Mount Google Drive

In [None]:
# Mount Google Drive so the dataset archive, the checkpoints and the output
# folder (all configured under /content/drive/...) become reachable.
drive.mount('/content/drive')

Mounted at /content/drive


### Import zip file from Google Drive

In [None]:
# Unpack the dataset archive stored on Drive into the local destination folder
# (created first if it does not exist yet).
os.makedirs(DESTINATION_FOLDER, exist_ok=True)

with zipfile.ZipFile(f"{G_DRIVE_FOLDER}{ZIP_NAME}", mode="r") as zf:
    zf.extractall(path=DESTINATION_FOLDER)

## Load Model and Processor

In [None]:
# Hugging Face repository providing the base processor and model weights.
repo_id = "google/pix2struct-base"

# `processor` is used later for both image patch extraction and text tokenization;
# `model` receives the fine-tuned checkpoint weights in a later cell.
processor = AutoProcessor.from_pretrained(repo_id)
model = Pix2StructForConditionalGeneration.from_pretrained(repo_id, is_encoder_decoder=True)

(…)se/resolve/main/preprocessor_config.json:   0%|          | 0.00/231 [00:00<?, ?B/s]

(…)-base/resolve/main/tokenizer_config.json:   0%|          | 0.00/2.61k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/851k [00:00<?, ?B/s]

(…)2struct-base/resolve/main/tokenizer.json:   0%|          | 0.00/3.27M [00:00<?, ?B/s]

(…)ase/resolve/main/special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

(…)pix2struct-base/resolve/main/config.json:   0%|          | 0.00/4.92k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

## Create Dataset class

### List the dataset samples

In [None]:
# Sample identifiers: the names of all PNG screenshots without their extension.
# os.path.splitext strips only the trailing ".png"; str.replace would also
# delete a ".png" that appears elsewhere in the file name.
# NOTE(review): os.listdir returns entries in arbitrary order, so the seeded
# shuffle later in the notebook reproduces the training-time split only if this
# listing order matches the one seen during training — confirm.
all_paths = [
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(DATASET_FOLDER)
    if file_name.endswith(".png")
]

In [None]:
len(all_paths)

2322

In [None]:
print(all_paths[0])

deltasonetabs.monster


### Find max sentence length and add new tokens to the tokenizer

In [None]:
# Scan every target text once to (a) measure the longest tokenized sample and
# (b) collect all token strings the tokenizer produces for this dataset.
max_length = 0
tokens_to_add = set()
for path in all_paths:
    with open(DATASET_FOLDER + "/" + path + ".txt", "r", encoding="utf-8") as reader:
        splitted_text = processor.tokenizer(reader.read()).tokens()

    # In-place update instead of rebuilding the whole set for every file.
    tokens_to_add.update(splitted_text)
    max_length = max(max_length, len(splitted_text))

print(f"Max sentence length = {max_length}")

# Add the tokens in sorted order. New tokens receive consecutive ids in the
# order they are passed, and iterating a set of strings is not stable between
# interpreter sessions, so an unsorted list would give the same token a
# different id on every run.
# NOTE(review): the ids must match the ones used when the checkpoint was
# trained; ideally the tokenizer saved at training time is reloaded — confirm.
newly_added_num = processor.tokenizer.add_tokens(sorted(tokens_to_add))
print(f"Number of new tokens = {newly_added_num}")

# Grow the decoder embedding matrix so the new token ids have rows to index.
if newly_added_num > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=8)

Max sentence length = 4092
Number of new tokens = 101


In [None]:
# Restore the fine-tuned weights (only when LOAD_FROM_CHECKPOINT is set), then
# move the model to the evaluation device.
if LOAD_FROM_CHECKPOINT:
    print("Loading model from checkpoint:", LAST_CHECKPOINT_NAME)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # runtime as well; the weights are moved to DEVICE by model.to() below.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(G_DRIVE_FOLDER_CHECKPOINTS + LAST_CHECKPOINT_NAME, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

# Trailing ';' keeps the notebook from printing the full module tree.
model.to(DEVICE);

Loading model from checkpoint: Pix2Struct_WebUI2Code_Complete_4096_FULL_epoch[9].pth


Pix2StructForConditionalGeneration(
  (encoder): Pix2StructVisionModel(
    (embeddings): Pix2StructVisionEmbeddings(
      (patch_projection): Linear(in_features=768, out_features=768, bias=True)
      (row_embedder): Embedding(4096, 768)
      (column_embedder): Embedding(4096, 768)
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (encoder): Pix2StructVisionEncoder(
      (layer): ModuleList(
        (0-11): 12 x Pix2StructVisionLayer(
          (attention): Pix2StructVisionAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
            (output): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): Pix2StructVisionMlp(
            (wi_0): Linear(in_features=768, out_features=2048, bias=False)
            (wi_1): Linear(in_features=768, out_features=2048, bias=False)
 

### Split files into training - validation - test sets

In [None]:
# Seed and shuffle a *copy* of the sample list. Shuffling `all_paths` in place
# would make this cell non-idempotent: running it a second time would shuffle
# an already shuffled list and silently produce a different split.
random.seed(RANDOM_SEED)
shuffled_paths = list(all_paths)
random.shuffle(shuffled_paths)

# int() truncates, so the test set absorbs the rounding remainder.
train_len = int(TRAIN_SET_PERCENTAGE * len(shuffled_paths))
valid_len = int(VALID_SET_PERCENTAGE * len(shuffled_paths))

train_paths = shuffled_paths[:train_len]
valid_paths = shuffled_paths[train_len:train_len + valid_len]
test_paths = shuffled_paths[train_len + valid_len:]

print(f"TRAIN_SET size = {len(train_paths)}")
print(f"VALID_SET size = {len(valid_paths)}")
print(f"TEST_SET size = {len(test_paths)}")

TRAIN_SET size = 2085
VALID_SET size = 4
TEST_SET size = 233


In [None]:
class WebUI2CodeDataset(Dataset):
    """In-memory dataset of (screenshot, target text) pairs.

    Every sample is fully pre-processed inside ``__init__``: the ``.png`` image
    is run through the module-level ``processor`` and the ``.txt`` file is
    tokenized into a fixed-length label tensor. ``__getitem__`` then only
    returns the cached result together with the sample name.

    Args:
        root_dir: Folder containing ``<name>.png`` / ``<name>.txt`` pairs.
        transform: Callable applied to the RGB PIL image before it is handed to
            the processor; skipped when falsy.
        paths: Sample names (file names without extension) to load.
    """

    def __init__(self, root_dir, transform, paths):

        self.root_dir = root_dir
        self.transform = transform
        self.paths = paths

        self.max_patches = MAX_PATCHES
        self.max_length = MAX_SENTENCE_LEN
        self.ignore_id = -100  # label value written over padding positions

        self.encodings = []

        for path in tqdm(paths):
            text_file_path = os.path.join(root_dir, path + ".txt")
            image_file_path = os.path.join(root_dir, path + ".png")

            # Context manager closes the file handle as soon as the pixel data
            # has been converted, instead of waiting for garbage collection.
            with Image.open(image_file_path) as image_file:
                image = image_file.convert('RGB')

            if DEBUG:
                image.show()

            if self.transform:
                # NOTE(review): the processor may apply its own normalization on
                # top of this transform — confirm it matches the training setup.
                image = self.transform(image)

            encoding = processor(images=image, max_patches=self.max_patches, return_tensors="pt")
            # Drop only the leading batch axis added by return_tensors="pt";
            # a dimensionless squeeze() would also collapse any other size-1 axis.
            encoding = {k: v.squeeze(0) for k, v in encoding.items()}

            # Explicit encoding so reading does not depend on the platform default.
            with open(text_file_path, 'r', encoding="utf-8") as f:
                text = f.read()

            if DEBUG:
                print("text:")
                print(text)

            input_ids = processor.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids

            labels = input_ids.squeeze(0).clone()
            # Replace padding positions with the ignore id.
            labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id

            encoding["labels"] = labels.to(torch.int32)

            # Cache the finished image + text encoding for this sample.
            self.encodings.append(encoding)

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return ``(encoding_dict, sample_name)`` for position ``idx``."""
        return self.encodings[idx], self.paths[idx]

In [None]:
# Transformations for the image
transform = transforms.Compose([
    transforms.ToTensor(),  # convert PIL Image to PyTorch Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # normalize for pretrained models
])

# Instantiate the CustomDataset
test_dataset = WebUI2CodeDataset(DATASET_FOLDER, transform, test_paths)

# Use DataLoader for batching and shuffling
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 233/233 [00:22<00:00, 10.53it/s]


In [None]:
print(f"test_dataloader size = {len(test_dataloader)}")

test_dataloader size = 30


In [None]:
batch = next(iter(test_dataloader))

In [None]:
encoding, text_file_paths = batch

In [None]:
print(len(encoding))

3


In [None]:
encoding["labels"]

tensor([[  411,   812,   482,  ...,  -100,  -100,  -100],
        [  411,   812, 11789,  ...,  -100,  -100,  -100],
        [  411,   812, 11789,  ...,  -100,  -100,  -100],
        ...,
        [  411,   812, 11789,  ...,  -100,  -100,  -100],
        [50190, 50194, 50130,  ...,  -100,  -100,  -100],
        [  411,   812, 11789,  ...,  -100,  -100,  -100]], dtype=torch.int32)

In [None]:
encoding["labels"][0]

tensor([ 411,  812,  482,  ..., -100, -100, -100], dtype=torch.int32)

In [None]:
# Sanity check: turn the first sample's label ids back into text.
labels_list = encoding["labels"][0].tolist()

# -100 marks ignored (padding) positions and is not a valid token id.
filtered_labels = list(filter(lambda token: token != -100, labels_list))

# Decoding a single sequence is what batch_decode does per element.
decoded_text_example = processor.tokenizer.decode(filtered_labels, skip_special_tokens=True)


In [None]:
decoded_text_example

'<html class="no-js seed-csp4" lang="en"> <head> <title></title> <link href="aracne.biz.css" rel="stylesheet"/> <link href="aracne.biz.css" rel="stylesheet"/> <style type="text/css"> calculated styles Background Style html background fafafa urlhttpsexample.com no-repeat top center fixed; -webkit-background-size cover; -moz-background-size cover; -o-background-size cover; background-size cover;.seed-csp4 body background transparent; Text Styles.seed-csp4 body font-family Helvetica Neue, Helvetica, Arial, sans-serif.seed-csp4 h1,.seed-csp4 h2,.seed-csp4 h3,.seed-csp4 h4,.seed-csp4 h5,.seed-csp4 h6 font-family Helvetica Neue, Helvetica, Arial, sans-serif.seed-csp4 body colorffffff;.seed-csp4 h1,.seed-csp4 h2,.seed-csp4 h3,.seed-csp4 h4,.seed-csp4 h5,.seed-csp4 h6 colorffffff;.seed-csp4 a,.seed-csp4 avisited,.seed-csp4 ahover,.seed-csp4 aactive,.seed-csp4 afocus color27AE60; supports -webkit-overflow-scrolling touch html height 100; overflow hidden; body height100; overflow auto; -webkit-o

In [None]:
for k,v in encoding.items():
    print(k,v.shape)

flattened_patches torch.Size([8, 1024, 770])
attention_mask torch.Size([8, 1024])
labels torch.Size([8, 4096])


In [None]:
print(text_file_paths)

('aracne.biz', 'albendazole.works', 'diflucan.run', 'vldb.org', 'hcch.net', 'tadalafilx.online', 'adium.im', 'tetracyclinetab.quest')


### Utility functions

In [None]:
def move_to_device(data):
    """Recursively move every tensor inside ``data`` to ``DEVICE``.

    Tensors are moved with ``.to(DEVICE)``; dicts are rebuilt with the same
    keys; lists and tuples are both rebuilt as lists; anything else is
    returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(DEVICE)
    if isinstance(data, dict):
        return {key: move_to_device(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return list(map(move_to_device, data))
    return data

### Main Testing function

In [None]:
# The tokenizer's padding id; the start-token constant is set to the same value.
PAD_TOKEN_ID = processor.tokenizer.pad_token_id
START_TOKEN_ID = PAD_TOKEN_ID

In [None]:
def _write_text_file(file_path, text):
    """Best-effort write of ``text`` (plus a trailing newline) to ``file_path``.

    Tries UTF-8 first; on ``UnicodeEncodeError`` retries with non-ASCII
    characters removed; any other error is reported and swallowed so that one
    bad file does not abort the whole test run.
    """
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            print(text, file=f)
    except UnicodeEncodeError:
        cleaned_text = ''.join(char for char in text if ord(char) < 128)
        with open(file_path, "w") as f:
            print(cleaned_text, file=f)
    except Exception as e:
        print(f"An unexpected error occurred for file {file_path}: {e}")


def testing_loop(testing_dataloader, model, processor, config, description, generation_config=None, do_sample=False):
    """Generate text for every batch and save predictions next to the references.

    Long outputs are produced chunk by chunk: the first ``generate`` call asks
    for up to ``CHUNK_LENGTH`` new tokens; each later call re-feeds the tail of
    what has been generated so far as decoder context and asks for
    ``CHUNK_LENGTH - CONTEXT_OVERLAP_LENGTH`` more. A sample stops being
    extended once an EOS token shows up in its output. For every sample two
    files are written to ``OUTPUT_FOLDER``: ``<name>_pred.txt`` and
    ``<name>_answer.txt``.

    Args:
        testing_dataloader: Yields ``(encoding, names)`` batches, where
            ``encoding`` has "labels", "flattened_patches" and "attention_mask".
        model: Model exposing ``eval()`` and ``generate()``.
        processor: Processor whose ``tokenizer`` provides ``batch_decode``,
            ``eos_token_id`` and ``pad_token_id``.
        config: Configuration dict; currently not read by this function.
        description: Label shown on the progress bar.
        generation_config: Optional generation config forwarded to ``generate``.
        do_sample: Forwarded to ``generate``.

    Returns:
        None. Results are written to disk.
    """
    if generation_config:
        print("using custom generation config in testing loop: \n")
        print(generation_config)
    if do_sample:
        print("\nusing sampling\n")

    # Create the output folder up front: otherwise every write below would fail
    # (and merely be printed) only after the expensive generation has finished.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    # Take the special ids from the processor that was passed in, not from globals.
    pad_token_id = processor.tokenizer.pad_token_id
    eos_token_id = processor.tokenizer.eos_token_id

    model.eval()

    with torch.no_grad():
        test_loop = tqdm(enumerate(testing_dataloader), total=len(testing_dataloader), desc=description)
        for _, batch in test_loop:
            encoding, text_file_paths = batch
            encoding = move_to_device(encoding)
            labels, flattened_patches, attention_mask = encoding["labels"], encoding["flattened_patches"], encoding["attention_mask"]

            # Generated ids for the full batch, grown chunk by chunk.
            total_outputs = None
            context_from_last = None

            # One flag per batch row: True once that row has produced EOS.
            finished_sentences_mask = torch.zeros(flattened_patches.size(0), dtype=torch.bool, device=flattened_patches.device)

            for iteration in range(MAX_N_CHUNKS_PER_SENTENCE):

                # Only rows that are still unfinished are sent to generate().
                generate_args = {
                    "flattened_patches": flattened_patches[~finished_sentences_mask],
                    "attention_mask": attention_mask[~finished_sentences_mask],
                    "max_new_tokens": CHUNK_LENGTH - (CONTEXT_OVERLAP_LENGTH if iteration else 0),
                    "generation_config": generation_config,
                    "do_sample": do_sample
                }

                if iteration and context_from_last is not None:
                    generate_args["decoder_input_ids"] = context_from_last[~finished_sentences_mask]

                outputs = model.generate(**generate_args)

                # From the second chunk on, the leading overlap is re-fed
                # context rather than new text, so it is cut off.
                new_chunks = outputs if iteration == 0 else outputs[:, CONTEXT_OVERLAP_LENGTH:]

                if iteration == 0:
                    total_outputs = new_chunks
                else:
                    # Scatter the new tokens back to full-batch shape; rows that
                    # already finished are filled with padding.
                    new_chunks_with_padding_chunks = torch.full((flattened_patches.shape[0], new_chunks.shape[1]), pad_token_id, dtype=new_chunks.dtype, device=new_chunks.device)
                    new_chunks_with_padding_chunks[~finished_sentences_mask] = new_chunks
                    total_outputs = torch.cat((total_outputs, new_chunks_with_padding_chunks), dim=1)

                # Mark rows whose output in this round contains EOS as finished.
                finished_sentences_mask[~finished_sentences_mask] |= (outputs == eos_token_id).any(dim=1)

                if finished_sentences_mask.all():
                    break

                if outputs.shape[1] < CHUNK_LENGTH:
                    print("ERROR: !! should have already exited because all sentences reached the end!!")

                # -1 because it will put in front a START_TOKEN automatically
                context_from_last = total_outputs[:, -(CONTEXT_OVERLAP_LENGTH-1):]

            predictions = processor.tokenizer.batch_decode(total_outputs, skip_special_tokens=True)

            # -100 is not a valid token id: swap it for padding on a copy (the
            # batch tensor itself is left untouched) before decoding.
            decodable_labels = labels.masked_fill(labels == -100, pad_token_id)
            answers = processor.tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

            for pred, answer, text_file_path in zip(predictions, answers, text_file_paths):
                _write_text_file(f"{OUTPUT_FOLDER}/{text_file_path}_answer.txt", answer)
                _write_text_file(f"{OUTPUT_FOLDER}/{text_file_path}_pred.txt", pred)

    return

In [None]:
# Run configuration handed to the testing loop; checked by validate_config.
config = dict(verbose=VERBOSE)

In [None]:
def validate_config(config):
    """Check that ``config`` contains every required key with a valid value.

    Raises:
        ValueError: If a required key is missing, or if ``config["verbose"]``
            is not a ``bool``.
    """
    mandatory_keys = ("verbose",)

    # Report the first required key that is absent, if any.
    absent_key = next((name for name in mandatory_keys if name not in config), None)
    if absent_key is not None:
        raise ValueError(f"Key '{absent_key}' must be present in the configuration.")

    # Type check on the individual values.
    if isinstance(config["verbose"], bool):
        return
    raise ValueError("verbose must be a boolean value.")

In [None]:
validate_config(config)
print(config)

{'verbose': True}


## Test the model

In [None]:
# Start from the model's own generation defaults and add a repetition penalty;
# sampling stays off.
# NOTE(review): the printed config reports "use_cache": false. Enabling the
# cache might shorten this multi-hour run; verify outputs stay the same first.
generation_config = GenerationConfig.from_model_config(model.config)
generation_config.repetition_penalty = 1.3
testing_loop(
    testing_dataloader=test_dataloader,
    model=model,
    processor=processor,
    config=config,
    description="Test loop",
    generation_config=generation_config,
    do_sample=False,
)

using custom generation config in testing loop: 

GenerationConfig {
  "_from_model_config": true,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "repetition_penalty": 1.3,
  "transformers_version": "4.33.1",
  "use_cache": false
}



Test loop: 100%|██████████| 30/30 [5:39:59<00:00, 679.97s/it]
