In [3]:
!pip install torchtext==0.6.0 --quiet
!pip3 install pytorch_lightning --quiet
!pip install keybert --quiet
!pip install keytotext --quiet
!pip install pycocoevalcap --quiet

In [4]:
!wget -O train_data4.pkl https://rice.box.com/shared/static/9qjcfo0p71goy286chr6doni9ipws98u.pkl --quiet
!wget -O model_checkpoint.ckpt https://rice.box.com/shared/static/3cfbqj3a2c44esp5wadukhfwwbkjbmpu.ckpt --quiet
!wget -O vocab_obj.pth https://rice.box.com/shared/static/k648cy0rmnp5kefkg3c0zpxaq7f4qe2f.pth --quiet
!wget -O t5-base_best.bin https://rice.box.com/shared/static/xzznmo11rfuw2ncw1gbkgbjuqyrtyij4.bin --quiet

In [5]:
import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import json
import torchtext
import pandas, os, pickle
from tqdm.notebook import tqdm
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
import requests
from io import BytesIO
import torchvision
import torchvision.transforms as transforms
from keybert import KeyBERT

# Plot the image.
def display_image(img):
  plt.figure(); plt.imshow(img)
  plt.grid(False);  plt.axis('off'); plt.show()

class CategoryDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, 
                 filenames: [], # location of the file.
                 build_vocab = True,
                 vocabulary_size = 12000, 
                 prev_texts = [],
                 start = 0,
                 end = 7000):

      self.texts = []
      self.images = []
      self.og_captions = []
      self.url = []
      self.keywords = []
      self.tokenized_keywords = []
      self.build_vocab = build_vocab
      self.textTokenizer = tokenizer
      keyword_model = KeyBERT()
      
      for filename in filenames: 

        print('Loading %s ...\n' % filename, end = '')

        with open(filename, 'rb') as f:
          temp_data = pickle.load(f)

        print("Pickle File Loaded")
        
        if (start != 0 or end != 7000):
          for i in tqdm(range(start, min(len(temp_data), end))):
            self.texts.append(tokenizer.preprocess(temp_data[i][0]))
            self.og_captions.append(temp_data[i][0])
            self.url.append(temp_data[i][1])
            self.images.append(temp_data[i][2])

            keywords = keyword_model.extract_keywords(temp_data[i][0])
            keyword_text = ""

            for elem in keywords:
              keyword_text = keyword_text + str(elem[0]) + " "

            self.keywords.append(keyword_text)
            self.tokenized_keywords.append(tokenizer.preprocess(keyword_text))

            keywords = list()
            token_keywords = list()
            
            temp_data[i] = None
          
          temp_data = list()

          print("\n")
        else :
          for i in tqdm(range(0, len(temp_data))):
            self.texts.append(tokenizer.preprocess(temp_data[i][0]))
            self.og_captions.append(temp_data[i][0])
            self.url.append(temp_data[i][1])
            self.images.append(temp_data[i][2])

            keywords = keyword_model.extract_keywords(temp_data[i][0])
            keyword_text = ""

            for elem in keywords:
              keyword_text = keyword_text + str(elem[0]) + " "

            self.keywords.append(keyword_text)
            self.tokenized_keywords.append(tokenizer.preprocess(keyword_text))

            keywords = list()
            token_keywords = list()
            
            temp_data[i] = None
          
          temp_data = list()

          print("\n")

      if self.build_vocab:
        tokenizer.build_vocab(self.tokenized_keywords + self.tokenized_keywords + self.texts, max_size = vocabulary_size)

      for i in tqdm(range(0 , len(self.texts))):
        self.texts[i] = tokenizer.process([self.texts[i]])
        self.tokenized_keywords[i] = tokenizer.process([self.tokenized_keywords[i]])

        
    def __len__(self):
        return len(self.texts)
  
    def __getitem__(self, i):
        return self.images[i], self.tokenized_keywords[i].squeeze()
    
    def get_original_caption(self, i):
      return self.og_captions[i]
    
    def get_url(self, i):
      return self.url[i]
    
    def get_keywords(self, i):
      return self.keywords[i]
    
    def get_tokenized_keywords(self, i):
      return self.tokenized_keywords[i]
    
    # To be used in the Data Loader collate_fn parameter.
    def create_batch(self, batch):
        images, keywords = zip(*batch)

        # Compute text lengths for Pytorch's RNN library.
        keywords_lengths = [len(keyword) for keyword in keywords]

        # Stack images and pad text.
        stacked_images = torch.stack(images)
        padded_keywords = pad_sequence(keywords, batch_first = self.textTokenizer.batch_first, 
                                    padding_value = self.textTokenizer.vocab.stoi["<pad>"])

        return stacked_images, padded_keywords, keywords_lengths

In [7]:
textTokenizer = torchtext.data.Field(sequential = True,
                                            init_token = "<start>", eos_token = "<end>", 
                                            pad_token = "<pad>", unk_token = "<unk>",
                                            batch_first = True)

textTokenizer.vocab = torch.load("/content/vocab_obj.pth")

In [8]:
test_set = CategoryDataset(tokenizer = textTokenizer, filenames = ["/content/train_data4.pkl"], build_vocab = False, start = 3000, end = 4000)

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Loading /content/train_data4.pkl ...
Pickle File Loaded


  0%|          | 0/1000 [00:00<?, ?it/s]





  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
import torch.nn as nn
import torchvision.models as models

class ImageEncoder(nn.Module):
    # Encode images using Resnet-152
    def __init__(self, encoding_size: int):
        super(ImageEncoder, self).__init__()
        self.base_network = models.resnet152(pretrained = True)
        self.base_network.fc = nn.Linear(self.base_network.fc.in_features, encoding_size)
        self.bn = nn.BatchNorm1d(encoding_size, momentum=0.01)
        self.init_weights()

    def init_weights(self):
        
        self.base_network.fc.weight.data.normal_(0.0, 0.02)
        self.base_network.fc.bias.data.fill_(0)

    def forward(self, image):

        with torch.no_grad():

            x = self.base_network.conv1(image)
            x = self.base_network.bn1(x)
            x = self.base_network.relu(x)
            x = self.base_network.maxpool(x)
          
            x = self.base_network.layer1(x)
            x = self.base_network.layer2(x)
            x = self.base_network.layer3(x)
            x = self.base_network.layer4(x)
          
            x = self.base_network.avgpool(x)
            x = torch.flatten(x, 1)
  
        featureMap = self.base_network.fc(x)
        featureMap = self.bn(featureMap)
        return featureMap

In [10]:
import torch.nn as nn
from torch.nn import functional as F

class TextDecoder(nn.Module):
    def __init__(self, input_size: int, state_size: int, vocab_size: int):
        super(TextDecoder, self).__init__()
        self.state_size = state_size
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.rnnCell = nn.LSTMCell(input_size, state_size, bias=True)
        self.predictionLayer = nn.Linear(state_size, vocab_size)
        self.init_weights()
        
    def dummy_input_state(self, batch_size ):

        return (torch.zeros(batch_size, self.state_size), torch.zeros(batch_size, self.state_size))
    
    def init_weights(self):
        
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.predictionLayer.bias.data.fill_(0)
        self.predictionLayer.weight.data.uniform_(-0.1, 0.1)
        
    def forward(self, input_state, current_token_id):
        # Embed the input token id into a vector.
        embedded_token = self.embedding(current_token_id)

        # Pass the embedding through the RNN cell.
        h,c = self.rnnCell(embedded_token, input_state)
        
        # Output prediction.
        prediction = self.predictionLayer(F.relu(h))
        
        return prediction, (h,c)

In [11]:
import random 
import pytorch_lightning as pl

# Image Captioning module.
# Using pytorch-lightning for simplicity.

class ImageCaptioner(pl.LightningModule):
    def __init__(self, textTokenizer, val_data = None, embedding_size = 512, state_size = 1024):
        super(ImageCaptioner, self).__init__()
        self.vocabulary_size = len(textTokenizer.vocab)
        self.padding_token_id = textTokenizer.vocab.stoi["<pad>"]

        
        self.val_data = val_data
        
        
        # Create image encoder and text decoder.
        self.image_encoder = ImageEncoder(state_size)
        self.text_decoder = TextDecoder(embedding_size, 
                                        state_size, 
                                        self.vocabulary_size)
        
        self.criterion = nn.CrossEntropyLoss(
            ignore_index = self.padding_token_id)
        
        self.init_image_transforms()
        self.text_tokenizer = textTokenizer

        self.image_encoder_learning_rate = 1e-4
        self.text_decoder_learning_rate = 1e-3
    
    def init_image_transforms(self):
        # Create image transforms using standard Imagenet-based model transforms.
        normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                         std = [0.229, 0.224, 0.225])
        
        self.image_train_transform = \
            transforms.Compose([transforms.Resize(256),
                                transforms.RandomCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                normalize])
        
        self.image_test_transform = \
            transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                normalize])

    # Predict text given image -- input text is for "teacher forcing" only.
    def forward(self, image, text, lengths, teacher_forcing = 1.0):
        # Keep output scores for tokens in a list.
        predicted_scores = list()
        
        # Encode the image.
        encoded_image = self.image_encoder(image)
        
        # Grab the first token in the sequence.
        start_token = text[:, 0]  # This should be the <start> symbol.
        
        # Predict the first token from the start token embedding 
        # and feed the image as the initial state.
        # let first input state = None
        token_scores, state = self.text_decoder((encoded_image,encoded_image), start_token)
        predicted_scores.append(token_scores)
        
        # Iterate as much as the longest sequence in the batch.
        # minus 1 because we already fed the first token above.
        # minus 1 because we don't need to feed the end token <end>.
        for i in range(0, max(lengths) - 2):
            if random.random() < teacher_forcing:
                current_token = text[:, i + 1]
            else:
                _, max_token = token_scores.max(dim = 1)
                current_token = max_token.detach() # No backprop.
            token_scores, state = self.text_decoder(state, current_token)
            predicted_scores.append(token_scores)
            
        # torch.stack(,1) forces batch_first = True on this output.
        return torch.stack(predicted_scores, 1), lengths

    def training_step(self, batch, batch_idx, optimizer_idx):
        images, texts, lengths = batch

        # Compute the predicted texts.
        predicted_texts, _ = self(images, texts, lengths, 
                                  teacher_forcing = 1.0)        

        # Define the target texts. 
        # We have to predict everything except the <start> token.
        target_texts =  texts[:, 1:].contiguous()

        # Use cross entropy loss.
        loss = self.criterion(predicted_texts.view(-1, self.vocabulary_size),
                              target_texts.view(-1))
        self.log('train_loss', loss, on_epoch = True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        images, texts, lengths = batch

        predicted_texts, _ = self(images, texts, lengths,
                                  teacher_forcing = 0.0)
        
        target_texts = texts[:, 1:].contiguous()

        loss = self.criterion(predicted_texts.view(-1, self.vocabulary_size),
                              target_texts.view(-1))
        self.log('val_loss', loss, on_epoch = True)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        print('Validation loss %.2f' %  loss_mean)
        
        return {'val_loss': loss_mean}
    
    def training_epoch_end(self, outputs):
        print(outputs[0])
        loss_mean = torch.stack([x['loss'] for x in outputs[0]]).mean()
        print('Training loss %.2f' %  loss_mean)


    def configure_optimizers(self):
        return [torch.optim.SGD(list(self.image_encoder.base_network.fc.parameters())+\
                                list(self.image_encoder.bn.parameters()),
                                lr = self.image_encoder_learning_rate), \
                torch.optim.Adam(self.text_decoder.parameters(), 
                                 lr = self.text_decoder_learning_rate)], []

In [13]:
# Load the pre-trained model
checkpoint = torch.load("/content/model_checkpoint.ckpt", map_location=torch.device('cpu'))
image_captioner = ImageCaptioner(textTokenizer)
image_captioner.load_state_dict(checkpoint['state_dict'])

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

<All keys matched successfully>

In [14]:
# Write a function to generate captions

def generate_caption(model, image, max_length = 15):

    # Set the starting text with just the start token
    text = "<start>"

    # We evaluate the model in order to get it started
    model.eval()

    # After initializing the transforms, we can apply them to our image
    model.init_image_transforms()
    image = model.image_test_transform(image)
    image = image[None, :]

    # Create a text_tensor which can be used for the model
    text_tensor = torch.tensor([textTokenizer.vocab.stoi["<start>"]]).unsqueeze(0);

    # Run the model and change the results to a desirable format
    predicted_scores, _ = model(image, text_tensor, [max_length], 0)
    predicted_scores = predicted_scores[0, :, :].detach().softmax(dim = 1)
    predicted_scores = predicted_scores.view(-1, model.vocabulary_size)

    # Choose the tokens with the highest prediction score
    _, max_token = predicted_scores.max(dim = 1)

    # Get started on creating a text from these tokens and stop when we hit an 
    # <end> token
    final_text = ""

    for id in max_token:
      final_text += " "
      final_text += textTokenizer.vocab.itos[id]
      if textTokenizer.vocab.itos[id] == "<end>":
        break

    return final_text

In [20]:
from transformers import T5ForConditionalGeneration,Adafactor,T5Tokenizer

t5_str = 't5-base'

t5_tokenizer = T5Tokenizer.from_pretrained(t5_str)
t5_model = T5ForConditionalGeneration.from_pretrained(t5_str, return_dict=True)
t5_model.load_state_dict(torch.load('t5-base_best.bin', map_location=torch.device('cpu')))

def generate(text_list,model,tokenizer):
   text=", ".join(text_list)
   model.eval()
   input_ids = tokenizer.encode(text, 
                               return_tensors="pt")  
   outputs = model.generate(input_ids)
   replaced = tokenizer.decode(outputs[0])
   replace_pairs = [('webNLG:', ''),
                    ('WebNLG:', ''),
                    ('webNLG', '<unk>'),
                    ('WebNLG', '<unk>'),
                    ('webNG', '<unk>'),
                    ('WebNG', '<unk>'),
                    ('webng', '<unk>'),
                    ('</s>', ''),
                    ('<pad>', '')]
   for (bad, good) in replace_pairs:
     replaced = replaced.replace(bad, good)
   return replaced

In [None]:
def create_keywords(filename):

    img = Image.open(filename)
    caption = generate_caption(image_captioner, img)

    # print("Caption from CNN: ", caption)

    keywords = caption.split(' ')[1:-1]
    # print("Keywords: ", keywords)

    keys = []
    for word in keywords:
      if word != '<unk>' and word != 'cat':
        keys.append(word)
    # print("Keys: ", keys)

    # display_image(img)

    return keys

In [106]:
def create_caption(url, input_words):

    filename=BytesIO(requests.get(url).content)

    display_image(Image.open(filename))

    try: 
      cnn_keywords = create_keywords(filename)
      print(cnn_keywords)
    except: 
      print("invalid")

    combined_keywords = list(set(input_words.split(" ")+cnn_keywords))
    # combined_keywords.remove("")

    if combined_keywords == []:
      combined_keywords = ['cat']

    # print(combined_keywords)
    caption = generate(combined_keywords, t5_model, t5_tokenizer)

    return caption

In [37]:
datasetGTS = {'annotations': []}
datasetRES = {'annotations': []}

for index in range(0, 1000):

  url=test_set.get_url(index)
  # url='https://i.redd.it/jmgo63dbxp721.jpg'
  filename=BytesIO(requests.get(url).content)
  input_words = ""
  input_words=test_set.get_keywords(index)

  try: 
    cnn_keywords = create_keywords(filename)
  except: 
    print("invalid")
    continue

  combined_keywords = list(set(input_words.split(" ")+cnn_keywords))
  # combined_keywords.remove("")

  if combined_keywords == []:
    combined_keywords = ['cat']

  # print(combined_keywords)
  caption = generate(combined_keywords, t5_model, t5_tokenizer)
  true_caption=test_set.get_original_caption(index)

  datasetGTS['annotations'].append({'image_id': index, 'caption': true_caption})
  datasetRES['annotations'].append({'image_id': index, 'caption': caption})

invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid
invalid


In [38]:
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

class COCOEvalCap:
    def __init__(self,images,gts,res):
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.params = {'image_id': images}
        self.gts = gts
        self.res = res

    def evaluate(self):
        imgIds = self.params['image_id']
        gts = self.gts
        res = self.res

        # =================================================
        # Set up scorers
        # =================================================
        print('tokenization...')
        tokenizer = PTBTokenizer()
        gts  = tokenizer.tokenize(gts)
        res = tokenizer.tokenize(res)

        # =================================================
        # Set up scorers
        # =================================================
        print('setting up scorers...')
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(),"METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]

        # =================================================
        # Compute scores
        # =================================================
        eval = {}
        for scorer, method in scorers:
            print('computing %s score...'%(scorer.method()))
            score, scores = scorer.compute_score(gts, res)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    self.setImgToEvalImgs(scs, imgIds, m)
                    print("%s: %0.3f"%(m, sc))
            else:
                self.setEval(score, method)
                self.setImgToEvalImgs(scores, imgIds, method)
                print("%s: %0.3f"%(method, score))
        self.setEvalImgs()

    def setEval(self, score, method):
        self.eval[method] = score

    def setImgToEvalImgs(self, scores, imgIds, method):
        for imgId, score in zip(imgIds, scores):
            if not imgId in self.imgToEval:
                self.imgToEval[imgId] = {}
                self.imgToEval[imgId]["image_id"] = imgId
            self.imgToEval[imgId][method] = score

    def setEvalImgs(self):
        self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]

def calculate_metrics(rng,datasetGTS,datasetRES):
    imgIds = rng
    gts = {}
    res = {}

    imgToAnnsGTS = {ann['image_id']: [] for ann in datasetGTS['annotations']}
    for ann in datasetGTS['annotations']:
        imgToAnnsGTS[ann['image_id']] += [ann]

    imgToAnnsRES = {ann['image_id']: [] for ann in datasetRES['annotations']}
    for ann in datasetRES['annotations']:
        imgToAnnsRES[ann['image_id']] += [ann]

    for imgId in imgIds:
        gts[imgId] = imgToAnnsGTS[imgId]
        res[imgId] = imgToAnnsRES[imgId]

    evalObj = COCOEvalCap(imgIds,gts,res)
    evalObj.evaluate()
    return evalObj.eval

rng = range(2)
print(calculate_metrics(rng,datasetGTS,datasetRES))

tokenization...
setting up scorers...
computing Bleu score...
{'testlen': 20, 'reflen': 17, 'guess': [20, 18, 16, 14], 'correct': [8, 3, 2, 1]}
ratio: 1.17647058816609
Bleu_1: 0.400
Bleu_2: 0.258
Bleu_3: 0.203
Bleu_4: 0.156
computing METEOR score...
METEOR: 0.282
computing Rouge score...
ROUGE_L: 0.350
computing CIDEr score...
CIDEr: 1.392
{'Bleu_1': 0.3999999999800001, 'Bleu_2': 0.258198889733534, 'Bleu_3': 0.20274006650775622, 'Bleu_4': 0.1561969968366612, 'METEOR': 0.28164211027232655, 'ROUGE_L': 0.3497932415842864, 'CIDEr': 1.3921974002369581}
