In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torchvision.models import ResNet50_Weights
import spacy
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import os, sys
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from utils import print_examples, save_checkpoint, load_checkpoint



In [16]:
text = "this is very"
words = text.split(' ')

In [20]:
torch.tensor([[dataset.vocab.stoi[w] for w in words[0:]]])

tensor([[482,  17, 371]])

In [5]:
spacy_eng = spacy.load("en_core_web_sm")

# class Vocabulary:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold
        
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer_eng(text):
        
        ## REMOVE USELESS WORDS AND CHARACTERS
        
        tokenized_text = [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
        return tokenized_text
    
    # Create dictionary of vocabulary and frequency
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    # add word to frequencies dictionary, set frequency to 1 (initial word)
                    frequencies[word] = 1
                else:
                    # word is in frequencies dictionary, increment frequency by 1
                    frequencies[word] += 1
                    
                if frequencies[word] == self.freq_threshold:
                    # word has reached frequency threshold in vocab dictionary
                    # add word to vocab dictionary (at most once)
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
                    
    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        
        return [self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text]



In [6]:
# Flickr8k Dataset
class Flickr8k(Dataset):
    def __init__(self, imgs_dir, captions_file, test_file, transform=None, freq_threshold=5):
        self.imgs_dir = imgs_dir
        self.df = pd.read_csv(captions_file)
        self.test_file = test_file
        self.transform = transform
        
        self.imgs = self.df['image']
        self.captions = self.df['caption']
        
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.imgs_dir, img_id)).convert("RGB")
        
        if self.transform is not None:
            img = self.transform(img)
            
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [7]:
# custom collate
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        # Sort batch list by caption length in descending order (longest to shortest)
        batch.sort(key=lambda x: len(x[1]), reverse=True)
        
        ## TEST IF SORTING IS NEEDED OR NOT
        
        imgs, captions = zip(*batch)
            
        imgs = torch.stack(imgs)
        lengths = [len(cap) for cap in captions]
        targets = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        """
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        captions = [item[1] for item in batch]
        lengths = [len(cap) for cap in captions]
        targets = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        """
        # imgs:    (batch size, 3, 224, 224)
        # targets: (sequence length, batch size)
        # lengths: (batch size)
        return imgs, targets, lengths

In [8]:
# def get_loader()
def get_loader(dataset_to_use, imgs_folder, annotation_file, transform, test_file="", batch_size=32, num_workers=8, freq_threshold=5, shuffle=True, pin_memory=True):
    
    #if dataset_to_use == "PCCD":
        #dataset = PCCD(imgs_folder, annotation_file, test_file, transform=transform, freq_threshold=freq_threshold)
    if dataset_to_use == "flickr8k":
        dataset = Flickr8k(imgs_folder, annotation_file, test_file, transform=transform, freq_threshold=freq_threshold)
    
    pad_idx = dataset.vocab.stoi["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )
    
    return loader, dataset

In [9]:
def save_checkpoint(state, filename):
    print("-- Saving Checkpoint --")
    torch.save(state, filename)
    
    
def load_checkpoint(checkpoint, model, optimizer):
    print("-- Loading Checkpoint --")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step

In [10]:
def print_examples(model, device, dataset):
    transform = transforms.Compose(
        [
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
        ]
    )
    
    filename_loc = dataset.test_file
    images_loc = dataset.imgs_dir
    
    filename_list = pd.read_csv(filename_loc, header=None)
    filename_list = filename_list.values.reshape(-1).tolist()
    
    
    model.eval()
    
    for i, dir in enumerate(filename_list):
        path = os.path.join(images_loc, dir)
        test_img = transform(Image.open(path).convert("RGB")).unsqueeze(0)
        print(f"Example {i}) OUTPUT: " + " ".join(model.caption_image(test_img.to(device), dataset.vocab)))
        if i > 5:
            break
        
    model.train()

In [11]:
# Language Model
class DecoderLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout, bidirectional=False):
        super(DecoderLSTM, self).__init__()
        if num_layers < 2: dropout=0.0
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        
        # Word Embeddings - https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        # Input:  (sequence length, batch size)
        # Output: (sequence length, batch size, embed size)
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        
        # LSTM - https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        # Input:  (sequence length, batch size, embed size)
        # Output: (sequence length, batch size, 2 if bidirectional else 1, hidden size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=False,bidirectional=bidirectional)
        
        # Fully Connected
        # Input:  (sequence length, batch size, hidden size)
        # Output: (sequence length, batch size, vocab size)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size,)
    
    def forward(self, features, captions, lengths):
        
        embeddings = self.embed(captions)
        
        #https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html#torch.nn.utils.rnn.pack_padded_sequence
        packed = pack_padded_sequence(embeddings, lengths, batch_first=False)    
        
        lstm_out, _ = self.lstm(packed)
        outputs = self.linear(lstm_out)
        
        return outputs
    
    def generate_text(self, vocabulary, max_length=50):
        result_text = []
        
        with torch.no_grad():
            inputs = vocabulary.stoi["<SOS>"]
            inputs = self.embed(inputs)
            hidden = None
            
            for _ in range(max_length):
                lstm_out, hidden = self.lstm(inputs, hidden)
                lstm_out = lstm_out.squeeze(0)
                output = self.linear(lstm_out)
                
                predicted = torch.argmax(output, dim=1)
                result_text.append(predicted.item())
                
                inputs = self.embed(predicted)
                inputs = inputs.unsqueeze(0)
                
                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break
                
        return [vocabulary.itos[idx] for idx in result_text]

In [12]:
# Hyperparameters
embed_size = 256
hidden_size = 256
num_layers = 2
learning_rate = 3e-4
batch_size = 32
num_workers = 2
dropout = 0.0

num_epochs = 5

In [13]:
#dataset_to_use = "PCCD"
dataset_to_use = "flickr8k"

if dataset_to_use == "PCCD":
    imgs_folder = "../datasets/PCCD/images/full"
    annotation_file = "../datasets/PCCD/raw.json"
    
elif dataset_to_use == "flickr8k":
    imgs_folder = "../datasets/flickr8k/images"
    annotation_file = "../datasets/flickr8k/captions.txt"

In [14]:
transform = transforms.Compose(
    [
        transforms.Resize((356,356)),
        transforms.RandomCrop((224,224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    ]
)

train_loader, dataset = get_loader(
    dataset_to_use=dataset_to_use,
    imgs_folder=imgs_folder,
    annotation_file=annotation_file,
    transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    freq_threshold=5,
)
vocab_size = len(dataset.vocab)

In [15]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")           ## Nvidia CUDA Acceleration
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")    ## Apple M1 Metal Acceleration

In [19]:
language_model = DecoderLSTM(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
    dropout=dropout,
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.AdamW(language_model.parameters(), lr=learning_rate)

In [21]:
load_model = False
save_model = False

In [22]:
if save_model:
    writer = SummaryWriter(os.path.join("CNN_LSTM/runs/language_model", dataset_to_use))
step = 0

if load_model:
    step = load_checkpoint(torch.load("CNN_LSTM/runs/language_model/checkpoint.path.tar"), language_model, optimizer)

In [26]:
language_model.train()

for epoch in range(num_epochs):
    
    for idx, (imgs, captions, lengths) in tqdm(enumerate(train_loader), total=len(train_loader), leave=True):
        optimizer.zero_grad()
        
        captions = captions.to(device)
        inputs = None
        
        outputs = language_model(inputs, captions, lengths)
        
        captions = captions.reshape(-1)
        
        outputs = outputs.reshape(-1, outputs.shape[2])
        
        loss = criterion(outputs, captions)
        
        loss.backward()
        optimizer.step()
    
    if save_model:
        checkpoint = {
            'epoch': epoch,
            'step': step,
            'loss': loss,
            'state_dict': language_model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename="CNN-LSTM/runs/language_model/checkpoint.pth.tar")
        
    print("Epoch [{}/[{}], Loss: {:.4f}".format(epoch+1, num_epochs, loss.item()))


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/keenansamway/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/keenansamway/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'Flickr8k' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [None]:
load_checkpoint(torch.load("CNN_LSTM/runs/language_model/checkpoint.path.tar"), language_model, optimizer)

print_examples(language_model, device, dataset)