In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torchvision.models import ResNet50_Weights
import spacy, re
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import os, sys, io, json
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence


In [3]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [7]:
spacy_eng = spacy.load("en_core_web_sm")

# class Vocabulary:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold
        
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer_eng(text):
        
        ## REMOVE SPECIAL CHARACTERS
        #cleaned_text = ''.join(e for e in text if e.isalnum())
        #cleaned_text = re.sub(r"[^a-zA-Z0-9 ]", "", text)
        cleaned_text = text
        
        tokenized_text = [tok.text.lower() for tok in spacy_eng.tokenizer(cleaned_text)]
        return tokenized_text
    
    # Create dictionary of vocabulary and frequency
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    # add word to frequencies dictionary, set frequency to 1 (initial word)
                    frequencies[word] = 1
                else:
                    # word is in frequencies dictionary, increment frequency by 1
                    frequencies[word] += 1
                    
                if frequencies[word] == self.freq_threshold:
                    # word has reached frequency threshold in vocab dictionary
                    # add word to vocab dictionary (at most once)
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
                    
    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        
        return [self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text]



In [16]:
# Flickr8k Dataset
class Flickr8k(Dataset):
    def __init__(self, imgs_dir, captions_file, test_file, transform=None, freq_threshold=5):
        self.imgs_dir = imgs_dir
        self.df = pd.read_csv(captions_file)
        self.test_file = test_file
        self.transform = transform
        
        self.imgs = self.df['image']
        self.captions = self.df['caption']
        
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.imgs_dir, img_id)).convert("RGB")
        
        if self.transform is not None:
            img = self.transform(img)
            
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [None]:
# PCCD Dataset
class PCCD(Dataset):
    def __init__(self, imgs_dir, captions_file, test_file, transform=None, freq_threshold=5):
        self.imgs_dir = imgs_dir
        self.df = pd.read_json(captions_file)
        self.test_file = test_file
        self.transform = transform
        
        self.df["general_impression"] = self.df["general_impression"].fillna("")
        
        # Get img, caption columns
        self.imgs = self.df["title"]
        self.captions = self.df["general_impression"]

        # Initialize and build vocabulary
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.imgs_dir, img_id)).convert("RGB")
                
        if self.transform is not None:
            img = self.transform(img)
            
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [8]:
# AVA-Captions Dataset
class AVA(Dataset):
    def __init__(self, imgs_dir, captions_file, test_file, transform=None, freq_threshold=5):
        self.imgs_dir = imgs_dir
        self.captions_file = captions_file
        self.test_file = test_file
        self.transform = transform
        self.freq_threshold = freq_threshold
        
        """
        # Open dataframe
        with io.open(self.captions_file, 'r', encoding='utf-8') as f:
            json_file = json.load(f)
        self.df = pd.DataFrame(json_file['images'])
        
        #Get img, caption columns
        self.imgs = []
        self.captions = []
        for i, img in enumerate(self.df['filename']):
            for j, caption in enumerate(self.df['sentences'][i]):
                if os.path.exists(os.path.join(self.imgs_dir, img)):
                    self.imgs.append(img)
                    self.captions.append(caption['clean'])
        """
        self.df = pd.read_feather(captions_file)
        
        self.imgs = self.df['filename']
        self.captions = self.df['clean_sentence']
        self.split = self.df['split']
        
        # Initialize and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions)
        
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.imgs_dir, img_id)).convert("RGB")
        
        if self.transform is not None:
            img = self.transform(img)
            
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [13]:
# custom collate
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        # Sort batch list by caption length in descending order (longest to shortest)
        #batch.sort(key=lambda x: len(x[1]), reverse=True)
        
        ## TEST IF SORTING IS NEEDED OR NOT
        
        imgs = [item[0] for item in batch]
        imgs = torch.stack(imgs)
        
        captions = [item[1] for item in batch]
        #truncated = [cap[0:50] if 50 > len(cap) else cap for cap in captions]
        targets = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        #packed_targets = pack_padded_sequence(targets, batch_first=False, enforce_sorted=False)
        
        lengths = [len(cap) for cap in captions]

        # imgs:    (batch size, 3, 224, 224)
        # targets: (sequence length, batch size)
        # lengths: (batch size)
        return imgs, targets, lengths

In [5]:
# def get_loader()
def get_loader(dataset_to_use, imgs_folder, annotation_file, transform, test_file="", batch_size=32, num_workers=8, freq_threshold=5, shuffle=True, pin_memory=True):
    
    if dataset_to_use == "PCCD":
        dataset = PCCD(imgs_folder, annotation_file, test_file, transform=transform, freq_threshold=freq_threshold)
    elif dataset_to_use == "flickr8k":
        dataset = Flickr8k(imgs_folder, annotation_file, test_file, transform=transform, freq_threshold=freq_threshold)
    elif dataset_to_use == "AVA":
        dataset = AVA(imgs_folder, annotation_file, test_file, transform=transform, freq_threshold=freq_threshold)
    
    pad_idx = dataset.vocab.stoi["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )
    
    return loader, dataset

In [6]:
def save_checkpoint(state, filename):
    print("-- Saving Checkpoint --")
    torch.save(state, filename)
    
    
def load_checkpoint(checkpoint, model, optimizer):
    print("-- Loading Checkpoint --")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step

In [11]:
def print_examples(model, device, dataset):
    transform = transforms.Compose(
        [
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
        ]
    )
    
    filename_loc = dataset.test_file
    images_loc = dataset.imgs_dir
    
    filename_list = pd.read_csv(filename_loc, header=None)
    filename_list = filename_list.values.reshape(-1).tolist()
    
    
    model.eval()
    
    for i, dir in enumerate(filename_list):
        path = os.path.join(images_loc, dir)
        test_img = transform(Image.open(path).convert("RGB")).unsqueeze(0)
        print(f"Example {i}) OUTPUT: " + " ".join(model.caption_image(test_img.to(device), dataset.vocab)))
        if i > 5:
            break
        
    model.train()

In [17]:
# class EncoderCNN(nn.Module):
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, dropout, train_model=False):
        super(EncoderCNN, self).__init__()
        self.train_model = train_model
        
        self.resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.in_features = self.resnet.fc.in_features
        
        self.resnet = nn.Sequential(*(list(self.resnet.children())[:-1]))
        self.linear = nn.Linear(self.in_features, embed_size)
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        #self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        
    def forward(self, images):
        # images: (batch_size, 3, 224, 224)
        
        features = self.resnet(images)                      # features: (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)      # features: (batch_size, 2048)
        features = self.linear(features)                    # features: (batch_size, embed_size)
        #features = self.relu(features)                      
        # features = self.bn(features)
        features = self.dropout(features)                   # features: (batch_size, embed_size)
        return features

In [36]:
# Language Model
class DecoderLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout, bidirectional=False):
        super(DecoderLSTM, self).__init__()
        if num_layers < 2: dropout=0.0
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        
        # Word Embeddings - https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        # Input:  (sequence length, batch size)
        # Output: (sequence length, batch size, embed size)
        self.embed = nn.Embedding(
            num_embeddings=vocab_size, 
            embedding_dim=embed_size,
            )
        
        # LSTM - https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        # Input:  (sequence length, batch size, embed size)
        # States: (2 if bidirectional else 1 * num layers, batch size, hidden size)
        # Output: (sequence length, batch size, 2 if bidirectional else 1 * hidden size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            dropout=dropout, 
            batch_first=False,
            bidirectional=bidirectional,
            )
        
        # Fully Connected
        # Input:  (sequence length, batch size, hidden size)
        # Output: (sequence length, batch size, vocab size)
        self.linear = nn.Linear(
            in_features=hidden_size, 
            out_features=vocab_size,
            )
    
    def forward(self, features, captions):
        # features: (batch_size, embed_size)
        # captions: (caption_length, batch_size)
        
        embeddings = self.embed(captions)
        states = torch.stack([features]*(self.num_layers), dim=0)

        #packed = pack_padded_sequence(embeddings, lengths, batch_first=False, enforce_sorted=True)    
        
        lstm_out, states = self.lstm(embeddings, states)
        linear_outputs = self.linear(lstm_out)
        
        #outputs = linear_outputs.reshape(-1, self.vocab_size)
        
        return linear_outputs
    
    def generate_text(self, inputs, vocabulary, max_length=50):
        result_text = []
        
        with torch.no_grad():
            features = inputs
            lstm_in = features.unsqueeze(0)
            hidden = None

            for _ in range(max_length):
                lstm_out, hidden = self.lstm(lstm_in, hidden)
                linear_in = lstm_out.squeeze(0)
                linera_out = self.linear(linear_in)
                
                predicted = torch.argmax(linera_out, dim=1)
                result_text.append(predicted.item())
                
                lstm_in = self.embed(predicted)
                lstm_in = lstm_in.unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break
                
        return [vocabulary.itos[idx] for idx in result_text]

In [19]:
# class CNNtoRNN(nn.Module):
class CNNtoLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout):
        super(CNNtoLSTM, self).__init__()
        self.encoder = EncoderCNN(embed_size, dropout)
        self.decoder = DecoderLSTM(embed_size, hidden_size, vocab_size, num_layers, dropout)
    
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs
    
    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []
        
        with torch.no_grad():
            # image: (3, 224, 224)

            features = self.encoder(image)                                    # inputs: (batch_size=1, embed_size)
            lstm_in = features.unsqueeze(0)                                    # inputs: (1, batch_size=1, embed_size)
            hidden = None

            for _ in range(max_length):
                lstm_out, hidden = self.decoder.lstm(lstm_in, hidden)        # lstm_out: (1, batch_size=1, hidden_size)
                linear_in = lstm_out.squeeze(0)                              # lstm_out: (batch_size=1, hidden_size)
                linera_out = self.decoder.linear(linear_in)                      # output: (batch_size=1, vocab_size)
                
                predicted = torch.argmax(linera_out, dim=1)                     # predicted: (batch_size=1)
                result_caption.append(predicted.item())
                
                lstm_in = self.decoder.embed(predicted)                      # input: (batch_size=1, embed_size)
                lstm_in = lstm_in.unsqueeze(0)                                # input: (1, batch_size=1, embed_size)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break
                
        return [vocabulary.itos[idx] for idx in result_caption]

In [27]:
# Hyperparameters
embed_size = 256
hidden_size = 256
num_layers = 1
learning_rate = 3e-4
batch_size = 64
num_workers = 0
dropout = 0.0

num_epochs = 5

In [30]:
#dataset_to_use = "PCCD"
dataset_to_use = "flickr8k"
#dataset_to_use = "AVA"

if dataset_to_use == "PCCD":
    imgs_folder = "../datasets/PCCD/images/full"
    annotation_file = "../datasets/PCCD/raw.json"
    test_file = "../datasets/PCCD/images/PCCD_test.txt"
    
elif dataset_to_use == "flickr8k":
    imgs_folder = "../datasets/flickr8k/images"
    annotation_file = "../datasets/flickr8k/captions.txt"
    test_file = "../datasets/flickr8k/flickr8k_test.txt"

elif dataset_to_use == "AVA":
    imgs_folder = "../datasets/AVA/images"
    #annotation_file = "../datasets/AVA/CLEAN_AVA_FULL_COMMENTS.feather"
    annotation_file = "../datasets/AVA/CLEAN_AVA_SAMPLE_COMMENTS.feather"
    test_file = "../datasets/AVA/AVA_test.txt"

In [31]:
transform = transforms.Compose(
    [
        transforms.Resize((356,356)),
        transforms.RandomCrop((224,224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    ]
)

train_loader, dataset = get_loader(
    dataset_to_use=dataset_to_use,
    imgs_folder=imgs_folder,
    annotation_file=annotation_file,
    test_file=test_file,
    transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    freq_threshold=8,
)
vocab_size = len(dataset.vocab)

In [None]:
#for i in range(vocab_size - 500, vocab_size):
    #print(dataset.vocab.itos[i])
    
for i in range(500):
    print(dataset.vocab.itos[i])

In [23]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")           ## Nvidia CUDA Acceleration
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")    ## Apple M1 Metal Acceleration

In [32]:
# initialize model, loss, etc
model = CNNtoLSTM(
    embed_size=embed_size, 
    hidden_size=hidden_size, 
    vocab_size=vocab_size,
    num_layers=num_layers,
    dropout=dropout,
    ).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [33]:
load_model = False
save_model = True

In [34]:
if save_model:
    writer = SummaryWriter(os.path.join("../CNN-LSTM/runs", dataset_to_use))
step = 0

if load_model:
    step = load_checkpoint(torch.load("../CNN-LSTM/runs/checkpoint.path.tar"), model, optimizer)

In [38]:
model.train()

for epoch in range(num_epochs):
        #for idx, (imgs, captions, lengths) in enumerate(train_loader):
        for idx, (imgs, captions, lengths) in tqdm(enumerate(train_loader), total=len(train_loader), leave=True):
            optimizer.zero_grad()
            
            imgs = imgs.to(device)
            captions = captions.to(device)
            
            outputs = model(imgs, captions)
            
            targets = captions.view(-1)
            
            outputs = outputs.view(-1, outputs.shape[2])
            
            loss = criterion(outputs, targets)
            
            if save_model:
                writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1
            
            loss.backward()
            optimizer.step()
            
            if step % 1000 == 0:
                print("Epoch [{}/[{}], Step [{}], Loss: {:.4f}".format(epoch+1, num_epochs, step, loss.item()))
            
        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            loc = "../CNN-LSTM/runs/checkpoint.pth.tar"
            save_checkpoint(checkpoint, filename=loc)
        print("Epoch [{}/[{}], Loss: {:.4f}".format(epoch+1, num_epochs, loss.item()))



  0%|          | 0/633 [00:00<?, ?it/s]


IndexError: index 1 is out of bounds for dimension 0 with size 1

In [None]:
checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": step,
}
loc = "../CNN-LSTM/runs/checkpoint.pth.tar"
save_checkpoint(checkpoint, filename=loc)

In [None]:
load_checkpoint(torch.load("../CNN-LSTM/runs/checkpoint.pth.tar"), model, optimizer)

print_examples(model, device, dataset)

In [None]:
## Generate text from random initialization
inputs = torch.rand(1, embed_size).to(device)
outputs = model.decoder.generate_text(inputs, dataset.vocab)
print(outputs)