In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import config
import pickle
from data_loader import get_loader,CocoDataset
from build_vocab import Vocabulary

In [3]:
from torchvision import models
import torch

In [4]:
with open(config.VOCAB_PATH, 'rb') as f:
    vocab = pickle.load(f)

TRAIN_LOADER = {'root':config.TRAIN_FEATURE_PATH, 'json':config.TRAIN_JSON_PATH, 'vocab':vocab, 'batch_size':32, 'shuffle':True, 'num_workers':4}
VAL_LOADER = {'root':config.VAL_FEATURE_PATH, 'json':config.VAL_JSON_PATH, 'vocab':vocab, 'batch_size':32, 'shuffle':False, 'num_workers':4}

In [5]:
train_loader = get_loader(**TRAIN_LOADER)
val_loader = get_loader(**VAL_LOADER)

loading annotations into memory...
Done (t=0.71s)
creating index...
index created!
loading annotations into memory...
Done (t=0.38s)
creating index...
index created!


In [6]:
vocab_size = len(vocab)
vocab_size

4530

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
from torch import nn,optim
from model import EncoderCNN,DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm
encoder = EncoderCNN(512,image_dim=256)
encoder.to(device)
decoder = DecoderRNN(256,300,300,vocab_size)
decoder.to(device)
enc_optimizer = optim.SGD(encoder.parameters(),lr=1e-4,momentum=0.9)
dec_optimizer = optim.SGD(decoder.parameters(),lr=1e-2,momentum=0.9)
for epoch in range(10):
    total_loss = 0
    for i,data in enumerate(tqdm(train_loader)):
        enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        images,captions,lengths = data
        images = images.to(device)
        captions = captions.to(device)
        encoded_images = encoder(images)
        predictions,captions,lengths,alphas = decoder(encoded_images,captions,lengths)
        if i%1000 ==0 and epoch ==1:
            predicted = predictions.argmax(dim=2)
            print(predicted.shape)
            for sentence in predicted:
                for word in sentence:
                    print(vocab.idx2word[word.cpu().detach().item()],end=" ")
                print("------------")
        targets = captions[:, 1:]
        scores, _,_,_ = pack_padded_sequence(predictions, lengths, batch_first=True)
        targets, _,_,_ = pack_padded_sequence(targets, lengths, batch_first=True)
        # Calculate loss
        loss = nn.CrossEntropyLoss()(scores, targets)
        if i%1000 == 0:
            print(loss.item())
        loss.backward()
        enc_optimizer.step()
        dec_optimizer.step()
        total_loss += loss.item()

    print(total_loss/len(train_loader))

  0%|          | 1/12942 [00:00<1:27:34,  2.46it/s]

8.437864303588867


  8%|▊         | 1004/12942 [00:48<10:38, 18.69it/s]

4.183603286743164


 15%|█▌        | 2005/12942 [01:37<08:04, 22.57it/s]

3.9741995334625244


 23%|██▎       | 3002/12942 [02:26<12:21, 13.40it/s]  

4.046089172363281


 31%|███       | 4004/12942 [03:14<06:43, 22.17it/s]

3.81294584274292


 39%|███▊      | 5002/12942 [04:03<06:14, 21.18it/s]

3.6298117637634277


 46%|████▋     | 6005/12942 [04:52<05:14, 22.08it/s]

3.885234832763672


 54%|█████▍    | 7003/12942 [05:40<04:36, 21.47it/s]

3.944934368133545


 62%|██████▏   | 8001/12942 [06:28<04:01, 20.49it/s]

3.1592273712158203


 70%|██████▉   | 9005/12942 [07:15<02:54, 22.56it/s]

3.047868490219116


 77%|███████▋  | 10003/12942 [08:04<02:22, 20.62it/s]

3.112248182296753


 85%|████████▌ | 11003/12942 [08:54<01:40, 19.27it/s]

3.3203375339508057


 93%|█████████▎| 12003/12942 [09:43<00:44, 20.87it/s]

3.2571170330047607


100%|██████████| 12942/12942 [10:28<00:00, 20.59it/s]
  0%|          | 0/12942 [00:00<?, ?it/s]

3.612608749998019


  0%|          | 1/12942 [00:00<1:38:20,  2.19it/s]

torch.Size([32, 23])
a man with a side a up to on . a <unk> . a street . a . . a <unk> . <end> ------------
a man parked a red on motorcycle a a street . a a front city . . <end> skiier skiier skiier skiier skiier ------------
a people are a beach beach with a water . a . a a of <end> skiier skiier skiier skiier skiier skiier skiier ------------
a man white white lot with on to a road . the street . <end> skiier skiier skiier skiier skiier skiier skiier skiier ------------
a man with a cake on with a and a glass of food . <end> skiier skiier skiier skiier skiier skiier skiier skiier ------------
a black with with a television with a building . . a desk . <end> skiier skiier skiier skiier skiier skiier skiier skiier ------------
a man with a red <unk> a <unk> of a <unk> . . <end> skiier skiier skiier skiier skiier skiier skiier skiier skiier ------------
a boy with on a park with with a frisbee bear bear . <end> skiier skiier skiier skiier skiier skiier skiier skiier skiier ------------

  8%|▊         | 1002/12942 [00:47<09:43, 20.45it/s]

torch.Size([32, 17])
a small and white fire is on top of a tree . . a tree . <end> ------------
a plate of food with a sandwich of fruit and . a table table . <end> skiier ------------
a man is a skateboard with a is holding to hit to the . . <end> skiier ------------
a people a and on a white with a . table <unk> table <unk> <end> skiier skiier ------------
a plate of a <unk> of food on a and and , cheese . <end> skiier skiier ------------
a and and and and yellow train station a with a . it . <end> skiier skiier ------------
a man sitting on a table <unk> to a train . a . <end> skiier skiier skiier ------------
a man of people riding a on a street of a city . <end> skiier skiier skiier ------------
a people standing in the beach of a beach beach water ocean . <end> skiier skiier skiier ------------
a man vase and and a vase <end> a wall . . . <end> skiier skiier skiier ------------
a table with a table white and a table . a . <end> skiier skiier skiier skiier ------------
a man holdi

 15%|█▌        | 2001/12942 [01:37<10:18, 17.70it/s]  

torch.Size([32, 17])
a vase table is a with a and sitting . a knife white . top of <end> ------------
a cat is a head on in on a head . the window . . <end> skiier ------------
a giraffe is in a of a tree green field . <end> a . <end> skiier skiier ------------
a woman man with with a laptop on a of a laptop . <end> skiier skiier skiier ------------
a man is a skateboard in a racket . . the air . <end> skiier skiier skiier ------------
a giraffes standing in a grass in a standing large and in is <end> skiier skiier skiier ------------
a man of a flying up a snowy . . man . . . skiier skiier skiier ------------
a man is on top bench with holding on a cell . <end> skiier skiier skiier skiier ------------
a man are in a room with standing is a . <end> skiier skiier skiier skiier skiier ------------
a dog dog dog sitting on top of a bed . <end> skiier skiier skiier skiier skiier ------------
a bear bear is standing next to a . a water . skiier skiier skiier skiier skiier ------------
a whi

 23%|██▎       | 3003/12942 [02:27<08:49, 18.76it/s]

torch.Size([32, 22])
a man building of a building and a car . a of a building building . a . the of it . ------------
a red of a street with a street and . it side . a . down the street . <end> skiier skiier ------------
a group baseball with to with a ready to a a man of . people of like a <end> skiier skiier skiier ------------
a pizza with with a sandwich , on and , cheese of pizza . <end> skiier skiier skiier skiier skiier skiier skiier ------------
a large vase with with with a and and other cup and flowers . <end> skiier skiier skiier skiier skiier skiier skiier ------------
a white up of a white on a toilet of a . . <end> skiier skiier skiier skiier skiier skiier skiier skiier ------------
a man holding in a tennis court holding to a ball . . <end> skiier skiier skiier skiier skiier skiier skiier skiier ------------
a man and white with with in has of be <unk> . <end> skiier skiier skiier skiier skiier skiier skiier skiier skiier ------------
a man with a cell phone in a <unk> .

 31%|███       | 4001/12942 [03:16<07:31, 19.79it/s]

torch.Size([32, 17])
a plate of pizza pizza on a plate with to a cup <end> food . . <end> ------------
a man of cake of pizza with is to on a plate . . the . <end> ------------
a bathroom bathroom is a <unk> in toilet toilet and and a sink . . <end> skiier ------------
a person is in a mountain on at . a mountain . . <end> skiier skiier skiier ------------
a men are a grass of a field with a a frisbee and <end> skiier skiier skiier ------------
a group and with a and <unk> with a window wall . . <end> skiier skiier skiier ------------
a plate with with on a table table with a table . . skiier skiier skiier skiier ------------
a man man holding a tennis of a in a hand . <end> skiier skiier skiier skiier ------------
a skier of a snowboard of down a hill slope . mountain . skiier skiier skiier skiier ------------
a group of people are on a with a boat umbrella . <end> skiier skiier skiier skiier ------------
a vase large with with a flowers on on a vase . <end> skiier skiier skiier skiie

 39%|███▊      | 5003/12942 [04:04<06:38, 19.91it/s]

torch.Size([32, 19])
a white is phone is a <unk> table on it bowl <end> it table . . it . <end> ------------
a man is a black shirt and a and . a black and . a <unk> . <end> skiier ------------
a bird bird white dog is a <unk> dog <unk> and and . . . . <end> skiier skiier ------------
a red truck truck parked a truck truck . parked on to a bus . <end> skiier skiier skiier ------------
a large boat with the water with water the the river . . body . <end> skiier skiier skiier ------------
a woman holding a shirt shirt holding glasses cell . a table . . <end> skiier skiier skiier skiier ------------
a open <unk> is a window room with . a bed . a <end> skiier skiier skiier skiier skiier ------------
a man man in a woman standing playing a game in . <end> skiier skiier skiier skiier skiier skiier ------------
a trains train to a train tracks on on the tracks . <end> skiier skiier skiier skiier skiier skiier ------------
a train is a train in a a a train train . <end> skiier skiier skiier sk

 46%|████▋     | 6001/12942 [04:51<06:18, 18.34it/s]

torch.Size([32, 18])
a bathroom bathroom with sink a sink and . . . <unk> window . . . . <end> ------------
a plate is sitting a table top to a bowl . food . a . . <end> skiier ------------
a man is a a laptop of a tie . and and and a and <end> skiier skiier ------------
a plate of <unk> sitting on on top table . to a table . coffee . skiier skiier ------------
a women are on a table holding <unk> in holding on a <unk> . <end> skiier skiier skiier ------------
a large bear bear walking in to a tree fence . a large . <end> skiier skiier skiier ------------
a white and white cat of a old and . a room . . <end> skiier skiier skiier ------------
a large <unk> sitting a on and and and , a items . . <end> skiier skiier skiier ------------
a holding a sitting a table of food . and and a . <end> skiier skiier skiier skiier ------------
a man boat is on a field with a beach of a water . skiier skiier skiier skiier ------------
a person riding a of a snow on a snowy day . <end> skiier skiier ski

 47%|████▋     | 6145/12942 [05:00<06:25, 17.64it/s]

In [62]:
vocab.idx2word[2]

'<end>'

KeyError: 4530