# NN-based Language Model
 - Goal: To build LM based on NN
 - Input: the seq of words (x1, xn)
 - Output: netx word (x_n+1)

## 01. Load packages

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
#import torch.optim as optim

## 02. preprocessing of data

In [6]:
test_sentence = """
This is the gravest crisis for Western security since the end of World War Two, and a lasting one. As one expert puts it, Trumpism will outlast his presidency. But which nations are equipped to step to the fore as the US stands back? At 09.00 one morning in February 1947, the UK ambassador in Washington, Lord Inverchapel, walked into the State Department to hand the US Secretary of State, George Marshall, two diplomatic messages printed on blue paper to emphasise their importance: one on Greece, the other on Turkey. Exhausted, broke and heavily in debt to the United States, Britain told the US that it could no longer continue its support for the Greek government forces that were fighting an armed Communist insurgency. Britain had already announced plans to pull out of Palestine and India and to wind down its presence in Egypt.
The United States saw immediately that there was now a real danger that Greece would fall to the Communists and, by extension, to Soviet control. And if Greece went, the United States feared that Turkey could be next, giving Moscow control of the Eastern Mediterranean including, potentially, the Suez Canal, a vital global trade route.
Almost overnight, the United States stepped into the vacuum left by the departing British.
"""


In [7]:
test_sentence = test_sentence.split()

In [8]:
test_sentence

['This',
 'is',
 'the',
 'gravest',
 'crisis',
 'for',
 'Western',
 'security',
 'since',
 'the',
 'end',
 'of',
 'World',
 'War',
 'Two,',
 'and',
 'a',
 'lasting',
 'one.',
 'As',
 'one',
 'expert',
 'puts',
 'it,',
 'Trumpism',
 'will',
 'outlast',
 'his',
 'presidency.',
 'But',
 'which',
 'nations',
 'are',
 'equipped',
 'to',
 'step',
 'to',
 'the',
 'fore',
 'as',
 'the',
 'US',
 'stands',
 'back?',
 'At',
 '09.00',
 'one',
 'morning',
 'in',
 'February',
 '1947,',
 'the',
 'UK',
 'ambassador',
 'in',
 'Washington,',
 'Lord',
 'Inverchapel,',
 'walked',
 'into',
 'the',
 'State',
 'Department',
 'to',
 'hand',
 'the',
 'US',
 'Secretary',
 'of',
 'State,',
 'George',
 'Marshall,',
 'two',
 'diplomatic',
 'messages',
 'printed',
 'on',
 'blue',
 'paper',
 'to',
 'emphasise',
 'their',
 'importance:',
 'one',
 'on',
 'Greece,',
 'the',
 'other',
 'on',
 'Turkey.',
 'Exhausted,',
 'broke',
 'and',
 'heavily',
 'in',
 'debt',
 'to',
 'the',
 'United',
 'States,',
 'Britain',
 'told'

In [9]:
CONTEXT_SIZE = 2

In [14]:
ngram=[
  (
    [test_sentence[i-j-1] for j in range(CONTEXT_SIZE)], test_sentence[i]

  )
  for i in range(CONTEXT_SIZE, len(test_sentence))
 ]

In [15]:
ngram

[(['is', 'This'], 'the'),
 (['the', 'is'], 'gravest'),
 (['gravest', 'the'], 'crisis'),
 (['crisis', 'gravest'], 'for'),
 (['for', 'crisis'], 'Western'),
 (['Western', 'for'], 'security'),
 (['security', 'Western'], 'since'),
 (['since', 'security'], 'the'),
 (['the', 'since'], 'end'),
 (['end', 'the'], 'of'),
 (['of', 'end'], 'World'),
 (['World', 'of'], 'War'),
 (['War', 'World'], 'Two,'),
 (['Two,', 'War'], 'and'),
 (['and', 'Two,'], 'a'),
 (['a', 'and'], 'lasting'),
 (['lasting', 'a'], 'one.'),
 (['one.', 'lasting'], 'As'),
 (['As', 'one.'], 'one'),
 (['one', 'As'], 'expert'),
 (['expert', 'one'], 'puts'),
 (['puts', 'expert'], 'it,'),
 (['it,', 'puts'], 'Trumpism'),
 (['Trumpism', 'it,'], 'will'),
 (['will', 'Trumpism'], 'outlast'),
 (['outlast', 'will'], 'his'),
 (['his', 'outlast'], 'presidency.'),
 (['presidency.', 'his'], 'But'),
 (['But', 'presidency.'], 'which'),
 (['which', 'But'], 'nations'),
 (['nations', 'which'], 'are'),
 (['are', 'nations'], 'equipped'),
 (['equipped',

In [17]:
vocab = set(test_sentence)

In [18]:
word_to_ix = {word: i for i, word in enumerate(vocab)}

In [19]:
word_to_ix

{'step': 0,
 'departing': 1,
 'already': 2,
 'as': 3,
 'the': 4,
 'vital': 5,
 'which': 6,
 'diplomatic': 7,
 'presidency.': 8,
 'This': 9,
 'giving': 10,
 'there': 11,
 'Washington,': 12,
 'At': 13,
 'Trumpism': 14,
 'Communist': 15,
 'that': 16,
 'will': 17,
 'potentially,': 18,
 'India': 19,
 'Britain': 20,
 'is': 21,
 'Inverchapel,': 22,
 'equipped': 23,
 'a': 24,
 'on': 25,
 'for': 26,
 'printed': 27,
 'in': 28,
 'February': 29,
 'Mediterranean': 30,
 'And': 31,
 'one': 32,
 'danger': 33,
 'trade': 34,
 'extension,': 35,
 'including,': 36,
 'stepped': 37,
 'vacuum': 38,
 'continue': 39,
 'out': 40,
 'down': 41,
 'Turkey': 42,
 'Turkey.': 43,
 'walked': 44,
 'Lord': 45,
 'no': 46,
 'State': 47,
 'an': 48,
 'was': 49,
 'As': 50,
 '09.00': 51,
 '1947,': 52,
 'forces': 53,
 'Greek': 54,
 'Soviet': 55,
 'paper': 56,
 'Department': 57,
 'Communists': 58,
 'UK': 59,
 'armed': 60,
 'hand': 61,
 'plans': 62,
 'overnight,': 63,
 'two': 64,
 'end': 65,
 'Suez': 66,
 'heavily': 67,
 'could': 

# 03. Build our model

In [50]:
EMBEDDING_DIM = 10
N_gram = 3

In [22]:
embeddings = nn.Embedding(len(vocab), EMBEDDING_DIM)

In [30]:
#torch.LongTensor([1,2,3])
embeddings(torch.LongTensor([1,2,3]))

tensor([[ 0.6620,  1.0747,  0.5249, -0.5987, -1.2036, -0.9915, -1.3313, -0.2857,
         -1.2303,  0.4614],
        [ 0.5354, -1.0840,  0.2831, -0.4681,  0.7538, -0.3961, -0.8153, -0.9076,
          0.0892,  0.0475],
        [-0.0669,  0.0545, -1.0317, -0.2964,  0.0479,  1.4567, -0.1848, -1.3572,
         -0.7292,  0.9147]], grad_fn=<EmbeddingBackward0>)

In [36]:
input = "the gravest crisis"

In [37]:
torch.LongTensor([1,2,3])

tensor([1, 2, 3])

In [38]:
input_tokens = input.split()

In [41]:
token_idx = []
for token in input_tokens:
 token_idx.append(word_to_ix[token])

In [43]:
token_idx_tensor = torch.LongTensor([token_idx])

In [46]:
 word_representation = embeddings(token_idx_tensor)

In [55]:
concate_words = word_representation.view((1,-1))

In [52]:
HIDDEN_DIM = 20

In [54]:
hidden_lin = nn.Linear(EMBEDDING_DIM*N_gram, HIDDEN_DIM)

In [57]:
hidden_vector = hidden_lin(concate_words)

In [58]:
classfier_lin = nn.Linear(HIDDEN_DIM, len(vocab))

In [61]:
logits = classfier_lin(hidden_vector)

In [63]:
logits

tensor([[ 0.2347, -0.4815,  0.0796, -0.9588,  0.2130,  0.1797,  0.8888, -0.6404,
          0.9185, -0.4924,  0.3753,  0.2282, -0.2226,  0.7298, -0.1459, -0.2145,
          0.4254,  0.0248, -0.4233,  0.2587, -0.2750,  0.1273,  0.0842,  0.2853,
          0.0336, -0.3442,  0.1598, -0.1389,  0.3663, -0.0679,  0.0530,  0.1560,
          0.5205, -0.4688, -0.0826,  0.5229,  0.4090,  0.3845, -0.2710, -0.0336,
         -0.0840, -0.2265,  0.1279,  0.2438, -0.1930, -0.7533, -0.3574,  0.8211,
         -0.0052, -0.0575, -0.0843,  0.3218,  0.0019, -0.2111,  1.0706, -0.2521,
          0.3894,  0.1402, -0.0920, -0.0583,  0.3553,  0.0091,  0.1716, -0.3444,
          0.1849,  0.4196,  1.0680, -0.1326,  0.0023,  0.5382, -0.4444, -0.2224,
         -0.3214, -0.0569,  0.1236, -0.2466,  0.4822, -0.3968, -0.5434,  0.0483,
         -1.0019, -0.1100,  0.5625,  0.5282,  0.1474,  0.8532, -0.3755, -0.2141,
         -0.0144,  0.0524, -0.2533, -0.2169, -0.4267,  0.1138, -0.3644, -0.0229,
          0.3397,  0.4723, -

In [64]:
pred_dist = torch.softmax(logits, dim=1)

In [65]:
torch.argmax(pred_dist)

tensor(54)

In [None]:
word_to_ix