In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
def prepare_sequence(seq, to_ix):
    """Input: takes in a list of words, and a dictionary containing the index of the words
    Output: a tensor containing the indexes of the word"""
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

# This is the example training data
training_data = [
    ("the dog happily ate the big apple".split(), ["DET", "NN", "ADV", "V", "DET", "ADJ", "NN"]),
    ("everybody read that good book quietly in the hall".split(), ["NN", "V", "DET", "ADJ", "NN", "ADV", "PRP", "DET", "NN"]),
    ("the old head master sternly scolded the naughty children for \
     being very loud".split(), ["DET", "ADJ", "ADJ", "NN", "ADV", "V", "DET", "ADJ",  "NN", "PRP", "V", "ADJ", "NN"]),
    ("i love you loads".split(), ["PRN", "V", "PRN", "ADV"])
]
#  These are other words which we would like to predict (within sentences) using the model
other_words = ["area", "book", "business", "case", "child", "company", "country", 
               "day", "eye", "fact", "family", "government", "group", "hand", "home", 
               "job", "life", "lot", "man", "money", "month", "mother", "food", "night", 
               "number", "part", "people", "place", "point", "problem", "program", 
               "question", "right", "room", "school", "state", "story", "student", 
               "study", "system", "thing", "time", "water", "way", "week", "woman", 
               "word", "work", "world", "year", "ask", "be", "become", "begin", "can", 
               "come", "do", "find", "get", "go", "have", "hear", "keep", "know", "let", 
               "like", "look", "make", "may", "mean", "might", "move", "play", "put", 
               "run", "say", "see", "seem", "should", "start", "think", "try", "turn", 
               "use", "want", "will", "work", "would", "asked", "was", "became", "began", 
               "can", "come", "do", "did", "found", "got", "went", "had", "heard", "kept", 
               "knew", "let", "liked", "looked", "made", "might", "meant", "might", "moved", 
               "played", "put", "ran", "said", "saw", "seemed", "should", "started", 
               "thought", "tried", "turned", "used", "wanted" "worked", "would", "able", 
               "bad", "best", "better", "big", "black", "certain", "clear", "different", 
               "early", "easy", "economic", "federal", "free", "full", "good", "great", 
               "hard", "high", "human", "important", "international", "large", "late", 
               "little", "local", "long", "low", "major", "military", "national", "new", 
               "old", "only", "other", "political", "possible", "public", "real", "recent", 
               "right", "small", "social", "special", "strong", "sure", "true", "white", 
               "whole", "young", "he", "she", "it", "they", "i", "my", "mine", "your", "his", 
               "her", "father", "mother", "dog", "cat", "cow", "tiger", "a", "about", "all", 
               "also", "and", "as", "at", "be", "because", "but", "by", "can", "come", "could", 
               "day", "do", "even", "find", "first", "for", "from", "get", "give", "go", 
               "have", "he", "her", "here", "him", "his", "how", "I", "if", "in", "into", 
               "it", "its", "just", "know", "like", "look", "make", "man", "many", "me", 
               "more", "my", "new", "no", "not", "now", "of", "on", "one", "only", "or", 
               "other", "our", "out", "people", "say", "see", "she", "so", "some", "take", 
               "tell", "than", "that", "the", "their", "them", "then", "there", "these", 
               "they", "thing", "think", "this", "those", "time", "to", "two", "up", "use", 
               "very", "want", "way", "we", "well", "what", "when", "which", "who", "will", 
               "with", "would", "year", "you", "your"]

word_to_ix = {} # This is the word dictionary which will contain the index to each word

for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix.keys():
            word_to_ix[word] = len(word_to_ix)
for word in other_words:
    if word not in word_to_ix.keys():
            word_to_ix[word] = len(word_to_ix)

# print(word_to_ix) # Just have a look at what it contains

tag_to_ix = {"DET": 0, "NN": 1, "V": 2, "ADJ": 3, "ADV": 4, "PRP": 5, "PRN": 6} # This dictionary contains the indices of the tags

EMBEDDING_DIM = 64
HIDDEN_DIM = 64

In [3]:
class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size):
        super(LSTMTagger, self).__init__()
        
        self.hidden_dim = hidden_dim
        
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, target_size)
        
    def forward(self, sentence):
        print(sentence.shape)
        embeds = self.word_embeddings(sentence)
        print(embeds.shape)
        print("in")
        print(embeds.view(len(sentence), 1, -1).shape)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        print("out")
        print(lstm_out.shape)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        print(tag_space.shape)
        tag_scores = F.log_softmax(tag_space, dim=1)
        print(tag_scores.shape)
        return tag_scores

In [8]:
# Here I initialize the model with all the necesarry parameters
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix.keys()), len(tag_to_ix.keys()))

# Define the loss function as the Negative Log Likelihood loss (NLLLoss)
loss_function = nn.NLLLoss()

# We will be using a simple SGD optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)

# The test sentence
seq1 = "everybody read the book and ate the food".split()
seq2 = "she like my dog".split()
print("Running a check on the model before training.\nSentences:\n{}\n{}".format(" ".join(seq1), " ".join(seq2)))
with torch.no_grad():
    for seq in [seq1, seq2]:
        inputs = prepare_sequence(seq, word_to_ix)
        tag_scores = model(inputs)
        _, indices = torch.max(tag_scores, 1)
        ret = []
        for i in range(len(indices)):
            for key, value in tag_to_ix.items():
                if indices[i] == value:
                    ret.append((seq[i], key))
        print(ret)
    
print("Training Started")
for epoch in range(300):
    for sentence, tags in training_data:
        model.zero_grad()
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        
        tag_scores = model(sentence_in)
        
        print("SHAPE", tag_scores.shape, targets.shape)
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()
        
print("Training Finished!!!\nAgain testing on unknown data")
# with torch.no_grad():
#     for seq in [seq1, seq2]:
#         inputs = prepare_sequence(seq, word_to_ix)
#         tag_scores = model(inputs)
#         _, indices = torch.max(tag_scores, 1)
#         ret = []
#         for i in range(len(indices)):
#             for key, value in tag_to_ix.items():
#                 if indices[i] == value:
#                     ret.append((seq[i], key))
#         print(ret)

Running a check on the model before training.
Sentences:
everybody read the book and ate the food
she like my dog
torch.Size([8])
torch.Size([8, 64])
in
torch.Size([8, 1, 64])
out
torch.Size([8, 1, 64])
torch.Size([8, 7])
torch.Size([8, 7])
[('everybody', 'V'), ('read', 'NN'), ('the', 'V'), ('book', 'NN'), ('and', 'DET'), ('ate', 'NN'), ('the', 'NN'), ('food', 'DET')]
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
[('she', 'NN'), ('like', 'PRN'), ('my', 'DET'), ('dog', 'PRN')]
Training Started
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 

torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) 

torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([1

torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([1

torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])

torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
tor

torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) torch.Size([7])
torch.Size([9])
torch.Size([9, 64])
in
torch.Size([9, 1, 64])
out
torch.Size([9, 1, 64])
torch.Size([9, 7])
torch.Size([9, 7])
SHAPE torch.Size([9, 7]) torch.Size([9])
torch.Size([13])
torch.Size([13, 64])
in
torch.Size([13, 1, 64])
out
torch.Size([13, 1, 64])
torch.Size([13, 7])
torch.Size([13, 7])
SHAPE torch.Size([13, 7]) torch.Size([13])
torch.Size([4])
torch.Size([4, 64])
in
torch.Size([4, 1, 64])
out
torch.Size([4, 1, 64])
torch.Size([4, 7])
torch.Size([4, 7])
SHAPE torch.Size([4, 7]) torch.Size([4])
torch.Size([7])
torch.Size([7, 64])
in
torch.Size([7, 1, 64])
out
torch.Size([7, 1, 64])
torch.Size([7, 7])
torch.Size([7, 7])
SHAPE torch.Size([7, 7]) 

KeyboardInterrupt: 