In [11]:
# batch = next(iter(dataloader))
# image, tokens, (image_path, image_id) = batch
# image, tokens = to_device(image), to_device(tokens)
# image.shape, image.device
# showtell_core = to_device(showtell_core)
# with torch.no_grad():
#     logits = showtell_core(image, teacher_forcing=False)
# out_tokens = logits.argmax(-1).detach().cpu().squeeze(0).numpy()
# logits.shape, out_tokens.shape
# out_tokens, tokens
# print(f'GT {vocab.decode_indexes(tokens.detach().cpu().squeeze(0).numpy())}')
# print(f'Pred {vocab.decode_indexes(out_tokens)}')
# temp = np.random.randint(low=0, high=924, size=(10, 12))
# list(map(vocab.decode_indexes, temp))

# Inference and Evaluation

In [1]:
import pickle

import pandas as pd
import numpy as np

import torch
from transformers import Dinov2Config
import pytorch_lightning as pl

import wandb

from dataset import ImageCaptionDataset, Vocab, random_sample, choose_index
from model import Dinov2Encoder, TextEncoder, ShowAndTell, Model, to_device, load_pretrained

In [2]:
with open('./coco-2014/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
    
vocab.size 

924

In [3]:
import torch.utils
import torch.utils.data


dataset = ImageCaptionDataset(
    vocab=vocab,
    dataset_path="./coco-2014/dataset.json",
    sampling_fn=choose_index,
    **dict(index=0)    
)

dataloader = torch.utils.data.DataLoader(dataset)
dataset, len(dataset), len(dataloader)

(<dataset.ImageCaptionDataset at 0x7ec6f86f73d0>, 100, 100)

In [4]:
project_name = "ShowAndTell"
run_name = "overfit_batch_choose_index[extreme]"
trained_weights_path = f'./weights/{run_name}-{project_name}.pth'

In [5]:
config = Dinov2Config(patch_size=14)
image_encoder = Dinov2Encoder(
    config=config, dinov2_weights_path="./weights/dinov2-base-weights.pth", freeze=True
)
text_encoder = TextEncoder(vocab_size=vocab.size)

showtell_core = ShowAndTell(
    vocab,
    image_encoder,
    text_encoder,
)
showtell_core = to_device(showtell_core)
load_pretrained(showtell_core, trained_weights_path)

model = Model(vocab=vocab, showtell_core=showtell_core)

In [6]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],
    overfit_batches=1,
    max_epochs=100,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/harsh/anaconda3/envs/DL/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.


In [8]:
trainer.validate(model, dataloaders=dataloader)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/harsh/anaconda3/envs/DL/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.11604510992765427, 'bleu_score': 1.0}]

In [10]:
preds = trainer.predict(model, dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/harsh/anaconda3/envs/DL/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [15]:
vocab.decode_indexes()

<bound method Vocab.encode_document of <dataset.Vocab object at 0x7ec6f86f7490>>

In [17]:
preds_gt = []
for pred, batch in map(preds, dataset):
    preds_gt.append((pred[0], vocab.decode_indexes(batch[1])))      
preds_gt[:5]

TypeError: 'list' object is not callable