In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm

# for plotting
%matplotlib inline
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),       # set default size of plots
    'image.interpolation': 'nearest',    # draw image pixels without smoothing
    'image.cmap': 'gray',                # default colormap for single-channel images
})

In [2]:
from PIL import Image
from typing import Any, Callable, Optional, Tuple, List
from collections import defaultdict

class CocoCaptionDataset(torchvision.datasets.vision.VisionDataset):
    """COCO images paired with one pre-encoded caption per image.

    Only the caption at position ``caption_idx`` of each image's annotation list is
    used. A word vocabulary is built from those captions when the dataset is created,
    and every caption is encoded once into a ``torch.long`` tensor (see
    ``_captions_to_idx``), so ``__getitem__`` only has to load the image.

    Args:
        root: Directory holding the image files referenced by the annotation file.
        annFile: Path to a COCO captions annotation ``.json`` file.
        caption_idx: Which caption of each image to use (must be < 4).
        transform, target_transform, transforms: Forwarded to ``VisionDataset``.
    """

    # Special tokens take the first vocabulary indices in this order:
    # <NULL> = 0 (padding), <START> = 1, <END> = 2, <UNK> = 3 (out-of-vocabulary word).
    SPECIAL_TOKENS = ('<NULL>', '<START>', '<END>', '<UNK>')

    def __init__(self, root, annFile, caption_idx = 1, transform = None, target_transform = None, transforms = None):
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO  # local import: pycocotools is only needed here
        assert caption_idx < 4

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))

        self.caption_idx = caption_idx
        self.vocab_to_idx, self.idx_to_vocab = self._build_vocab_to_idx_mapp()
        self.idx_captions = self._captions_to_idx()

    @staticmethod
    def _tokenize(caption):
        """Lower-case ``caption``, drop surrounding whitespace and periods, split into words.

        Whitespace is stripped *before* the periods: otherwise a caption ending in
        e.g. ``'road. '`` or ``'road.\\n'`` keeps the period glued to its last word.
        """
        return caption.lower().strip().strip('.').split()

    def _get_caption(self, img_id):
        """Return the raw caption string selected by ``caption_idx`` for COCO image ``img_id``."""
        annotations = self.coco.loadAnns(self.coco.getAnnIds(img_id))
        return annotations[self.caption_idx]['caption']

    def _build_vocab_to_idx_mapp(self, num_words= 1000):
        """Build the word <-> index mapping from the selected captions.

        The ``num_words - len(SPECIAL_TOKENS)`` most frequent words are kept, sorted
        alphabetically, and placed after the special tokens. Every other word is
        encoded as ``<UNK>`` by ``_captions_to_idx``.

        Returns:
            (vocab_to_idx, idx_to_vocab): dict word -> index, and the inverse list.

        Raises:
            ValueError: if ``num_words`` cannot even hold the special tokens.
        """
        num_regular = num_words - len(self.SPECIAL_TOKENS)
        if num_regular < 0:
            raise ValueError(f'num_words must be >= {len(self.SPECIAL_TOKENS)}, got {num_words}')

        vocab_count = defaultdict(int)
        for img_id in self.ids:
            for word in self._tokenize(self._get_caption(img_id)):
                vocab_count[word] += 1

        # Most frequent first; Python's sort is stable (also with reverse=True),
        # so words with equal counts keep their first-seen order.
        by_frequency = sorted(vocab_count.items(), key=lambda item: item[1], reverse=True)
        kept_words = sorted(word for word, _ in by_frequency[:num_regular])

        idx_to_vocab = list(self.SPECIAL_TOKENS) + kept_words
        vocab_to_idx = {word: i for i, word in enumerate(idx_to_vocab)}
        return vocab_to_idx, idx_to_vocab

    def _captions_to_idx(self, fix_length= 16):
        """Encode every selected caption as a ``torch.long`` tensor of ``fix_length + 1`` indices.

        Layout: ``<START>``, then up to ``fix_length`` word indices (unknown words ->
        ``<UNK>``), one ``<END>`` right after the last word, then ``<NULL>`` padding.
        Captions with ``fix_length`` or more words are truncated and get no ``<END>``.

        Returns:
            dict mapping COCO image id -> encoded caption tensor.
        """
        print('Map caption words to idx!!')

        start = self.vocab_to_idx['<START>']
        end = self.vocab_to_idx['<END>']
        null = self.vocab_to_idx['<NULL>']
        unk = self.vocab_to_idx['<UNK>']

        master = {}
        for img_id in tqdm(self.ids):
            words = self._tokenize(self._get_caption(img_id))[:fix_length]
            idx = [start] + [self.vocab_to_idx.get(word, unk) for word in words]
            if len(idx) <= fix_length:  # caption shorter than fix_length: room for <END>
                idx.append(end)
            idx += [null] * (fix_length + 1 - len(idx))
            master[img_id] = torch.tensor(idx, dtype=torch.long)
        return master

    def _load_image(self, id: int) -> Image.Image:
        """Open the image file of COCO image ``id`` and return it as an RGB PIL image."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, id) -> torch.Tensor:
        """Return the pre-encoded caption tensor of COCO image ``id``."""
        return self.idx_captions[id]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, caption_tensor)`` for the ``index``-th image (after transforms)."""
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Number of images in the dataset."""
        return len(self.ids)
        
        
        
class CocoCaptionsLT(pl.LightningDataModule):
    """LightningDataModule serving COCO 2017 train/val images with one encoded caption each.

    Both splits are encoded with the vocabulary built from the *training* captions,
    so a given index means the same word in train and validation batches.
    """

    def __init__(self, folder_path, batch_size = 64, caption_idx = 1, transformer = None, num_workers = 8):
        '''
        Input:
          folder_path: Folder containing ``train2017/``, ``val2017/`` and
                       ``annotations/captions_{train,val}2017.json``.
          batch_size: Batch size of both dataloaders.
          caption_idx: Images have several captions; only the one at this index is used (must be < 4).
          transformer: torchvision transform applied to each image every time it is loaded.
                       Defaults to Resize(112) -> CenterCrop(112) -> ToTensor().
                       (Only captions are pre-processed up front; images are transformed on the fly.)
          num_workers: Number of worker processes per DataLoader.
        '''
        assert caption_idx < 4

        super().__init__()
        self.folder_path = folder_path
        self.batch_size = batch_size
        self.caption_idx = caption_idx
        self.transformer = transformer
        self.num_workers = num_workers

        self.train_img_path = os.path.join(folder_path, 'train2017')
        self.train_caption_path = os.path.join(folder_path, 'annotations/captions_train2017.json')
        self.val_img_path = os.path.join(folder_path,'val2017')
        self.val_caption_path = os.path.join(folder_path,'annotations/captions_val2017.json')

        # Built lazily by setup(); None means "not built yet".
        self.train_dataset = None
        self.val_dataset = None

        if self.transformer is None:
            self.transformer = transforms.Compose([
                transforms.Resize(112),
                transforms.CenterCrop(112),
                transforms.ToTensor(),
            ])

    def prepare_data(self):
        """Build the datasets (kept so existing manual ``prepare_data()`` calls still work)."""
        self.setup()

    def setup(self, stage = None):
        """Build the train/val datasets and vocabulary once; later calls are no-ops.

        Lightning calls ``setup`` on every process but ``prepare_data`` only on one,
        so assigning dataset state here also works for multi-process training.
        """
        if self.train_dataset is not None:
            return

        train_dataset = CocoCaptionDataset(self.train_img_path, self.train_caption_path,
                                           caption_idx=self.caption_idx, transform=self.transformer)
        self.vocab_to_idx = train_dataset.vocab_to_idx
        self.idx_to_vocab = train_dataset.idx_to_vocab

        val_dataset = CocoCaptionDataset(self.val_img_path, self.val_caption_path,
                                         caption_idx=self.caption_idx, transform=self.transformer)
        # The val dataset builds its own vocabulary from the val captions; re-encode
        # them with the training vocabulary so val targets use the model's indices.
        val_dataset.vocab_to_idx = self.vocab_to_idx
        val_dataset.idx_to_vocab = self.idx_to_vocab
        val_dataset.idx_captions = val_dataset._captions_to_idx()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

    def sample(self, num_samples= 4):
        """Plot ``num_samples`` random training images titled with their decoded captions."""
        self.setup()
        picks = torch.randint(0, len(self.train_dataset), (num_samples,)).tolist()

        for i in picks:
            img, cap = self.train_dataset[i]
            plt.imshow(img.permute(1,2,0).numpy())
            plt.title(self._decode_captions(cap))
            plt.axis('off')
            plt.show()

    def _decode_captions(self, captions):
        """
        Decode caption indices into words using ``self.idx_to_vocab``.
        Inputs:
        - captions: Caption indices of shape (T,) or (N, T).
        Outputs:
        - decoded: A sentence (for (T,) input) or a list of N sentences.
          ``<NULL>`` tokens are dropped and decoding stops after the first ``<END>``.
        """
        singleton = captions.ndim == 1
        if singleton:
            captions = captions[None]

        decoded = []
        for row in captions.tolist():  # plain ints, usable directly as list indices
            words = []
            for token in row:
                word = self.idx_to_vocab[token]
                if word != '<NULL>':
                    words.append(word)
                if word == '<END>':
                    break
            decoded.append(' '.join(words))
        return decoded[0] if singleton else decoded

    def train_dataloader(self):
        """Shuffled training DataLoader."""
        return torch.utils.data.DataLoader(self.train_dataset, self.batch_size, shuffle=True,
                                           num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        """Validation DataLoader; not shuffled so every validation pass sees the same order."""
        return torch.utils.data.DataLoader(self.val_dataset, self.batch_size, shuffle=False,
                                           num_workers=self.num_workers, pin_memory=True)

In [3]:
class CaptioningRNN(pl.LightningModule):
    """Image captioning model: frozen ResNet-50 encoder + single LSTMCell decoder.

    The pooled, L2-normalized backbone feature is projected to the LSTM's initial hidden
    state; the cell state starts at zero. ``training_step`` uses teacher forcing, while
    ``forward`` decodes greedily (each step feeds back its own argmax word).

    Token conventions must match the dataset vocabulary: index 0 is ``<NULL>`` (padding,
    ignored by the loss) and index 1 is ``<START>``.
    """

    def __init__(self, img_size= 2048, hidden_size= 512, vocab_vec_size = 128, vocab_size= 1000,
                 caption_length = 16, lr = 1e-3):
        """
        Args:
            img_size: Size of the backbone feature vector (2048 for ResNet-50).
            hidden_size: LSTM hidden/cell state size.
            vocab_vec_size: Word embedding size.
            vocab_size: Number of vocabulary entries, special tokens included.
            caption_length: Number of greedy decoding steps in ``forward``; should match the
                dataset's fixed caption length (encoded captions are one token longer: <START>).
            lr: AdamW learning rate.
        """
        super().__init__()
        self.save_hyperparameters()

        model = torchvision.models.resnet50(pretrained=True)
        self.backBone = nn.Sequential(*list(model.children())[:-1]) # remove the last classifer

        # Freeze the backbone: only the decoder layers below are optimized.
        for param in self.backBone.parameters():
            param.requires_grad = False

        self.feat_to_h0 = nn.Linear(img_size, hidden_size)
        self.cell = nn.LSTMCell(vocab_vec_size, hidden_size)
        self.wordEmbed = nn.Embedding(vocab_size, vocab_vec_size)
        self.cellOuput_to_vocab = nn.Linear(hidden_size, vocab_size)

        self.T = caption_length
        self.NULL = 0
        self.START = 1
        self.lr = lr

    def train(self, mode = True):
        """Set train/eval mode for the decoder while keeping the frozen backbone in eval mode.

        ``requires_grad = False`` does not stop BatchNorm layers from updating their running
        statistics in train mode; pinning the backbone to eval mode keeps the pretrained ones.
        """
        super().train(mode)
        self.backBone.eval()
        return self

    def _extract_image_features(self, x):
        """Return L2-normalized backbone features of shape (B, img_size)."""
        feat = self.backBone(x)
        # flatten(…, 1) keeps the batch dimension even when B == 1 (squeeze() would drop it).
        feat = torch.flatten(feat, 1)
        return F.normalize(feat, dim=1) # l2 normalize

    def _loss(self, x, y):
        """Cross-entropy summed over non-<NULL> targets, divided by the batch size.

        x: (B, T, vocab_size) raw scores; y: (B, T) target indices.
        """
        return F.cross_entropy(x.reshape(-1, x.shape[2]), y.reshape(-1), ignore_index=self.NULL, reduction='sum') / x.shape[0]

    def forward(self, x):
        """Greedily decode ``self.T`` words per image; return raw scores of shape (B, T, vocab_size)."""
        B = x.shape[0]

        h = self.feat_to_h0(self._extract_image_features(x))  # initial hidden state from image
        c = torch.zeros_like(h)                                # initial cell state

        start_tokens = torch.full((B,), self.START, dtype=torch.long, device=x.device)
        words = self.wordEmbed(start_tokens)

        raw_vocab_scores = []
        for _ in range(self.T):
            h, c = self.cell(words, (h, c))
            vocab_score = self.cellOuput_to_vocab(h)
            raw_vocab_scores.append(vocab_score)
            # Feed the most likely word back in as the next input.
            words = self.wordEmbed(torch.argmax(vocab_score, dim=1))

        return torch.stack(raw_vocab_scores, dim=1)

    def training_step(self, batch, batch_idx):
        """Teacher-forced step: feed caption tokens 0..L-1 and predict tokens 1..L."""
        x, y = batch

        captions_in = y[:, :-1]
        captions_out = y[:, 1:]

        gt_words = self.wordEmbed(captions_in)  # (B, L, vocab_vec_size)

        h = self.feat_to_h0(self._extract_image_features(x))
        c = torch.zeros_like(h)

        hidden_states = []
        # Step over the actual caption length so inputs and targets always align.
        for t in range(captions_in.shape[1]):
            h, c = self.cell(gt_words[:, t, :], (h, c))
            hidden_states.append(h)

        vocab_score = self.cellOuput_to_vocab(torch.stack(hidden_states, dim=1))  # (B, L, vocab_size)

        loss = self._loss(vocab_score, captions_out)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Loss of the greedy (free-running) decoding against the ground-truth caption.

        Note: this is not teacher-forced, so it is not directly comparable to ``train_loss``.
        Requires ``y.shape[1] - 1 == self.T``.
        """
        x, y = batch

        captions_out = y[:, 1:]
        vocab_score = self(x) # calling forward here

        loss = self._loss(vocab_score, captions_out)
        self.log('val_loss', loss)
        return loss

    @torch.no_grad()
    def sample(self, x):
        """Return greedily decoded word indices of shape (B, T); no autograd graph is built."""
        vocab_score = self(x)
        return torch.argmax(vocab_score, dim=2)

    def configure_optimizers(self):
        """AdamW over the trainable (non-backbone) parameters only."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.lr)
    

In [4]:
# Root of the COCO 2017 download (train2017/, val2017/, annotations/).
# Set the COCO_DIR environment variable to point elsewhere instead of editing this path.
folder_path = os.environ.get("COCO_DIR", "/home/fred/datasets/coco/")
DataModule = CocoCaptionsLT(folder_path, batch_size = 1024)

In [5]:
# Seed Python/NumPy/torch so decoder init, data shuffling and sampling are reproducible.
pl.seed_everything(42)
model = CaptioningRNN()
trainer = pl.Trainer(gpus = 1, max_epochs = 2)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [6]:
# Train (and validate each epoch) using the data module's dataloaders.
trainer.fit(model, datamodule=DataModule)

loading annotations into memory...
Done (t=0.48s)
creating index...
index created!
Map caption words to idx!!


  0%|          | 0/118287 [00:00<?, ?it/s]

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Map caption words to idx!!


  0%|          | 0/5000 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type       | Params
--------------------------------------------------
0 | backBone           | Sequential | 23.5 M
1 | feat_to_h0         | Linear     | 1.0 M 
2 | cell               | LSTMCell   | 1.3 M 
3 | wordEmbed          | Embedding  | 128 K 
4 | cellOuput_to_vocab | Linear     | 513 K 
--------------------------------------------------
3.0 M     Trainable params
23.5 M    Non-trainable params
26.5 M    Total params
106.052   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1