# Test Pix2Struct model on Pix2Code HTML Lorem Ipsum dataset

## Set Up Environment

In [1]:
!pip install transformers==4.33.1

Collecting transformers==4.33.1
  Downloading transformers-4.33.1-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers==4.33.1)
  Downloading huggingface_hub-0.17.2-py3-none-any.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.33.1)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers==4.33.1)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [

In [2]:
#!pip install --upgrade git+https://github.com/huggingface/transformers

## Import necessary libraries

In [3]:
from google.colab import drive
import os
import zipfile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import re
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
import torch
from torch.nn import functional as F
from pathlib import Path
from nltk import edit_distance
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from torch.utils.data import random_split
import random

## Define variables and parameters

In [4]:
# --- Locations -------------------------------------------------------------
# Google Drive folders holding the zipped dataset and the model checkpoints.
G_DRIVE_FOLDER = '/content/drive/MyDrive/Datasets/'
G_DRIVE_FOLDER_CHECKPOINTS = '/content/drive/MyDrive/Checkpoints/'

# Dataset archive on Drive and the local folder it is extracted into.
DATASET_NAME = 'pix2code_web_with_html_loremipsum'
ZIP_NAME = f'{DATASET_NAME}.zip'
DESTINATION_FOLDER = '/content/data/'
DATASET_FOLDER = f'{DESTINATION_FOLDER}web_with_html_loremipsum/'  # folder name inside the archive

# Prediction / ground-truth text files are written here (no trailing slash).
OUTPUT_FOLDER = f'/content/drive/MyDrive/Testing_output/{DATASET_NAME}'

EXPERIMENT_NAME = "Pix2Struct_Pix2Code_HTML_LI_FULL_TEST"

# --- Model / data sizes ----------------------------------------------------
# Token budget: used as the tokenizer max_length for labels and as
# max_new_tokens during generation.
MAX_SENTENCE_LEN = 1024

# Passed to the processor as max_patches when encoding images.
MAX_PATCHES = 1024

BATCH_SIZE = 10

# --- Behaviour switches ----------------------------------------------------
# DEBUG makes the dataset show every image and print raw/cleaned HTML.
DEBUG = False
VERBOSE = True

# --- Dataset split ---------------------------------------------------------
# The test split is whatever remains: 1 - TRAIN_SET_PERCENTAGE - VALID_SET_PERCENTAGE.
TRAIN_SET_PERCENTAGE = 0.89
VALID_SET_PERCENTAGE = 0.01

# Seed for the shuffle that precedes the split.
RANDOM_SEED = 100

# --- Checkpoint ------------------------------------------------------------
LOAD_FROM_CHECKPOINT = True
LAST_CHECKPOINT_NAME = "Pix2Struct_Pix2Code_HTML_LI_FULL_epoch[14].pth"

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')


## Load Pix2Code Dataset

### Mount Google Drive

In [6]:
drive.mount('/content/drive')

Mounted at /content/drive


### Import zip file from Google Drive

In [7]:
# Make sure the local target directory exists, then unpack the dataset
# archive stored on Google Drive into it.
os.makedirs(DESTINATION_FOLDER, exist_ok=True)

with zipfile.ZipFile(file=G_DRIVE_FOLDER + ZIP_NAME, mode="r") as zf:
    zf.extractall(path=DESTINATION_FOLDER)

## Load Model and Processor

In [8]:
# Hugging Face Hub id used for both the processor and the initial weights.
repo_id = "google/pix2struct-base"

# The processor encodes images into patches and exposes the text tokenizer
# (`processor.tokenizer`) used for labels and decoding in later cells.
processor = AutoProcessor.from_pretrained(repo_id)
# These weights are only a starting point: test_model() overwrites them with
# the fine-tuned state dict from the checkpoint before running inference.
model = Pix2StructForConditionalGeneration.from_pretrained(repo_id, is_encoder_decoder=True)

Downloading (…)rocessor_config.json:   0%|          | 0.00/231 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.61k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/851k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/3.27M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.92k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

## Create Dataset class

In [9]:
def preprocess_html_file(html_text):
    """Strip page boilerplate from an HTML string and collapse whitespace.

    The following blocks are removed (non-greedy, spanning newlines), in order:
    ``<header>...</header>``, ``<footer class="footer">...</footer>`` and
    ``<script ...>...</script>`` (only script tags that have a space after
    ``<script``, i.e. tags carrying attributes). Afterwards every run of
    whitespace, newlines included, becomes a single space.

    Args:
        html_text: Raw HTML source as a string.

    Returns:
        The cleaned, single-line HTML string.
    """
    boilerplate_patterns = (
        r'<header>.*?</header>',
        r'<footer class="footer">.*?</footer>',
        r'<script .*?</script>',
    )

    cleaned = html_text
    for pattern in boilerplate_patterns:
        cleaned = re.sub(pattern, '', cleaned, flags=re.DOTALL)

    # `\s` already matches newlines, so one substitution handles both the
    # line breaks and the repeated spaces.
    return re.sub(r'\s+', ' ', cleaned)

### Keep files with at most 1024 tokens and add new unknown tokens

In [10]:
# Every entry (of any type) found in the extracted dataset folder
files = os.listdir(DATASET_FOLDER)

# Keep the HTML sources only; the matching screenshots are looked up later
all_html_files = list(filter(lambda name: name.endswith('.html'), files))

In [11]:
# Keep only the HTML files whose preprocessed text tokenizes to at most
# MAX_SENTENCE_LEN tokens, and collect every token seen in the kept files so
# that any token missing from the tokenizer vocabulary can be registered.

# Not updated anywhere in this cell; kept so the name stays defined.
max_length = 0

bigger_than_1024 = 0
lower_than_1024 = 0

html_files_filtered = []

tokens_to_add = set()

for html_file_path in all_html_files:
    # Explicit encoding so the result does not depend on the platform default.
    with open(os.path.join(DATASET_FOLDER, html_file_path), "r", encoding="utf-8") as reader:
        preprocessed_text = preprocess_html_file(reader.read())

    # Tokenize after the file is closed; the handle is not needed any more.
    splitted_text = processor.tokenizer(preprocessed_text).tokens()
    if len(splitted_text) > MAX_SENTENCE_LEN:
        bigger_than_1024 += 1
    else:
        lower_than_1024 += 1
        html_files_filtered.append(html_file_path)
        # In-place update avoids rebuilding the whole set for every file.
        tokens_to_add.update(splitted_text)

print("bigger_than_1024= ", bigger_than_1024)
print("lower_than_1024= ", lower_than_1024)

# Sort before registering: set iteration order is not stable between runs,
# and the ids of newly added tokens follow the order they are passed in.
newly_added_num = processor.tokenizer.add_tokens(sorted(tokens_to_add))
print(f"Number of new tokens = {newly_added_num}")

# Resize the model's token embeddings if there are new tokens
if newly_added_num > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=8)

bigger_than_1024=  0
lower_than_1024=  1742
Number of new tokens = 0


In [12]:
print(len(html_files_filtered))

1742


In [13]:
# Reproduce the train/validation/test partition of the file list.
random.seed(RANDOM_SEED)

# Shuffle a copy rather than `html_files_filtered` itself: shuffling in place
# would make a second run of this cell start from an already-shuffled list and
# silently produce a different split, even with the same seed.
# NOTE(review): the input order comes from os.listdir, whose ordering is not
# guaranteed; confirm it matches the order seen when the checkpoint was trained.
shuffled_html_files = list(html_files_filtered)
random.shuffle(shuffled_html_files)

train_len = int(TRAIN_SET_PERCENTAGE * len(shuffled_html_files))
valid_len = int(VALID_SET_PERCENTAGE * len(shuffled_html_files))

# Contiguous slices: [train | valid | test (remainder)]
train_paths = shuffled_html_files[:train_len]
valid_paths = shuffled_html_files[train_len:train_len+valid_len]
test_paths = shuffled_html_files[train_len+valid_len:]

print(f"TRAIN_SET size = {len(train_paths)}")
print(f"VALID_SET size = {len(valid_paths)}")
print(f"TEST_SET size = {len(test_paths)}")

TRAIN_SET size = 1550
VALID_SET size = 17
TEST_SET size = 175


In [14]:
class Pix2CodeDataset(Dataset):
    """In-memory dataset of (screenshot, HTML) pairs encoded for Pix2Struct.

    Every sample is encoded once in ``__init__`` and kept in RAM, so indexing
    is a plain list lookup. Encoding relies on the module-level ``processor``
    and on the ``MAX_PATCHES`` / ``MAX_SENTENCE_LEN`` constants.

    Args:
        root_dir: Directory containing the ``.html`` files and their
            same-named ``.png`` screenshots.
        transform: Optional callable applied to the RGB PIL image before it is
            handed to the processor; skipped when falsy.
        text_files_paths: HTML file names (relative to ``root_dir``) to load.
    """

    def __init__(self, root_dir, transform, text_files_paths):

        self.root_dir = root_dir
        self.transform = transform
        self.text_files_paths = text_files_paths

        self.max_patches = MAX_PATCHES
        self.max_length = MAX_SENTENCE_LEN
        # Label value that replaces padding positions in the targets.
        self.ignore_id = -100

        self.encodings = []

        for text_file in tqdm(text_files_paths):
            # The screenshot shares the HTML file's name, with a .png extension.
            image_file = self._sample_stem(text_file) + '.png'

            text_file_path = os.path.join(root_dir, text_file)
            image_file_path = os.path.join(root_dir, image_file)

            # Load image; convert() returns a new image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(image_file_path) as image_handle:
                image = image_handle.convert('RGB')

            if DEBUG:
                image.show()

            if self.transform:
                image = self.transform(image)

            encoding = processor(images=image, max_patches=self.max_patches, return_tensors="pt")
            # Drop only the leading batch dimension added by the processor; a
            # bare squeeze() would also remove any other dimension of size 1.
            encoding = {k: v.squeeze(0) for k, v in encoding.items()}

            # Load text (explicit encoding: do not depend on the platform default)
            with open(text_file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            text_cleaned = preprocess_html_file(text)

            if DEBUG:
                print("text:")
                print(text)
                print("\n\n\ntext_cleaned:")
                print(text_cleaned)

            input_ids = processor.tokenizer(
                text_cleaned,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids

            labels = input_ids.squeeze(0).clone()
            labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token

            encoding["labels"] = labels

            # For each sample save directly the encoding of both text and image
            self.encodings.append(encoding)

    @staticmethod
    def _sample_stem(text_file):
        """Return ``text_file`` with a trailing ``.html`` removed, if present."""
        suffix = '.html'
        if text_file.endswith(suffix):
            return text_file[:-len(suffix)]
        return text_file

    def __len__(self):
        """Number of encoded samples."""
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return ``(encoding, sample_id)`` for position ``idx``.

        ``encoding`` is the dict built in ``__init__`` (processor outputs plus
        ``"labels"``); ``sample_id`` is the HTML file name without its
        ``.html`` extension.
        """
        return self.encodings[idx], self._sample_stem(self.text_files_paths[idx])

In [15]:
# Image pipeline applied before the Pix2Struct processor:
# PIL image -> float tensor -> channel-wise normalization.
# NOTE(review): the processor may normalize images on its own; confirm this
# extra normalization matches what was used when the checkpoint was trained.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Encode the whole test split up front (Pix2CodeDataset keeps it in memory)
test_dataset = Pix2CodeDataset(
    root_dir=DATASET_FOLDER,
    transform=transform,
    text_files_paths=test_paths,
)

# Batch the test samples; order is kept fixed (no shuffling) for evaluation
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 175/175 [00:14<00:00, 11.85it/s]


In [16]:
print(f"test_dataloader size = {len(test_dataloader)}")

test_dataloader size = 18


In [17]:
batch = next(iter(test_dataloader))

In [18]:
encoding, text_file_paths = batch

In [19]:
print(text_file_paths)

('07040F49-1886-4F5B-ADB7-4B80D37738F5', '4414C575-EF1C-4F9F-8645-3213EA59F4AC', '355278CC-B263-4A1B-913B-9C2331988EEE', '7DB70CF7-EE4A-42FD-B5C2-AF5ACA7D7473', '74E5489C-7C59-4E0E-901F-16F2051853EC', 'A7BFA607-C139-4300-B64C-CC7B5F770B3E', '0EBD0467-076F-4946-8549-C3EEF47F37AF', '0EA46C27-81CE-4D3D-A613-E54BE9DB1890', '764A5E68-5A8C-41FA-8447-619823DA554B', '5C65608B-F47F-4877-8922-E13E3D926253')


In [20]:
encoding

{'flattened_patches': tensor([[[ 1.0000,  1.0000, -0.3106,  ..., -0.3106,  0.1386,  0.6672],
          [ 1.0000,  2.0000, -0.3106,  ..., -8.1887, -5.7268, -2.6882],
          [ 1.0000,  3.0000, -0.3106,  ..., -7.9945, -5.5811, -2.6050],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 1.0000,  1.0000, -0.1188,  ..., -0.1188,  0.2312,  0.6431],
          [ 1.0000,  2.0000, -0.1188,  ..., -6.9246, -6.7265, -6.2837],
          [ 1.0000,  3.0000, -0.1188,  ..., -6.9219, -6.7146, -6.2621],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 1.0000,  1.0000, -0.0567,  ..., -0.0567,  0.2832,  0.6833],
       

In [21]:
encoding["flattened_patches"][0]

tensor([[ 1.0000,  1.0000, -0.3106,  ..., -0.3106,  0.1386,  0.6672],
        [ 1.0000,  2.0000, -0.3106,  ..., -8.1887, -5.7268, -2.6882],
        [ 1.0000,  3.0000, -0.3106,  ..., -7.9945, -5.5811, -2.6050],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [22]:
encoding["labels"]

tensor([[50190, 50227,   411,  ...,  -100,  -100,  -100],
        [50190, 50227,   411,  ...,  -100,  -100,  -100],
        [50190, 50227,   411,  ...,  -100,  -100,  -100],
        ...,
        [50190, 50227,   411,  ...,  -100,  -100,  -100],
        [50190, 50227,   411,  ...,  -100,  -100,  -100],
        [50190, 50227,   411,  ...,  -100,  -100,  -100]])

In [23]:
encoding["labels"][0]

tensor([50190, 50227,   411,  ...,  -100,  -100,  -100])

In [24]:
# Sanity check: turn the first sample's label ids back into text.
labels_list = encoding["labels"][0].tolist()

# -100 marks ignored (padding) positions and is not a real token id
filtered_labels = list(filter(lambda token_id: token_id != -100, labels_list))

# batch_decode expects a batch, so wrap the single sequence and unwrap the result
decoded_batch = processor.tokenizer.batch_decode([filtered_labels], skip_special_tokens=True)
decoded_text_example = decoded_batch[0]


In [25]:
decoded_text_example

'<html> <body> <main class="container"> <div class="header clearfix"> <nav> <ul class="nav nav-pills pull-left"> <li class="active"><a href="#">Aute reprehenderit</a></li> <li><a href="#">Fugiat anim</a></li> <li><a href="#">Aliquip sint</a></li> </ul> </nav> </div> <div class="row"><div class="col-lg-12"> <h4>Et</h4><p>Eiusmod velit culpa non ullamco aliquip dolor</p> <a class="btn btn-danger" href="#" role="button">Consequat nostrud</a> </div> </div> </main> </body> </html> '

In [26]:
for k,v in encoding.items():
    print(k,v.shape)

flattened_patches torch.Size([10, 1024, 770])
attention_mask torch.Size([10, 1024])
labels torch.Size([10, 1024])


### Main Testing function

In [27]:
# Both names hold the tokenizer's padding token id.
PAD_TOKEN_ID = processor.tokenizer.pad_token_id
START_TOKEN_ID = PAD_TOKEN_ID

In [28]:
def testing_loop(testing_dataloader, model, processor, config, description):
    """Generate HTML for every test batch and write predictions to disk.

    For each sample two files are written into ``OUTPUT_FOLDER``:
    ``<sample_id>_pred.txt`` with the generated text and
    ``<sample_id>_answer.txt`` with the decoded ground-truth labels.

    Args:
        testing_dataloader: Yields ``(encoding, sample_ids)`` batches where
            ``encoding`` has ``"labels"``, ``"flattened_patches"`` and
            ``"attention_mask"`` entries.
        model: Model exposing ``eval()`` and ``generate(...)``.
        processor: Processor whose ``tokenizer`` decodes ids back to text.
        config: Run configuration; currently not read by this function.
        description: Label shown on the progress bar.

    Returns:
        None.
    """
    model.eval()

    # Create the output directory up front so the first write cannot fail
    # after a full (slow) generation step has already been paid for.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    pad_token_id = processor.tokenizer.pad_token_id

    with torch.no_grad():
        test_loop = tqdm(testing_dataloader, total=len(testing_dataloader), desc=description)
        for batch in test_loop:
            encoding, text_file_paths = batch
            encoding = move_to_device(encoding)
            labels, flattened_patches, attention_mask = encoding["labels"], encoding["flattened_patches"], encoding["attention_mask"]

            outputs = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_new_tokens=MAX_SENTENCE_LEN)

            predictions = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # -100 is not a valid token id: map the ignored positions back to
            # the pad token (skipped while decoding). masked_fill returns a
            # new tensor, so the batch's own labels are left untouched.
            answer_ids = labels.masked_fill(labels == -100, pad_token_id)
            answers = processor.tokenizer.batch_decode(answer_ids, skip_special_tokens=True)

            for pred, answer, text_file_path in zip(predictions, answers, text_file_paths):
                with open(os.path.join(OUTPUT_FOLDER, f"{text_file_path}_pred.txt"), "w", encoding="utf-8") as f:
                    print(pred, file=f)

                with open(os.path.join(OUTPUT_FOLDER, f"{text_file_path}_answer.txt"), "w", encoding="utf-8") as f:
                    print(answer, file=f)


In [29]:
# Run configuration handed to validate_config() and test_model().
config = dict(verbose=VERBOSE)

In [30]:
def validate_config(config):
    """Check that ``config`` has every required key with a valid value.

    Args:
        config: Mapping with the run configuration.

    Raises:
        ValueError: If a required key is missing, or if ``"verbose"`` is not
            a ``bool``.
    """
    required_keys = ("verbose",)

    # Report the first required key that is absent.
    missing_keys = [name for name in required_keys if name not in config]
    if missing_keys:
        raise ValueError(f"Key '{missing_keys[0]}' must be present in the configuration.")

    # Type check for the individual settings.
    if isinstance(config["verbose"], bool):
        return
    raise ValueError("verbose must be a boolean value.")

In [31]:
validate_config(config)
print(config)

{'verbose': True}


### Utility functions

In [32]:
def move_to_device(data):
    """Recursively move every tensor inside ``data`` to ``DEVICE``.

    Args:
        data: A tensor, a dict, a list/tuple, or any other object; containers
            may be nested.

    Returns:
        The same structure with tensors moved to ``DEVICE``. Dicts keep their
        keys, lists and tuples both come back as lists, and any other object
        is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(DEVICE)

    if isinstance(data, dict):
        return {key: move_to_device(value) for key, value in data.items()}

    if isinstance(data, (list, tuple)):
        return [move_to_device(item) for item in data]

    # Anything else (strings, numbers, None, ...) passes through untouched.
    return data

## Test the model

In [33]:
def test_model(config, processor, model):
    """Load the fine-tuned weights into ``model`` and run the test loop.

    The checkpoint ``LAST_CHECKPOINT_NAME`` is read from
    ``G_DRIVE_FOLDER_CHECKPOINTS``; its ``'model_state_dict'`` entry is loaded
    into ``model``, which is then moved to ``DEVICE`` and evaluated on the
    module-level ``test_dataloader``.

    Args:
        config: Run configuration, forwarded to ``testing_loop``.
        processor: Processor forwarded to ``testing_loop`` for decoding.
        model: Model instance that receives the checkpoint weights.
    """
    print("Loading model from checkpoint: ", LAST_CHECKPOINT_NAME)
    checkpoint_path = os.path.join(G_DRIVE_FOLDER_CHECKPOINTS, LAST_CHECKPOINT_NAME)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved from a GPU run can still be loaded on a CPU-only runtime.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE)
    testing_loop(test_dataloader, model, processor, config, "Test loop")

In [34]:
test_model(config, processor, model)

Loading model from checkpoint:  Pix2Struct_Pix2Code_HTML_LI_FULL_epoch[14].pth


Test loop: 100%|██████████| 18/18 [26:18<00:00, 87.71s/it]
