In [1]:
#import libraries
import argparse
import importlib
import torch
import cv2
import os
import pandas as pd
import h5py as h5
import evaluate

from matplotlib import pyplot as plt

from src.common.configuration import get_dataset_configuration, get_model_configuration
from src.datasets.comics_dialogue_generation import ComicsDialogueGenerationDataset
# from src.datasets.comics_images_Sim_CLR_text import ComicsImageTextDataset, create_test_dataset
from src.datasets.comics_dialogue_generation_1_description_panel import create_test_dataset
from src.models.dialogue_generation_vlt5 import DialogueGenerationVLT5Model
from src.tokenizers.vlt5_tokenizers import VLT5TokenizerFast
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def parse_arguments(argv=None):
    """Build the option namespace for the plotting notebook.

    Args:
        argv: Optional sequence of command-line style tokens, e.g.
            ``["--sample_id", "10"]``. When omitted, no tokens are parsed, so the
            Jupyter kernel's own ``sys.argv`` is never interpreted as options and
            every option takes its default value.

    Returns:
        argparse.Namespace with ``model``, ``load_cloze_checkpoint``,
        ``dataset_config``, ``dataset_dir``, ``output_dir``, ``sample_id`` and
        ``seed``.
    """
    parser = argparse.ArgumentParser(description='Plotting script')

    parser.add_argument('--model', type=str, default="dialogue_generation_vlt5_inf",
                        help='Model to run')
    # NOTE(review): the checkpoint and dataset defaults are machine-specific
    # absolute paths; override them through `argv` on other machines.
    parser.add_argument('--load_cloze_checkpoint', type=str,  default="/home/jlafuente/Comics_dialogs_generation/runs/DialogueGenerationVLT5Model_comics_dialogue_generation_1_description_panel_2024-02-09_00:43:43/models/epoch_7.pt",
                        help='Path to text cloze model checkpoint')
    parser.add_argument('--dataset_config', type=str, default="comics_dialogue_generation_textract_Blip2",
                        help='Dataset config to use')
    parser.add_argument('--dataset_dir', type=str, default="/data/data/datasets/COMICS",
                        help='Dataset directory path')
    parser.add_argument('--output_dir', type=str, default="plots_textract_gen/",
                        help='Output directory path')
    parser.add_argument('--sample_id', type=int, default=275,
                        help='Sample id to plot')
    parser.add_argument('--seed', type=int, default=4,
                        help='Random seed')

    # argparse expects a list of tokens. The previous `parse_args("")` only
    # worked because `list("")` happens to be empty.
    return parser.parse_args([] if argv is None else list(argv))

In [3]:
def load_checkpoint(checkpoint_path, model):
    """Restore trained weights into ``model`` and put it in evaluation mode.

    The file at ``checkpoint_path`` is loaded with every tensor mapped onto the
    CPU. Its ``"model_state_dict"`` entry is handed to
    ``model.load_checkpoint``.

    Args:
        checkpoint_path: Path of the checkpoint file written by ``torch.save``.
        model: Object exposing ``load_checkpoint(state_dict)`` and ``eval()``.

    Returns:
        The same ``model`` object, after loading and ``eval()``.
    """
    cpu_device = torch.device("cpu")
    state_dict = torch.load(checkpoint_path, map_location=cpu_device)["model_state_dict"]
    model.load_checkpoint(state_dict)
    model.eval()
    return model

# Loading the model and dataset

In [4]:
# Parse the notebook's default options, seed torch for reproducibility, and pick
# the device: the first CUDA GPU when one is available, otherwise the CPU.
args = parse_arguments()
torch.manual_seed(args.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


### Model

In [5]:
# Look up the configuration of the selected model. The cells below read its
# `classname`, `backbone`, `max_text_length` and `do_lower_case` fields.
model_text_cloze_config = get_model_configuration(args.model)

In [6]:
# Resolve the model class dynamically. The module is `src.models.<args.model>`
# and the class name comes from the model configuration.
model_module = importlib.import_module(f"src.models.{args.model}")
ModelClass = getattr(model_module, model_text_cloze_config.classname)

# Instantiate the model on the selected device and restore the trained weights.
# The last expression displays the model's repr.
model_text_cloze = ModelClass(model_text_cloze_config, device).to(device)
load_checkpoint(args.load_cloze_checkpoint, model_text_cloze)

DialogueGenerationVLT5Model(
  (shared): Embedding(32200, 768)
  (encoder): JointEncoder(
    (embed_tokens): Embedding(32200, 768)
    (visual_embedding): VisualEmbedding(
      (feat_embedding): Sequential(
        (0): Linear(in_features=2048, out_features=768, bias=True)
        (1): T5LayerNorm()
      )
      (absolute_vis_pos_embedding): Sequential(
        (0): Linear(in_features=5, out_features=768, bias=True)
        (1): T5LayerNorm()
      )
      (obj_order_embedding): Embedding(32200, 768)
      (img_order_embedding): Embedding(4, 768)
    )
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=

### Loading the dataset

In [7]:
# Build the tokenizer from the options stored in the model configuration.
tokenizer = VLT5TokenizerFast.from_pretrained(
    model_text_cloze_config.backbone,
    max_length=model_text_cloze_config.max_text_length,
    do_lower_case=model_text_cloze_config.do_lower_case,
)

# Load the dataset configuration, flag it for the test split and build the
# test dataframe together with the matching dataset object.
dataset_config = get_dataset_configuration(args.dataset_config)
dataset_config["test"] = True
df, dataset = create_test_dataset(args.dataset_dir, device, dataset_config, tokenizer)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'VLT5TokenizerFast'.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [8]:
# Deterministic, single-process iteration: one sample per batch, in dataset order.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

### Example of dialogue generation 
Without answer candidates in the encoder input.

In [9]:
# Run the model on the first two test samples and print each decoded prediction.
# After the loop, `args.sample_id` and `prediction_text_cloze` hold the values of
# the last processed sample, and the cells below use them. `args.sample_id` is
# stored exactly as the DataLoader yields it (presumably a 1-element batch
# tensor, not a plain int).
model_text_cloze.eval()
with torch.no_grad():
    for i, sample_data in zip(range(2), dataloader):
        args.sample_id = sample_data["sample_id"]

        # Keep only tensor entries and cast float64 ones to float32 for the model.
        sample_data = {
            key: (value.type(torch.float32) if value.dtype == torch.float64 else value)
            for key, value in sample_data.items()
            if isinstance(value, torch.Tensor)
        }

        output = model_text_cloze.run(**sample_data)
        prediction_text_cloze = tokenizer.decode(output["prediction"][0], skip_special_tokens=True)
        print(f"Prediction: {prediction_text_cloze}")

Prediction: ah! there he goes!
Prediction: i'll have to, dick. i'll have to. i '


In [10]:
# Display the dataframe entry of the selected example. `args.sample_id` is the
# value stored by the example cell above. When it is the DataLoader's batch
# value (presumably a 1-element tensor), `iloc` returns a one-row DataFrame
# rather than a Series.
df.iloc[args.sample_id]

Unnamed: 0,book_id,page_id,context_panel_0_id,context_panel_1_id,context_panel_2_id,answer_panel_id,context_text_0_0,context_text_0_1,context_text_0_2,context_text_1_0,context_text_1_1,context_text_1_2,context_text_2_0,context_text_2_1,context_text_2_2,answer_candidate_0_text,answer_candidate_1_text,answer_candidate_2_text,correct_answer
1,3451,6,1,2,3,4,pepper sprinkled on the trail! the hound's no ...,alertly watching for footprints and other sign...,ow! my foot!,ouch! doggoned thing drove right through my sole!,"why, the trail is studded with 'em!",,"i'll have to drop out, dick.. can't keep up to...","tough luck, simba. i'll push on. tell the othe...",,championship and prog ! i don ' t get it ! wou...,speaking of the rifle look its gone !,"marked repair wheel boys , and even the fire d...",0


In [11]:
# Pull the identifiers and the target text of the selected example.
# `args.sample_id` may be a plain int or the 1-element batch value stored by the
# example cell above. With the latter, `df.iloc` yields a one-row DataFrame, so
# `sample` maps each column to a 1-element Series. `sample` is kept in exactly
# that form for the plotting cells below. The scalar values are read from a
# single row instead: calling `int()` on a Series is deprecated in pandas 2.1+,
# and `.iloc[0]` on a plain value would fail when `args.sample_id` is an int.
sample = dict(df.iloc[args.sample_id])
sample_row = df.iloc[int(torch.as_tensor(args.sample_id).reshape(-1)[0])]
book_id = int(sample_row["book_id"])
page_id = int(sample_row["page_id"])
target_text = sample_row[f"answer_candidate_{int(sample_row['correct_answer'])}_text"]
print(f"Book id: {book_id}")
print(f"Page id: {page_id}")
print(f"Target text: {target_text}")

Book id: 3451
Page id: 6
Target text: championship and prog ! i don ' t get it ! would cooperate criminals use such sometime ?


### Plotting the dialogue generation example
Without answer candidates in the encoder input.

In [28]:
# RGB colours (matplotlib 0-1 scale) used to draw the answer candidates:
# green for the correct answer, red for the distractors.
CORRECT_COLOR = (0, 1, 0)
INCORRECT_COLOR = (1, 0, 0)

In [None]:
# Plot the selected example.
# - Top row: the three context panels and the answer panel.
# - Bottom row: the three answer candidates (green = correct, red = distractor)
#   and the model prediction.
#
# `sample` may map columns to 1-element Series (built from a DataFrame slice)
# or to plain scalars (built from a single row, e.g. `dict(df.iloc[<int>])`).
# Unwrap both forms; the previous `.iloc[0]` only worked for the first one.
sample_values = {
    key: value.iloc[0] if isinstance(value, pd.Series) else value
    for key, value in sample.items()
}
correct_answer = int(sample_values["correct_answer"])
panel_dir = f'{args.dataset_dir}/panels/{int(sample_values["book_id"])}'
panel_page_id = int(sample_values["page_id"])
# `args.sample_id` may be an int or a 1-element tensor. Use a plain int in the
# file name instead of e.g. "tensor([1])".
plot_sample_id = int(torch.as_tensor(args.sample_id).reshape(-1)[0])

rows = 2
columns = 4
fig = plt.figure(figsize=(16, 8))

# Top row, positions 1-4: panel images (OpenCV loads BGR; matplotlib expects RGB).
panel_columns = [
    ("context_panel_0_id", "Context panel 1"),
    ("context_panel_1_id", "Context panel 2"),
    ("context_panel_2_id", "Context panel 3"),
    ("answer_panel_id", "Answer panel"),
]
for position, (panel_column, panel_title) in enumerate(panel_columns, start=1):
    panel_path = f'{panel_dir}/{panel_page_id}_{int(sample_values[panel_column])}.jpg'
    panel_bgr = cv2.imread(panel_path)
    if panel_bgr is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"Could not read panel image: {panel_path}")
    fig.add_subplot(rows, columns, position)
    plt.imshow(cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(panel_title)

# Bottom row, positions 5-7: the answer candidates.
for candidate in range(3):
    fig.add_subplot(rows, columns, 5 + candidate)
    color = CORRECT_COLOR if candidate == correct_answer else INCORRECT_COLOR
    plt.title(f"Candidate {candidate + 1}")
    txt = plt.text(0.5, 0.5, sample_values[f"answer_candidate_{candidate}_text"], fontsize=14,
                   wrap=True, ha="center", va="top", color=color)
    # Private matplotlib hook: force a fixed wrap width for long strings.
    txt._get_wrap_line_width = lambda: 300.
    plt.axis('off')

# Bottom row, position 8: the model prediction.
fig.add_subplot(rows, columns, 8)
plt.title("Predicted text")
txt = plt.text(0.5, 0.5, prediction_text_cloze, fontsize=14, wrap=True,
               ha="center", va="top", color=(0, 0, 0))
txt._get_wrap_line_width = lambda: 300.
plt.axis('off')

os.makedirs(args.output_dir, exist_ok=True)
plt.savefig(f'{args.output_dir}/{args.dataset_config.split("_")[-1]}_{plot_sample_id}.png')

### Iterating over the test set 
Generating an image of the prediction for every 25th example.

In [None]:
# Directory where the figures generated by the test-set loop below are saved
# (overrides the --output_dir default).
args.output_dir = "plots_easy_vlt5_textract_blip2_gen/"

In [None]:
# Render a figure for every 25th test sample.
# - Top row: the three context panels and the answer panel.
# - Bottom row: the three answer candidates (green = correct, red = distractor)
#   and the model prediction.
#
# Loop state uses its own names, so the globals used by the single-example
# cells are not overwritten: `sample`, `prediction_text_cloze`, `book_id`,
# `page_id` and `args.sample_id`. Otherwise, running the notebook top to
# bottom would feed the last looped sample into those cells.
from tqdm.auto import tqdm

CORRECT_COLOR = (0, 1, 0)
INCORRECT_COLOR = (1, 0, 0)
PLOT_EVERY = 25  # plot the samples whose id is a multiple of this value
PANEL_COLUMNS = [
    ("context_panel_0_id", "Context panel 1"),
    ("context_panel_1_id", "Context panel 2"),
    ("context_panel_2_id", "Context panel 3"),
    ("answer_panel_id", "Answer panel"),
]
rows = 2
columns = 4
file_prefix = args.dataset_config.split("_")[-1]
os.makedirs(args.output_dir, exist_ok=True)

model_text_cloze.eval()
model_text_cloze.to(device)
with torch.no_grad():  # inference only: no autograd graph is needed
    for batch in tqdm(dataloader):
        row_sample_id = int(batch["sample_id"][0])
        if row_sample_id % PLOT_EVERY != 0:
            continue

        # Keep only tensor entries and cast float64 ones to float32 for the model.
        model_inputs = {
            key: value.type(torch.float32) if value.dtype == torch.float64 else value
            for key, value in batch.items()
            if isinstance(value, torch.Tensor)
        }
        output = model_text_cloze.run(**model_inputs)
        row_prediction = tokenizer.decode(output["prediction"][0], skip_special_tokens=True)

        # `sample_id` is used as a positional index into the test dataframe.
        row = df.iloc[row_sample_id]
        row_panel_dir = f'{args.dataset_dir}/panels/{int(row["book_id"])}'
        row_page_id = int(row["page_id"])
        row_correct = int(row["correct_answer"])

        fig = plt.figure(figsize=(16, 8))
        try:
            # Top row, positions 1-4: panel images (OpenCV loads BGR; matplotlib expects RGB).
            for position, (panel_column, panel_title) in enumerate(PANEL_COLUMNS, start=1):
                panel_path = f'{row_panel_dir}/{row_page_id}_{int(row[panel_column])}.jpg'
                panel_bgr = cv2.imread(panel_path)
                if panel_bgr is None:
                    # cv2.imread returns None instead of raising on a missing/unreadable file.
                    raise FileNotFoundError(f"Could not read panel image: {panel_path}")
                fig.add_subplot(rows, columns, position)
                plt.imshow(cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2RGB))
                plt.axis('off')
                plt.title(panel_title)

            # Bottom row, positions 5-7: the answer candidates.
            for candidate in range(3):
                fig.add_subplot(rows, columns, 5 + candidate)
                color = CORRECT_COLOR if candidate == row_correct else INCORRECT_COLOR
                plt.title(f"Candidate {candidate + 1}")
                txt = plt.text(0.5, 0.5, row[f"answer_candidate_{candidate}_text"], fontsize=14,
                               wrap=True, ha="center", va="top", color=color)
                # Private matplotlib hook: force a fixed wrap width for long strings.
                txt._get_wrap_line_width = lambda: 300.
                plt.axis('off')

            # Bottom row, position 8: the model prediction.
            fig.add_subplot(rows, columns, 8)
            plt.title("Predicted text")
            txt = plt.text(0.5, 0.5, row_prediction, fontsize=14, wrap=True,
                           ha="center", va="top", color=(0, 0, 0))
            txt._get_wrap_line_width = lambda: 300.
            plt.axis('off')

            # Saved once (the previous version saved the same file twice).
            plt.savefig(f'{args.output_dir}/{file_prefix}_{row_sample_id}.png')
        finally:
            # Always release the figure so open figures do not pile up in the loop.
            plt.close(fig)


### Generating an image of the prediction for every 25th example, showing only the target and the prediction
Meant to be used with a dataset class that does not provide the possible answers in the encoder input.

In [12]:
# Directory for the target-vs-prediction figures produced by the cells below
# (overrides the --output_dir default).
args.output_dir = "plots_easy_vlt5_textract_blip2_gen_only_pred_and_target/"

#### Showing one example

In [None]:
# Plot the selected example with the four panels on the top row. Below them,
# show only the target text (position 5) and the model prediction (position 7).
#
# `sample` may map columns to 1-element Series (built from a DataFrame slice)
# or to plain scalars (built from a single row, e.g. `dict(df.iloc[<int>])`).
# Unwrap both forms; the previous `.iloc[0]` raised AttributeError on scalars.
sample_values = {
    key: value.iloc[0] if isinstance(value, pd.Series) else value
    for key, value in sample.items()
}
target_column = f"answer_candidate_{int(sample_values['correct_answer'])}_text"
panel_dir = f'{args.dataset_dir}/panels/{int(sample_values["book_id"])}'
panel_page_id = int(sample_values["page_id"])
# `args.sample_id` may be an int or a 1-element tensor. Use a plain int in the
# file name instead of e.g. "tensor([1])".
plot_sample_id = int(torch.as_tensor(args.sample_id).reshape(-1)[0])

rows = 2
columns = 4
fig = plt.figure(figsize=(16, 8))

# Top row, positions 1-4: panel images (OpenCV loads BGR; matplotlib expects RGB).
panel_columns = [
    ("context_panel_0_id", "Context panel 1"),
    ("context_panel_1_id", "Context panel 2"),
    ("context_panel_2_id", "Context panel 3"),
    ("answer_panel_id", "Answer panel"),
]
for position, (panel_column, panel_title) in enumerate(panel_columns, start=1):
    panel_path = f'{panel_dir}/{panel_page_id}_{int(sample_values[panel_column])}.jpg'
    panel_bgr = cv2.imread(panel_path)
    if panel_bgr is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"Could not read panel image: {panel_path}")
    fig.add_subplot(rows, columns, position)
    plt.imshow(cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(panel_title)

# Bottom row, position 5: the ground-truth target text.
fig.add_subplot(rows, columns, 5)
plt.title("Target text")
txt = plt.text(0.5, 0.5, sample_values[target_column], fontsize=14, wrap=True,
               ha="center", va="top", color=(0, 0, 0))
# Private matplotlib hook: force a fixed wrap width for long strings.
txt._get_wrap_line_width = lambda: 300.
plt.axis('off')

# Bottom row, position 7: the model prediction.
fig.add_subplot(rows, columns, 7)
plt.title("Predicted text")
txt = plt.text(0.5, 0.5, prediction_text_cloze, fontsize=14, wrap=True,
               ha="center", va="top", color=(0, 0, 0))
txt._get_wrap_line_width = lambda: 300.
plt.axis('off')

os.makedirs(args.output_dir, exist_ok=True)
plt.savefig(f'{args.output_dir}/{args.dataset_config.split("_")[-1]}_{plot_sample_id}.png')

#### Iterating over the test set 

In [15]:
# Render a figure for every 25th test sample. Each figure shows the four panels,
# the target text and the model prediction. This is meant for dataset variants
# whose encoder input does not contain the answer candidates.
#
# Loop state uses its own names, so the globals used by the single-example
# cells are left untouched: `sample`, `prediction_text_cloze`, `book_id`,
# `page_id` and `args.sample_id`.
from tqdm.auto import tqdm

PLOT_EVERY = 25  # plot the samples whose id is a multiple of this value
PANEL_COLUMNS = [
    ("context_panel_0_id", "Context panel 1"),
    ("context_panel_1_id", "Context panel 2"),
    ("context_panel_2_id", "Context panel 3"),
    ("answer_panel_id", "Answer panel"),
]
rows = 2
columns = 4
file_prefix = args.dataset_config.split("_")[-1]
os.makedirs(args.output_dir, exist_ok=True)

model_text_cloze.eval()
model_text_cloze.to(device)
with torch.no_grad():  # inference only: no autograd graph is needed
    for batch in tqdm(dataloader):
        row_sample_id = int(batch["sample_id"][0])
        if row_sample_id % PLOT_EVERY != 0:
            continue

        # Keep only tensor entries and cast float64 ones to float32 for the model.
        model_inputs = {
            key: value.type(torch.float32) if value.dtype == torch.float64 else value
            for key, value in batch.items()
            if isinstance(value, torch.Tensor)
        }
        output = model_text_cloze.run(**model_inputs)
        row_prediction = tokenizer.decode(output["prediction"][0], skip_special_tokens=True)

        # `sample_id` is used as a positional index into the test dataframe.
        row = df.iloc[row_sample_id]
        row_panel_dir = f'{args.dataset_dir}/panels/{int(row["book_id"])}'
        row_page_id = int(row["page_id"])
        row_target = row[f"answer_candidate_{int(row['correct_answer'])}_text"]

        fig = plt.figure(figsize=(16, 8))
        try:
            # Top row, positions 1-4: panel images (OpenCV loads BGR; matplotlib expects RGB).
            for position, (panel_column, panel_title) in enumerate(PANEL_COLUMNS, start=1):
                panel_path = f'{row_panel_dir}/{row_page_id}_{int(row[panel_column])}.jpg'
                panel_bgr = cv2.imread(panel_path)
                if panel_bgr is None:
                    # cv2.imread returns None instead of raising on a missing/unreadable file.
                    raise FileNotFoundError(f"Could not read panel image: {panel_path}")
                fig.add_subplot(rows, columns, position)
                plt.imshow(cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2RGB))
                plt.axis('off')
                plt.title(panel_title)

            # Bottom row, position 5: the ground-truth target text.
            fig.add_subplot(rows, columns, 5)
            plt.title("Target text")
            txt = plt.text(0.5, 0.5, row_target, fontsize=14, wrap=True,
                           ha="center", va="top", color=(0, 0, 0))
            # Private matplotlib hook: force a fixed wrap width for long strings.
            txt._get_wrap_line_width = lambda: 300.
            plt.axis('off')

            # Bottom row, position 7: the model prediction.
            fig.add_subplot(rows, columns, 7)
            plt.title("Predicted text")
            txt = plt.text(0.5, 0.5, row_prediction, fontsize=14, wrap=True,
                           ha="center", va="top", color=(0, 0, 0))
            txt._get_wrap_line_width = lambda: 300.
            plt.axis('off')

            # Saved once (the previous version saved the same file twice).
            plt.savefig(f'{args.output_dir}/{file_prefix}_{row_sample_id}.png')
        finally:
            # Always release the figure so open figures do not pile up in the loop.
            plt.close(fig)


  0%|          | 0/11909 [00:00<?, ?it/s]

100%|██████████| 11909/11909 [10:07<00:00, 19.59it/s]
