<a href="https://colab.research.google.com/github/mehrshad-sdtn/DeepLearning/blob/master/PyTorch/5_PyTorch_Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
import os

In [None]:
# load data

In [None]:
# model

class EncoderCNN(nn.Module):
    """Image encoder built on a pretrained Inception v3.

    The network's final ``fc`` layer is replaced by a linear projection to
    ``embed_size``; the projected features then go through ReLU and dropout.

    Args:
        embed_size: Output width of the replacement ``fc`` layer.
        train_CNN: When ``False`` (default) only the replacement ``fc`` layer
            is trainable. When ``True`` every Inception parameter is trained.
    """

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        # Load with the default auxiliary head and switch it off afterwards.
        # NOTE: torchvision >= 0.13 raises ValueError when aux_logits=False is
        # passed together with pretrained weights; disabling the head after
        # loading is what older torchvision releases did internally.
        self.inception = models.inception_v3(pretrained=True)
        self.inception.aux_logits = False
        self.inception.AuxLogits = None
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        # Decide trainability once, before any forward pass, so that the very
        # first backward pass already skips the frozen backbone.
        self._freeze_backbone()

    def _freeze_backbone(self):
        """Make the ``fc`` head trainable and the rest follow ``train_CNN``."""
        for name, param in self.inception.named_parameters():
            is_head = 'fc.weight' in name or 'fc.bias' in name
            param.requires_grad = is_head or self.train_CNN

    def forward(self, images):
        """Return ``dropout(relu(inception(images)))``."""
        features = self.inception(images)
        return self.dropout(self.relu(features))





class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    Args:
        embed_size: Width of the token embeddings (and of the LSTM input).
        hidden_size: Width of the LSTM hidden state.
        vocab_size: Number of tokens; sizes both the embedding table and the
            output projection.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        # Sub-modules are created in the same order as before so that a
        # seeded initialisation yields the same parameters.
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, features, captions):
        """Score every vocabulary token at each step of the sequence.

        The image ``features`` are prepended along dim 0 as an extra first
        step in front of the embedded ``captions``, so the output has one
        more entry along dim 0 than ``captions``.
        """
        token_vectors = self.embed(captions)
        token_vectors = self.dropout(token_vectors)
        image_step = features.unsqueeze(0)
        sequence = torch.cat([image_step, token_vectors], dim=0)
        hidden_seq, _final_state = self.lstm(sequence)
        return self.linear(hidden_seq)






class CNNtoRNN(nn.Module):
    """Captioning model: an ``EncoderCNN`` feeding a ``DecoderRNN``.

    Args:
        embed_size: Passed to both the encoder and the decoder.
        hidden_size: Passed to the decoder.
        vocab_size: Passed to the decoder.
        num_layers: Passed to the decoder.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and decode them against ``captions``."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for ``image``.

        At each step the highest-scoring token is chosen and fed back as the
        next LSTM input. Decoding stops after ``max_length`` tokens or once
        the token ``'<end>'`` has been produced (it is kept in the result).

        Args:
            image: Input for the encoder; ``.item()`` is called on the
                predicted index, so a single image is expected.
            vocabulary: Object whose ``itos`` maps token index to string.
            max_length: Maximum number of tokens to generate.

        Returns:
            The generated tokens as a list of strings.
        """
        decoder = self.decoderRNN
        token_ids = []

        with torch.no_grad():
            step_input = self.encoderCNN(image).unsqueeze(0)
            lstm_state = None

            for _step in range(max_length):
                hidden, lstm_state = decoder.lstm(step_input, lstm_state)
                scores = decoder.linear(hidden.squeeze(0))
                best = scores.argmax(1)
                token_id = best.item()
                token_ids.append(token_id)
                # The chosen token becomes the input of the next step.
                step_input = decoder.embed(best).unsqueeze(0)
                if vocabulary.itos[token_id] == '<end>':
                    break

        return [vocabulary.itos[token_id] for token_id in token_ids]



In [None]:
# train
def train():
  """Set up data, model, loss and optimizer for caption training.

  Only the setup is implemented so far: the function builds the image
  transform, data loader, TensorBoard writer, model, loss and optimizer,
  optionally restores a checkpoint, and puts the model in training mode.
  The training loop itself is still to be written.

  NOTE(review): ``get_loader`` and ``load_checkpoint`` are not defined in
  this part of the file; ``load_checkpoint`` is assumed to restore model and
  optimizer state and return the saved step -- confirm.
  """
  transform = transforms.Compose([
       transforms.Resize((356, 356)),
       transforms.CenterCrop((299, 299)),
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
       ])

  train_loader, dataset = get_loader(
      root_folder='flickr8k/images',
      annotation_file='flickr8k/captions.txt',
      transform=transform,
      num_workers=2,
      )
  torch.backends.cudnn.benchmark = True
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  load_model = False
  save_model = True

  writer = SummaryWriter('runs/flickr8k')
  step = 0

  model = CNNtoRNN(embed_size=256,
                   hidden_size=256,
                   vocab_size=len(dataset.vocab),
                   num_layers=1)
  # Move the model to its device BEFORE building the optimizer or restoring a
  # checkpoint, so that optimizer state is created/loaded on the same device
  # as the parameters it belongs to.
  model.to(device)
  # Padding positions do not contribute to the loss.
  criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
  optimizer = optim.Adam(model.parameters(), lr=0.0003)

  if load_model:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load('my_checkpoint.pth.tar', map_location=device)
    step = load_checkpoint(checkpoint, model, optimizer)

  model.train()


### To Be Completed


