In [1]:
from paths import paths
import torch
import os
import pickle
import matplotlib.pyplot as plt
from tokenizers import Tokenizer
import sys
from transformer_components import (
    TransformerEncoderDecoder,
    get_causal_mask,
)
from image_captioner import ImageEncoder, CaptionDecoder
import yaml
from coco_loader import get_coco_loader, ImgFirstDataset, decode_predictions
from image_transforms import image_transform_index
from PIL import Image
import matplotlib.pyplot as plt
import evaluate
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the trained captioner checkpoint evaluated by this notebook.
checkpoint_path = os.path.join(paths["captioner_checkpoint"], "checkpoint1.pt")

# Fail with a descriptive exception instead of sys.exit(): sys.exit() is meant
# for scripts, and inside a notebook kernel it only surfaces as a SystemExit.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"The path {checkpoint_path} does not exist!")

# map_location="cpu" keeps the load independent of the hardware the checkpoint
# was saved on (e.g. a CUDA checkpoint opened on a CPU/MPS machine); the
# tensors are copied to the target device when a state dict is loaded into a
# model that already lives there.
# NOTE(security): this is a pickle-based load, which can execute arbitrary
# code. Only open checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", pickle_module=pickle)

In [3]:
# Load the tokenizer from its serialized file and parse the YAML run configuration.
tokenizer = Tokenizer.from_file(paths["tokenizer"])
with open(paths["config"], "r") as config_file:
    config = yaml.safe_load(config_file)

# Set device: an explicit "device" entry in the config wins; otherwise fall
# back to CUDA, then Apple MPS, then CPU, in that order of preference.
if "device" not in config:
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
else:
    device = config["device"]
print(f"You are using {device}.")

You are using cuda.


In [4]:
# Ids of the special tokens passed to `model.generate` in the metrics cell.
SOS_IDX = tokenizer.token_to_id("<SOS>")
EOS_IDX = tokenizer.token_to_id("<EOS>")
PAD_IDX = tokenizer.token_to_id("<PAD>")

# `Tokenizer.token_to_id` returns None for a token that is not in the
# vocabulary. Check here so a mismatched tokenizer file is reported clearly
# instead of a None id failing somewhere inside generation.
missing_special_tokens = [
    token
    for token, token_id in (("<SOS>", SOS_IDX), ("<EOS>", EOS_IDX), ("<PAD>", PAD_IDX))
    if token_id is None
]
if missing_special_tokens:
    raise ValueError(
        f"Special tokens missing from the tokenizer vocabulary: {missing_special_tokens}"
    )

# Data-loading settings.
BATCH_SIZE = config["batch_size"]
NUM_WORKERS = config["num_workers"]

# Model dimensions.
VOCAB_SIZE = config["vocab_size"]
CONTEXT_SIZE = config["context_size"]
PATCH_SIZE = config["patch_size"]
IMAGE_SIZE = config["image_size"]

# Generation settings.
# NOTE(review): the float() cast is presumably there because YAML can load some
# numeric spellings (e.g. "1e-1") as strings -- confirm against the config file.
LENGTH_ALPHA = float(config["length_alpha"])
NUM_BEAMS = config["num_beams"]

# Sub-configurations handed to the encoder and decoder constructors.
image_encoder_config = config["image_encoder"]
caption_decoder_config = config["caption_decoder"]

In [5]:
# Loader over the "val" split in "image_first" mode, using the validation
# image transform; consumed by the metrics loop further down.
loader_for_metrics = get_coco_loader(
    "val",
    BATCH_SIZE,
    image_transform_index["val"],
    NUM_WORKERS,
    mode="image_first",
)

In [6]:
# Initialize model.
model = TransformerEncoderDecoder(
    ImageEncoder(IMAGE_SIZE, PATCH_SIZE, image_encoder_config),
    CaptionDecoder(VOCAB_SIZE, CONTEXT_SIZE, caption_decoder_config),
).to(device)

# model.load_state_dict(checkpoint["model_state_dict"])

In [7]:
# Load the "bleu" metric through the `evaluate` library; it is used in the
# metrics loop via `bleu.compute(predictions=..., references=...)`.
bleu = evaluate.load("bleu")

In [None]:
# Hoisted out of the f-string: reusing double quotes inside a replacement
# field is only valid syntax on Python 3.12+.
epochs_completed = checkpoint["history"]["epochs_completed"]

metric_batches = tqdm(
    loader_for_metrics,
    # No trailing colon: tqdm appends its own ": " after the description.
    desc=f"Metrics for epoch {epochs_completed}",
    leave=True,
)

# Collected across batches so BLEU can be reported once for the whole
# validation set instead of being overwritten batch after batch.
all_predictions = []
all_references = []

model.eval()
with torch.no_grad():
    for img, references, _ in metric_batches:
        img = img.to(device)
        # Generate caption token ids for the batch of images.
        pred = model.generate(
            img, None, NUM_BEAMS, CONTEXT_SIZE, LENGTH_ALPHA, SOS_IDX, PAD_IDX, EOS_IDX
        )
        decoded_preds = decode_predictions(pred, tokenizer)
        # NOTE(review): assumes each element of `references` holds all the
        # reference captions of one image, so decoded_refs[i] lines up with
        # decoded_preds[i] -- confirm against the "image_first" collate.
        decoded_refs = [tokenizer.decode_batch(ref) for ref in references]

        all_predictions.extend(decoded_preds)
        all_references.extend(decoded_refs)

        # Per-batch score, shown on the progress bar as a running indication.
        result = bleu.compute(predictions=decoded_preds, references=decoded_refs)
        metric_batches.set_postfix({"bleu": result["bleu"]})

# Score over every prediction gathered above; this is the number to report.
corpus_result = bleu.compute(predictions=all_predictions, references=all_references)
print(f"Corpus BLEU after epoch {epochs_completed}: {corpus_result['bleu']:.4f}")

Metrics for epoch 1::   0%|          | 0/40 [00:00<?, ?it/s]

Metrics for epoch 1:: 100%|██████████| 40/40 [01:55<00:00,  2.88s/it, bleu=0]
