# Getting Data Ready 

In [2]:
import re
import json
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.transforms import transforms
import numpy as np
import random

def preprocess(text):
    """Normalize a raw title into a space-separated string for tokenization.

    The input is converted with ``str`` and then rewritten by a fixed, ordered
    list of regular-expression substitutions:

    1. parenthesised spans (shortest match) become a space;
    2. the Persian comma, tab, ZWNJ, LRM and non-breaking space become a space;
    3. a set of punctuation characters become a space;
    4. ``!``, ``؟`` or ``:`` *together with the character that follows it*
       become a space (such a mark at the very end of the text is left as is);
    5. ASCII digits and lowercase ASCII letters become ``_``;
    6. each run of ``_`` becomes the word 'یک';
    7. hyphens become a space and runs of spaces collapse to one space.

    Leading/trailing spaces are not stripped. Returns the rewritten string.
    """
    # Order matters: later rules operate on the output of earlier ones
    # (e.g. digits/letters are first turned into '_' and only then merged).
    substitutions = (
        (r'\(.*?\)', ' '),
        ('[،\t\u200c\u200e\xa0]', ' '),
        ('[#;&،»«؛/\\\']', ' '),
        ('[!؟:].', ' '),
        ('[0-9]', '_'),
        ('[a-z]', '_'),
        (r'_+', 'یک'),
        (r'-', ' '),
        (r' +', ' '),
    )
    cleaned = str(text)
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned

# Every image under the root folder, resized to a fixed 140x200 so that images
# can be stacked into batches.
data = ImageFolder('persian_image_captioning/', transform=transforms.Compose([
    transforms.Resize((140, 200)), transforms.ToTensor()
]))
# Map bare file name -> index into `data`. Split on both '/' and '\' so the
# lookup works on POSIX as well as Windows paths.
image_names = {re.split(r'[\\/]', path)[-1]: i for i, (path, _) in enumerate(data.imgs)}
with open('persian_image_captioning/news.json', encoding='utf8') as f:
    json_file = json.load(f)

# Collect (title, image file names) pairs and the vocabulary of all titles.
title_images = []
words = set()
for x in json_file:
    title_images.append([x['title'], x['images']])
    words.update(preprocess(x['title']).split())
words.update(['<start>', '<end>'])

# Enumerate the vocabulary in sorted order: iterating the set directly gives
# ids that change from run to run, which makes trained weights unusable in a
# later session.
word_to_idx = {}
idx_to_word = {}
for i, w in enumerate(sorted(words)):
    word_to_idx[w] = i
    idx_to_word[i] = w

# Build parallel lists of (preprocessed title, image index). A title that was
# already used and an image that was already used are both skipped; the sets
# mirror the lists so the membership tests are O(1).
titles = []
image_indexes = []
seen_titles = set()
seen_image_indexes = set()
for title, images in title_images:
    t = preprocess(title)
    if t in seen_titles:
        continue
    for image in images:
        # This one file is explicitly excluded from the dataset.
        if image == '1400070516181726323700473.jpg':
            continue
        image_index = image_names[image]
        if image_index in seen_image_indexes:
            continue
        titles.append(t)
        image_indexes.append(image_index)
        seen_titles.add(t)
        seen_image_indexes.add(image_index)

# Bucket samples by title length (1..17 words) so every bucket holds token
# rows of equal width and can be batched without padding.
MAX_TITLE_WORDS = 17
title_tokens = [[] for _ in range(MAX_TITLE_WORDS)]
image_indexes_splitted = [[] for _ in range(MAX_TITLE_WORDS)]
for title, image_index in zip(titles, image_indexes):
    s = title.split()
    if not s:
        # An empty title would index bucket -1, i.e. land in the 17-word
        # bucket and make its rows ragged.
        continue
    arr = [word_to_idx['<start>'], *(word_to_idx[x] for x in s), word_to_idx['<end>']]
    title_tokens[len(s) - 1].append(arr)
    image_indexes_splitted[len(s) - 1].append(image_index)

# One train and one test loader pair per bucket. All loaders use shuffle=False
# because the token loader and the image loader of a bucket are consumed in
# lockstep (zip) and must yield samples in the same order.
TRAIN_TOKENS_DATALOADER = []
TRAIN_IMAGES_DATALOADER = []
TEST_TOKENS_DATALOADER = []
TEST_IMAGES_DATALOADER = []
for n_words, (t, i) in enumerate(zip(title_tokens, image_indexes_splitted), start=1):
    all_indexes = list(range(len(t)))
    # Hold out 20% of the bucket (rounded down) for testing.
    test_indexes = random.sample(all_indexes, int(0.2 * len(all_indexes)))
    held_out = set(test_indexes)
    train_indexes = [x for x in all_indexes if x not in held_out]
    # Row width is the word count plus <start> and <end>. Reshaping with an
    # explicit width (instead of -1) also works when the bucket is empty.
    bucket_tokens = np.array(t).reshape(-1, n_words + 2)
    bucket_images = np.array(i)
    train_t = bucket_tokens[train_indexes]
    train_i = bucket_images[train_indexes]
    test_t = bucket_tokens[test_indexes]
    test_i = bucket_images[test_indexes]
    TRAIN_TOKENS_DATALOADER.append(DataLoader(train_t, batch_size=64, shuffle=False))
    TRAIN_IMAGES_DATALOADER.append(DataLoader(Subset(data, train_i), batch_size=64, shuffle=False))
    TEST_TOKENS_DATALOADER.append(DataLoader(test_t, batch_size=64, shuffle=False))
    TEST_IMAGES_DATALOADER.append(DataLoader(Subset(data, test_i), batch_size=64, shuffle=False))

# Initialize Models

In [3]:
import torch
from torch import nn
from torchvision import models
from torch.optim import Adam

class Decoder(nn.Module):
    """LSTM decoder that sees the image feature vector at every time step.

    Each token embedding is concatenated with the (repeated) image feature
    vector, run through a single-layer batch-first LSTM and projected to
    vocabulary logits.

    Args:
        vocab_size: number of distinct tokens (embedding rows / output logits).
        embedding_dim: size of a token embedding.
        feature_dim: size of the image feature vector.
        hidden_size: LSTM hidden state size.
    """

    def __init__(self, vocab_size, embedding_dim, feature_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + feature_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, tokens, hidden, c):
        """Run the decoder over a token sequence.

        Args:
            features: image features, shape (batch, feature_dim).
            tokens: integer token ids, shape (batch, seq_len).
            hidden: initial LSTM hidden state, shape (1, batch, hidden_size).
            c: initial LSTM cell state, same shape as ``hidden``.

        Returns:
            A tuple ``(outputs, hidden, c)`` where ``outputs`` holds the logits
            with shape (batch, vocab_size, seq_len) and ``hidden``/``c`` are
            the LSTM states after the last step.
        """
        embeddings = self.embedding(tokens)
        # Repeat the feature vector along the time axis as a broadcast view;
        # this stays on the features' own device instead of relying on a
        # module-level DEVICE global and avoids allocating a scratch buffer.
        f = features.unsqueeze(1).expand(-1, embeddings.size(1), -1)
        lstm_input = torch.cat((embeddings, f), dim=2)
        lstm_output, (hidden_, c_) = self.lstm(lstm_input, (hidden, c))
        # (batch, seq_len, vocab) -> (batch, vocab, seq_len)
        outputs = self.fc(lstm_output).permute(0, 2, 1)
        return outputs, hidden_, c_

class Img2Seq(nn.Module):
    """Image-to-caption model: an image encoder followed by a token decoder.

    Args:
        vocab_size: number of distinct tokens the decoder can emit.
        encoder: module mapping an image batch to a feature batch.
        decoder: module called as ``decoder(features, tokens, hidden, c)`` that
            returns ``(logits, hidden, c)`` with logits shaped
            (batch, vocab_size, seq_len); it must expose ``lstm.hidden_size``.
    """

    def __init__(self, vocab_size, encoder, decoder):
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder = encoder
        self.decoder = decoder

    def choose_guess_target(self, guess, target, ratio):
        """Pick, independently per sample, the model's guess or the target.

        Args:
            guess: previously predicted token ids, one per sample.
            target: ground-truth token ids, one per sample.
            ratio: probability of taking ``guess`` (0 -> always ``target``,
                1 -> always ``guess``).

        Returns:
            Integer tensor of shape (batch, 1) on the device of ``guess``.
        """
        guess = guess.reshape(-1).long()
        target = target.reshape(-1).long().to(guess.device)
        # One uniform draw per sample, done on-device in a single call rather
        # than a Python loop over the batch.
        use_guess = torch.rand(guess.shape[0], device=guess.device) < ratio
        return torch.where(use_guess, guess, target).unsqueeze(1)

    def forward(self, images, tokens, ratio=0.7):
        """Decode one token per step, mixing own guesses with the targets.

        Args:
            images: image batch fed to the encoder.
            tokens: input token ids, shape (batch, seq_len); column 0 is the
                first decoder input.
            ratio: per-sample probability of feeding the previous prediction
                instead of ``tokens[:, t]`` at step ``t``.

        Returns:
            ``(outputs, logits)``: ``outputs`` holds the arg-max token ids as a
            float tensor of shape (batch, seq_len); ``logits`` has shape
            (batch, vocab_size, seq_len).
        """
        features = self.encoder(images)
        device = features.device
        batch_size, seq_len = tokens.shape
        outputs = torch.zeros(seq_len, batch_size, device=device)
        logits = torch.zeros(seq_len, batch_size, self.vocab_size, device=device)
        best_guess = tokens[:, 0]
        hidden = torch.zeros(1, batch_size, self.decoder.lstm.hidden_size, device=device)
        c = torch.zeros_like(hidden)
        for t in range(seq_len):
            x = self.choose_guess_target(
                best_guess.unsqueeze(1), tokens[:, t].unsqueeze(1), ratio
            ).type(torch.int)
            predictions, hidden, c = self.decoder(features, x, hidden, c)
            step_logits = predictions.squeeze(2)  # (batch, vocab_size)
            logits[t] = step_logits
            best_guess = step_logits.argmax(dim=-1)
            outputs[t] = best_guess
        return outputs.permute(1, 0), logits.permute(1, 2, 0)

    def generate(self, image, start_token, end_token, max_length):
        """Greedily decode a caption for a single image.

        Args:
            image: one image tensor without batch dimension.
            start_token: id fed to the decoder at the first step.
            end_token: id that stops decoding (it is not included in the output).
            max_length: maximum number of decoding steps.

        Returns:
            List of predicted token ids (Python ints).
        """
        outputs = []
        with torch.no_grad():  # inference only; no autograd graph needed
            features = self.encoder(image.unsqueeze(0))
            device = features.device
            x = torch.tensor([[start_token]], dtype=torch.int, device=device)
            hidden = torch.zeros(1, 1, self.decoder.lstm.hidden_size, device=device)
            c = torch.zeros_like(hidden)
            for _ in range(max_length):
                predictions, hidden, c = self.decoder(features, x, hidden, c)
                next_token = predictions[:, :, -1].argmax(dim=1)  # shape (1,)
                token_id = next_token.item()
                if token_id == end_token:
                    break
                outputs.append(token_id)
                # Feed only the newest token: the LSTM state already encodes
                # the prefix, so re-feeding the whole sequence (as was done
                # before) would process earlier tokens repeatedly and diverge
                # from how forward() steps the decoder during training.
                x = next_token.unsqueeze(1)
        return outputs

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only the new `fc` head is optimized below, so the backbone must start from
# pretrained weights; a randomly initialised backbone that is never updated
# would only provide random features. The weights are downloaded on first use.
encoder = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# The backbone is not in any optimizer; turning its gradients off avoids
# computing and storing them on every backward pass.
for param in encoder.parameters():
    param.requires_grad = False
# Replace the classification head with a trainable projection to the
# 256-dimensional feature vector the decoder expects.
encoder.fc = nn.Linear(encoder.fc.in_features, 256)
decoder = Decoder(len(words), 50, 256, 256).to(DEVICE)
img2seq = Img2Seq(len(words), encoder, decoder).to(DEVICE)
criterion = nn.CrossEntropyLoss()
encoder_optimizer = Adam(encoder.fc.parameters(), lr=0.0008)
decoder_optimizer = Adam(decoder.parameters(), lr=0.0008)

# Train Model

In [None]:
n_epochs = 50
for epoch in range(n_epochs):
    # Probability of feeding the model its own previous prediction instead of
    # the ground-truth token; grows linearly from 0.2 towards 0.8.
    guess_ratio = 0.2 + (0.8 - 0.2) / n_epochs * epoch

    # --- training pass ---
    img2seq.train()  # the evaluation pass below switches to eval mode
    b1 = 0
    train_acc_sum = 0
    for image_loader, token_loader in zip(TRAIN_IMAGES_DATALOADER, TRAIN_TOKENS_DATALOADER):
        # Both loaders of a bucket are unshuffled, so zipping pairs each image
        # batch with its token batch.
        for (images, _), tokens in zip(image_loader, token_loader):
            tokens = tokens.type(torch.long).to(DEVICE)
            expected = tokens[:, 1:]  # targets: input sequence shifted by one
            images = images.to(DEVICE)
            _, pred = img2seq(images, tokens[:, :-1], ratio=guess_ratio)
            loss = criterion(pred, expected)
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            # Sum of per-sample token accuracies (argmax over the class axis;
            # softmax is monotonic so it is not needed here).
            train_acc_sum += ((pred.argmax(dim=1) == expected).sum() / expected.size(1)).item()
            b1 += len(tokens)

    # --- evaluation pass ---
    # eval() stops BatchNorm from updating its running statistics on test
    # data, and no_grad() avoids building autograd graphs.
    img2seq.eval()
    b2 = 0
    test_acc_sum = 0
    with torch.no_grad():
        for image_loader, token_loader in zip(TEST_IMAGES_DATALOADER, TEST_TOKENS_DATALOADER):
            for (images, _), tokens in zip(image_loader, token_loader):
                tokens = tokens.type(torch.long).to(DEVICE)
                expected = tokens[:, 1:]
                images = images.to(DEVICE)
                # ratio=0: always feed the ground-truth previous token.
                _, pred = img2seq(images, tokens[:, :-1], ratio=0)
                test_acc_sum += ((pred.argmax(dim=1) == expected).sum() / expected.size(1)).item()
                b2 += len(tokens)
    # max(..., 1) keeps the report from dividing by zero when a split is empty.
    print(f'epoch {epoch}   train accuracy {train_acc_sum/max(b1, 1)}   test accuracy {test_acc_sum/max(b2, 1)}')

# Generate Caption

In [None]:
from PIL import Image
from torchvision.transforms import transforms

def generate_caption(image_path):
    """Generate a caption for the image stored at ``image_path``.

    The image is loaded as RGB, passed through the module-level ``transform``
    and decoded greedily by ``img2seq`` for at most 17 tokens.

    Args:
        image_path: path of the image file to caption.

    Returns:
        The predicted words, each followed by a single space.
    """
    img2seq.eval()
    # Close the file deterministically and force three channels: training
    # images come from ImageFolder, whose default loader converts to RGB, so
    # grayscale/RGBA files must be converted here as well.
    with Image.open(image_path) as raw:
        img = transform(raw.convert('RGB')).to(DEVICE)
    with torch.no_grad():  # inference only
        generated_sequence = img2seq.generate(img, word_to_idx['<start>'], word_to_idx['<end>'], 17)
    return ''.join(f'{idx_to_word[idx]} ' for idx in generated_sequence)

image_path = './persian_image_captioning/selected_images/1.jpg'
# Use the same preprocessing as the training images (140x200 resize followed
# by ToTensor); feeding the encoder a different resolution than it was trained
# on changes the feature statistics the decoder sees.
transform = transforms.Compose([
    transforms.Resize((140, 200)), transforms.ToTensor()
])
generate_caption(image_path)