# Setup

In [None]:
# Torch device that input images and the normalization constants are placed on.
device = "cuda:0"

# Load annotated data

This lets us check how well the captioner's output agrees with what human annotators wrote.

In [None]:
import csv
import pathlib

from torch.utils import data
from torchvision import datasets, transforms


class Dataset(data.Dataset):
    """ImageFolder images paired with human-written annotations from a CSV file.

    Each item is ``(image, annotation, category)``. ``category`` is the name of
    the folder directly containing the image file.
    """

    def __init__(self, root, annotation_csv_file, **kwargs):
        """
        :param root: image root directory passed to ``datasets.ImageFolder``
        :param annotation_csv_file: CSV file with ``Input.image_url_1`` and
            ``Answer.summary`` columns
        :param kwargs: forwarded to ``datasets.ImageFolder`` (e.g. ``transform``)
        """
        self.image_folder = datasets.ImageFolder(root, **kwargs)
        self.annotations = self._read_annotations(annotation_csv_file)

    @staticmethod
    def _read_annotations(annotation_csv_file):
        """Map the last two '/'-separated parts of each row's image URL to its summary.

        Later rows with the same key overwrite earlier ones.
        """
        # newline='' is required by the csv module so quoted fields that
        # contain line breaks are parsed correctly.
        with pathlib.Path(annotation_csv_file).open('r', newline='') as handle:
            return {
                tuple(row['Input.image_url_1'].split('/')[-2:]): row['Answer.summary']
                for row in csv.DictReader(handle)
            }

    def __getitem__(self, index):
        path, _ = self.image_folder.samples[index]
        # Path.parts splits on the OS path separator, so this also works when
        # sample paths use '\\' (Windows); str.split('/') would not.
        key = pathlib.Path(path).parts[-2:]
        try:
            annotation = self.annotations[key]
        except KeyError:
            raise KeyError(f'no annotation for image {path!r}') from None
        image, _ = self.image_folder[index]
        category = key[0]
        return image, annotation, category

    def __len__(self):
        return len(self.image_folder)

# Human-annotated Places images, resized then center-cropped to 224x224 tensors.
dataset = Dataset(
    root='Places_label',
    annotation_csv_file='Places_label.csv',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)

# Load pretrained captioner

Run the cell below if the model files are missing (for example, if they were deleted); it downloads the checkpoint and word map from Google Drive.

In [None]:
!pip install gdown
import gdown

# Fetch the captioning checkpoint and its word map from Google Drive. Files
# that already exist locally are skipped, so re-running this cell does not
# download them again.
for url, output in (
    ('https://drive.google.com/uc?id=1FYZ446OPEqhe-uLkgyVICjD_3-N3IZn1', 'model.pth.tar'),
    ('https://drive.google.com/uc?id=1bt_TmTC_rUcss2MJsG_C_6DtwEttRVKc', 'wordmap.json'),
):
    if not pathlib.Path(output).exists():
        gdown.download(url, output, quiet=False)

In [None]:
import torch

# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# checkpoints from a trusted source. This checkpoint appears to store whole
# module objects (they are .eval()'d below). Recent PyTorch releases default
# to weights_only=True, which presumably cannot load them; pass
# weights_only=False there if loading fails.
# map_location puts the loaded tensors on `device`, the device input images
# are moved to, regardless of which device the checkpoint was saved from.
model = torch.load('model.pth.tar', map_location=device)

encoder = model['encoder']
encoder.eval()  # inference mode

decoder = model['decoder']
decoder.eval()  # inference mode

In [None]:
import json

# Vocabulary mapping word -> id, plus the inverse id -> word used to turn
# decoded id sequences back into words.
with open('wordmap.json') as j:
    word_map = json.load(j)
rev_word_map = dict(zip(word_map.values(), word_map.keys()))
vocab_size = len(word_map)

The decoder class is awkward to use directly, so we copy its evaluation code below and modify it to produce captions.

In [None]:
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm


# Per-channel normalization constants, shaped (1, 3, 1, 1) so they broadcast
# over an (N, 3, H, W) image batch.
MEAN = torch.tensor([0.485, 0.456, 0.406], device=device)[None, :, None, None]
STD = torch.tensor([0.229, 0.224, 0.225], device=device)[None, :, None, None]


def _beam_search(image, beam_size):
    """Decode one image with beam search and return the best word-id sequence.

    Uses the module-level ``encoder``, ``decoder``, ``word_map`` and
    ``vocab_size``.

    :param image: normalized image batch holding exactly one image, already on
        ``device``
    :param beam_size: number of beams kept at each decoding step
    :return: list of word ids for the highest-scoring sequence. It starts with
        ``<start>``. It ends with ``<end>`` unless no beam finished before the
        step limit.
    """
    k = beam_size
    start, end = word_map['<start>'], word_map['<end>']

    # Encode
    encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    encoder_dim = encoder_out.size(3)

    # Flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)

    # Treat the problem as having a batch size of k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # Top k previous words at each step; initially just <start>
    k_prev_words = torch.LongTensor([[start]] * k).to(device)  # (k, 1)
    # Top k sequences so far; initially just <start>
    seqs = k_prev_words  # (k, 1)
    # Running scores (summed log-probabilities) of the top k sequences
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

    # Sequences that emitted <end>, with their final scores as plain floats
    complete_seqs = []
    complete_seqs_scores = []

    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s (the live beam count) is <= k: beams leave once they emit <end>
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

        awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        scores = F.log_softmax(decoder.fc(h), dim=1)  # (s, vocab_size)

        # Add each beam's running score to its next-word log-probabilities
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        if step == 1:
            # All k beams are identical at the first step (same <start>, h, c),
            # so pick the top k words from a single row.
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Convert unrolled indices to (source beam, next word) indices
        prev_word_inds = top_k_words // vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        # Add new words to sequences
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

        # Split beams by whether they just emitted <end>
        next_words = next_word_inds.tolist()
        incomplete_inds = [ind for ind, word in enumerate(next_words) if word != end]
        complete_inds = [ind for ind, word in enumerate(next_words) if word == end]

        # Set aside complete sequences
        if complete_inds:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1

    if not complete_seqs:
        # No beam reached <end> within the step limit. Fall back to the live
        # beams instead of calling max() on an empty list (ValueError).
        complete_seqs = seqs.tolist()
        complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    best = complete_seqs_scores.index(max(complete_seqs_scores))
    return complete_seqs[best]


def evaluate(dataset, beam_size=1):
    """Generate a caption for every image in ``dataset`` with beam search.

    Despite the name, which comes from the original eval script, this computes
    no metric such as BLEU; it only returns the generated captions.

    :param dataset: dataset whose items are sequences whose first element is
        an un-normalized image tensor; any other elements are ignored
    :param beam_size: beam size at which to generate captions
    :return: list with one caption per image, in dataset order. Each caption is
        a list of words with ``<start>``, ``<end>`` and ``<pad>`` removed.
    """
    # _beam_search decodes a single image, so batch_size must stay 1.
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    special = {word_map['<start>'], word_map['<end>'], word_map['<pad>']}

    hypotheses = []
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for image, *_ in tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size)):
            # Move to the model's device, then normalize with MEAN/STD.
            image = (image.to(device) - MEAN) / STD
            seq = _beam_search(image, beam_size)
            hypotheses.append([rev_word_map[w] for w in seq if w not in special])

    return hypotheses

# Caption the annotated Places_label images with a beam of 5.
captions = evaluate(beam_size=5, dataset=dataset)

In [None]:
# Compare one generated caption with its human annotation and show the image.
index = 5
sample = dataset[index]
# Join the words into a sentence, as the BigGAN preview cell does, rather than
# printing the raw token list.
print('model:', ' '.join(captions[index]))
print('human:', sample[1])
transforms.ToPILImage()(sample[0])

In [None]:
# Alias of `captions` (the same list object, not a copy). Despite the name, no
# subset is taken, so in-place changes through either name show up in both.
captions_subset = captions

In [None]:
import csv

# Join each caption's words into a sentence without modifying `captions` in
# place, so later cells can still treat every caption as a list of words.
sentences = [' '.join(words) for words in captions]
for sentence in sentences:
    print(sentence)

# NOTE(review): the Places365 cell further down writes to this same file name
# and will overwrite this output if both are run.
# newline='' lets csv.writer control line endings (no blank rows on Windows).
with open('places-dataset-captions.csv', 'w', newline='') as handle:
    writer = csv.writer(handle)
    writer.writerow(('caption',))
    # csv.writer expects one sequence per row; a bare string would be split
    # into one column per character.
    writer.writerows((sentence,) for sentence in sentences)

In [None]:
# Print every caption as one space-separated sentence.
for words in captions:
    print(' '.join(words))

# Caption Places365

Next, caption the entire Places365 training set.

In [None]:
# Every Places365 training image, preprocessed like the annotated set above.
dataset = datasets.ImageFolder(
    root='/raid/lingo/dez/data/places365/files/train',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)
captions = evaluate(dataset, beam_size=5)

In [None]:
import csv

# Write one caption per row under a "caption" header. csv.writer needs one
# sequence per row, so each sentence is wrapped in a 1-tuple; passing the word
# lists directly would put every word in its own column.
# newline='' lets csv.writer control line endings (no blank rows on Windows).
with open('places-dataset-captions.csv', 'w', newline='') as handle:
    writer = csv.writer(handle)
    writer.writerow(('caption',))
    writer.writerows((' '.join(words),) for words in captions)

# Load BigGAN

In [None]:
import pathlib
import re

from PIL import Image
from torch.utils import data
from torchvision import transforms
from tqdm.auto import tqdm

pattern = re.compile(r'cat(\d+)im(\d+)')


class BigGANDataset(data.Dataset):
    """BigGAN-generated images stored as files named ``cat<target>im<idx>...``.

    Every matching image under ``root`` is decoded into memory at construction.
    Items are ``(image, target)`` pairs. The parsed ``targets`` and ``idx``
    lists are kept in the same order as ``images``.
    """

    def __init__(self, root, transform=None):
        """
        :param root: directory containing the generated image files
        :param transform: optional callable applied to each PIL image in
            ``__getitem__``
        """
        self.transform = transform

        self.images = []
        self.targets = []
        self.idx = []
        # iterdir() order is filesystem-dependent; sorting makes the item order
        # (and therefore the caption order) reproducible.
        for file in tqdm(sorted(pathlib.Path(root).iterdir())):
            ids = self._parse_name(file.name)
            if ids is None:
                # Skip entries that do not follow the cat<N>im<M> naming scheme.
                continue
            # Decode now and close the file: a bare Image.open() is lazy and
            # keeps one file handle open per image. Converting to RGB matches
            # ImageFolder's default loader and guarantees 3 channels.
            with Image.open(file) as image:
                self.images.append(image.convert('RGB'))
            target, idx = ids
            self.targets.append(target)
            self.idx.append(idx)

    @staticmethod
    def _parse_name(name):
        """Return ``(target, idx)`` ints parsed from ``name``, or None if it does not match."""
        match = pattern.match(name)
        if match is None:
            return None
        target, idx = match.groups()
        return int(target), int(idx)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index]

    def __len__(self):
        return len(self.images)

# BigGAN samples, preprocessed the same way as the real Places images.
generated = BigGANDataset(
    root='BIGGANplaces',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)

In [None]:
# Caption the generated images with a beam of 5.
captions = evaluate(beam_size=5, dataset=generated)

In [None]:
# Preview one generated image alongside its caption.
index = 1
print(*captions[index])  # words separated by single spaces
generated.images[index]

In [None]:
import csv

# Fail loudly if `captions` is stale (e.g. produced for another dataset by an
# earlier cell). Otherwise the file would get misaligned rows, or the write
# would crash partway through.
if len(captions) != len(generated):
    raise ValueError(
        f'have {len(captions)} captions for {len(generated)} generated images')

# newline='' lets csv.writer control line endings (no blank rows on Windows).
with pathlib.Path('BIGGANplaces-captions.csv').open('w', newline='') as handle:
    writer = csv.writer(handle)
    writer.writerow(('id', 'cat', 'im', 'caption'))
    for target, idx, words in zip(generated.targets, generated.idx, captions):
        writer.writerow((
            f'cat{target}im{idx}',
            str(target),
            str(idx),
            ' '.join(words),
        ))