## Load corpus

In [42]:
import torch

from utils.data import load_captions

captions = load_captions()
corpus = captions['caption']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Create vocabulary

In [2]:
from utils.vocab import Vocabulary, get_tokens

vocabulary = Vocabulary(tokens=get_tokens(corpus))

## Load image embeddings

In [33]:
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl


class BasicModel(pl.LightningModule):
    def __init__(self, vocab, emb_size=64, hid_size=128):
        super().__init__()
        self.inp_voc = vocab
        self.emb_out = nn.Embedding(len(self.inp_voc), emb_size)

        self.dec0 = nn.GRUCell(emb_size, hid_size)
        self.logits = nn.Linear(hid_size, len(self.inp_voc))

    def forward(self, embeddings, captions):
        """ Apply model in training mode """
        return self.decode(embeddings, captions)

    def decode(self, initial_state, out_tokens, **flags):
        """ Iterate over reference tokens (out_tokens) with decode_step """
        batch_size = out_tokens.shape[0]
        state = initial_state

        # initial logits: always predict BOS
        onehot_bos = F.one_hot(torch.full([batch_size], self.inp_voc.bos_ix, dtype=torch.int64),
                               num_classes=len(self.inp_voc)).to(device=out_tokens.device)
        first_logits = torch.log(onehot_bos.to(torch.float32) + 1e-9)

        logits_sequence = [first_logits]
        for i in range(out_tokens.shape[1] - 1):
            state, logits = self.decode_step(state, out_tokens[:, i])
            logits_sequence.append(logits)
        return torch.stack(logits_sequence, dim=1)

    def decode_step(self, prev_state, prev_tokens):
        """
        Takes previous decoder state and tokens, returns new state and logits for next tokens
        :param prev_state: a list of previous decoder state tensors, same as returned by encode(...)
        :param prev_tokens: previous output tokens, an int vector of [batch_size]
        :return: a list of next decoder state tensors, a tensor of logits [batch, len(inp_voc)]
        """
        embedded_tokens = self.emb_out(prev_tokens)  # batch_size X emb_size
        new_dec_state = self.dec0(embedded_tokens, prev_state)
        output_logits = self.logits(new_dec_state)

        return new_dec_state, output_logits

    def decode_inference(self, initial_state, max_length):
        """ Generate solutions from model (greedy version) """
        batch_size = len(initial_state)
        state = initial_state
        outputs = [torch.full([batch_size], self.inp_voc.bos_ix, dtype=torch.int64,
                              device=device)]
        all_states = [initial_state]

        for i in range(max_length):
            state, logits = self.decode_step(state, outputs[-1])
            outputs.append(logits.argmax(dim=-1))
            all_states.append(state)
        return torch.stack(outputs, dim=1), all_states

    def predict(self, embeddings, max_length=20):
        captions, _ = self.create_captions(embeddings, max_length)
        return captions

    def create_captions(self, embeddings, max_length):
        out_ids, states = self.decode_inference(embeddings, max_length)
        return self.inp_voc.to_lines(out_ids.cpu().numpy()), states

    # PyTorch Lightning methods
    def training_step(self, batch, *args):
        # training_step defined the train loop.
        # It is independent of forward
        embeddings, captions = batch
        outputs = self.forward(embeddings, captions)
        outputs = outputs.view(-1, len(self.inp_voc))
        captions = captions.view(-1)

        loss = F.cross_entropy(outputs, captions)

        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

hid_size = dataset.get_embedding_size()
model = BasicModel(vocabulary, hid_size=hid_size)
model = model.to(device)

In [34]:
image_emb = dataset[0][0].unsqueeze(0)
gen_captions = model.decode_inference(image_emb, 10)
gen_captions

(tensor([[9428, 2824, 1384, 1384, 5720, 5000, 5940,  773, 9165, 8342, 5876]]),
 [tensor([[4.5942e-01, 1.3169e+00, 2.1997e+00, 5.1553e-01, 1.3884e+00, 1.1270e-01,
           2.2398e+00, 1.3023e+00, 3.6841e-01, 1.0618e+00, 3.1622e-01, 9.8743e-01,
           2.0474e-01, 1.1671e+00, 5.5059e-01, 8.9741e-01, 1.6429e+00, 8.6365e-01,
           1.5410e+00, 1.0466e+00, 1.2675e+00, 2.3765e+00, 2.9265e-01, 5.4496e-01,
           1.2527e+00, 5.2479e-01, 1.3628e+00, 6.1880e-01, 1.0940e-01, 1.1321e-01,
           1.0457e+00, 1.0767e+00, 1.1283e-01, 1.0297e+00, 7.6173e-01, 9.9569e-01,
           2.8583e+00, 2.6657e+00, 5.5009e-01, 1.4049e+00, 9.2663e-01, 2.8282e-01,
           2.5377e-01, 6.1729e-01, 4.1290e-01, 2.8497e-01, 7.4697e-01, 5.5146e-01,
           7.7172e-01, 6.5697e-01, 2.4369e+00, 6.5214e-01, 2.0131e+00, 2.2778e+00,
           3.6309e-01, 1.2101e-01, 6.2812e-01, 7.6428e-02, 5.7939e-01, 4.5406e-01,
           5.4265e-01, 1.9778e+00, 3.5323e-01, 1.0189e+00, 3.1592e-01, 6.7330e-01,
        

In [5]:
trainer = pl.Trainer()
trainer.fit(model, dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type      | Params
--------------------------------------
0 | emb_out | Embedding | 603 K 
1 | dec0    | GRUCell   | 887 K 
2 | logits  | Linear    | 4.8 M 
--------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.318    Total estimated model params size (MB)
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: -1it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [91]:
folds_mapping = pd.read_csv(FOLDS_DATA_PATH, index_col='image_id')

In [93]:
fold = 1
images_ids = {
    'train': folds_mapping.loc[folds_mapping['kfold'] != fold].index.tolist(),
    'valid': folds_mapping.loc[folds_mapping['kfold'] == fold].index.tolist(),
    'test': pd.read_csv(ORIGINAL_TEST_PATH, sep='\n', names=['image_id'])['image_id'].tolist()
}

In [99]:
set_captions = captions[captions['image_id'].isin(images_ids['valid'])]
set_captions

Unnamed: 0,image_id,caption,caption_number
15,1003163366_44323f5815.jpg,a man lays on a bench while his dog sits by him,0
16,1003163366_44323f5815.jpg,a man lays on the bench to which a white dog i...,1
17,1003163366_44323f5815.jpg,a man sleeping on a bench outside with a white...,2
18,1003163366_44323f5815.jpg,a shirtless man lies on a park bench with his dog,3
19,1003163366_44323f5815.jpg,man laying on bench holding leash of dog sitti...,4
...,...,...,...
40435,990890291_afc72be141.jpg,a man does a wheelie on his bicycle on the sid...,0
40436,990890291_afc72be141.jpg,a man is doing a wheelie on a mountain bike,1
40437,990890291_afc72be141.jpg,a man on a bicycle is on only the back wheel,2
40438,990890291_afc72be141.jpg,asian man in orange hat is popping a wheelie o...,3


In [101]:
set_captions.reset_index(drop=True)

Unnamed: 0,image_id,caption,caption_number
0,1003163366_44323f5815.jpg,a man lays on a bench while his dog sits by him,0
1,1003163366_44323f5815.jpg,a man lays on the bench to which a white dog i...,1
2,1003163366_44323f5815.jpg,a man sleeping on a bench outside with a white...,2
3,1003163366_44323f5815.jpg,a shirtless man lies on a park bench with his dog,3
4,1003163366_44323f5815.jpg,man laying on bench holding leash of dog sitti...,4
...,...,...,...
6995,990890291_afc72be141.jpg,a man does a wheelie on his bicycle on the sid...,0
6996,990890291_afc72be141.jpg,a man is doing a wheelie on a mountain bike,1
6997,990890291_afc72be141.jpg,a man on a bicycle is on only the back wheel,2
6998,990890291_afc72be141.jpg,asian man in orange hat is popping a wheelie o...,3


In [85]:
train_loader, valid_loader, test_loader = create_dataloaders(fold=1)