In [1]:
import torchvision.models as models
from torch import nn
import torch
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision.transforms import transforms
from nltk.tokenize import word_tokenize
from string import punctuation
from torchtext.vocab import build_vocab_from_iterator
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# Seed torch's global RNG (e.g. weight initialisation, dropout masks) for reproducible runs.
torch.manual_seed(42)
# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'



In [2]:
class Encoder(nn.Module):
    """ResNet-50 image encoder producing one `embed_dim`-sized feature per image.

    The ImageNet classification head (`resnet.fc`) is replaced by a linear
    projection to `embed_dim`, followed by ReLU and dropout.

    Args:
        embed_dim: size of the output feature vector.
        dropout: dropout probability applied to the projected features.
        grad: if False (default), the pretrained backbone is frozen and only the
            new projection layer is trained.
    """
    def __init__(self, embed_dim, dropout = 0.5, grad = False):
        super(Encoder, self).__init__()
        self.freeze_backbone = not grad
        self.resnet = models.resnet50(weights='DEFAULT')

        if self.freeze_backbone:
            # Freeze the pretrained weights BEFORE swapping in the new head:
            # freezing afterwards would also freeze the randomly initialised
            # projection below, leaving it permanently untrainable.
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Fresh, trainable projection from ResNet's pooled features to embed_dim.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        With frozen weights, BatchNorm layers would otherwise still overwrite
        their pretrained running statistics with batch statistics in train mode.
        The projection layer is a plain Linear, so eval mode does not affect it.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.resnet.eval()
        return self

    def forward(self, x):
        """Encode an image batch `x` into features of size `embed_dim`."""
        feature = self.resnet(x)
        return self.dropout(self.relu(feature))

In [3]:
class Decoder(nn.Module):
    """Image-captioning model: an image encoder followed by an LSTM language model.

    The encoder output is prepended to the embedded caption as time step 0, so
    for a caption of T tokens the LSTM sees T + 1 steps and emits T + 1 logits.

    Args:
        embed_dim: word-embedding size; must equal the encoder's output size,
            because encoder features and word embeddings are concatenated in time.
        hidden_dim: LSTM hidden size.
        vocab_size: number of embedding rows / output classes.
        num_layers: number of stacked LSTM layers.
        device: device the encoder is moved to; also stored for callers.
        encoder: module mapping an image batch to (batch, embed_dim) features.
        dropout: dropout on word embeddings and between stacked LSTM layers.
    """
    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout=0.5):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM applies dropout only between stacked layers; a non-zero value
        # with a single layer has no effect and triggers a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, dropout=lstm_dropout)
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.device = device
        self.encoder = encoder.to(device)

    def forward(self, image, caption):
        """Return vocabulary logits for every step of [image, caption tokens].

        Args:
            image: image batch accepted by the encoder.
            caption: (batch, T) tensor of token ids.

        Returns:
            (batch, T + 1, vocab_size) logits.
        """
        features = self.encoder(image)                  # (batch, embed_dim)
        embeddings = self.dropout(self.embed(caption))  # (batch, T, embed_dim)
        # The image feature becomes the first input step of the sequence.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        outputs, _ = self.lstm(embeddings)
        return self.linear(outputs)

In [4]:
# Load the captions table; the preview below shows one row per (image file, caption) pair.
data = pd.read_csv("captions.txt")

In [5]:
# Peek at the first rows of the raw captions table.
data.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [6]:
def clean_text(text, lowercase=False, remove_punc=False, remove_num=False, sos_token='<sos>', eos_token='<eos>'):
    """Normalise a caption string and tokenize it between start/end markers.

    Args:
        text: raw caption string.
        lowercase: lowercase the text first.
        remove_punc: delete every character in `string.punctuation`.
        remove_num: delete the digits 0-9.
        sos_token: token prepended to the result.
        eos_token: token appended to the result.

    Returns:
        list[str]: [sos_token, *word_tokenize(cleaned_text), eos_token]
    """
    cleaned = text.lower() if lowercase else text
    if remove_punc:
        cleaned = cleaned.translate(str.maketrans('', '', punctuation))
    if remove_num:
        cleaned = cleaned.translate(str.maketrans('', '', '1234567890'))
    return [sos_token, *word_tokenize(cleaned), eos_token]

In [7]:
import nltk

In [8]:
# word_tokenize (used by clean_text) needs NLTK's Punkt tokenizer data.
# NLTK 3.9+ loads it from the "punkt_tab" package rather than "punkt", so fetch
# both to keep the notebook working on old and new NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\balaganapathi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [9]:
# Quick check of the cleaning pipeline on a sample sentence.
clean_text("A cat is sitting on the table.", lowercase=True, remove_punc=True, remove_num=True)

['<sos>', 'a', 'cat', 'is', 'sitting', 'on', 'the', 'table', '<eos>']

In [10]:
# Special vocabulary tokens: unknown word, padding, start and end of sentence.
unk_token, pad_token, sos_token, eos_token = '<unk>', '<pad>', '<sos>', '<eos>'

In [11]:
# Tokenize every caption: lowercase, strip punctuation and digits, wrap in <sos>/<eos>.
clean_cap = data['caption'].apply(lambda x: clean_text(x, lowercase=True, remove_punc=True, remove_num=True))

In [12]:
# Preview the tokenized captions.
clean_cap.head()

0    [<sos>, a, child, in, a, pink, dress, is, clim...
1    [<sos>, a, girl, going, into, a, wooden, build...
2    [<sos>, a, little, girl, climbing, into, a, wo...
3    [<sos>, a, little, girl, climbing, the, stairs...
4    [<sos>, a, little, girl, in, a, pink, dress, g...
Name: caption, dtype: object

In [13]:
# Keep the token lists next to the raw captions.
data['clean_caption'] = clean_cap

In [14]:
# The table now includes the clean_caption column.
data.head()

Unnamed: 0,image,caption,clean_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,"[<sos>, a, child, in, a, pink, dress, is, clim..."
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<sos>, a, girl, going, into, a, wooden, build..."
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<sos>, a, little, girl, climbing, into, a, wo..."
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,"[<sos>, a, little, girl, climbing, the, stairs..."
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,"[<sos>, a, little, girl, in, a, pink, dress, g..."


In [15]:
# Build the vocabulary from every tokenized caption. The specials are inserted
# first, so <unk>, <pad>, <sos>, <eos> get ids 0-3 (see the output below).
# NOTE(review): this runs before the train/test split, so words that occur only
# in test captions also receive ids.
vocab = build_vocab_from_iterator(clean_cap, specials=[unk_token, pad_token, sos_token, eos_token])

In [16]:
# First vocabulary entries: the four specials come first, then regular words.
vocab.get_itos()[:10]

['<unk>', '<pad>', '<sos>', '<eos>', 'a', 'in', 'the', 'on', 'is', 'and']

In [17]:
# Id of the padding token (used to pad batches and ignored by the loss).
pad_token_idx = vocab[pad_token]
# Id of the unknown-word token.
unk_token_idx = vocab[unk_token]

In [18]:
# Map out-of-vocabulary tokens to <unk> instead of raising on lookup.
vocab.set_default_index(unk_token_idx)

In [19]:
# Token -> id conversion.
def text_to_number(text, vocab):
    """Return the list of ids obtained by looking up each token with `vocab[token]`."""
    ids = []
    for token in text:
        ids.append(vocab[token])
    return ids
    

In [20]:
# Convert each token list into a list of vocabulary ids.
to_int = clean_cap.apply(lambda x: text_to_number(x, vocab))

In [21]:
# Inspect the id sequences.
to_int

0        [2, 4, 43, 5, 4, 91, 171, 8, 120, 54, 4, 400, ...
1                      [2, 4, 20, 316, 65, 4, 196, 118, 3]
2                 [2, 4, 41, 20, 120, 65, 4, 196, 2569, 3]
3             [2, 4, 41, 20, 120, 6, 394, 21, 61, 2569, 3]
4        [2, 4, 41, 20, 5, 4, 91, 171, 316, 65, 4, 196,...
                               ...                        
40450         [2, 4, 12, 5, 4, 91, 38, 253, 4, 85, 124, 3]
40451             [2, 4, 12, 8, 85, 120, 197, 5, 6, 66, 3]
40452    [2, 4, 44, 5, 4, 26, 38, 120, 54, 4, 85, 124, ...
40453                     [2, 4, 85, 359, 5, 4, 26, 38, 3]
40454         [2, 4, 85, 359, 1915, 7, 4, 85, 120, 110, 3]
Name: caption, Length: 40455, dtype: object

In [22]:
# Store the id sequences; CustomDataset reads captions from this column.
data['embed_caption'] = to_int

In [23]:
# The table now also holds the id-encoded captions.
data.head()

Unnamed: 0,image,caption,clean_caption,embed_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,"[<sos>, a, child, in, a, pink, dress, is, clim...","[2, 4, 43, 5, 4, 91, 171, 8, 120, 54, 4, 400, ..."
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<sos>, a, girl, going, into, a, wooden, build...","[2, 4, 20, 316, 65, 4, 196, 118, 3]"
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<sos>, a, little, girl, climbing, into, a, wo...","[2, 4, 41, 20, 120, 65, 4, 196, 2569, 3]"
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,"[<sos>, a, little, girl, climbing, the, stairs...","[2, 4, 41, 20, 120, 6, 394, 21, 61, 2569, 3]"
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,"[<sos>, a, little, girl, in, a, pink, dress, g...","[2, 4, 41, 20, 5, 4, 91, 171, 316, 65, 4, 196,..."


In [24]:
# Round-trip check: decode the first caption's ids back into tokens.
vocab.lookup_tokens(data['embed_caption'][0])

['<sos>',
 'a',
 'child',
 'in',
 'a',
 'pink',
 'dress',
 'is',
 'climbing',
 'up',
 'a',
 'set',
 'of',
 'stairs',
 'in',
 'an',
 'entry',
 'way',
 '<eos>']

In [25]:
# Split by image rather than by caption row: each image has several captions
# (see the data preview above), so a row-level split would put the same photo
# in both train and test and make the validation loss optimistic.
unique_images = data['image'].unique()
train_images, test_images = train_test_split(unique_images, test_size=0.2, random_state=42)
train = data[data['image'].isin(train_images)].reset_index(drop=True)
test = data[data['image'].isin(test_images)].reset_index(drop=True)

In [26]:
def collate_fn(batch, pad_index):
    """Stack images and right-pad captions into one batch.

    Args:
        batch: sequence of (image_tensor, caption_tensor) pairs; images must all
            share one shape, captions are 1-D id tensors of varying length.
        pad_index: id used to fill captions shorter than the longest one.

    Returns:
        (images, captions): images stacked along a new leading dimension, and a
        (batch, max_len) caption tensor padded with `pad_index`.
    """
    images = [img for img, _ in batch]
    captions = [cap for _, cap in batch]
    images = torch.stack(images)
    captions = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=pad_index)
    return images, captions


def get_collate_fn(pad_index):
    """Return a one-argument collate function with `pad_index` bound.

    functools.partial is used instead of a lambda because partial objects are
    picklable, which DataLoader needs to ship the collate function to worker
    processes (num_workers > 0 with the 'spawn' start method).
    """
    from functools import partial
    return partial(collate_fn, pad_index=pad_index)

In [27]:
class CustomDataset(Dataset):
    """Map-style dataset yielding (image, caption_ids) pairs.

    Args:
        root_dir: directory holding the image files named in data['image'].
        data: DataFrame with an 'image' column (file names) and an
            'embed_caption' column (lists of token ids).
        transform: optional callable applied to the loaded PIL image.
    """
    def __init__(self, root_dir, data, transform=None):
        self.root_dir = root_dir
        self.captions = data['embed_caption']
        self.images = data['image']
        self.transform = transform

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # Positional lookup, so the dataset works even if `data` was not re-indexed.
        path = os.path.join(self.root_dir, self.images.iloc[idx])
        # Load inside a context manager so the file handle is closed, and force
        # 3 channels: grayscale/RGBA files would otherwise break torch.stack
        # when images are batched.
        with Image.open(path) as img:
            image = img.convert('RGB')
        caption = torch.tensor(self.captions.iloc[idx])
        if self.transform:
            image = self.transform(image)

        return image, caption

In [28]:
# Preprocessing for the pretrained ResNet-50 encoder: resize, convert the PIL
# image to a float tensor in [0, 1], then normalise with the ImageNet channel
# statistics that torchvision's pretrained weights expect.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [29]:
# Both splits read image files from the same "images" directory and share one preprocessing pipeline.
train_dataset = CustomDataset("images", train, transform=transform)
test_dataset = CustomDataset("images", test, transform=transform)

In [30]:
batch_size = 256
num_workers = 0  # 0 = batches are loaded in the main process

In [31]:
# Shuffle only the training split. Evaluation does not need random order, and a
# fixed order keeps the batch composition (and so the averaged validation loss)
# identical across epochs.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                               collate_fn=get_collate_fn(pad_token_idx))
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                              collate_fn=get_collate_fn(pad_token_idx))


In [32]:
# Size of both the encoder output and the word embeddings; they must match
# because the image feature is concatenated with the word embeddings.
embed_dim = 100
hidden_dim = 200  # LSTM hidden size
vocab_size = len(vocab)
num_layers = 2  # stacked LSTM layers
dropout = 0.5

In [33]:
# Build the image encoder (its `grad` argument is left at the default) and the
# decoder that wraps it, then move the whole model to `device`.
encoder = Encoder(embed_dim, dropout)
model = Decoder(embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout )
model = model.to(device)

In [34]:
# Optimisation settings.
n_epochs = 10
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = pad_token_idx)
clip = 1.0  # max gradient norm passed to clip_grad_norm_ in train_fn
# NOTE(review): teacher_forcing_ratio is not used by train_fn / evaluate_fn
# below; the model is always fed the ground-truth caption during training.
teacher_forcing_ratio = 0.5
best_valid_loss = float("inf")  # lowest validation loss so far; drives checkpointing

In [35]:
def train_fn(model, data_loader, optimizer, criterion, clip, device):
    """Run one training epoch and return the mean of the per-batch losses.

    Each batch is an (images, captions) pair. The model receives the captions
    without their last token and must return logits whose time dimension equals
    the full caption length (e.g. by prepending one image step); the flattened
    logits are then scored against every caption token.

    Args:
        model: module called as model(images, captions_in) -> (B, T, vocab) logits.
        data_loader: iterable of (images, captions) batches with a len().
        optimizer: optimizer over the model's parameters.
        criterion: loss taking (N, vocab) logits and (N,) targets.
        clip: maximum gradient norm for clip_grad_norm_.
        device: device the batches are moved to.

    Returns:
        float: sum of batch losses divided by the number of batches.
    """
    model.train()
    epoch_loss = 0.0

    for images, captions in data_loader:
        images, captions = images.to(device), captions.to(device)

        optimizer.zero_grad()

        # Drop the final token from the input; the target is the full caption.
        outputs = model(images, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(data_loader)

In [36]:
def evaluate_fn(model, data_loader, criterion, device):
    """Return the mean per-batch loss over `data_loader` without updating the model.

    Uses the same input/target layout as train_fn: the model receives captions
    without their last token and its logits are scored against the full caption.

    Args:
        model: module called as model(images, captions_in) -> (B, T, vocab) logits.
        data_loader: iterable of (images, captions) batches with a len().
        criterion: loss taking (N, vocab) logits and (N,) targets.
        device: device the batches are moved to.

    Returns:
        float: sum of batch losses divided by the number of batches.
    """
    model.eval()
    epoch_loss = 0.0
    with torch.no_grad():
        for images, captions in data_loader:
            images, captions = images.to(device), captions.to(device)

            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
            epoch_loss += loss.item()

    return epoch_loss / len(data_loader)

In [None]:
for epoch in tqdm(range(n_epochs)):
    train_loss = train_fn(model, train_data_loader, optimizer, criterion, clip, device)
    valid_loss = evaluate_fn(model, test_data_loader, criterion, device)

    # Keep a checkpoint of the weights with the lowest validation loss so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "best-model.pt")

    # tqdm.write prints above the progress bar instead of breaking it, and the
    # epoch number makes the log readable on its own.
    tqdm.write(f"Epoch {epoch + 1}/{n_epochs}")
    tqdm.write(f"\tTrain Loss: {train_loss:7.3f}")
    tqdm.write(f"\tValid Loss: {valid_loss:7.3f}")

# Always keep the final weights as well.
torch.save(model.state_dict(), "last_model.pt")

  0%|                                                                                          | 0/10 [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt
# Load a sample image from the same "images" directory the datasets use (a
# differently-cased "Images/" path fails on case-sensitive filesystems), forcing
# 3 channels to match the training preprocessing.
image = Image.open(os.path.join("images", "1015118661_980735411b.jpg")).convert("RGB")

In [None]:
# Display the loaded sample image.
image

In [None]:

def predict_caption(model, image, vocab, max_length=50):
    """Greedily decode a caption for a single preprocessed image.

    The encoder feature is fed to the LSTM as the first input; afterwards each
    step's argmax token is embedded and fed back in. Decoding stops right after
    '<eos>' is produced (it is kept in the result) or after `max_length` tokens.

    Args:
        model: captioning model exposing encoder, lstm, linear, embed, dropout,
            num_layers and device.
        image: image batch for the encoder; `.item()` below requires batch size 1.
        vocab: maps tokens to ids (`vocab[token]`) and ids back (`lookup_tokens`).
        max_length: maximum number of generated tokens.

    Returns:
        list[str]: the generated tokens.
    """
    model.eval()
    with torch.no_grad():
        # Step 0 input: the image feature as a length-1 sequence.
        step_input = model.encoder(image).unsqueeze(1)
        state_shape = (model.num_layers, 1, model.lstm.hidden_size)
        state = (
            torch.zeros(*state_shape).to(model.device),
            torch.zeros(*state_shape).to(model.device),
        )

        token_ids = []
        for _ in range(max_length):
            lstm_out, state = model.lstm(step_input, state)
            logits = model.linear(lstm_out.squeeze(1))
            next_id = logits.argmax(1)
            token_ids.append(next_id.item())
            # Feed the predicted word back in (dropout is inactive in eval mode).
            step_input = model.dropout(model.embed(next_id)).unsqueeze(1)
            if next_id.item() == vocab['<eos>']:
                break

    return vocab.lookup_tokens(token_ids)

In [None]:
# Preprocess into a new variable so this cell can be re-run: overwriting `image`
# with a tensor would make the next transform(image) call fail in ToTensor.
image_tensor = transform(image).unsqueeze(0).to(device)
pred = predict_caption(model, image_tensor, vocab, 20)
pred