In [1]:
from data import Dataset
import torch
from models import RNNG, RNNLM
import torch.nn.functional as F
import pandas as pd

In [2]:
train_data = Dataset('data/ptb-train.pkl')
test_data = Dataset('data/ptb-test.pkl')
val_data = Dataset('data/ptb-val.pkl')

In [3]:
loaded_data = torch.load('rnng.pt')
model_args = loaded_data['args']
model_state_dict = loaded_data['model'].state_dict()

In [4]:
rnng = RNNG(
    vocab=len(loaded_data['word2idx']),
    w_dim=model_args['w_dim'],           # Dimensionality of word embeddings
    h_dim=model_args['h_dim'],           # Dimensionality of hidden states
    q_dim=model_args['q_dim'],           # Dimensionality of 'q' vector
    num_layers=model_args['num_layers'], # Number of layers
    dropout=model_args['dropout'],       # Dropout rate
    max_len=250
)
rnng.load_state_dict(model_state_dict)

<All keys matched successfully>

In [5]:
rnng.eval()
rnng.cuda()

RNNG(
  (emb): Embedding(288, 650)
  (dropout): Dropout(p=0.5, inplace=False)
  (stack_rnn): SeqLSTM(
    (linears): ModuleList(
      (0): Linear(in_features=1300, out_features=2600, bias=True)
      (1): Linear(in_features=1300, out_features=2600, bias=True)
    )
    (dropout_layer): Dropout(p=0.5, inplace=False)
  )
  (tree_rnn): TreeLSTM(
    (linear): Linear(in_features=1300, out_features=3250, bias=True)
  )
  (vocab_mlp): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=288, bias=True)
  )
  (q_binary): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=1, bias=True)
  )
  (action_mlp_p): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=1, bias=True)
  )
  (q_leaf_rnn): LSTM(650, 256, batch_first=Tru

In [6]:
def predict_next_token(model, input_tensor, device='cuda'):

    model.eval()
    with torch.no_grad():
        next_word_probs = model(input_tensor)
            
    _, max_idx = torch.max(next_word_probs, 1)
    print(max_idx)
    next_word = loaded_data['idx2word'][max_idx.item()]
    return next_word_probs, next_word

In [7]:
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[11]
sents = sents.cuda()
print(sents)
sents = sents[:, :-8]
print(sents)

tensor([[  2, 241,  28,  23, 242, 243, 144,  28, 244,  75, 173, 135, 136,  42,
         184, 134,  75,  58,   9, 130, 179, 180, 181, 245,  61,  96, 246, 247,
          23, 248,  45,   3]], device='cuda:0')
tensor([[  2, 241,  28,  23, 242, 243, 144,  28, 244,  75, 173, 135, 136,  42,
         184, 134,  75,  58,   9, 130, 179, 180, 181, 245]], device='cuda:0')


In [8]:
out, next_word = predict_next_token(rnng, sents)
print(next_word)



TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)
 * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)


In [None]:
out = out.squeeze()

In [None]:
i = 0
for x in out:
    print(i, ": ", x)
    i += 1

# Try the language model RNN

In [9]:
loaded_lm = torch.load('lm.pt')
lm_args = loaded_lm['args']
lm_state_dict = loaded_lm['model'].state_dict()

In [10]:
rnnlm = RNNLM(
    vocab=len(loaded_lm['word2idx']),
    w_dim=lm_args['w_dim'],
    h_dim=lm_args['h_dim'],           # Dimensionality of hidden states
    num_layers=lm_args['num_layers'], # Number of layers
    dropout=lm_args['dropout']
)
rnnlm.load_state_dict(lm_state_dict)

<All keys matched successfully>

In [11]:
rnnlm.eval()
rnnlm.cuda()

RNNLM(
  (word_vecs): Embedding(288, 650)
  (dropout): Dropout(p=0.5, inplace=False)
  (rnn): LSTM(650, 650, num_layers=100, batch_first=True, dropout=0.5)
  (vocab_linear): Linear(in_features=650, out_features=288, bias=True)
)

In [12]:
p = rnnlm(sents)

In [13]:
p.shape

torch.Size([1, 288])

In [14]:
probs, word = predict_next_token(rnnlm, sents)

tensor([3], device='cuda:0')


In [15]:
word

'</s>'

In [16]:
for x in train_data.sents[11]:
    print(x.item(), ":", train_data.idx2word[x.item()])

2 : <s>
72 : Howard
73 : Mosher
28 : ,
74 : president
75 : and
76 : chief
77 : executive
78 : officer
28 : ,
53 : said
79 : he
80 : anticipates
81 : growth
82 : for
23 : the
65 : luxury
66 : auto
67 : maker
25 : in
83 : Britain
75 : and
84 : Europe
28 : ,
75 : and
25 : in
85 : Far
86 : Eastern
87 : markets
45 : .
3 : </s>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <pad>
0 : <p

In [17]:
indices = []
for x in train_data[11][6][0][0].split():
    print(x)
    idx = train_data.word2idx[x]
    indices.append(idx)

Currently
,
the
rules
force
executives
,
directors
and
other
corporate
insiders
to
report
purchases
and
sales
of
their
companies
'
shares
within
about
a
month
after
the
transaction
.


In [18]:
indices.pop()
print(indices)
test_tensor = torch.tensor(indices).unsqueeze(dim=0).cuda()

[241, 28, 23, 242, 243, 144, 28, 244, 75, 173, 135, 136, 42, 184, 134, 75, 58, 9, 130, 179, 180, 181, 245, 61, 96, 246, 247, 23, 248]


In [19]:
test_tensor.shape

torch.Size([1, 29])

In [20]:
probs, word = predict_next_token(rnnlm, test_tensor)

tensor([3], device='cuda:0')


In [21]:
word

'</s>'

In [27]:
with open('../train_100M/bnc_spoken.train', 'r', encoding='utf-8') as file:
    content = file.read()

# Print the content or the first few lines
print(type(content))

<class 'str'>
