In [1]:
from debug_flickr8k import Flickr8kDataset
from debug_glove import embedding_matrix_creator
from torchvision import transforms
from resnet101_attention import Captioner as resnet101_attention
from resnet50_attention import Captioner as resnet50_attention
from resnext50_attention import Captioner as resnext50_attention
from incepv3_attention import Captioner as incepv3_attention

import itertools
import numpy as np
import torch
from torch.utils.data import DataLoader

In [2]:
DATASET_BASE_PATH = '../../data/flickr8k/'
BATCH_SIZE = 128
EMBEDDING = 300  # word-embedding dimension; also part of MODEL_NAME
NAME = "0503_1"  # run identifier used in the checkpoint file name
MODEL = "resnet101_attention"
# Checkpoint path prefix; the loading cell appends a "_<suffix>.pt" ending to it.
MODEL_NAME = f'../../saved_models/{NAME}_m{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}'
# Fall back to CPU so the notebook still runs (slowly) on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:

def _make_flickr_split(dist, return_type, **extra):
    """Build one Flickr8k split with the settings shared by every split in this notebook.

    :param dist: split name passed to Flickr8kDataset ('train', 'val', 'test').
    :param return_type: 'tensor' or 'corpus', forwarded unchanged.
    :param extra: additional keyword arguments (e.g. vocab_set) forwarded unchanged.
    """
    return Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist=dist, device=device,
                           return_type=return_type, load_img_to_memory=False, **extra)


# The training split builds the vocabulary; every other split is given the same
# vocab_set so token ids agree across splits.
train_set = _make_flickr_split('train', 'tensor')
vocab, word2idx, idx2word, max_len = vocab_set = train_set.get_vocab()
val_set = _make_flickr_split('val', 'corpus', vocab_set=vocab_set)
test_set = _make_flickr_split('test', 'corpus', vocab_set=vocab_set)
# Training images again, with eval preprocessing and 'corpus' output
# (presumably for scoring on the training data).
train_eval_set = _make_flickr_split('train', 'corpus', vocab_set=vocab_set)

# Dataset Transformation
# ImageNet channel statistics used to normalize inputs for the pre-trained encoder.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training: random crop + horizontal flip as augmentation.
train_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(256),  # 256x256 crop from a random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Evaluation: deterministic, no augmentation.
eval_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.CenterCrop(256),  # 256x256 crop from the centre of the image
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_set.transformations = train_transformations
val_set.transformations = eval_transformations
test_set.transformations = eval_transformations
train_eval_set.transformations = eval_transformations

vocab_size = len(vocab)


In [4]:
def eval_collate_fn(batch):
    """Collate evaluation samples: stack element 0 of each sample into one tensor,
    and keep elements 1 and 2 as plain per-sample lists (they may be ragged)."""
    images = torch.stack([sample[0] for sample in batch])
    captions = [sample[1] for sample in batch]
    extras = [sample[2] for sample in batch]
    return images, captions, extras


def _make_loader(dataset, shuffle, collate_fn=None):
    """DataLoader with the batch size / memory settings shared by every split.
    collate_fn=None means DataLoader's default collation."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, sampler=None,
                      pin_memory=False, collate_fn=collate_fn)


train_loader = _make_loader(train_set, shuffle=True)
val_loader = _make_loader(val_set, shuffle=False, collate_fn=eval_collate_fn)
test_loader = _make_loader(test_set, shuffle=False, collate_fn=eval_collate_fn)
train_eval_loader = _make_loader(train_eval_set, shuffle=False, collate_fn=eval_collate_fn)

In [None]:
    # Optional (currently disabled): build a 300-d embedding matrix for the vocabulary
    # from the GloVe files in ../../data/glove.6B/. To use it, uncomment this line and
    # the `embedding_matrix=...` argument in the model-construction cell.
    # embedding_matrix = embedding_matrix_creator(embedding_dim=300, word2idx=word2idx, GLOVE_DIR='../../data/glove.6B/')

In [5]:
    # Attention captioner (Captioner from resnet101_attention). embed_dim follows
    # EMBEDDING so the architecture matches the dimension encoded in MODEL_NAME.
    # NOTE(review): the class is hard-coded while MODEL names it too; keep them in sync.
    model = resnet101_attention(encoded_image_size=14, encoder_dim=2048,
                                attention_dim=256, embed_dim=EMBEDDING, decoder_dim=256,
                                vocab_size=vocab_size,
                                # embedding_matrix=embedding_matrix, train_embd=False,
                                pretrained=False).to(device)

In [6]:
# Restore trained weights. map_location remaps tensors saved on another device
# (e.g. a GPU) onto the device this notebook is running on.
model.load_state_dict(torch.load(f"{MODEL_NAME}_0.pt", map_location=device)["state_dict"])

<All keys matched successfully>

In [20]:
def words_from_tensors_fn(idx2word, max_len=40, startseq='<start>', endseq='<end>'):
    """Build a decoder that turns batches of token ids into lists of words.

    :param idx2word: mapping (or sequence) from integer token id to word.
    :param max_len: currently unused; kept for backward compatibility.
    :param startseq: start token; removed if it is the first decoded word.
    :param endseq: end token; each caption is cut before its first occurrence.
    :return: function ``words_from_tensors(captions) -> list[list[str]]``.
    """
    def words_from_tensors(captions) -> list:
        """
        :param captions: [b, seq_len] token ids as a torch tensor, numpy array or nested list
        :return: one list of words per caption, without a leading start token and
                 truncated before the first end token
        """
        # Iterating a torch tensor yields 0-d tensors, which hash by identity, so
        # dict lookups in idx2word fail. tolist() converts tensors/arrays (including
        # CUDA tensors, in one transfer) to plain Python ints.
        if hasattr(captions, 'tolist'):
            captions = captions.tolist()
        captoks = []
        for capidx in captions:
            words = list(itertools.takewhile(lambda word: word != endseq,
                                             (idx2word[int(idx)] for idx in capidx)))
            if words and words[0] == startseq:
                words = words[1:]
            captoks.append(words)
        return captoks

    return words_from_tensors

In [None]:
def accuracy_fn(ignore_value: int = 0):
    """Build a token-accuracy function that ignores positions where target == ignore_value.

    :param ignore_value: target value to exclude from the score (default 0).
    :return: function ``accuracy_ignoring_value(source, target) -> float``.
    """
    def accuracy_ignoring_value(source: torch.Tensor, target: torch.Tensor) -> float:
        """Fraction of non-ignored positions where source equals target.

        Returns NaN when every target position is ignored, since accuracy is
        undefined there (previously this raised ZeroDivisionError).
        """
        mask = target != ignore_value
        n_valid = mask.sum().item()
        if n_valid == 0:
            return float('nan')
        return (source[mask] == target[mask]).sum().item() / n_valid

    return accuracy_ignoring_value

In [28]:
tensor_to_word_fn = words_from_tensors_fn(idx2word=idx2word)
acc_fn = accuracy_fn()  # kept for later cells; needs (source, target) id tensors
model.eval()

# Greedy-decode the validation set and collect outputs for scoring (e.g. BLEU).
# The former bare `acc_fn()` call always raised TypeError (missing arguments); the
# 'corpus' val loader yields reference captions rather than target id tensors,
# so token accuracy is not computed here.
# NOTE(review): the recorded `KeyError: 'resembles'` is a word-to-id lookup, so it
# presumably comes from Flickr8kDataset mapping a val-caption word missing from the
# train vocabulary, not from this cell -- TODO confirm in debug_flickr8k.
hypotheses = []  # decoded word lists, one per validation image
references = []  # reference captions as returned by val_loader
with torch.no_grad():
    for image, captions, _ in val_loader:
        sampled_ids = model.sample(image.to(device), word2idx["<start>"])
        hypotheses.extend(tensor_to_word_fn(sampled_ids))
        references.extend(captions)

KeyError: 'resembles'