In [1]:
%load_ext autoreload
%autoreload 2

# Training

In [2]:
import time
import os
import torch
import math
import numpy as np
from torch import nn
from torch.utils.data import Dataset, sampler
from models import Encoder, Decoder
from train_utils import train, validate, early_stopping
from data import Vocabulary, get_loader
from fastprogress.fastprogress import master_bar, progress_bar
import wandb

In [3]:
#!wandb login

Remember to run `wandb login` in a terminal to authenticate before running this notebook; for some reason, running `!wandb login` from a notebook cell doesn't work.

In [4]:
# Run configuration: every tunable value lives in this one dict, which is also
# handed to wandb as the run config further down.
cfg = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=32,
    num_epochs=50,
    lr=0.001,
    momentum=0.01,
    hidden_size=512,
    embed_size=512,
    n_layers=1,
    dropout=0.1,
    seed=0,
    dataset="flickr8k",
)

# Seed both RNGs from the configured seed so runs are repeatable.
np.random.seed(cfg["seed"])
torch.manual_seed(cfg["seed"])
#torch.backends.cudnn.deterministic = True # It makes training slower

# Per-epoch history, filled in by the training loop.
train_losses, val_losses, val_bleus = [], [], []
# Start below any achievable score so the first epoch always counts as an improvement.
best_bleu = -float("inf")

In [5]:
# Data loaders for the training and validation splits
train_loader = get_loader("TRAIN", cfg["batch_size"])
val_loader = get_loader("VAL", cfg["batch_size"])

# Encoder/decoder pair, moved to the configured device.
# The decoder's vocabulary size is taken from the training dataset.
encoder = Encoder(cfg["embed_size"], cfg["momentum"])
encoder = encoder.to(cfg["device"])
decoder = Decoder(
    cfg["embed_size"],
    cfg["hidden_size"],
    len(train_loader.dataset.vocab),
    cfg["n_layers"],
    cfg["dropout"],
)
decoder = decoder.to(cfg["device"])

# Loss and optimizer: only parameters that require gradients are optimised,
# encoder parameters first, then decoder parameters.
criterion = nn.CrossEntropyLoss().to(cfg["device"])
params = [
    p
    for model in (encoder, decoder)
    for p in model.parameters()
    if p.requires_grad
]
optimizer = torch.optim.Adam(params, lr=cfg["lr"])

In [6]:
# Start a wandb run for this experiment; the whole cfg dict is recorded as the run config.
wandb.init(project="autocaption", notes="baseline", tags=["baseline"], config=cfg)

# Register both models with wandb.
wandb.watch([encoder, decoder])

[<wandb.wandb_torch.TorchGraph at 0x1ed7c33eb88>,
 <wandb.wandb_torch.TorchGraph at 0x1ed7c339c88>]

In [7]:
# Main training loop: one train pass and one validation pass per epoch, with
# logging to the progress bar and wandb, a checkpoint after every epoch and
# BLEU-based early stopping.
mb = master_bar(range(cfg["num_epochs"]))

# Make sure the checkpoint directory exists so the first torch.save cannot fail
# on a fresh checkout.
os.makedirs("./data/models", exist_ok=True)

for epoch in mb:
    start = time.time()
    train_loss = train(loader=train_loader,
                       encoder=encoder,
                       decoder=decoder,
                       criterion=criterion,
                       opt=optimizer,
                       epoch=epoch,
                       cfg=cfg,
                       mb=mb)
    train_losses.append(train_loss)

    val_loss, val_bleu = validate(loader=val_loader,
                                  encoder=encoder,
                                  decoder=decoder,
                                  criterion=criterion,
                                  epoch=epoch,
                                  cfg=cfg,
                                  mb=mb)
    val_losses.append(val_loss)
    val_bleus.append(val_bleu)

    # Perplexity is reported as exp(loss); compute it once for both the
    # progress-bar text and the wandb log.
    train_perplexity = np.exp(train_loss)
    val_perplexity = np.exp(val_loss)

    mb.write('> Epoch {}/{}'.format(epoch + 1, cfg["num_epochs"]))
    mb.write('# TRAIN')
    mb.write('# Loss {:.3f}, Perplexity {:.3f}'.format(train_loss, train_perplexity))
    mb.write('# VALIDATION')
    mb.write('# Loss {:.3f}, Perplexity {:.3f}, BLEU {:.3f}'.format(val_loss, val_perplexity, val_bleu))
    mb.write(">Runtime {:.3f}".format(time.time() - start))

    # Send logs to wandb
    wandb.log({'train_loss': train_loss,
               'train_perplexity': train_perplexity,
               'val_loss': val_loss,
               'val_perplexity': val_perplexity,
               'val_bleu': val_bleu}, step=epoch)

    # Choose the checkpoint file: the best model so far overwrites
    # best-model.ckpt, any other epoch gets its own numbered file.
    if val_bleu > best_bleu:
        mb.write("Validation BLEU improved from {} to {}, saving model at ./data/models/best-model.ckpt".format(best_bleu, val_bleu))
        best_bleu = val_bleu
        filename = os.path.join("./data/models", "best-model.ckpt")
    else:
        mb.write("Validation BLEU did not improve, saving model at ./data/models/model-{}.ckpt".format(epoch))
        filename = os.path.join("./data/models", "model-{}.ckpt".format(epoch))

    # The checkpoint payload is the same in both cases, so it is written once here.
    torch.save({"encoder": encoder.state_dict(),
                "decoder": decoder.state_dict(),
                "optimizer": optimizer.state_dict(),
                "train_losses": train_losses,
                "val_losses": val_losses,
                "best_bleu": best_bleu,
                "val_bleus": val_bleus,
                "epoch": epoch
                }, filename)

    # Saving last model to wandb, works only if Jupyter is executed as Admin.
    # The upload stays best-effort, but the failure is reported instead of being
    # silently dropped, and KeyboardInterrupt is no longer swallowed.
    try:
        wandb.save(filename)
    except Exception as err:
        mb.write("Could not upload {} to wandb: {}".format(filename, err))

    # Early stopping is only considered after the first few epochs.
    if epoch > 5:
        if early_stopping(val_bleus, patience=3):
            mb.write("Validation BLEU did not improve for 3 consecutive epochs, stopping")
            break

# Testing

In [11]:
import os
import torch
from models import Encoder, Decoder
import matplotlib.pyplot as plt
from utils import get_caption
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Test-time setup: build the test loader, rebuild the models with the same
# hyperparameters used for training and restore the best checkpoint.
loader = get_loader("TEST", cfg["batch_size"])

model_name = 'best-model.ckpt'
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load(os.path.join('./data/models', model_name),
                        map_location=cfg["device"])
encoder = Encoder(cfg["embed_size"], cfg["momentum"]).to(cfg["device"])
# The vocabulary size comes from the training dataset, as it did when the
# decoder was first built, so the restored weights match the layer shapes.
decoder = Decoder(cfg["embed_size"], 
                  cfg["hidden_size"], 
                  len(train_loader.dataset.vocab),
                  cfg["n_layers"],
                  cfg["dropout"]).to(cfg["device"])

encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])

# Switch both models to evaluation mode before captioning.
encoder.eval()
decoder.eval()

Decoder(
  (embed): Embedding(1842, 512)
  (dropout): Dropout(p=0.1, inplace=False)
  (lstm): LSTM(512, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=1842, bias=True)
)

In [13]:
# Caption every test image with greedy and beam search and record, per image,
# the generated caption, the reference captions and the smoothed sentence BLEU.
results = []
results_beam = []

mb = progress_bar(loader)

# Build the smoothing function once instead of once per image.
smoothing = SmoothingFunction().method1

with torch.no_grad():
    for (orig_image, image, all_caps) in mb:
        image = image.to(cfg["device"])

        # Tokenise the references exactly once and reuse them for both decoders.
        # Previously `all_caps` was joined back into strings after the greedy
        # pass and then re-tokenised with c[0].split(), so the beam BLEU was
        # scored against single-character references.
        references = [c[0].split() for c in all_caps]
        all_caps = [" ".join(ref) for ref in references]

        # Greedy
        candidates = get_caption(image, 
                                encoder, 
                                decoder, 
                                loader.dataset.vocab,
                                "greedy",
                                True)

        # Score the first returned caption, as before.
        hypothesis = candidates[0].split()
        bleu_score = sentence_bleu(references, hypothesis, smoothing_function=smoothing)

        # Store the caption as a single string (it used to be a list holding the
        # same string once per token).
        results.append({"orig_image": orig_image, 
                        "caption": " ".join(hypothesis),
                        "all_caps": all_caps,
                        "bleu": bleu_score})

        # Beam
        candidates = get_caption(image, 
                                encoder, 
                                decoder, 
                                loader.dataset.vocab,
                                "beam",
                                True)

        # Score the full token sequence of the first beam candidate; the old code
        # truncated the hypothesis to its first token.
        # NOTE(review): assumes get_caption returns caption strings with the
        # best beam first - confirm against utils.get_caption.
        hypothesis = candidates[0].split()
        bleu_score = sentence_bleu(references, hypothesis, smoothing_function=smoothing)

        results_beam.append({"orig_image": orig_image, 
                        "caption": candidates,
                        "all_caps": all_caps,
                        "bleu": bleu_score})
        

In [None]:
# Rank the greedy-search results from highest to lowest BLEU and keep the five
# best and five worst entries for inspection.
results = sorted(results, key=lambda item: item['bleu'], reverse=True)

top, worst = results[0:5], results[-5:]

# Show the five best greedy captions next to their reference captions.
for e in top:
    orig_image, caption, all_caps, bleu = (
        e[field] for field in ("orig_image", "caption", "all_caps", "bleu")
    )

    plt.imshow(np.squeeze(orig_image))
    plt.show()

    print(">Generated caption: ")
    print(caption)
    print(">Original captions: ")
    for c in all_caps:
        print(c)
    print(">BLEU4 score: {:.3f}".format(bleu))


In [None]:
# Show the five lowest-scoring greedy captions.
# The slice is derived here from `results` rather than read from the global
# `worst`, because the beam-search cell further down rebinds `worst`; running
# the cells out of order would otherwise display beam results here.
worst = sorted(results, key=lambda k: k['bleu'], reverse=True)[-5:]

for e in worst:
    orig_image = e["orig_image"]
    caption = e["caption"]
    all_caps = e["all_caps"]
    bleu = e["bleu"]

    plt.imshow(np.squeeze(orig_image))
    plt.show()
    print(">Generated caption: ")
    print(caption)
    print(">Original captions: ")
    for c in all_caps:
        print(c)
    print(">BLEU4 score: {:.3f}".format(bleu))


In [None]:
# Rank the beam-search results from highest to lowest BLEU and keep the five
# best and five worst entries for inspection.
results_beam = sorted(results_beam, key=lambda item: item['bleu'], reverse=True)

top, worst = results_beam[0:5], results_beam[-5:]

# Show the five best beam-search captions next to their reference captions.
for e in top:
    orig_image, caption, all_caps, bleu = (
        e[field] for field in ("orig_image", "caption", "all_caps", "bleu")
    )

    plt.imshow(np.squeeze(orig_image))
    plt.show()

    print(">Generated caption: ")
    print(caption)
    print(">Original captions: ")
    for c in all_caps:
        print(c)
    print(">BLEU4 score: {:.3f}".format(bleu))


In [None]:
# Show the five lowest-scoring beam-search captions.
# The slice is derived here from `results_beam` rather than read from the
# global `worst`, which the greedy-search cell also assigns; running the cells
# out of order would otherwise display greedy results here.
worst = sorted(results_beam, key=lambda k: k['bleu'], reverse=True)[-5:]

for e in worst:
    orig_image = e["orig_image"]
    caption = e["caption"]
    all_caps = e["all_caps"]
    bleu = e["bleu"]

    plt.imshow(np.squeeze(orig_image))
    plt.show()
    print(">Generated caption: ")
    print(caption)
    print(">Original captions: ")
    for c in all_caps:
        print(c)
    print(">BLEU4 score: {:.3f}".format(bleu))


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Collect the per-image BLEU scores of each decoding strategy.
greedy_test_bleu = [entry['bleu'] for entry in results]
beam_test_bleu = [entry['bleu'] for entry in results_beam]

# Overlay both score distributions on one wide figure, greedy first.
sns.set(style="whitegrid")
plt.figure(figsize=(20, 6))
for scores, label in ((greedy_test_bleu, "Greedy search"),
                      (beam_test_bleu, "Beam search")):
    sns.distplot(scores, label=label)
plt.legend(fontsize="large")

In [None]:
from numpy import asarray
from numpy import savetxt
# Persist the test BLEU scores as a two-row CSV: row 0 is greedy search,
# row 1 is beam search.
# Create the output directory first so savetxt does not fail on a fresh checkout.
os.makedirs('data/results', exist_ok=True)
data = asarray([greedy_test_bleu, beam_test_bleu])
savetxt('data/results/greedy_beam_20_04.csv', data, delimiter=',')

# Attach the same score lists to the wandb run summary.
wandb.run.summary["greedy_test_bleu"] = greedy_test_bleu
wandb.run.summary["beam_test_bleu"] = beam_test_bleu

In [None]:
from numpy import loadtxt
# Reload previously saved BLEU scores: row 0 is greedy search, row 1 is beam search.
# NOTE(review): this reads greedy_beam_19_04.csv while the cell above writes
# greedy_beam_20_04.csv - confirm that loading the older file is intended.
data = loadtxt('data/results/greedy_beam_19_04.csv', delimiter=',')
greedy_test_bleu, beam_test_bleu = data[0], data[1]
