In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict
torch.manual_seed(1)
import tensorflow as tf

In [6]:
import numpy as np
import json
import os
from nltk.tokenize import word_tokenize
with open('config.json') as config_file:
    config = json.load(config_file)
data_path = config['data_path']

### CREATING DATA FROM LABELS

In [3]:
file_id = 250
with open(os.path.join(data_path,f'refCOCO/test/labels/lab_{file_id}.json')) as json_file:
    label = json.load(json_file)
print(label)

{'ref_sents': ['comp monitor', 'computer monitor, left', 'Left monitor'], 'label': 72, 'bbox': [[0.0, 98.36000061035156, 167.91000366210938, 277.989990234375]]}


In [13]:
word_tokenize(label['ref_sents'][0])

['comp', 'monitor']

In [16]:
label_data = []
for file_id in range(5000):
    with open(os.path.join(data_path,f'refCOCO/test/labels/lab_{file_id}.json')) as json_file:
        label = json.load(json_file)
    ref_sents = [word_tokenize(sent) for sent in label['ref_sents']]
    label_data += ref_sents

In [27]:
train_label_data = []
for file_id in range(40000):
    with open(os.path.join(data_path,f'refCOCO/train/labels/lab_{file_id}.json')) as json_file:
        label = json.load(json_file)
    ref_sents = [word_tokenize(sent) for sent in label['ref_sents']]
    train_label_data += ref_sents

In [42]:
len(train_label_data)

113762

In [43]:
np.save('train_label_tokenized.npy', train_label_data)

In [7]:
train_dat = np.load('train_label_tokenized.npy', allow_pickle=True)
test_dat = np.load('test_label_tokenized.npy', allow_pickle=True)


In [46]:
combined_dat = list(train_dat) + list(test_dat)

In [47]:
len(combined_dat)

127964

In [51]:
np.save('combined_label_tokenized.npy', combined_dat)

In [8]:
combined_data = np.load('combined_label_tokenized.npy', allow_pickle=True)

## BUILDING MODEL

In [9]:
combined_data = list(combined_data)

In [10]:
vocab = set()
for sent in combined_data:
    for w in sent:
        vocab.add(w.lower())
vocab.add('')

In [11]:
vocab_size = len(vocab)


In [8]:
def prepare_sequence(seq, to_idx):
    idxs = [to_idx[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


In [9]:
class LSTMNextword(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, num_layers=1):
        super(LSTMNextword, self).__init__()
        self.hidden_dim = hidden_dim
        
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers)
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.Softmax()
    
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
#         print(embeds.shape)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
#         print(lstm_out.shape)
        lin1_out = self.linear1(lstm_out.view(len(sentence), -1))
#         print(lin1_out.shape)
        lin2_out = self.softmax(self.linear2(lin1_out))
#         print(lin2_out.shape)
        return lin2_out
        

In [24]:
# model = nn.Sequential(OrderedDict([
#     ('embedding', nn.Embedding(num_embeddings=vocab_size, embedding_dim=32)),
#     ('LSTM', nn.LSTM(input_size=32, hidden_size=10, num_layers=2)),
#     ('linear', nn.Linear(10, vocab_size)),
#     ('relu', nn.ReLU()),
#     ('linear', nn.Linear(vocab_size, vocab_size)),
#     ('softmax', nn.Softmax())]))

In [14]:
print(model)

Sequential(
  (embedding): Embedding(9904, 10)
  (LSTM): LSTM(10, 10, num_layers=2)
  (linear): Linear(in_features=9904, out_features=9904, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
)


## PREPROCESS DATA

In [12]:
train_dat = list(train_dat)
train_data = []
for data_point in train_dat:
    d_point = [w.lower() for w in data_point[::-1]]
    train_data.append(d_point)
test_dat = list(test_dat)
test_data = []
for data_point in test_dat:
    d_point = [w.lower() for w in data_point[::-1]]
    test_data.append(d_point)

In [13]:
combined_data[10:20]

[['front', 'cnter', 'cow'],
 ['front', 'cow'],
 ['thingy', 'on', 'right'],
 ['candleholder', 'on', 'right'],
 ['right', 'thing'],
 ['biggest', 'monitor'],
 ['front', 'monitor'],
 ['the', 'monitor', 'dead', 'center', '(', 'apple', 'logo', ')'],
 ['person', 'in', 'chair', 'at', 'left'],
 ['guy', 'sitting', 'on', 'left']]

In [14]:
processed_data = []
for data_point in combined_data:
    processed_data_point = [w.lower() for w in data_point[::-1]]
    processed_data.append(processed_data_point)

In [14]:
train_data[145]

['head', 'no', ',', 'animal', 'rightmost']

In [15]:
one_gram_data = []
for sentence in processed_data:
    for i in range(len(sentence)-1):
        one_gram_data.append([[sentence[i]], sentence[i+1]])

In [16]:
two_gram_train = []
for sentence in train_data:
    sentence.insert(0, '')
    sentence.append('')
    for i in range(len(sentence)-2):
        two_gram_train.append([sentence[i:i+2], sentence[i+2]])
two_gram_test = []
for sentence in test_data:
    sentence.insert(0, '')
    sentence.append('')
    for i in range(len(sentence)-2):
        two_gram_test.append([sentence[i:i+2], sentence[i+2]])

In [None]:
three_gram_train = []
for sentence in train_data:
    sentence.insert(0, '')
    sentence.insert(0, '')
    sentence.append('')
    for i in range(len(sentence)-3):
        two_gram_train.append([sentence[i:i+3], sentence[i+3]])
three_gram_test = []
for sentence in test_data:
    sentence.insert(0, '')
    sentence.insert(0, '')
    sentence.append('')
    for i in range(len(sentence)-3):
        two_gram_test.append([sentence[i:i+3], sentence[i+3]])

In [17]:
print(processed_data[:5])
print(two_gram_train[:10])
print(len(two_gram_train))

[['center', 'and', 'front', 'creature', 'zebra'], ['zebra'], ['zebra', 'whole'], ['buny', 'most', 'left'], ['bunny', 'side', 'left']]
[[['', 'center'], 'and'], [['center', 'and'], 'front'], [['and', 'front'], 'creature'], [['front', 'creature'], 'zebra'], [['creature', 'zebra'], ''], [['', 'zebra'], ''], [['', 'zebra'], 'whole'], [['zebra', 'whole'], ''], [['', 'buny'], 'most'], [['buny', 'most'], 'left']]
407304


In [18]:
word_to_idx = {}
for sentence in processed_data:
    for word in sentence:
        if word.lower() not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)
print(len(word_to_idx))

9904


In [19]:
word_to_idx[''] = len(word_to_idx)

In [21]:
EMBEDDING_DIMS = 256
HIDDEN_DIMS = 128
LEARNING_RATE = 0.5

In [168]:
model = LSTMNextword(EMBEDDING_DIMS, HIDDEN_DIMS, vocab_size, num_layers=2)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

with torch.no_grad():
    inputs = prepare_sequence(two_gram_train[0][0], word_to_idx)
    out = model(inputs)
    print(torch.argmax(out))
    a = prepare_sequence([two_gram_train[0][1]], word_to_idx).cuda()
    b = torch.tensor(1).cuda()
    if a == b:
        print(True)
    reshaped = out[0].view(1,-1)
    print(out)
    print(out[0][-1])
    print(reshaped.shape)

tensor(19131)
True
tensor([[1.0794e-04, 9.4969e-05, 1.0696e-04,  ..., 9.1370e-05, 1.0080e-04,
         9.4991e-05],
        [1.0683e-04, 9.5908e-05, 1.0646e-04,  ..., 9.2259e-05, 1.0067e-04,
         9.5502e-05]])
tensor(9.4991e-05)
torch.Size([1, 9905])




In [169]:
num_epoch = 10
running_loss = 0.0
model = model.cuda()
for epoch in range(num_epoch):
    c = 0
    correct = 0
    for i, (sent, next_word) in enumerate(two_gram_train):
        model.zero_grad()
#         print(sent, next_word)
        sentence_in = prepare_sequence(sent, word_to_idx).cuda()
        next_word_out = prepare_sequence([next_word], word_to_idx).cuda()
#         print(next_word_out)
        out = model(sentence_in)
#         print(out.shape)
        loss = loss_function(out[0].view(1,-1), next_word_out)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        correct += 1 if torch.argmax(out).cuda == next_word_out else 0
        c += 1
        if i % 2000 == 1999:
            print(f'epoch:{epoch + 1}, {i+1} loss: {running_loss/2000}, accuracy: {correct / 2000}')
            running_loss = 0
            c = 0
            correct = 0
            
                        
                                          
                                          



epoch:1, 2000 loss: 9.200796207904816, accuracy: 0.0
epoch:1, 4000 loss: 9.200794809341431, accuracy: 0.0
epoch:1, 6000 loss: 9.20079367685318, accuracy: 0.0
epoch:1, 8000 loss: 9.200791687965394, accuracy: 0.0
epoch:1, 10000 loss: 9.200789872646332, accuracy: 0.0
epoch:1, 12000 loss: 9.200786926746368, accuracy: 0.0
epoch:1, 14000 loss: 9.200785173416138, accuracy: 0.0
epoch:1, 16000 loss: 9.200782630443573, accuracy: 0.0
epoch:1, 18000 loss: 9.20077845621109, accuracy: 0.0
epoch:1, 20000 loss: 9.20077441930771, accuracy: 0.0
epoch:1, 22000 loss: 9.200770786762238, accuracy: 0.0
epoch:1, 24000 loss: 9.200764177322387, accuracy: 0.0
epoch:1, 26000 loss: 9.20075569486618, accuracy: 0.0
epoch:1, 28000 loss: 9.200744349956512, accuracy: 0.0
epoch:1, 30000 loss: 9.20072652721405, accuracy: 0.0
epoch:1, 32000 loss: 9.200701904296874, accuracy: 0.0
epoch:1, 34000 loss: 9.200609782218933, accuracy: 0.0
epoch:1, 36000 loss: 9.0747164850235, accuracy: 0.0
epoch:1, 38000 loss: 8.919499505519868,

epoch:1, 304000 loss: 8.916968354225158, accuracy: 0.0
epoch:1, 306000 loss: 8.918468344211579, accuracy: 0.0
epoch:1, 308000 loss: 8.925968345165252, accuracy: 0.0
epoch:1, 310000 loss: 8.924468353748322, accuracy: 0.0
epoch:1, 312000 loss: 8.916968336105347, accuracy: 0.0
epoch:1, 314000 loss: 8.925968354701995, accuracy: 0.0
epoch:1, 316000 loss: 8.93396832370758, accuracy: 0.0
epoch:1, 318000 loss: 8.933468351364136, accuracy: 0.0
epoch:1, 320000 loss: 8.92096835231781, accuracy: 0.0
epoch:1, 322000 loss: 8.916968357086182, accuracy: 0.0
epoch:1, 324000 loss: 8.918968342304229, accuracy: 0.0
epoch:1, 326000 loss: 8.919468358516694, accuracy: 0.0
epoch:1, 328000 loss: 8.916968336582183, accuracy: 0.0
epoch:1, 330000 loss: 8.919968361377716, accuracy: 0.0
epoch:1, 332000 loss: 8.917468357086182, accuracy: 0.0
epoch:1, 334000 loss: 8.912468366622925, accuracy: 0.0
epoch:1, 336000 loss: 8.915468356132507, accuracy: 0.0
epoch:1, 338000 loss: 8.928468365192414, accuracy: 0.0
epoch:1, 340

epoch:2, 200000 loss: 8.912968710899353, accuracy: 0.0
epoch:2, 202000 loss: 8.927468708515168, accuracy: 0.0
epoch:2, 204000 loss: 8.924468707084655, accuracy: 0.0
epoch:2, 206000 loss: 8.930968705654145, accuracy: 0.0
epoch:2, 208000 loss: 8.906468708515167, accuracy: 0.0
epoch:2, 210000 loss: 8.916468716144562, accuracy: 0.0
epoch:2, 212000 loss: 8.906468715667724, accuracy: 0.0
epoch:2, 214000 loss: 8.902468715667725, accuracy: 0.0
epoch:2, 216000 loss: 8.927968718528748, accuracy: 0.0
epoch:2, 218000 loss: 8.918468707084656, accuracy: 0.0
epoch:2, 220000 loss: 8.929468710899354, accuracy: 0.0
epoch:2, 222000 loss: 8.919968707561493, accuracy: 0.0
epoch:2, 224000 loss: 8.915968717098236, accuracy: 0.0
epoch:2, 226000 loss: 8.922968713760376, accuracy: 0.0
epoch:2, 228000 loss: 8.92346871471405, accuracy: 0.0
epoch:2, 230000 loss: 8.911968716621399, accuracy: 0.0
epoch:2, 232000 loss: 8.927968712806702, accuracy: 0.0
epoch:2, 234000 loss: 8.921968707561494, accuracy: 0.0
epoch:2, 23

epoch:3, 96000 loss: 8.911968742370606, accuracy: 0.0
epoch:3, 98000 loss: 8.906968742370605, accuracy: 0.0
epoch:3, 100000 loss: 8.922468742370606, accuracy: 0.0
epoch:3, 102000 loss: 8.899968742370605, accuracy: 0.0
epoch:3, 104000 loss: 8.919468742370606, accuracy: 0.0
epoch:3, 106000 loss: 8.913468742370606, accuracy: 0.0
epoch:3, 108000 loss: 8.914468742370605, accuracy: 0.0
epoch:3, 110000 loss: 8.904468742370605, accuracy: 0.0
epoch:3, 112000 loss: 8.914968742370606, accuracy: 0.0
epoch:3, 114000 loss: 8.910968742370606, accuracy: 0.0
epoch:3, 116000 loss: 8.925968742370605, accuracy: 0.0
epoch:3, 118000 loss: 8.936468742370606, accuracy: 0.0
epoch:3, 120000 loss: 8.921968742370606, accuracy: 0.0
epoch:3, 122000 loss: 8.929468742370606, accuracy: 0.0
epoch:3, 124000 loss: 8.935968742370605, accuracy: 0.0
epoch:3, 126000 loss: 8.919968742370605, accuracy: 0.0
epoch:3, 128000 loss: 8.914968742370606, accuracy: 0.0
epoch:3, 130000 loss: 8.927968742370606, accuracy: 0.0
epoch:3, 132

epoch:3, 394000 loss: 8.932968742370605, accuracy: 0.0
epoch:3, 396000 loss: 8.910968742370606, accuracy: 0.0
epoch:3, 398000 loss: 8.917968742370606, accuracy: 0.0
epoch:3, 400000 loss: 8.930468742370605, accuracy: 0.0
epoch:3, 402000 loss: 8.919468742370606, accuracy: 0.0
epoch:3, 404000 loss: 8.934468742370605, accuracy: 0.0
epoch:3, 406000 loss: 8.919468742370606, accuracy: 0.0
epoch:4, 2000 loss: 14.75350036239624, accuracy: 0.0
epoch:4, 4000 loss: 8.939968742370606, accuracy: 0.0
epoch:4, 6000 loss: 8.924468742370605, accuracy: 0.0
epoch:4, 8000 loss: 8.933968742370606, accuracy: 0.0
epoch:4, 10000 loss: 8.908468742370605, accuracy: 0.0
epoch:4, 12000 loss: 8.910968742370606, accuracy: 0.0
epoch:4, 14000 loss: 8.925968742370605, accuracy: 0.0
epoch:4, 16000 loss: 8.930968742370606, accuracy: 0.0
epoch:4, 18000 loss: 8.915468742370605, accuracy: 0.0
epoch:4, 20000 loss: 8.912468742370605, accuracy: 0.0
epoch:4, 22000 loss: 8.931468742370605, accuracy: 0.0
epoch:4, 24000 loss: 8.92

epoch:4, 288000 loss: 8.919468742370606, accuracy: 0.0
epoch:4, 290000 loss: 8.926968742370606, accuracy: 0.0
epoch:4, 292000 loss: 8.930468742370605, accuracy: 0.0
epoch:4, 294000 loss: 8.927468742370605, accuracy: 0.0
epoch:4, 296000 loss: 8.935968742370605, accuracy: 0.0
epoch:4, 298000 loss: 8.921968742370606, accuracy: 0.0
epoch:4, 300000 loss: 8.940468742370605, accuracy: 0.0
epoch:4, 302000 loss: 8.926968742370606, accuracy: 0.0
epoch:4, 304000 loss: 8.916968742370605, accuracy: 0.0
epoch:4, 306000 loss: 8.918468742370605, accuracy: 0.0
epoch:4, 308000 loss: 8.925968742370605, accuracy: 0.0
epoch:4, 310000 loss: 8.924468742370605, accuracy: 0.0
epoch:4, 312000 loss: 8.916968742370605, accuracy: 0.0
epoch:4, 314000 loss: 8.925968742370605, accuracy: 0.0
epoch:4, 316000 loss: 8.933968742370606, accuracy: 0.0
epoch:4, 318000 loss: 8.933468742370605, accuracy: 0.0
epoch:4, 320000 loss: 8.920968742370606, accuracy: 0.0
epoch:4, 322000 loss: 8.916968742370605, accuracy: 0.0
epoch:4, 3

epoch:5, 182000 loss: 8.920468742370605, accuracy: 0.0
epoch:5, 184000 loss: 8.928468742370605, accuracy: 0.0
epoch:5, 186000 loss: 8.929968742370605, accuracy: 0.0
epoch:5, 188000 loss: 8.909468742370606, accuracy: 0.0
epoch:5, 190000 loss: 8.919968742370605, accuracy: 0.0
epoch:5, 192000 loss: 8.911468742370605, accuracy: 0.0
epoch:5, 194000 loss: 8.922968742370605, accuracy: 0.0
epoch:5, 196000 loss: 8.924468742370605, accuracy: 0.0
epoch:5, 198000 loss: 8.923468742370606, accuracy: 0.0
epoch:5, 200000 loss: 8.912968742370605, accuracy: 0.0
epoch:5, 202000 loss: 8.927468742370605, accuracy: 0.0
epoch:5, 204000 loss: 8.924468742370605, accuracy: 0.0
epoch:5, 206000 loss: 8.930968742370606, accuracy: 0.0
epoch:5, 208000 loss: 8.906468742370606, accuracy: 0.0
epoch:5, 210000 loss: 8.916468742370606, accuracy: 0.0
epoch:5, 212000 loss: 8.906468742370606, accuracy: 0.0
epoch:5, 214000 loss: 8.902468742370605, accuracy: 0.0
epoch:5, 216000 loss: 8.927968742370606, accuracy: 0.0
epoch:5, 2

epoch:6, 76000 loss: 8.925968742370605, accuracy: 0.0
epoch:6, 78000 loss: 8.916468742370606, accuracy: 0.0
epoch:6, 80000 loss: 8.916468742370606, accuracy: 0.0
epoch:6, 82000 loss: 8.910468742370606, accuracy: 0.0
epoch:6, 84000 loss: 8.935468742370606, accuracy: 0.0
epoch:6, 86000 loss: 8.914968742370606, accuracy: 0.0
epoch:6, 88000 loss: 8.936968742370606, accuracy: 0.0
epoch:6, 90000 loss: 8.911968742370606, accuracy: 0.0
epoch:6, 92000 loss: 8.932468742370606, accuracy: 0.0
epoch:6, 94000 loss: 8.927968742370606, accuracy: 0.0
epoch:6, 96000 loss: 8.911968742370606, accuracy: 0.0
epoch:6, 98000 loss: 8.906968742370605, accuracy: 0.0
epoch:6, 100000 loss: 8.922468742370606, accuracy: 0.0
epoch:6, 102000 loss: 8.899968742370605, accuracy: 0.0
epoch:6, 104000 loss: 8.919468742370606, accuracy: 0.0
epoch:6, 106000 loss: 8.913468742370606, accuracy: 0.0
epoch:6, 108000 loss: 8.914468742370605, accuracy: 0.0
epoch:6, 110000 loss: 8.904468742370605, accuracy: 0.0
epoch:6, 112000 loss: 

epoch:6, 376000 loss: 8.911468742370605, accuracy: 0.0
epoch:6, 378000 loss: 8.915968742370605, accuracy: 0.0
epoch:6, 380000 loss: 8.931968742370605, accuracy: 0.0
epoch:6, 382000 loss: 8.910968742370606, accuracy: 0.0
epoch:6, 384000 loss: 8.934468742370605, accuracy: 0.0
epoch:6, 386000 loss: 8.914468742370605, accuracy: 0.0
epoch:6, 388000 loss: 8.911468742370605, accuracy: 0.0
epoch:6, 390000 loss: 8.921468742370605, accuracy: 0.0
epoch:6, 392000 loss: 8.926968742370606, accuracy: 0.0
epoch:6, 394000 loss: 8.932968742370605, accuracy: 0.0
epoch:6, 396000 loss: 8.910968742370606, accuracy: 0.0
epoch:6, 398000 loss: 8.917968742370606, accuracy: 0.0
epoch:6, 400000 loss: 8.930468742370605, accuracy: 0.0
epoch:6, 402000 loss: 8.919468742370606, accuracy: 0.0
epoch:6, 404000 loss: 8.934468742370605, accuracy: 0.0
epoch:6, 406000 loss: 8.919468742370606, accuracy: 0.0
epoch:7, 2000 loss: 14.75350036239624, accuracy: 0.0
epoch:7, 4000 loss: 8.939968742370606, accuracy: 0.0
epoch:7, 6000 

epoch:7, 270000 loss: 8.927468742370605, accuracy: 0.0
epoch:7, 272000 loss: 8.927968742370606, accuracy: 0.0
epoch:7, 274000 loss: 8.912968742370605, accuracy: 0.0
epoch:7, 276000 loss: 8.905968742370606, accuracy: 0.0
epoch:7, 278000 loss: 8.921968742370606, accuracy: 0.0
epoch:7, 280000 loss: 8.917968742370606, accuracy: 0.0
epoch:7, 282000 loss: 8.920968742370606, accuracy: 0.0
epoch:7, 284000 loss: 8.921468742370605, accuracy: 0.0
epoch:7, 286000 loss: 8.911968742370606, accuracy: 0.0
epoch:7, 288000 loss: 8.919468742370606, accuracy: 0.0
epoch:7, 290000 loss: 8.926968742370606, accuracy: 0.0
epoch:7, 292000 loss: 8.930468742370605, accuracy: 0.0
epoch:7, 294000 loss: 8.927468742370605, accuracy: 0.0
epoch:7, 296000 loss: 8.935968742370605, accuracy: 0.0
epoch:7, 298000 loss: 8.921968742370606, accuracy: 0.0
epoch:7, 300000 loss: 8.940468742370605, accuracy: 0.0
epoch:7, 302000 loss: 8.926968742370606, accuracy: 0.0
epoch:7, 304000 loss: 8.916968742370605, accuracy: 0.0
epoch:7, 3

epoch:8, 164000 loss: 8.924468742370605, accuracy: 0.0
epoch:8, 166000 loss: 8.922468742370606, accuracy: 0.0
epoch:8, 168000 loss: 8.935468742370606, accuracy: 0.0
epoch:8, 170000 loss: 8.929968742370605, accuracy: 0.0
epoch:8, 172000 loss: 8.907468742370606, accuracy: 0.0
epoch:8, 174000 loss: 8.927968742370606, accuracy: 0.0
epoch:8, 176000 loss: 8.930468742370605, accuracy: 0.0
epoch:8, 178000 loss: 8.928968742370605, accuracy: 0.0
epoch:8, 180000 loss: 8.922968742370605, accuracy: 0.0
epoch:8, 182000 loss: 8.920468742370605, accuracy: 0.0
epoch:8, 184000 loss: 8.928468742370605, accuracy: 0.0
epoch:8, 186000 loss: 8.929968742370605, accuracy: 0.0
epoch:8, 188000 loss: 8.909468742370606, accuracy: 0.0
epoch:8, 190000 loss: 8.919968742370605, accuracy: 0.0
epoch:8, 192000 loss: 8.911468742370605, accuracy: 0.0
epoch:8, 194000 loss: 8.922968742370605, accuracy: 0.0
epoch:8, 196000 loss: 8.924468742370605, accuracy: 0.0
epoch:8, 198000 loss: 8.923468742370606, accuracy: 0.0
epoch:8, 2

epoch:9, 58000 loss: 8.928468742370605, accuracy: 0.0
epoch:9, 60000 loss: 8.925468742370606, accuracy: 0.0
epoch:9, 62000 loss: 8.930968742370606, accuracy: 0.0
epoch:9, 64000 loss: 8.922968742370605, accuracy: 0.0
epoch:9, 66000 loss: 8.937468742370605, accuracy: 0.0
epoch:9, 68000 loss: 8.931468742370605, accuracy: 0.0
epoch:9, 70000 loss: 8.934968742370605, accuracy: 0.0
epoch:9, 72000 loss: 8.923968742370606, accuracy: 0.0
epoch:9, 74000 loss: 8.936468742370606, accuracy: 0.0
epoch:9, 76000 loss: 8.925968742370605, accuracy: 0.0
epoch:9, 78000 loss: 8.916468742370606, accuracy: 0.0
epoch:9, 80000 loss: 8.916468742370606, accuracy: 0.0
epoch:9, 82000 loss: 8.910468742370606, accuracy: 0.0
epoch:9, 84000 loss: 8.935468742370606, accuracy: 0.0
epoch:9, 86000 loss: 8.914968742370606, accuracy: 0.0
epoch:9, 88000 loss: 8.936968742370606, accuracy: 0.0
epoch:9, 90000 loss: 8.911968742370606, accuracy: 0.0
epoch:9, 92000 loss: 8.932468742370606, accuracy: 0.0
epoch:9, 94000 loss: 8.92796

epoch:9, 358000 loss: 8.927968742370606, accuracy: 0.0
epoch:9, 360000 loss: 8.919468742370606, accuracy: 0.0
epoch:9, 362000 loss: 8.925968742370605, accuracy: 0.0
epoch:9, 364000 loss: 8.934968742370605, accuracy: 0.0
epoch:9, 366000 loss: 8.925468742370606, accuracy: 0.0
epoch:9, 368000 loss: 8.913468742370606, accuracy: 0.0
epoch:9, 370000 loss: 8.937468742370605, accuracy: 0.0
epoch:9, 372000 loss: 8.918968742370605, accuracy: 0.0
epoch:9, 374000 loss: 8.922968742370605, accuracy: 0.0
epoch:9, 376000 loss: 8.911468742370605, accuracy: 0.0
epoch:9, 378000 loss: 8.915968742370605, accuracy: 0.0
epoch:9, 380000 loss: 8.931968742370605, accuracy: 0.0
epoch:9, 382000 loss: 8.910968742370606, accuracy: 0.0
epoch:9, 384000 loss: 8.934468742370605, accuracy: 0.0
epoch:9, 386000 loss: 8.914468742370605, accuracy: 0.0
epoch:9, 388000 loss: 8.911468742370605, accuracy: 0.0
epoch:9, 390000 loss: 8.921468742370605, accuracy: 0.0
epoch:9, 392000 loss: 8.926968742370606, accuracy: 0.0
epoch:9, 3

epoch:10, 248000 loss: 8.922468742370606, accuracy: 0.0
epoch:10, 250000 loss: 8.916968742370605, accuracy: 0.0
epoch:10, 252000 loss: 8.910968742370606, accuracy: 0.0
epoch:10, 254000 loss: 8.921468742370605, accuracy: 0.0
epoch:10, 256000 loss: 8.920468742370605, accuracy: 0.0
epoch:10, 258000 loss: 8.925468742370606, accuracy: 0.0
epoch:10, 260000 loss: 8.903968742370605, accuracy: 0.0
epoch:10, 262000 loss: 8.940468742370605, accuracy: 0.0
epoch:10, 264000 loss: 8.935468742370606, accuracy: 0.0
epoch:10, 266000 loss: 8.912468742370605, accuracy: 0.0
epoch:10, 268000 loss: 8.924468742370605, accuracy: 0.0
epoch:10, 270000 loss: 8.927468742370605, accuracy: 0.0
epoch:10, 272000 loss: 8.927968742370606, accuracy: 0.0
epoch:10, 274000 loss: 8.912968742370605, accuracy: 0.0
epoch:10, 276000 loss: 8.905968742370606, accuracy: 0.0
epoch:10, 278000 loss: 8.921968742370606, accuracy: 0.0
epoch:10, 280000 loss: 8.917968742370606, accuracy: 0.0
epoch:10, 282000 loss: 8.920968742370606, accura

model

In [18]:
print(two_gram_train[0])

[['', 'center'], 'and']


In [20]:
two_gram_train_inputs = []
two_gram_train_outputs = []
for sent, next_word in two_gram_train:
    sentence_in = np.array([word_to_idx[w] for w in sent])
    two_gram_train_inputs.append(sentence_in)
    next_word_out = np.array([word_to_idx[next_word]])
    two_gram_train_outputs.append(next_word_out)
two_gram_train_inputs = np.array(two_gram_train_inputs)
two_gram_train_outputs = to_categorical(two_gram_train_outputs, num_classes=vocab_size)

In [172]:
training_doc3 = """
how many people live in atlanta georgia	Atlanta ( , stressed , locally ) is the capital of and the most populous city in the U.S. state of Georgia , with an estimated 2011 population of 432,427 .	1
how many people live in atlanta georgia	Atlanta is the cultural and economic center of the Atlanta metropolitan area , home to 5,457,831 people and the ninth largest metropolitan area in the United States .	1
how many people live in atlanta georgia	Atlanta is the county seat of Fulton County , and a small portion of the city extends eastward into DeKalb County .	0
how many people live in atlanta georgia	Atlanta was established in 1837 at the intersection of two railroad lines , and the city rose from the ashes of the Civil War to become a national center of commerce .	0
how many people live in atlanta georgia	In the decades following the Civil Rights Movement , during which the city earned a reputation as `` too busy to hate '' for the progressive views of its citizens and leaders , Atlanta attained international prominence .	0
how many people live in atlanta georgia	Atlanta is the primary transportation hub of the Southeastern United States , via highway , railroad , and air , with Hartsfield–Jackson Atlanta International Airport being the world 's busiest airport since 1998 .	0
how many people live in atlanta georgia	Atlanta is considered an `` alpha ( - ) world city , '' and , with a gross domestic product of US $ 270 billion , Atlanta’s economy ranks 15th among world cities and sixth in the nation .	0
how many people live in atlanta georgia	Although Atlanta’s economy is considered diverse , dominant sectors include logistics , professional and business services , media operations , government administration , and higher education .	0
how many people live in atlanta georgia	Topographically , Atlanta is marked by rolling hills and dense tree coverage .	0
how many people live in atlanta georgia	Revitalization of Atlanta 's neighborhoods , initially spurred by the 1996 Olympics , has intensified in the 21st century , altering the city 's demographics , politics , and culture .	0
how many people died at the pentagon in 9 11	The September 11 attacks ( also referred to as September 11 , September 11th , or 9/11 ) were a series of four coordinated terrorist attacks launched by the Islamic terrorist group al-Qaeda upon the United States in New York City and the Washington , D.C. area on September 11 , 2001 .	0
how many people died at the pentagon in 9 11	Four passenger airliners were hijacked by 19 al-Qaeda terrorists so they could be flown into buildings in suicide attacks .	0
how many people died at the pentagon in 9 11	Two of those planes , American Airlines Flight 11 and United Airlines Flight 175 , were crashed into the North and South towers , respectively , of the World Trade Center complex in New York City .	0
how many people died at the pentagon in 9 11	Within two hours , both towers collapsed with debris and the resulting fires causing partial or complete collapse of all other buildings in the WTC complex , as well as major damage to ten other large surrounding structures .	0
how many people died at the pentagon in 9 11	A third plane , American Airlines Flight 77 , was crashed into the Pentagon ( the headquarters of the United States Department of Defense ) , leading to a partial collapse in its western side .	0
how many people died at the pentagon in 9 11	The fourth plane , United Airlines Flight 93 , was targeted at the United States Capitol in Washington , D.C. , but crashed into a field near Shanksville , Pennsylvania , after its passengers tried to overcome the hijackers .	0
how many people died at the pentagon in 9 11	In total , almost 3,000 people died in the attacks , including the 227 civilians and 19 hijackers aboard the four planes .	0
how many people died at the pentagon in 9 11	Suspicion quickly fell on al-Qaeda .	0
how many people died at the pentagon in 9 11	Although the group 's leader , Osama bin Laden , initially denied any involvement , in 2004 he claimed responsibility for the attacks .	0
how many people died at the pentagon in 9 11	Al-Qaeda and bin Laden cited U.S. support of Israel , the presence of U.S. troops in Saudi Arabia , and sanctions against Iraq as motives for the attacks .	0
how many people died at the pentagon in 9 11	The United States responded to the attacks by launching the War on Terror and invading Afghanistan to depose the Taliban , which had harbored al-Qaeda .	0
how many people died at the pentagon in 9 11	Many countries strengthened their anti-terrorism legislation and expanded law enforcement powers .	0
how many people died at the pentagon in 9 11	Having evaded capture for years , bin Laden was located and killed by U.S. forces in May 2011 .	0
how many people died at the pentagon in 9 11	The destruction of the Twin Towers and other properties caused serious damage to the economy of Lower Manhattan and had a significant effect on global markets .	0
how many people died at the pentagon in 9 11	Cleanup of the World Trade Center site was completed in May 2002 , and the Pentagon was repaired within a year .	0
how many people died at the pentagon in 9 11	Numerous memorials have been constructed , including the National September 11 Memorial & Museum in New York , the Pentagon Memorial , and the Flight 93 National Memorial in Pennsylvania .	0
how many people died at the pentagon in 9 11	After a lengthy delay , the One World Trade Center is expected to be completed at Ground Zero in New York City in 2013 .	0
how many people visit crater lake national park each year	Crater Lake National Park is a United States National Park located in southern Oregon .	0
how many people visit crater lake national park each year	Established in 1902 , Crater Lake National Park is the fifth oldest national park in the United States and the only one in the state of Oregon .	0
how many people visit crater lake national park each year	The park encompasses the caldera of Crater Lake , a remnant of a destroyed volcano , Mount Mazama , and the surrounding hills and forests .	0
how many people visit crater lake national park each year	The lake is deep at its deepest point , which makes it the deepest lake in the United States , the second deepest in North America and the ninth deepest in the world .	0
how many people visit crater lake national park each year	Crater Lake is often referred to as the seventh deepest lake in the world , but this former listing excludes the approximately depth of subglacial Lake Vostok in Antarctica , which resides under nearly of ice , and the recent report of a maximum depth for Lake O'Higgins /San Martin , located on the border of Chile and Argentina .	0
how many people visit crater lake national park each year	However , when comparing its average depth of to the average depth of other deep lakes , Crater Lake becomes the deepest in the Western Hemisphere and the third deepest in the world .	0
how many people visit crater lake national park each year	The impressive average depth of this volcanic lake is due to the nearly symmetrical deep caldera formed 7,700 years ago during the violent climactic eruptions and subsequent collapse of Mount Mazama and the relatively moist climate that is typical of the crest of the Cascade Range .	0
how many people visit crater lake national park each year	The caldera rim ranges in elevation from .	0
how many people visit crater lake national park each year	The United States Geological Survey benchmarked elevation of the lake surface itself is .	0
how many people visit crater lake national park each year	This National Park encompasses .	0
how many people visit crater lake national park each year	Crater Lake has no streams flowing into or out of it .	0
how many people visit crater lake national park each year	All water that enters the lake is eventually lost from evaporation or subsurface seepage .	0
how many people visit crater lake national park each year	The lake 's water commonly has a striking blue hue , and the lake is re-filled entirely from direct precipitation in the form of snow and rain .	0
"""


In [5]:
from keras.preprocessing.text import Tokenizer
import re
from keras.utils import to_categorical

Using TensorFlow backend.


In [177]:
cleaned = re.sub(r'\W+', ' ', training_doc3).lower()
tokens = word_tokenize(cleaned)
train_len = 3+1
text_sequences = []
for i in range(train_len,len(tokens)):
    seq = tokens[i-train_len:i]
    text_sequences.append(seq)
sequences = {}
count = 1
for i in range(len(tokens)):
    if tokens[i] not in sequences:
        sequences[tokens[i]] = count
        count += 1

In [179]:
text_sequences[:3]

[['how', 'many', 'people', 'live'],
 ['many', 'people', 'live', 'in'],
 ['people', 'live', 'in', 'atlanta']]

In [180]:
tokenizer = Tokenizer()
tokenizer.fit_on_texts(text_sequences)
sequences = tokenizer.texts_to_sequences(text_sequences) 

#Collecting some information   
vocabulary_size = len(tokenizer.word_counts)+1

n_sequences = np.empty([len(sequences),train_len], dtype='int32')
for i in range(len(sequences)):
    n_sequences[i] = sequences[i]


In [26]:
train_inputs = n_sequences[:,:-1]
train_targets = n_sequences[:,-1]
train_targets = to_categorical(train_targets, num_classes=vocabulary_size)
seq_len = train_inputs.shape[1]
train_inputs.shape

NameError: name 'n_sequences' is not defined

In [223]:
print(seq_len)

3


In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Embedding
#model = load_model("mymodel.h5")
sequence_length = 2
model = Sequential()
model.add(Embedding(vocab_size, sequence_length, input_length=sequence_length))
model.add(LSTM(50,return_sequences=True))
model.add(LSTM(50))
model.add(Dense(50,activation='relu'))
model.add(Dense(vocab_size, activation='softmax'))
print(model.summary())
# compile network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(two_gram_train_inputs,two_gram_train_outputs,epochs=500,verbose=1)
model.save("mymodel.h5")

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0


In [3]:
from tensorflow.python.client import device_lib 
print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 5447545439523951135
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 22156005888
locality {
  bus_id: 1
  links {
  }
}
incarnation: 12256241962652004224
physical_device_desc: "device: 0, name: TITAN RTX, pci bus id: 0000:1a:00.0, compute capability: 7.5"
]


In [4]:
import tensorflow as tf
a = tf.constant([1, 2, 3])
print(a.device) 

/job:localhost/replica:0/task:0/device:CPU:0


In [3]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))


Num GPUs Available:  1


In [None]:
new_model = tf.keras.models.load_model('mymodel.h5')
