In [1]:
#import libraries
import argparse
import importlib
import torch
import cv2
import os
import pandas as pd
import h5py as h5
import evaluate

from matplotlib import pyplot as plt

from src.common.configuration import get_dataset_configuration, get_model_configuration
from src.datasets.comics_dialogue_generation import ComicsDialogueGenerationDataset
# from src.datasets.comics_images_Sim_CLR_text import ComicsImageTextDataset, create_test_dataset
from src.datasets.text_cloze_image_text_vlt5_simCLR import TextClozeImageTextVLT5Dataset, create_test_dataset
from src.models.dialogue_generation_vlt5 import DialogueGenerationVLT5Model
from src.tokenizers.vlt5_tokenizers import VLT5TokenizerFast
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def parse_arguments():
    """Build the notebook's configuration namespace from the argparse defaults.

    The parser is given an explicit empty argument list, so the kernel's own
    ``sys.argv`` is never consulted and every option takes its default value.
    Change the defaults below (or mutate the returned namespace) to reconfigure
    the notebook.

    Returns:
        argparse.Namespace: namespace with the attributes ``model``,
        ``load_cloze_checkpoint``, ``dataset_config``, ``dataset_dir``,
        ``output_dir``, ``sample_id`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Plotting script')

    parser.add_argument('--model', type=str, default="text_cloze_image_text_vlt5",
                        help='Model to run')
    parser.add_argument('--load_cloze_checkpoint', type=str, default="runs/TextClozeImageTextVLT5Model_text_cloze_image_text_vlt5_simCLR_2023-12-13_11:22:31/models/epoch_3.pt", # Textract + blip2 hard
                        help='Path to text cloze model checkpoint')
    parser.add_argument('--dataset_config', type=str, default="text_cloze_image_text_blip2_hard_textract_vlt5",
                        help='Dataset config to use')
    parser.add_argument('--dataset_dir', type=str, default="/data/data/datasets/COMICS",
                        help='Dataset directory path')
    parser.add_argument('--output_dir', type=str, default="plots_textract/",
                        help='Output directory path')
    parser.add_argument('--sample_id', type=int, default=275,
                        help='Sample id to plot')
    parser.add_argument('--seed', type=int, default=4,
                        help='Random seed')

    # Pass an explicit empty token list: parse_args expects a sequence of
    # argument strings, and "" only worked because it iterates to nothing.
    return parser.parse_args([])

In [3]:
def load_checkpoint(checkpoint_path, model):
    """Restore ``model`` from a checkpoint file and switch it to eval mode.

    Args:
        checkpoint_path: path of a file readable by ``torch.load`` that holds a
            mapping with a ``"model_state_dict"`` entry.
        model: object exposing ``load_checkpoint(state_dict)`` and ``eval()``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Always deserialize onto the CPU, whatever device the checkpoint was saved from.
    cpu = torch.device("cpu")
    state_dict = torch.load(checkpoint_path, map_location=cpu)["model_state_dict"]
    model.load_checkpoint(state_dict)
    model.eval()
    return model

# Loading the model and dataset

In [4]:
# Build the run configuration and seed torch's random number generator.
args = parse_arguments()
torch.manual_seed(args.seed)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### Model

In [5]:
# Look up the configuration for the model named by --model; later cells read
# its `classname`, `backbone`, `max_text_length` and `do_lower_case` fields.
model_text_cloze_config = get_model_configuration(args.model)

In [None]:
# Resolve the model class by name from `src.models.<--model>`.
model_module = importlib.import_module(f"src.models.{args.model}")
ModelClass = getattr(model_module, model_text_cloze_config.classname)

# Instantiate on the chosen device, then restore the trained weights.
model_text_cloze = ModelClass(model_text_cloze_config, device).to(device)
load_checkpoint(args.load_cloze_checkpoint, model_text_cloze)

### Loading the dataset

In [7]:
# Tokenizer settings come from the model configuration.
tokenizer_kwargs = dict(
    max_length=model_text_cloze_config.max_text_length,
    do_lower_case=model_text_cloze_config.do_lower_case,
)
tokenizer = VLT5TokenizerFast.from_pretrained(
    model_text_cloze_config.backbone, **tokenizer_kwargs
)

# Load the dataset configuration and set its "test" entry before building
# the test dataframe and dataset.
dataset_config = get_dataset_configuration(args.dataset_config)
dataset_config["test"] = True
df, dataset = create_test_dataset(
    args.dataset_dir, device, dataset_config, tokenizer
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'VLT5TokenizerFast'.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [24]:
# One sample per batch, in dataset order, loaded in the main process.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

### Checking some predictions

In [25]:
# Run the model on the first 11 batches (i = 0..10) and print both the raw
# predicted token ids and their decoded text.
model_text_cloze.eval()
with torch.no_grad():
    for i, sample_data in enumerate(dataloader):
        # Keep only the tensor entries of the batch, casting float64 tensors to float32.
        sample_data = {
            key: (value.type(torch.float32) if value.dtype == torch.float64 else value)
            for key, value in sample_data.items()
            if isinstance(value, torch.Tensor)
        }

        #args.sample_id = sample_data["idx"].item() # If it crashes, uncoment the line in the dataset class and in the collate function
        output = model_text_cloze.run(**sample_data)
        print(output["prediction"])

        decoded_prediction = tokenizer.decode(output["prediction"], skip_special_tokens=False)
        # An empty decoding is reported as 0.
        prediction_text_cloze = 0 if decoded_prediction == "" else decoded_prediction
        print(f"Prediction: {prediction_text_cloze}")

        if i >= 10:
            break

tensor([204], device='cuda:1')
Prediction: 2
tensor([3], device='cuda:1')
Prediction: 0
tensor([204], device='cuda:1')
Prediction: 2
tensor([204], device='cuda:1')
Prediction: 2
tensor([209], device='cuda:1')
Prediction: 1
tensor([204], device='cuda:1')
Prediction: 2
tensor([209], device='cuda:1')
Prediction: 1
tensor([204], device='cuda:1')
Prediction: 2
tensor([209], device='cuda:1')
Prediction: 1
tensor([204], device='cuda:1')
Prediction: 2
tensor([3], device='cuda:1')
Prediction: 0


#### Looking at a random example

In [26]:
# Decode the prediction of the last batch processed by the loop above.
prediction_text_cloze = tokenizer.decode(output["prediction"], skip_special_tokens=False)
if prediction_text_cloze == "":
    # An empty decoding is treated as candidate 0.
    prediction_text_cloze = 0
# Keep the prediction as an int: the plotting cell compares it against integer
# candidate indices, and a decoded string such as "2" would never compare equal.
prediction_text_cloze = int(prediction_text_cloze)
print(f"Prediction: {prediction_text_cloze}")

Prediction: 0


In [27]:
# Decode the target of the current sample, dropping special tokens.
decoded_target = tokenizer.decode(sample_data["target"], skip_special_tokens=True)
# An empty decoding is reported as 0.
target_label = 0 if decoded_target == "" else decoded_target
print(f"Target: {target_label}")
# `target` ends up as the decoded integer shifted up by one.
target = int(target_label) + 1

Target: 0


In [28]:
# Show which fields the current sample batch contains.
print(sample_data.keys())

dict_keys(['boxes', 'vis_feats', 'input_ids', 'input_length', 'label', 'target', 'target_ids', 'target_length', 'idx'])


In [29]:
# Fetch the dataframe row whose position matches the sample's "idx" field.
row_index = sample_data["idx"].item()
sample = df.iloc[row_index]
book_id, page_id = sample["book_id"], sample["page_id"]

# Text of the candidate that the row marks as the correct answer.
correct_candidate_column = f"answer_candidate_{sample['correct_answer']}_text"
target_text = sample[correct_candidate_column]

print(f"Book id: {book_id}")
print(f"Page id: {page_id}")

Book id: 3452
Page id: 45


In [30]:
# RGB tuples passed as the text colour of the candidate answers:
# green for the correct candidate, red for the others.
CORRECT_COLOR = (0, 1, 0)
INCORRECT_COLOR = (1, 0, 0)

In [None]:
# Plot the sample selected above: the three context panels and the answer
# panel on the top row, the three candidate texts on the bottom row.
rows, columns = 2, 4

# (subplot title, panel id) for the top row, in display order.
panels = [
    ("Context panel 1", sample["context_panel_0_id"]),
    ("Context panel 2", sample["context_panel_1_id"]),
    ("Context panel 3", sample["context_panel_2_id"]),
    ("Answer panel", sample["answer_panel_id"]),
]

# Read every panel before creating the figure so a missing file fails early.
panel_images = []
for panel_title, panel_id in panels:
    panel_path = f'{args.dataset_dir}/panels/{book_id}/{page_id}_{panel_id}.jpg'
    image = cv2.imread(panel_path)
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise FileNotFoundError(f"Could not read panel image: {panel_path}")
    # OpenCV loads images as BGR; matplotlib expects RGB.
    panel_images.append((panel_title, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))

fig = plt.figure(figsize=(16, 8))

# Top row: subplot positions 1-4.
for position, (panel_title, image) in enumerate(panel_images, start=1):
    fig.add_subplot(rows, columns, position)
    plt.imshow(image)
    plt.axis('off')
    plt.title(panel_title)

# Bottom row: subplot positions 5-7, one per candidate answer.
for i in range(3):
    fig.add_subplot(rows, columns, i + 5)

    color = CORRECT_COLOR if i == sample["correct_answer"] else INCORRECT_COLOR
    # prediction_text_cloze may be a decoded string ("2") or an int (0), so
    # compare textual forms; comparing an int index to a str never matches.
    is_predicted = str(prediction_text_cloze).strip() == str(i)
    bb = dict(facecolor='white', alpha=1.) if is_predicted else None
    content = sample[f"answer_candidate_{i}_text"]
    plt.title(f"Candidate {i + 1}")
    txt = plt.text(0.5, 0.5, content, fontsize=14, wrap=True,
                   ha="center", va="top", color=color, bbox=bb)
    # Overrides a private matplotlib Text method to force the wrap width to 300
    # (presumably pixels - confirm against the installed matplotlib version).
    txt._get_wrap_line_width = lambda: 300.
    plt.axis('off')

# Save as "<last token of the dataset config name>_<args.sample_id>.png".
# NOTE(review): args.sample_id is not updated from sample_data["idx"] in this
# cell, so the file name may not match the sample that is plotted - confirm.
os.makedirs(args.output_dir, exist_ok=True)
plt.savefig(f'{args.output_dir}/{args.dataset_config.split("_")[-1]}_{args.sample_id}.png')

### Iterating over the test set 
Generating an image of the prediction for every 25th example

In [32]:
# Output directory where all the generated images are saved.
# This replaces the --output_dir default returned by parse_arguments().
args.output_dir = "plots_hard_vlt5_textract_blip2/"

In [33]:
from tqdm import tqdm

CORRECT_COLOR = (0, 1, 0)
INCORRECT_COLOR = (1, 0, 0)
PLOT_EVERY = 25  # a figure is produced for sample ids that are multiples of this

# Loop-invariant layout and output settings.
rows, columns = 2, 4
file_prefix = args.dataset_config.split("_")[-1]
os.makedirs(args.output_dir, exist_ok=True)

model_text_cloze.eval()
model_text_cloze.to(device)
# Inference only: disable gradient tracking, as the preview loop above does.
with torch.no_grad():
    for sample_data in tqdm(dataloader):
        args.sample_id = sample_data["idx"].item()
        if args.sample_id % PLOT_EVERY != 0:
            continue

        # Keep only the tensor entries of the batch, casting float64 tensors to float32.
        sample_data = {
            key: (value.type(torch.float32) if value.dtype == torch.float64 else value)
            for key, value in sample_data.items()
            if isinstance(value, torch.Tensor)
        }
        output = model_text_cloze.run(**sample_data)

        prediction_text_cloze = tokenizer.decode(output["prediction"], skip_special_tokens=False)
        if prediction_text_cloze == "":
            # An empty decoding is treated as candidate 0.
            prediction_text_cloze = 0
        prediction_text_cloze = int(prediction_text_cloze)

        sample = df.iloc[args.sample_id]
        book_id = sample["book_id"]
        page_id = sample["page_id"]
        target_text = sample[f"answer_candidate_{sample['correct_answer']}_text"]

        # (subplot title, panel id) for the top row, in display order.
        panels = [
            ("Context panel 1", sample["context_panel_0_id"]),
            ("Context panel 2", sample["context_panel_1_id"]),
            ("Context panel 3", sample["context_panel_2_id"]),
            ("Answer panel", sample["answer_panel_id"]),
        ]

        # Read every panel before creating the figure so a missing file fails
        # early and no figure is left open.
        panel_images = []
        for panel_title, panel_id in panels:
            panel_path = f'{args.dataset_dir}/panels/{book_id}/{page_id}_{panel_id}.jpg'
            image = cv2.imread(panel_path)
            if image is None:
                # cv2.imread returns None instead of raising when the file cannot be read.
                raise FileNotFoundError(f"Could not read panel image: {panel_path}")
            # OpenCV loads images as BGR; matplotlib expects RGB.
            panel_images.append((panel_title, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))

        fig = plt.figure(figsize=(16, 8))

        # Top row: subplot positions 1-4.
        for position, (panel_title, image) in enumerate(panel_images, start=1):
            fig.add_subplot(rows, columns, position)
            plt.imshow(image)
            plt.axis('off')
            plt.title(panel_title)

        # Bottom row: subplot positions 5-7, one per candidate answer.
        for i in range(3):
            fig.add_subplot(rows, columns, i + 5)

            color = CORRECT_COLOR if i == sample["correct_answer"] else INCORRECT_COLOR
            # The predicted candidate gets a boxed background.
            bb = dict(facecolor='white', alpha=1.) if i == prediction_text_cloze else None
            content = sample[f"answer_candidate_{i}_text"]
            plt.title(f"Candidate {i + 1}")
            txt = plt.text(0.5, 0.5, content, fontsize=14, wrap=True,
                           ha="center", va="top", color=color, bbox=bb)
            # Overrides a private matplotlib Text method to force the wrap width to 300.
            txt._get_wrap_line_width = lambda: 300.
            plt.axis('off')

        # File name: "<config suffix>_<sample id>.png", with "_WRONG" appended
        # when the prediction differs from the correct answer.
        suffix = "" if prediction_text_cloze == sample["correct_answer"] else "_WRONG"
        plt.savefig(os.path.join(args.output_dir, f"{file_prefix}_{args.sample_id}{suffix}.png"))
        # plt.show()
        plt.close(fig)


  2%|▏         | 201/11909 [00:04<03:42, 52.71it/s]

100%|██████████| 11909/11909 [04:26<00:00, 44.75it/s]


### Only showing dataset examples

In [None]:
import random

CORRECT_COLOR = (0, 1, 0)
INCORRECT_COLOR = (1, 0, 0)
NUM_EXAMPLES = 10  # number of random dataset rows to plot

# Loop-invariant layout and output settings.
rows, columns = 2, 4
file_prefix = args.dataset_config.split("_")[-1]
os.makedirs(args.output_dir, exist_ok=True)

# Local RNG seeded from --seed so the same examples are drawn on every run;
# change args.seed to look at a different set.
rng = random.Random(args.seed)

# NOTE(review): files are written to args.output_dir with the same
# "<prefix>_<id>.png" pattern as the prediction plots, so an id that was already
# plotted there would be overwritten - confirm this is intended.
for _ in range(NUM_EXAMPLES):
    # randrange excludes the upper bound; randint(0, len(dataset)) could return
    # len(dataset), which is one past the last valid row.
    args.sample_id = rng.randrange(len(dataset))
    sample = df.iloc[args.sample_id]
    book_id = sample["book_id"]
    page_id = sample["page_id"]
    target_text = sample[f"answer_candidate_{sample['correct_answer']}_text"]

    # (subplot title, panel id) for the top row, in display order. The "*_id"
    # column names match the ones used by the other plotting cells on this df.
    panels = [
        ("Context panel 1", sample["context_panel_0_id"]),
        ("Context panel 2", sample["context_panel_1_id"]),
        ("Context panel 3", sample["context_panel_2_id"]),
        ("Answer panel", sample["answer_panel_id"]),
    ]

    # Read every panel before creating the figure so a missing file fails
    # early and no figure is left open.
    panel_images = []
    for panel_title, panel_id in panels:
        panel_path = f'{args.dataset_dir}/panels/{book_id}/{page_id}_{panel_id}.jpg'
        image = cv2.imread(panel_path)
        if image is None:
            # cv2.imread returns None instead of raising when the file cannot be read.
            raise FileNotFoundError(f"Could not read panel image: {panel_path}")
        # OpenCV loads images as BGR; matplotlib expects RGB.
        panel_images.append((panel_title, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))

    fig = plt.figure(figsize=(16, 8))

    # Top row: subplot positions 1-4.
    for position, (panel_title, image) in enumerate(panel_images, start=1):
        fig.add_subplot(rows, columns, position)
        plt.imshow(image)
        plt.axis('off')
        plt.title(panel_title)

    # Bottom row: subplot positions 5-7, one per candidate answer.
    for i in range(3):
        fig.add_subplot(rows, columns, i + 5)

        color = CORRECT_COLOR if i == sample["correct_answer"] else INCORRECT_COLOR
        content = sample[f"answer_candidate_{i}_text"]
        plt.title(f"Candidate {i + 1}")
        txt = plt.text(0.5, 0.5, content, fontsize=14, wrap=True,
                       ha="center", va="top", color=color)
        # Overrides a private matplotlib Text method to force the wrap width to 300.
        txt._get_wrap_line_width = lambda: 300.
        plt.axis('off')

    # File name: "<config suffix>_<sample id>.png".
    plt.savefig(f'{args.output_dir}/{file_prefix}_{args.sample_id}.png')
    # plt.show()
    plt.close(fig)
