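# Fine-tune a Pix2Struct model on the Synthetic Bootstrap dataset

Fine-tunes `google/pix2struct-base` to generate the HTML source of a web page from its corresponding page image. Targets longer than the decoder window are handled as overlapping token chunks, and generations are scored against the reference HTML with BLEU.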

## Set up the environment

In [None]:
!pip install --upgrade git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-ts5t8clz
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-ts5t8clz
  Resolved https://github.com/huggingface/transformers to commit 1982dd3b15867c46e1c20645901b0de469fd935f
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers==4.32.0.dev0)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.32.0.dev0)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━

In [None]:
!pip install -q wandb

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/2.1 MB[0m [31m1.9 MB/s[0m eta [36m0:00:02[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.1/2.1 MB[0m [31m32.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.5/188.5 kB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.6/215.6 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


## Import necessary libraries

In [None]:
from google.colab import drive
import os
import zipfile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import re
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
import torch
from torch.nn import functional as F
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
from pathlib import Path
from nltk import edit_distance
import numpy as np
import wandb
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from torch.utils.data import random_split
import random

## Define variables and parameters

In [None]:
G_DRIVE_FOLDER = '/content/drive/MyDrive/Datasets/'
G_DRIVE_FOLDER_CHECKPOINTS = '/content/drive/MyDrive/Checkpoints/'
DATASET_NAME = 'synthBootstrap_mini'
ZIP_NAME = DATASET_NAME + '.zip'
DESTINATION_FOLDER = '/content/data/'
# NOTE(review): not referenced by the cells in this notebook; files are read from
# DESTINATION_FOLDER + "html/" and DESTINATION_FOLDER + "images/" instead.
DATASET_FOLDER = DESTINATION_FOLDER + DATASET_NAME

# DESTINATION_FOLDER already ends with "/", so no leading slash here
HTML_FILES_FOLDER = DESTINATION_FOLDER + "html/"

EXPERIMENT_NAME = "Pix2Struct_SynthBootstrap_1000_Complete"

# Token budget per HTML file: the datasets pad/truncate every file to this length
MAX_SENTENCE_LEN = 4096

# Decoder window length, and how many tokens each window shares with the previous one
CHUNK_LENGTH = 1024
CONTEXT_OVERLAP_LENGTH = 256

MAX_PATCHES = 1024

DEBUG = False
VERBOSE = True

BATCH_SIZE = 4
# Samples contributing to one optimizer step; gradients are accumulated over
# EFFECTIVE_BATCH_SIZE // BATCH_SIZE mini-batches
EFFECTIVE_BATCH_SIZE = 8
NUM_WARMUP_STEPS = 1000
MAX_EPOCHS = 20
LR = 1e-4
CHECK_VAL_EVERY_N_EPOCH = 5
GRADIENT_CLIP_VAL = 1.0
# Integer division (at least 1): true division would give a float (8 / 4 == 2.0)
ACCUMULATE_GRAD_BATCHES = max(1, EFFECTIVE_BATCH_SIZE // BATCH_SIZE)

# Fractions (not percentages) of the HTML files assigned to each split;
# the test set gets the remainder, 1 - TRAIN_SET_PERCENTAGE - VALID_SET_PERCENTAGE
TRAIN_SET_PERCENTAGE = 0.88
VALID_SET_PERCENTAGE = 0.02  # 20 files when the dataset has 1000

RANDOM_SEED = 123

# Set LOAD_FROM_CHECKPOINT and the file name (inside G_DRIVE_FOLDER_CHECKPOINTS) to resume training
LOAD_FROM_CHECKPOINT = False
LAST_CHECKPOINT_NAME = ""

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Windows needed to cover MAX_SENTENCE_LEN tokens: the first window takes CHUNK_LENGTH
# tokens and every further window adds only its non-overlapping part.
_chunk_stride = CHUNK_LENGTH - CONTEXT_OVERLAP_LENGTH
MAX_N_CHUNKS_PER_SENTENCE = (MAX_SENTENCE_LEN - CHUNK_LENGTH) // _chunk_stride + 1
print("MAX_N_CHUNKS_PER_SENTENCE", MAX_N_CHUNKS_PER_SENTENCE)

MAX_N_CHUNKS_PER_SENTENCE 5



## Load Synthetic Bootstrap Dataset

### Mount Google Drive

In [None]:
# G_DRIVE_FOLDER and G_DRIVE_FOLDER_CHECKPOINTS live under this mount point
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


### Import zip file from Google Drive

In [None]:
# Unpack the dataset archive from Drive to local disk; the dataset classes below
# read the html/ and images/ folders from DESTINATION_FOLDER.
archive_path = os.path.join(G_DRIVE_FOLDER, ZIP_NAME)
os.makedirs(DESTINATION_FOLDER, exist_ok=True)
with zipfile.ZipFile(archive_path, mode="r") as archive:
    archive.extractall(path=DESTINATION_FOLDER)

## Load Model and Processor

In [None]:
# Pretrained Pix2Struct base checkpoint: the processor turns images into flattened
# patches and owns the tokenizer; the model is the image-to-text encoder-decoder.
repo_id = "google/pix2struct-base"

processor = AutoProcessor.from_pretrained(repo_id)
model = Pix2StructForConditionalGeneration.from_pretrained(
    repo_id,
    is_encoder_decoder=True,
)

Downloading (…)rocessor_config.json:   0%|          | 0.00/231 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.61k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/851k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/3.27M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.92k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

## Create Dataset classes

### Preprocessing functions

In [None]:
def round_floats_in_text(text, precision=0):
    """Round every decimal literal that has two or more fractional digits.

    Numbers with fewer than two decimals are left untouched.

    Args:
        text: input string.
        precision: decimal places kept in the replacement (0 gives integer form).

    Returns:
        `text` with each matching number replaced by its rounded representation.
    """
    long_decimal = re.compile(r"\b\d+\.\d{2,}\b")
    return long_decimal.sub(lambda found: format(float(found.group()), f".{precision}f"), text)

In [None]:
def remove_html_comments(text):
    """Strip every ``<!-- ... -->`` comment from `text` (comments may span lines).

    The match is non-greedy, so text between two separate comments is kept.
    """
    comment = re.compile(r"<!--.*?-->", flags=re.DOTALL)
    return comment.sub('', text)

In [None]:
def preprocess_html_file(html_text):
    """Normalize raw HTML before tokenization.

    Steps, in order: newlines become spaces and whitespace runs collapse to a single
    space; HTML comments are removed; decimals with 2+ fractional digits are rounded
    to integers.
    """
    single_spaced = re.sub(r'\s+', ' ', html_text.replace('\n', ' '))
    return round_floats_in_text(remove_html_comments(single_spaced))

### Find the maximum tokenized length and any unknown tokens

In [None]:
# Scan every HTML file once: record the longest tokenized length and collect the
# tokens produced, so truncation and unknown text are visible before training.
# Only *.html files are kept: the datasets derive each image name by replacing
# ".html" with ".png", so any other file in the folder would break them.
all_paths = [name for name in os.listdir(HTML_FILES_FOLDER) if name.endswith(".html")]

max_length = 0
tokens_to_add = set()
unk_token_count = 0  # tokens the vocabulary could only map to the unknown token

for html_file_path in all_paths:
    with open(os.path.join(HTML_FILES_FOLDER, html_file_path), "r", encoding="utf-8") as reader:
        splitted_text = processor.tokenizer(preprocess_html_file(reader.read())).tokens()
    tokens_to_add.update(splitted_text)
    unk_token_count += splitted_text.count(processor.tokenizer.unk_token)
    max_length = max(max_length, len(splitted_text))

print(f"Max sentence length = {max_length}")
if max_length > MAX_SENTENCE_LEN:
    # The datasets tokenize with truncation=True, so the excess is silently dropped
    print(f"WARNING: longest file exceeds MAX_SENTENCE_LEN={MAX_SENTENCE_LEN} tokens and will be truncated")
print(f"Number of {processor.tokenizer.unk_token} tokens = {unk_token_count}")

# NOTE: every collected token was produced by this same tokenizer, so add_tokens()
# is expected to find nothing new (the recorded run reports 0). Text the vocabulary
# cannot represent shows up in the unknown-token count above instead.
newly_added_num = processor.tokenizer.add_tokens(list(tokens_to_add))
print(f"Number of new tokens = {newly_added_num}")

# Resize the model's token embeddings if there are new tokens
if newly_added_num > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer))

Max sentence length = 3995
Number of new tokens = 0


### Split files into training, validation, and test sets

In [None]:
# os.listdir() returns files in arbitrary order, so sort first to make the seeded
# shuffle reproducible. random.shuffle() works in place and returns None, so it must
# shuffle the very list that is sliced below.
random.seed(RANDOM_SEED)
all_paths = sorted(all_paths)
random.shuffle(all_paths)

train_len = int(TRAIN_SET_PERCENTAGE * len(all_paths))
valid_len = int(VALID_SET_PERCENTAGE * len(all_paths))

train_paths = all_paths[:train_len]
valid_paths = all_paths[train_len:train_len + valid_len]
test_paths = all_paths[train_len + valid_len:]  # the remainder is the test set

assert len(train_paths) + len(valid_paths) + len(test_paths) == len(all_paths)

print(f"TRAIN_SET size = {len(train_paths)}")
print(f"VALID_SET size = {len(valid_paths)}")
print(f"TEST_SET size = {len(test_paths)}")

TRAIN_SET size = 880
VALID_SET size = 20
TEST_SET size = 100


In [None]:
class SythBootstrapTrainingDataset(Dataset):
    """Training dataset whose items are overlapping chunks of tokenized HTML files.

    This is a variant of ``SythBootstrapDataset`` (used for validation/testing).
    Every HTML file is preprocessed, then tokenized to ``MAX_SENTENCE_LEN`` tokens
    (padded/truncated). It is cut into windows of ``CHUNK_LENGTH`` tokens, and each
    window after the first starts with the last ``CONTEXT_OVERLAP_LENGTH`` tokens of
    the previous one. A window after the first is dropped when its tokens past the
    context are all padding.

    Everything is pre-computed in ``__init__`` and kept in RAM. The processor's image
    encoding of a file is stored once and shared by all of that file's chunks.

    Args:
        root_dir: folder containing the ``html/`` and ``images/`` subfolders
            (joined by string concatenation, so it must end with a separator).
        transform: optional callable applied to the RGB PIL image before it is
            passed to the processor; falsy to skip.
        text_files_paths: HTML file names inside ``root_dir + "html/"``. The matching
            image has the same name with ``.html`` replaced by ``.png``.

    Items are ``(encoding, part)``:
        encoding: a fresh dict with the processor outputs (e.g. ``flattened_patches``,
            ``attention_mask``) plus int64 ``labels``, with padding set to -100.
        part: the chunk index within its file.
    """

    def __init__(self, root_dir, transform, text_files_paths):

        self.root_dir = root_dir
        self.transform = transform
        self.text_files_paths = text_files_paths

        self.max_patches = MAX_PATCHES
        self.max_length = MAX_SENTENCE_LEN
        self.ignore_id = -100

        # Entries: (labels stored as int32 to save RAM, index into images_encoding, part)
        self.data = []
        # One processor encoding per file, shared by all chunks of that file
        self.images_encoding = []

        for text_file in tqdm(text_files_paths):
            image_file = text_file.replace('.html', '.png')

            text_file_path = os.path.join(root_dir + "html/", text_file)
            image_file_path = os.path.join(root_dir + "images/", image_file)

            # Load image; the context manager closes the file, convert() returns a loaded copy
            with Image.open(image_file_path) as raw_image:
                image = raw_image.convert('RGB')

            if DEBUG:
                image.show()

            if self.transform:
                image = self.transform(image)

            encoding = processor(images=image, max_patches=self.max_patches, return_tensors="pt")
            encoding = {k: v.squeeze() for k, v in encoding.items()}

            self.images_encoding.append(encoding)
            image_encoding_idx = len(self.images_encoding) - 1

            # Load text
            with open(text_file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            text_cleaned = preprocess_html_file(text)

            if DEBUG:
                print("text:")
                print(text)
                print("\n\n\ntext_cleaned:")
                print(text_cleaned)

            input_ids = processor.tokenizer(
                text_cleaned,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids

            # Windows of CHUNK_LENGTH tokens; each starts CONTEXT_OVERLAP_LENGTH tokens
            # before the previous one ends
            input_ids_slices = []
            start_index = 0
            end_index = CHUNK_LENGTH
            while end_index <= MAX_SENTENCE_LEN:
                input_ids_slices.append(input_ids[:, start_index:end_index])
                start_index = end_index - CONTEXT_OVERLAP_LENGTH
                end_index = start_index + CHUNK_LENGTH

            for part, input_ids_slice in enumerate(input_ids_slices):
                labels = input_ids_slice.squeeze().clone()

                # model doesn't need to predict pad token
                labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id

                # Skip slices with only padding after the context from the previous chunk
                if part != 0 and bool((labels[CONTEXT_OVERLAP_LENGTH:] == self.ignore_id).all()):
                    continue

                self.data.append((labels.to(torch.int32), image_encoding_idx, part))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        labels, image_encoding_idx, part = self.data[idx]
        # Shallow copy: the encoding dict is shared by every chunk of the same file.
        # Writing "labels" into it directly would let a later chunk of that file in
        # the same batch overwrite an earlier chunk's labels before collation.
        encoding = dict(self.images_encoding[image_encoding_idx])
        encoding["labels"] = labels.to(torch.int64)

        return encoding, part

In [None]:
class SythBootstrapDataset(Dataset):
    """Validation/test dataset with one item per HTML file and its full target.

    Everything is pre-computed in ``__init__`` and kept in RAM. Each item is a dict
    with the processor's image outputs plus ``labels``: the preprocessed file
    tokenized to ``MAX_SENTENCE_LEN`` tokens (padded/truncated), with padding set to
    -100 and returned as int64.

    Args:
        root_dir: folder containing the ``html/`` and ``images/`` subfolders
            (joined by string concatenation, so it must end with a separator).
        transform: optional callable applied to the RGB PIL image before it is
            passed to the processor; falsy to skip.
        text_files_paths: HTML file names inside ``root_dir + "html/"``. The matching
            image has the same name with ``.html`` replaced by ``.png``.
    """

    def __init__(self, root_dir, transform, text_files_paths):

        self.root_dir = root_dir
        self.transform = transform
        self.text_files_paths = text_files_paths

        self.max_patches = MAX_PATCHES
        self.max_length = MAX_SENTENCE_LEN
        self.ignore_id = -100

        # One dict per file: processor image outputs + int32 labels
        self.encodings = []

        for text_file in tqdm(text_files_paths):
            image_file = text_file.replace('.html', '.png')

            text_file_path = os.path.join(root_dir + "html/", text_file)
            image_file_path = os.path.join(root_dir + "images/", image_file)

            # Load image; the context manager closes the file, convert() returns a loaded copy
            with Image.open(image_file_path) as raw_image:
                image = raw_image.convert('RGB')

            if DEBUG:
                image.show()

            if self.transform:
                image = self.transform(image)

            encoding = processor(images=image, max_patches=self.max_patches, return_tensors="pt")
            encoding = {k: v.squeeze() for k, v in encoding.items()}

            # Load text
            with open(text_file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            text_cleaned = preprocess_html_file(text)

            if DEBUG:
                print("text:")
                print(text)
                print("\n\n\ntext_cleaned:")
                print(text_cleaned)

            input_ids = processor.tokenizer(
                text_cleaned,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids

            labels = input_ids.squeeze().clone()
            labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token

            # Stored as int32 to save RAM; widened again in __getitem__
            encoding["labels"] = labels.to(torch.int32)

            self.encodings.append(encoding)

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        # Return a fresh dict with int64 labels, like SythBootstrapTrainingDataset, so
        # callers always get the standard index dtype and cannot mutate the cached entry.
        item = dict(self.encodings[idx])
        item["labels"] = item["labels"].to(torch.int64)
        return item

In [None]:
# Image transform applied to each RGB PIL image before the Pix2Struct processor.
# NOTE(review): the Pix2Struct processor presumably applies its own normalization when
# building patches; stacking ImageNet mean/std normalization in front of it may shift
# the inputs away from what the pretrained checkpoint saw. Kept unchanged; confirm intent.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

# Training items are chunks of files; validation items are whole files
train_dataset = SythBootstrapTrainingDataset(DESTINATION_FOLDER, transform, train_paths)
val_dataset = SythBootstrapDataset(DESTINATION_FOLDER, transform, valid_paths)

VAL_BATCH_SIZE = 10  # whole-file samples per validation batch

# A seeded generator makes the training shuffle order reproducible across runs
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(RANDOM_SEED),
)
val_dataloader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

100%|██████████| 880/880 [00:58<00:00, 15.15it/s]
100%|██████████| 20/20 [00:01<00:00, 18.77it/s]


In [None]:
# Number of batches per epoch for each loader
for loader_name, loader in (("train_dataloader", train_dataloader), ("val_dataloader", val_dataloader)):
    print(f"{loader_name} size = {len(loader)}")

train_dataloader size = 506
val_dataloader size = 2


## Training

In [None]:
# The tokenizer's pad id doubles as the decoder start token
PAD_TOKEN_ID = processor.tokenizer.pad_token_id
START_TOKEN_ID = PAD_TOKEN_ID

### Utility functions

In [None]:
def move_to_device(data, device=None):
    """Recursively move every tensor inside `data` to a device.

    Lists and tuples come back as lists, and dicts as dicts with the same keys.
    Any other non-tensor value is returned unchanged.

    Args:
        data: a tensor, or an arbitrarily nested list/tuple/dict containing tensors.
        device: target device. Defaults to the notebook-wide DEVICE, which is looked
            up only when a tensor is actually moved.

    Returns:
        The same structure with tensors on the target device.
    """
    if isinstance(data, (list, tuple)):
        return [move_to_device(item, device) for item in data]
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, torch.Tensor):
        return data.to(DEVICE if device is None else device)
    return data

In [None]:
def create_extended_attention_mask_for_decoder_with_context(input_shape, attention_mask, part):
    """Combine a causal mask, the context window and padding into a 3-D decoder mask.

    Args:
        input_shape: ``(batch_size, seq_length)`` of the decoder inputs.
        attention_mask: 2-D ``(batch_size, key_length)`` padding mask. Its dtype and
            device are used for the result.
        part: per-sample chunk index tensor; non-zero marks a continuation chunk.

    Returns:
        A ``(batch_size, seq_length, key_length)`` mask. Entry ``[b, i, j]`` is
        non-zero when query ``i`` may attend to key ``j``: either causally
        (``j <= i``), or within the first CONTEXT_OVERLAP_LENGTH keys for
        continuation chunks. Each entry is then multiplied by the padding mask of key ``j``.
    """
    device = attention_mask.device
    batch_size, seq_length = input_shape
    positions = torch.arange(seq_length, device=device)

    # allowed[b, i, j] is True when key j is not after query i
    allowed = positions[None, None, :].repeat(batch_size, seq_length, 1) <= positions[None, :, None]

    # Continuation chunks may also see the entire overlapping context, in any direction
    has_context = (part != 0).unsqueeze(-1).unsqueeze(-1).expand(-1, seq_length, CONTEXT_OVERLAP_LENGTH)
    allowed[:, :, :CONTEXT_OVERLAP_LENGTH] |= has_context

    allowed = allowed.to(attention_mask.dtype)

    # Longer padding mask than queries (e.g. cached past key/values): keys before the
    # current window are always visible
    if allowed.shape[1] < attention_mask.shape[1]:
        print("!!should not enter here in my case!!")
        prefix_len = attention_mask.shape[1] - allowed.shape[1]
        visible_prefix = torch.ones((batch_size, seq_length, prefix_len), device=device, dtype=allowed.dtype)
        allowed = torch.cat([visible_prefix, allowed], dim=-1)

    return allowed * attention_mask[:, None, :]


In [None]:
def get_attention_mask(decoder_input_ids, part):
    """Build the 2-D ``(batch, seq)`` decoder padding mask as floats.

    Positions holding PAD_TOKEN_ID get 0.0 and all others 1.0, with two exceptions:
    - Position 0 (the start token, which shares the pad id) is always attended.
    - For continuation chunks (``part != 0``), the first CONTEXT_OVERLAP_LENGTH
      positions are always attended.
    """
    attend = (decoder_input_ids != PAD_TOKEN_ID).float()
    attend[:, 0] = 1

    is_continuation = (part != 0).unsqueeze(-1).expand(-1, CONTEXT_OVERLAP_LENGTH)
    window = attend[:, :CONTEXT_OVERLAP_LENGTH]
    attend[:, :CONTEXT_OVERLAP_LENGTH] = window.masked_fill(is_continuation, 1.0)

    return attend

In [None]:
def shift_right_modified(input_ids, decoder_starting_token_idx):
    """Shift token ids one step right to form teacher-forcing decoder inputs.

    Position 0 receives `decoder_starting_token_idx` and the last input token is
    dropped. Any -100 (ignore index) in the result is replaced with PAD_TOKEN_ID.
    The input tensor is not modified.
    """
    shifted = input_ids.new_empty(input_ids.shape)
    shifted[..., 0] = decoder_starting_token_idx
    shifted[..., 1:] = input_ids[..., :-1]
    return shifted.masked_fill(shifted == -100, PAD_TOKEN_ID)

In [None]:
def get_decoder_input_ids(labels_chunk, start_id):
    """Teacher-forcing decoder inputs for a labels chunk, starting with `start_id`."""
    return shift_right_modified(labels_chunk, decoder_starting_token_idx=start_id)

In [None]:
def get_decoder_input_ids_and_attention_mask(labels, part):
    """Build teacher-forcing decoder inputs and their 3-D attention mask.

    Args:
        labels: label ids (padding already set to -100).
        part: per-sample chunk index; non-zero marks a continuation chunk.

    Returns:
        ``(decoder_input_ids, extended_decoder_attention_mask)``. The inputs are the
        labels shifted right behind START_TOKEN_ID. The mask combines causality, the
        context window and padding.
    """
    decoder_input_ids = get_decoder_input_ids(labels, START_TOKEN_ID)
    extended_mask = create_extended_attention_mask_for_decoder_with_context(
        decoder_input_ids.shape,
        get_attention_mask(decoder_input_ids, part),
        part,
    )
    return decoder_input_ids, extended_mask

### Main training function

In [None]:
def train_model(config, processor, model, train_dataloader, val_dataloader):
    """Fine-tune `model`, with periodic validation, checkpointing and W&B logging.

    Validation runs on the first epoch of this run, on the final epoch, and whenever
    ``epoch + 1`` is a multiple of ``check_val_every_n_epoch``. Each validation is
    followed by a checkpoint save and a W&B summary log.

    When LOAD_FROM_CHECKPOINT is set, the model/optimizer/scheduler state is restored
    from G_DRIVE_FOLDER_CHECKPOINTS + LAST_CHECKPOINT_NAME. The saved W&B run is
    resumed and training continues from the epoch after the saved one.
    ``max_epochs`` is the total epoch budget of the run, resumed or not, matching the
    length of the learning-rate schedule.

    Args:
        config: dict with ``lr``, ``max_epochs``, ``num_warmup_steps`` and
            ``check_val_every_n_epoch``, optionally ``verbose``, plus the keys read by
            training_loop / testing_loop.
        processor: Pix2Struct processor, used for decoding during validation.
        model: model to train; moved to DEVICE here.
        train_dataloader, val_dataloader: training and validation loaders.
    """
    # Extract configuration values
    lr = config.get("lr")
    max_epochs = config.get("max_epochs")
    num_warmup_steps = config.get("num_warmup_steps")

    model.to(DEVICE)

    optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=lr, weight_decay=1e-05)

    # The schedule is counted in batches: training_loop steps the scheduler once per batch
    total_steps = max_epochs * len(train_dataloader)
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=num_warmup_steps,
                                                num_training_steps=total_steps)

    global_step = 0  # batches processed so far
    epoch_start = 0

    if LOAD_FROM_CHECKPOINT:
        print("Loading model from checkpoint:", LAST_CHECKPOINT_NAME)
        checkpoint = torch.load(G_DRIVE_FOLDER_CHECKPOINTS + LAST_CHECKPOINT_NAME)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        epoch_start = checkpoint["epoch"] + 1
        # The checkpoint stores the counter after the last processed batch, i.e. the
        # number of batches already done, so it is resumed as-is
        global_step = checkpoint["global_step"]
        wandb_run_id = checkpoint["wandb_run_id"]

        # Resume the WandB run: the run id goes in `id`, `resume` selects the policy
        wandb.init(project="Pix2Struct", name="run-" + EXPERIMENT_NAME, config=config,
                   id=wandb_run_id, resume="allow")
    else:
        wandb.init(project="Pix2Struct", name="run-" + EXPERIMENT_NAME, config=config)

    # Stop at max_epochs in total: the scheduler above spans exactly that many epochs,
    # so a resumed run must not train past it
    epoch_last = max_epochs - 1
    for epoch in range(epoch_start, max_epochs):
        global_step, moving_avg_loss = training_loop(epoch, train_dataloader, model, config, optimizer, scheduler, global_step, epoch_last)

        if epoch == epoch_start or epoch == epoch_last or (epoch + 1) % config.get("check_val_every_n_epoch") == 0:
            avg_bleu_score = testing_loop(val_dataloader, model, processor, config, f"Epoch {epoch}/{epoch_last} - valid loop")

            # Save the model after each validation step
            save_checkpoint(model, optimizer, scheduler, epoch, global_step, wandb.run.id, avg_bleu_score, EXPERIMENT_NAME, G_DRIVE_FOLDER_CHECKPOINTS)

            if config.get("verbose", False):
                print(f"Moving Avg Loss: {moving_avg_loss:.3f}")
                print(f" Avg Bleu Score: {avg_bleu_score:.2f}")

            wandb.log({"moving_avg_loss": moving_avg_loss, "bleu": avg_bleu_score, **{f'lr_{i}': param_group['lr'] for i, param_group in enumerate(optimizer.param_groups)}})

    wandb.finish()

In [None]:
def training_loop(epoch, train_dataloader, model, config, optimizer, scheduler, global_step, epoch_last):
    """Run one training epoch with gradient accumulation, logging per-batch loss to W&B.

    Args:
        epoch, epoch_last: used only for the progress-bar label.
        train_dataloader: yields ``(encoding, part)`` batches from
            SythBootstrapTrainingDataset.
        model: Pix2Struct model (expected to already be on DEVICE).
        config: reads ``accumulate_grad_batches`` (default 1; integral floats such as
            2.0 are accepted) and ``gradient_clip_val`` (falsy disables clipping).
        optimizer: stepped once every ``accumulate_grad_batches`` batches and at the
            end of the epoch.
        scheduler: stepped once per batch, matching the per-batch ``total_steps``
            used in train_model.
        global_step: running batch counter across epochs.

    Returns:
        ``(global_step, moving_avg_loss)``: the updated counter and an exponential
        moving average (alpha=0.1) of the un-scaled batch loss.
    """
    model.train()
    num_batches = len(train_dataloader)
    train_loop = tqdm(enumerate(train_dataloader), total=num_batches, desc=f"Epoch {epoch}/{epoch_last} - train loop")

    # Extract configuration values
    accumulate_grad_batches = max(1, int(config.get('accumulate_grad_batches', 1)))
    gradient_clip_val = config.get("gradient_clip_val")

    moving_avg_loss = 0
    alpha = 0.1  # Smoothing factor

    for step, batch in train_loop:
        encoding, part = map(move_to_device, batch)
        labels, flattened_patches, attention_mask = encoding["labels"], encoding["flattened_patches"], encoding["attention_mask"]

        decoder_input_ids, decoder_attention_mask = get_decoder_input_ids_and_attention_mask(labels, part)

        # For continuation chunks (part != 0) the decoder mask lets every position see
        # the whole context window. For positions < CONTEXT_OVERLAP_LENGTH - 1 the target
        # token is therefore visible as the next decoder input, so exclude them from the
        # loss. The decoder inputs above come from the unmasked labels, so the model
        # still receives the full context.
        loss_labels = labels.clone()
        loss_labels[part != 0, :CONTEXT_OVERLAP_LENGTH - 1] = -100

        outputs = model(labels=loss_labels, flattened_patches=flattened_patches, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)
        loss = outputs.loss

        # Scale so the accumulated gradient is the mean over the effective batch
        (loss / accumulate_grad_batches).backward()

        # Step after every full group of batches within this epoch, and flush any
        # partial group on the last batch
        if (step + 1) % accumulate_grad_batches == 0 or step == num_batches - 1:
            if gradient_clip_val:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()
        global_step += 1

        loss_value = loss.item()

        # Update the progress bar
        train_loop.set_postfix({'loss': loss_value}, refresh=True)

        # Update the moving average loss
        moving_avg_loss = loss_value if moving_avg_loss == 0 else alpha * loss_value + (1 - alpha) * moving_avg_loss

        # Log loss after each step
        wandb.log({"loss": loss_value})

    return global_step, moving_avg_loss

In [None]:
def testing_loop(testing_dataloader, model, processor, config, description):
    """Evaluate `model` by generating each sample chunk by chunk and scoring it with BLEU.

    Generation runs for at most MAX_N_CHUNKS_PER_SENTENCE windows per batch.
    - From the second window on, the last CONTEXT_OVERLAP_LENGTH - 1 generated tokens
      are passed back as ``decoder_input_ids``.
    - The first CONTEXT_OVERLAP_LENGTH output tokens of those windows are dropped
      before concatenation.
    - A sample stops being generated once its output contains EOS.

    Args:
        testing_dataloader: yields encoding dicts from SythBootstrapDataset.
        model: Pix2Struct model (expected to be on DEVICE).
        processor: its processor, used for decoding.
        config: reads ``verbose`` to print every prediction/answer pair.
        description: progress-bar label.

    Returns:
        Mean per-sample BLEU over all batches, or NaN if the loader is empty.
    """
    model.eval()
    bleu_scores = []
    avg_bleu_score = float("nan")  # returned as-is when the loader yields no batches
    smoothing_function = SmoothingFunction().method4

    with torch.no_grad():
        test_loop = tqdm(testing_dataloader, total=len(testing_dataloader), desc=description)
        for batch in test_loop:
            encoding = move_to_device(batch)
            labels, flattened_patches, attention_mask = encoding["labels"], encoding["flattened_patches"], encoding["attention_mask"]

            total_outputs = None
            context_from_last = None

            # True for samples whose generated output already contains EOS
            finished_sentences_mask = torch.zeros(flattened_patches.size(0), dtype=torch.bool, device=flattened_patches.device)

            for iteration in range(MAX_N_CHUNKS_PER_SENTENCE):

                generate_args = {
                    "flattened_patches": flattened_patches[~finished_sentences_mask],
                    "attention_mask": attention_mask[~finished_sentences_mask],
                    # Later windows leave room for the carried-over context
                    "max_new_tokens": CHUNK_LENGTH - (CONTEXT_OVERLAP_LENGTH if iteration else 0),
                }

                if iteration and context_from_last is not None:
                    generate_args["decoder_input_ids"] = context_from_last[~finished_sentences_mask]

                outputs = model.generate(**generate_args)

                # Remove context overlap only from the second iteration onwards
                new_chunks = outputs if iteration == 0 else outputs[:, CONTEXT_OVERLAP_LENGTH:]

                if iteration == 0:
                    total_outputs = new_chunks
                else:
                    # Finished samples get a padding chunk so all rows stay aligned
                    new_chunks_with_padding_chunks = torch.full((flattened_patches.shape[0], new_chunks.shape[1]), PAD_TOKEN_ID, dtype=new_chunks.dtype, device=new_chunks.device)
                    new_chunks_with_padding_chunks[~finished_sentences_mask] = new_chunks
                    total_outputs = torch.cat((total_outputs, new_chunks_with_padding_chunks), dim=1)

                # Update the finished_sentences_mask
                finished_sentences_mask[~finished_sentences_mask] |= (outputs == processor.tokenizer.eos_token_id).any(dim=1)

                # If all sentences are finished, exit the loop
                if finished_sentences_mask.all():
                    break

                if outputs.shape[1] < CHUNK_LENGTH:
                    print("ERROR: !! should have already exited because all sentences reached the end!!")

                # -1 because it will put in front a START_TOKEN automatically
                context_from_last = total_outputs[:, -(CONTEXT_OVERLAP_LENGTH-1):]

            predictions = processor.tokenizer.batch_decode(total_outputs, skip_special_tokens=True)

            # -100 is not a decodable id; map it to padding (dropped by skip_special_tokens)
            # without modifying the batch tensor
            decodable_labels = labels.masked_fill(labels == -100, PAD_TOKEN_ID)
            answers = processor.tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

            # NOTE(review): raw strings are passed, so nltk iterates them character by
            # character, i.e. this is character-level BLEU. Kept as-is for comparability.
            batch_bleu_scores = [corpus_bleu([[answer]], [pred], smoothing_function=smoothing_function) for pred, answer in zip(predictions, answers)]
            bleu_scores.extend(batch_bleu_scores)

            avg_bleu_score = np.mean(bleu_scores)
            test_loop.set_postfix(bleu_score=avg_bleu_score)

            if config.get("verbose", False):
                # Pair each sample with its own score (not the cumulative list)
                for pred, answer, bleu_score in zip(predictions, answers, batch_bleu_scores):
                    tqdm.write(f"\nPrediction: {pred}\n    Answer: {answer}\n      Bleu: {bleu_score:.2f}")

    return avg_bleu_score


In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, global_step, wandb_run_id, avg_bleu_score, experiment_name, folder_path):
    """Save everything train_model needs to resume training.

    The file is named ``{experiment_name}_epoch[{epoch}]_bleu[{avg_bleu_score:.2f}].pth``
    and written inside `folder_path`, which is created first if it does not exist.

    Args:
        model, optimizer, scheduler: objects whose ``state_dict()`` is saved.
        epoch: index of the epoch just completed.
        global_step: batches processed so far.
        wandb_run_id: id of the W&B run, so it can be resumed.
        avg_bleu_score: validation score, used only in the file name.
        experiment_name: file name prefix.
        folder_path: destination directory.

    Returns:
        Path of the written checkpoint file.
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "epoch": epoch,
        "global_step": global_step,
        'wandb_run_id': wandb_run_id
    }
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    model_name = f"{experiment_name}_epoch[{epoch}]_bleu[{avg_bleu_score:.2f}].pth"
    checkpoint_path = os.path.join(folder_path, model_name)
    torch.save(checkpoint, checkpoint_path)
    return checkpoint_path


In [None]:
# Hyper-parameters passed to train_model (and logged to W&B); checked by validate_config
config = dict(
    batch_size=BATCH_SIZE,
    num_warmup_steps=NUM_WARMUP_STEPS,
    max_epochs=MAX_EPOCHS,
    lr=LR,
    check_val_every_n_epoch=CHECK_VAL_EVERY_N_EPOCH,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    verbose=VERBOSE,
)

In [None]:
def validate_config(config):
    """Validate the training configuration before the (long) training run starts.

    Every required hyper-parameter must be present, have a usable type and lie in its
    allowed range. A bad value therefore fails immediately instead of partway through
    training.

    Args:
        config: Mapping containing every key listed in ``required_keys`` below.

    Raises:
        ValueError: in any of these cases:

            * A required key is missing (all missing keys are listed).
            * A count-like value (``batch_size``, ``num_warmup_steps``, ``max_epochs``,
              ``check_val_every_n_epoch``, ``accumulate_grad_batches``) is not a whole
              number. Integral floats such as ``2.0`` are accepted.
            * ``lr`` or ``gradient_clip_val`` is not a real number or is NaN.
            * ``lr`` is not finite.
            * Any value is outside its allowed range.
            * ``verbose`` is not a ``bool``.
    """
    import math
    import numbers

    def _is_real(value):
        # bool is a subclass of int, but True/False are never meaningful numeric settings.
        if isinstance(value, bool) or not isinstance(value, numbers.Real):
            return False
        # NaN compares False against every bound, so it would slip past the range checks.
        return isinstance(value, numbers.Integral) or not math.isnan(value)

    def _is_whole(value):
        # Integral-valued floats (e.g. 2.0) are accepted for backward compatibility.
        if not _is_real(value):
            return False
        return isinstance(value, numbers.Integral) or (math.isfinite(value) and float(value).is_integer())

    # Check required keys, reporting every missing one at once.
    required_keys = [
        "batch_size",
        "num_warmup_steps",
        "max_epochs",
        "lr",
        "check_val_every_n_epoch",
        "gradient_clip_val",
        "accumulate_grad_batches",
        "verbose"
    ]
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        missing_list = ", ".join(repr(key) for key in missing_keys)
        raise ValueError(f"Key(s) {missing_list} must be present in the configuration.")

    # Check types before ranges so that e.g. a string raises ValueError, not TypeError.
    count_keys = (
        "batch_size",
        "num_warmup_steps",
        "max_epochs",
        "check_val_every_n_epoch",
        "accumulate_grad_batches",
    )
    for key in count_keys:
        if not _is_whole(config[key]):
            raise ValueError(f"{key} must be a whole number, got {config[key]!r}.")
    for key in ("lr", "gradient_clip_val"):
        if not _is_real(config[key]):
            raise ValueError(f"{key} must be a real number, got {config[key]!r}.")

    # Check that values are in expected ranges
    if config["batch_size"] <= 0:
        raise ValueError("batch_size must be positive.")
    if config["num_warmup_steps"] < 0:
        raise ValueError("num_warmup_steps must be non-negative.")
    if config["max_epochs"] <= 0:
        raise ValueError("max_epochs must be positive.")
    if config["lr"] <= 0:
        raise ValueError("Learning rate must be positive.")
    if not math.isfinite(config["lr"]):
        raise ValueError("Learning rate must be finite.")
    if config["check_val_every_n_epoch"] <= 0:
        raise ValueError("check_val_every_n_epoch must be positive.")
    if config["gradient_clip_val"] < 0:
        raise ValueError("gradient_clip_val must be non-negative.")
    if config["accumulate_grad_batches"] <= 0:
        raise ValueError("accumulate_grad_batches must be positive.")
    if not isinstance(config["verbose"], bool):
        raise ValueError("verbose must be a boolean value.")


In [None]:
# Fail fast: reject a malformed configuration before the expensive training cell runs,
# then echo the accepted settings so they are captured in the notebook output.
validate_config(config)
print(f"{config}")

{'batch_size': 4, 'num_warmup_steps': 1000, 'max_epochs': 20, 'lr': 0.0001, 'check_val_every_n_epoch': 5, 'gradient_clip_val': 1.0, 'accumulate_grad_batches': 2.0, 'verbose': True}


In [None]:
# Launch fine-tuning with the validated config; the captured output below shows the
# per-epoch training loss and the periodic validation BLEU scores.
train_model(
    config,
    processor,
    model,
    train_dataloader,
    val_dataloader,
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch 0/19 - train loop: 100%|██████████| 506/506 [06:32<00:00,  1.29it/s, loss=1.66]
Epoch 0/19 - valid loop:  50%|█████     | 1/2 [13:47<13:47, 827.14s/it, bleu_score=0.0623]


Prediction: <<!DOCTYPE html> <html> <head> <title>Dolore Cillum</title> <meta content="{&quot;primary&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;dark&quot;: &quot;rgb(122, 122, 122)&quot;, &quot;da

Epoch 0/19 - valid loop: 100%|██████████| 2/2 [27:26<00:00, 823.48s/it, bleu_score=0.0776]



Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;rgb(172, 172, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="width=device-width, initial-scale=1.0" nam

Epoch 1/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.668]
Epoch 2/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=1.3]
Epoch 3/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.387]
Epoch 4/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.267]
Epoch 4/19 - valid loop:  50%|█████     | 1/2 [08:34<08:34, 514.17s/it, bleu_score=0.497]


Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(189, 189, 189)&quot;, &quot;secondary&quot;: &quot;rgb(189, 189, 189)&quot;, &quot;light&quot;: &quot;rgb(189, 189, 189)&quot;, &quot;dark&quot;: &quot;rgb(189, 189, 189)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-lg" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-lg navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Ex</a> <button class="n

Epoch 4/19 - valid loop: 100%|██████████| 2/2 [14:30<00:00, 435.41s/it, bleu_score=0.541]



Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;secondary&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;light&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="4" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-lg" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-lg navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Incididunt</a> <button 

Epoch 5/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.423]
Epoch 6/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.553]
Epoch 7/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.191]
Epoch 8/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.511]
Epoch 9/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.258]
Epoch 9/19 - valid loop:  50%|█████     | 1/2 [07:46<07:46, 466.22s/it, bleu_score=0.692]


Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(189, 224, 224)&quot;, &quot;secondary&quot;: &quot;rgb(189, 224, 224)&quot;, &quot;light&quot;: &quot;rgb(189, 224, 224)&quot;, &quot;dark&quot;: &quot;rgb(189, 224, 224)&quot;, &quot;enable-gradients&quot;: true}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-lg" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-lg navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Ex</a> <button class="na

Epoch 9/19 - valid loop: 100%|██████████| 2/2 [13:24<00:00, 402.39s/it, bleu_score=0.746]



Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(240, 190, 190)&quot;, &quot;secondary&quot;: &quot;rgb(190, 190, 190)&quot;, &quot;light&quot;: &quot;rgb(240, 190, 190)&quot;, &quot;dark&quot;: &quot;rgb(240, 190, 190)&quot;, &quot;enable-gradients&quot;: true}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-lg" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-lg navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Incididunt</a> <button c

Epoch 10/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.278]
Epoch 11/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.087]
Epoch 12/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.186]
Epoch 13/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.306]
Epoch 14/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.0952]
Epoch 14/19 - valid loop:  50%|█████     | 1/2 [07:31<07:31, 451.75s/it, bleu_score=0.738]


Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(189, 204, 204)&quot;, &quot;secondary&quot;: &quot;rgb(189, 189, 204)&quot;, &quot;light&quot;: &quot;rgb(189, 204, 204)&quot;, &quot;dark&quot;: &quot;rgb(189, 189, 204)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-sm" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-sm navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Ex</a> <button class="n

Epoch 14/19 - valid loop: 100%|██████████| 2/2 [13:54<00:00, 417.00s/it, bleu_score=0.795]



Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(172, 212, 212)&quot;, &quot;secondary&quot;: &quot;rgb(172, 172, 212)&quot;, &quot;light&quot;: &quot;rgb(172, 212, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 212)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-sm" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-sm navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Incididunt</a> <button 

Epoch 15/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.257]
Epoch 16/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.238]
Epoch 17/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.098]
Epoch 18/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.228]
Epoch 19/19 - train loop: 100%|██████████| 506/506 [06:29<00:00,  1.30it/s, loss=0.277]
Epoch 19/19 - valid loop:  50%|█████     | 1/2 [07:16<07:16, 436.56s/it, bleu_score=0.773]


Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(189, 204, 204)&quot;, &quot;secondary&quot;: &quot;rgb(189, 189, 204)&quot;, &quot;light&quot;: &quot;rgb(189, 204, 204)&quot;, &quot;dark&quot;: &quot;rgb(189, 189, 204)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="2" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-sm" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-sm navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Ex</a> <button class="n

Epoch 19/19 - valid loop: 100%|██████████| 2/2 [13:12<00:00, 396.24s/it, bleu_score=0.809]



Prediction: <!DOCTYPE html> <html> <head> <title>Dominate</title> <meta content="{&quot;primary&quot;: &quot;rgb(172, 212, 172)&quot;, &quot;secondary&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;light&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;dark&quot;: &quot;rgb(172, 172, 172)&quot;, &quot;enable-gradients&quot;: false}" name="wg-palette"> <meta content="width=device-width, initial-scale=1.0" name="viewport"> <meta content="Web Generator" name="author"> <meta content="1" name="wg-layout"> <link href="../css/custom-bootstrap.css" rel="stylesheet"> <link href="../css/wg-extras.css" rel="stylesheet"> <script src="../js/jquery-3.2.1.slim.min.js" type="text/javascript"></script> <script src="../js/bootstrap.min.js" type="text/javascript"></script> </head> <body class=""> <div class=" rounded shadow-sm" id="full-wrapper"> <nav class="bg-gradient-light navbar-light rounded shadow-sm navbar navbar-expand-md" data-wg-type="navbar"> <a class="navbar-brand" href="#">Incididunt</a> <button 

0,1
bleu,▁▅▇██
loss,█▄▃▂▂▂▁▃▁▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂
lr_0,▅█▅▂▁
moving_avg_loss,█▂▁▁▁

0,1
bleu,0.80886
loss,0.27705
lr_0,0.0
moving_avg_loss,0.23817
