# Table of Contents
1. [Imports](#Imports)
2. [Data Read In](#Data-Read-in)

## Imports
[back to top](#Table-of-Contents)

In [1]:
import argparse
import csv
import os
import pickle
import random
import sys
import unittest

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.autograd import Variable

SEED = 1234

random.seed(SEED
            )
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.has_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
# Part 1
def prepare_sequence(seq, to_ix):
    """Input: takes in a list of words, and a dictionary containing the index of the words
    Output: a tensor containing the indexes of the word"""
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)
# This is the example training data
training_data = [
    ("the dog happily ate the big apple".split(), ["DET", "NN", "ADV", "V", "DET", "ADJ", "NN"]),
    ("everybody read that good book quietly in the hall".split(), ["NN", "V", "DET", "ADJ", "NN", "ADV", "PRP", "DET", "NN"]),
    ("the old head master sternly scolded the naughty children for \
     being very loud".split(), ["DET", "ADJ", "ADJ", "NN", "ADV", "V", "DET", "ADJ",  "NN", "PRP", "V", "ADJ", "NN"]),
    ("i love you loads".split(), ["PRN", "V", "PRN", "ADV"])
]
#  These are other words which we would like to predict (within sentences) using the model
other_words = ["area", "book", "business", "case", "child", "company", "country",
               "day", "eye", "fact", "family", "government", "group", "hand", "home",
               "job", "life", "lot", "man", "money", "month", "mother", "food", "night",
               "number", "part", "people", "place", "point", "problem", "program",
               "question", "right", "room", "school", "state", "story", "student",
               "study", "system", "thing", "time", "water", "way", "week", "woman",
               "word", "work", "world", "year", "ask", "be", "become", "begin", "can",
               "come", "do", "find", "get", "go", "have", "hear", "keep", "know", "let",
               "like", "look", "make", "may", "mean", "might", "move", "play", "put",
               "run", "say", "see", "seem", "should", "start", "think", "try", "turn",
               "use", "want", "will", "work", "would", "asked", "was", "became", "began",
               "can", "come", "do", "did", "found", "got", "went", "had", "heard", "kept",
               "knew", "let", "liked", "looked", "made", "might", "meant", "might", "moved",
               "played", "put", "ran", "said", "saw", "seemed", "should", "started",
               "thought", "tried", "turned", "used", "wanted" "worked", "would", "able",
               "bad", "best", "better", "big", "black", "certain", "clear", "different",
               "early", "easy", "economic", "federal", "free", "full", "good", "great",
               "hard", "high", "human", "important", "international", "large", "late",
               "little", "local", "long", "low", "major", "military", "national", "new",
               "old", "only", "other", "political", "possible", "public", "real", "recent",
               "right", "small", "social", "special", "strong", "sure", "true", "white",
               "whole", "young", "he", "she", "it", "they", "i", "my", "mine", "your", "his",
               "her", "father", "mother", "dog", "cat", "cow", "tiger", "a", "about", "all",
               "also", "and", "as", "at", "be", "because", "but", "by", "can", "come", "could",
               "day", "do", "even", "find", "first", "for", "from", "get", "give", "go",
               "have", "he", "her", "here", "him", "his", "how", "I", "if", "in", "into",
               "it", "its", "just", "know", "like", "look", "make", "man", "many", "me",
               "more", "my", "new", "no", "not", "now", "of", "on", "one", "only", "or",
               "other", "our", "out", "people", "say", "see", "she", "so", "some", "take",
               "tell", "than", "that", "the", "their", "them", "then", "there", "these",
               "they", "thing", "think", "this", "those", "time", "to", "two", "up", "use",
               "very", "want", "way", "we", "well", "what", "when", "which", "who", "will",
               "with", "would", "year", "you", "your"]


In [4]:
word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix.keys():
            word_to_ix[word] = len(word_to_ix)
for word in other_words:
    if word not in word_to_ix.keys():
        word_to_ix[word] = len(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2, "ADJ": 3, "ADV": 4, "PRP": 5, "PRN": 6}
EMBEDDING_DIM = 64
HIDDEN_DIM = 64

In [26]:
class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim).to(device)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim).to(device)
        self.hidden2tag = nn.Linear(hidden_dim, target_size).to(device)
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        sentence.to(device)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        lstm_out.to(device)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_score = F.log_softmax(tag_space, dim = 1)
        return tag_score

In [6]:
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

In [11]:
# test a sentence
seq1 = "everybody read the book and ate the food".split()
seq2 = "she like my dog".split()
print("Running a sample tenset \n Sentence:\n {} \n {}".format(" ".join(seq1),
                                                               " ".join(seq2)))
with torch.no_grad():
    for seq in [seq1, seq2]:
        model.to(device)
        inputs = prepare_sequence(seq, word_to_ix).to(device)
        tag_score = model(inputs)
        max_indices = tag_score.max(dim=1)[1]
        ret = []
        # reverse tag_to_ix
        reverse_tag_index = {v: k for k, v in tag_to_ix.items()}
        for i in range(len(max_indices)):
            idx = int(max_indices[i])
            ret.append((seq[i], reverse_tag_index[idx]))
        print(ret)

Running a sample tenset 
 Sentence:
 everybody read the book and ate the food 
 she like my dog
[('everybody', 'ADJ'), ('read', 'ADJ'), ('the', 'V'), ('book', 'NN'), ('and', 'V'), ('ate', 'V'), ('the', 'V'), ('food', 'V')]
[('she', 'V'), ('like', 'V'), ('my', 'NN'), ('dog', 'NN')]


In [13]:
# Train
losses = []
model.to(device)
for epoch in range(300):
    count = 0
    sum_loss = 0
    for sentence, tags in training_data:
        sentence_in = prepare_sequence(sentence, word_to_ix).to(device)
        targets = prepare_sequence(tags, tag_to_ix).to(device)
        out = model(sentence_in)
        loss = loss_function(out, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        count += 1
        sum_loss += loss
        losses.append(sum_loss / count)
    print("Epoch: {}, Loss {}".format(epoch, losses[-1]))
print("Train Finished")

Epoch: 0, Loss 1.9295401573181152
Epoch: 1, Loss 1.8933115005493164
Epoch: 2, Loss 1.8580784797668457
Epoch: 3, Loss 1.8237969875335693
Epoch: 4, Loss 1.7892918586730957
Epoch: 5, Loss 1.753929615020752
Epoch: 6, Loss 1.7171865701675415
Epoch: 7, Loss 1.6786421537399292
Epoch: 8, Loss 1.6379826068878174
Epoch: 9, Loss 1.595015287399292
Epoch: 10, Loss 1.5496798753738403
Epoch: 11, Loss 1.5020530223846436
Epoch: 12, Loss 1.4523417949676514
Epoch: 13, Loss 1.4008657932281494
Epoch: 14, Loss 1.3480287790298462
Epoch: 15, Loss 1.2942864894866943
Epoch: 16, Loss 1.2401100397109985
Epoch: 17, Loss 1.1859546899795532
Epoch: 18, Loss 1.1322342157363892
Epoch: 19, Loss 1.0793037414550781
Epoch: 20, Loss 1.0274531841278076
Epoch: 21, Loss 0.97690749168396
Epoch: 22, Loss 0.927833080291748
Epoch: 23, Loss 0.8803455233573914
Epoch: 24, Loss 0.8345179557800293
Epoch: 25, Loss 0.7903887033462524
Epoch: 26, Loss 0.747968316078186
Epoch: 27, Loss 0.7072464227676392
Epoch: 28, Loss 0.668196439743042
Ep

In [16]:
# predict function
def predict_seq(seq_list, model):
    """

    :param seq_list: list of sequences
    :param model: NN model
    :return: tuple predictions
    """
    # model.to(device)
    with torch.no_grad():
        for seq in seq_list:
            inputs = prepare_sequence(seq, word_to_ix).to(device)
            tags_score = model(inputs)
            max_indices = tags_score.max(dim=1)[1]
            pred = []
            reverse_tag_index = {v: k for k, v in tag_to_ix.items()}
            for i in range(len(max_indices)):
                idx = int(max_indices[i])
                pred.append(reverse_tag_index[idx])
            print("Sequence: {} \n"
              "Tag Prediction: {}\n".format(seq, pred))

In [17]:
# test on unkown data
predict_seq([seq1, seq2], model)

Sequence: ['everybody', 'read', 'the', 'book', 'and', 'ate', 'the', 'food'] 
Tag Prediction: ['NN', 'V', 'DET', 'NN', 'NN', 'V', 'DET', 'ADJ']

Sequence: ['she', 'like', 'my', 'dog'] 
Tag Prediction: ['PRN', 'V', 'PRN', 'NN']



## Data Read in
[back to top](#Table-of-Contents)

In [18]:
def split_text(text_file, by_line=False):
    """

    :param by_line: bool, whether to split by lines; if False, split by word
    :param text_file: training file
    :return: DIC, TOKENS and TAGS

    """
    if by_line == False:
        with open(text_file, mode="r") as file:
            text_f = file.read()
            text_f_lst = text_f.split()
            file.close()
        keys, values = text_f_lst[::2], text_f_lst[1::2]
        result_dic = dict(zip(keys, values))
        return result_dic, keys, values
    else:
        with open(text_file, mode="r") as file:
            text_f = file.read()
            text_f_lst = text_f.splitlines()
            file.close()
        keys = [line.split()[::2] for line in text_f_lst]
        values = [line.split()[1::2] for line in text_f_lst]
        # result_dic = dict(zip(keys, values))
        return keys, values
# create a list of list of tuples for training data
def combine_lists(vocab_list, tags_list):
    """

    :param vocab_list: list of sentence
    :param tags_list:
    :return: list of list of sentence of words tuples e.g. [[('Pierre', 'NOUN'), ('Vinken', 'NOUN'), (',', '.')]]
    """
    result = []
    for i in range(len(vocab_list)):
        sentence, tags = vocab_list[i], tags_list[i]
        zipped = zip(sentence, tags)
        result.append(list(zipped))
    return result

In [19]:
vocab_list, tags_list = split_text("wsj1-18.training", by_line=True)
train_list = combine_lists(vocab_list, tags_list)
test_vocab_list, test_tags_list = split_text("wsj19-21.truth", by_line=True)
test_list = combine_lists(test_vocab_list, test_tags_list)

### Construct dictionary
1. A word/tag dictionary
2. A letter/character dictionary
3. A POS tag dictionary


In [20]:
def sequence_to_idx(words, dic_ix):
    """

    :param words: list of words
    :param dic_ix: dictionary with the index as values, word as keys
    :return: list of indices
    """
    return torch.tensor([dic_ix[word] for word in words], dtype=torch.long)

In [21]:
word_to_idx = {}
tag_to_idx = {}
char_to_idx = {}
for sentence in train_list:
    for word, tag in sentence:
        if word not in word_to_idx.keys():
            word_to_idx[word] = len(word_to_idx)
        if tag not in tag_to_idx.keys():
            tag_to_idx[tag] = len(tag_to_idx)
        for char in word:
            if char not in char_to_idx.keys():
                char_to_idx[char] = len(char_to_idx)
word_vocab_size = len(word_to_idx)
tag_vocab_size = len(tag_to_idx)
char_vocab_size = len(char_to_idx)
for sentence in test_vocab_list:
    for word in sentence:
        if word not in word_to_idx.keys():
            word_to_idx[word] = len(word_to_idx)
print("Unique words: {}".format(len(word_to_idx)))
print("Unique tags: {}".format(len(tag_to_idx)))
print("Unique characters: {}".format(len(char_to_idx)))

Unique words: 46620
Unique tags: 45
Unique characters: 80


### Specify hyperparamters

In [22]:
def get_accuracy(model, if_train):
    if if_train:
        data = list(zip(vocab_list, tags_list))
    else:
        data = list(zip(test_vocab_list, test_tags_list))
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_tuple in data:
            # get X in a list by unzipping the input tuple list
            X = input_tuple[0]
            # get Y similarly
            y = input_tuple[1]
            # convert into index
            X = prepare_sequence(X, word_to_idx).to(device)
            y = prepare_sequence(y, tag_to_idx).to(device)
            # forward model
            out = model(X)
            max_indices = out.max(dim=1)[1]
            total += len(y)
            # because prepare sequence output long type tensor
            correct = torch.eq(max_indices, y).sum().item()
        return correct / total

In [23]:
WORD_EMBEDDING_DIM = 1024
CHAR_EMBEDDING_DIM = 128
WORD_HIDDEN_DIM = 1024
CHAR_HIDDEN_DIM = 1024
EPOCHS = 100
lr = 1e-3

In [31]:
debug = True
def train(model, lr, epochs=EPOCHS):
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_function = nn.NLLLoss()
    data = list(zip(vocab_list, tags_list))
    # init losses,
    losses, train_acc, test_acc = [], [], []
    for epoch in range(epochs):
        iters = 0
        sum_loss = 0
        for batch_id, (X, y) in enumerate(data):
            if debug: print('Beginning Reading Batch: {}'.format(batch_id))
            X = prepare_sequence(X, word_to_idx).to(device)
            y = prepare_sequence(y, tag_to_idx).to(device)
            model.to(device)
            # forward model pass
            out = model(X)
            loss = loss_function(out, y) # compute the loss
            loss.backward() # backward pass
            optimizer.step() # make the update to each parameter
            optimizer.zero_grad()

            # save result
            if debug: print('Calculated Loss')
            sum_loss += loss
            iters += 1
            train_acc.append(get_accuracy(model, True))
            if debug: print('Calculated Train Acc')
            test_acc.append(get_accuracy(model, False))
            if debug: print('Calculated Test Acc')
            losses.append(sum_loss / iters)
            # if batch_id % 100 == 0:
            print("Epoch: {}, Batch: {} \n"
                      "Loss{}, Train Accuracy{:.2%}, Test Accuracy:{:.2%}".format(
                    epoch, batch_id, losses[-1], train_acc[-1], test_acc[-1]
                ) )
        # plotting
        plt.title("Training Curve")
        plt.plot(np.arange(len(losses)), losses, label="Loss")
        plt.plot(np.arange(len(losses)), train_acc, linestyle='-.', label="Train")
        plt.plot(np.arange(len(losses)), test_acc, linestyle='-.', label="Test")
        plt.xlabel("Epoch")
        plt.ylabel("Loss/ Accuracy")
        plt.legend(loc="best")
        plt.show()
    return model
    print("Finished Training\n" + "-" * 50)

In [34]:
model = LSTMTagger(WORD_EMBEDDING_DIM, WORD_HIDDEN_DIM, word_vocab_size, tag_vocab_size)
get_accuracy(model, False)

KeyboardInterrupt: 

In [35]:
out = torch.tensor([[-3.8064, -3.6861, -3.8779, -3.8170, -3.8165, -3.8231, -3.7911, -3.7696,
         -3.7775, -3.7533, -3.8098, -3.7418, -3.6407, -3.8113, -3.8538, -3.9267,
         -3.8926, -3.7769, -3.9160, -3.7349, -3.8407, -3.8920, -3.8366, -3.8859,
         -3.8680, -3.8010, -3.8925, -3.7323, -3.8046, -3.8256, -3.8375, -3.7195,
         -3.8401, -3.8339, -3.8228, -3.7650, -3.6468, -3.9694, -3.7715, -3.7675,
         -3.7508, -3.8195, -3.7975, -3.8476, -3.8106],
        [-3.8204, -3.8007, -3.8056, -3.8045, -3.7528, -3.8282, -3.8481, -3.9125,
         -3.7190, -3.7485, -3.7960, -3.8217, -3.6641, -3.9218, -3.8684, -3.8554,
         -3.7569, -3.8234, -3.7805, -3.7599, -3.7906, -3.8922, -3.7402, -3.8611,
         -3.7926, -3.8411, -3.8018, -3.8120, -3.7535, -3.8739, -3.7824, -3.7007,
         -3.6913, -3.9396, -3.8181, -3.8038, -3.7777, -3.7581, -3.8490, -3.9085,
         -3.7803, -3.7878, -3.7937, -3.8650, -3.8788],
        [-3.7569, -3.6763, -3.7374, -3.8051, -3.8043, -3.9225, -3.9404, -3.7755,
         -3.8280, -3.8150, -3.8818, -3.7948, -3.7637, -3.9012, -3.8660, -3.8266,
         -3.6830, -3.8665, -3.8371, -3.7036, -3.7418, -3.8788, -3.8114, -3.8174,
         -3.8085, -3.7919, -3.9459, -3.8970, -3.7160, -3.8466, -3.7621, -3.6743,
         -3.8311, -3.7939, -3.8235, -3.8004, -3.8518, -3.7786, -3.9431, -3.7774,
         -3.7599, -3.8003, -3.7347, -3.8027, -3.8277],
        [-3.7147, -3.8125, -3.8317, -3.8253, -3.8458, -3.8643, -4.0009, -3.8040,
         -3.6672, -3.7972, -3.8258, -3.7802, -3.6552, -3.8383, -3.8643, -3.8381,
         -3.8039, -3.7895, -3.7981, -3.6961, -3.8237, -3.8914, -3.8287, -3.8331,
         -3.8750, -3.8002, -3.9484, -3.7619, -3.7871, -3.8754, -3.6865, -3.8059,
         -3.8233, -3.7737, -3.7274, -3.6838, -3.9293, -3.6790, -4.0250, -3.9600,
         -3.7300, -3.8410, -3.6889, -3.8822, -3.7465],
        [-3.6645, -3.7854, -3.9518, -3.8959, -3.9094, -3.8217, -3.9750, -3.8862,
         -3.7229, -3.7642, -3.8015, -3.8544, -3.7993, -3.8259, -3.8096, -3.7836,
         -3.7416, -3.7645, -3.6310, -3.7750, -3.8153, -3.9916, -3.9341, -3.8078,
         -3.7710, -3.8773, -3.8675, -3.7755, -3.8245, -3.7448, -3.6529, -3.8821,
         -3.8041, -3.7302, -3.7981, -3.6997, -3.7841, -3.8125, -3.8720, -3.8339,
         -3.8696, -3.8532, -3.6832, -3.7968, -3.7966],
        [-3.6956, -3.8468, -3.9124, -3.8209, -4.0121, -3.8150, -3.9132, -3.8701,
         -3.7798, -3.8592, -3.8821, -3.9241, -3.6536, -3.9529, -3.8873, -3.7850,
         -3.8297, -3.7944, -3.6789, -3.8263, -3.7308, -4.0436, -3.9455, -3.7647,
         -3.7955, -3.8426, -3.9011, -3.8362, -3.7854, -3.6361, -3.6591, -3.8751,
         -3.7482, -3.7232, -3.7242, -3.7383, -3.7086, -3.8622, -3.8436, -3.8506,
         -3.7065, -3.7771, -3.6798, -3.7712, -3.8052],
        [-3.6481, -3.6580, -3.8551, -3.8303, -3.9002, -3.9156, -3.9515, -3.7677,
         -3.8439, -3.8388, -3.9169, -3.8357, -3.7593, -3.9459, -3.8980, -3.7742,
         -3.7654, -3.8263, -3.7763, -3.7291, -3.7382, -3.9384, -3.9579, -3.7597,
         -3.7772, -3.7760, -3.9407, -3.9049, -3.7665, -3.7555, -3.6619, -3.7553,
         -3.8602, -3.6737, -3.8073, -3.7668, -3.8015, -3.8503, -3.9475, -3.7952,
         -3.7017, -3.8189, -3.7468, -3.7496, -3.7689],
        [-3.7404, -3.7937, -3.7227, -3.8239, -3.6854, -3.7936, -3.9346, -3.8563,
         -3.7726, -3.9204, -3.8403, -3.9875, -3.7404, -3.9096, -3.8798, -3.7911,
         -3.8267, -3.8006, -3.6606, -3.7481, -3.8127, -3.8676, -3.8146, -3.7009,
         -3.9097, -3.7640, -3.8747, -3.8817, -3.8971, -3.7301, -3.7155, -3.7230,
         -3.7566, -3.7933, -3.7800, -3.7091, -3.8970, -3.8651, -3.9053, -3.8382,
         -3.7105, -3.7954, -3.7717, -3.9200, -3.7702],
        [-3.8033, -3.8047, -3.7593, -3.8153, -3.8882, -3.8669, -3.8223, -3.8760,
         -3.6245, -3.9543, -3.7404, -3.8470, -3.8600, -3.7891, -3.8045, -3.6943,
         -3.8185, -3.8256, -3.7910, -3.6924, -3.8688, -3.8127, -3.8309, -3.7382,
         -3.9271, -3.7385, -3.7064, -3.7792, -3.9335, -3.6776, -3.6746, -3.7578,
         -3.8621, -3.9000, -3.8528, -3.6432, -3.9483, -4.0025, -3.8727, -3.7270,
         -3.8257, -3.8437, -3.7421, -4.0123, -3.7257],
        [-3.7449, -3.7040, -3.8578, -3.7934, -3.7758, -3.8208, -3.8068, -3.7229,
         -3.7574, -3.8762, -3.7685, -3.9012, -3.7643, -3.8826, -3.8437, -3.7367,
         -3.6822, -3.8863, -3.8725, -3.7964, -3.8735, -3.8933, -3.9225, -3.7333,
         -4.0712, -3.8093, -3.8181, -3.6794, -3.8298, -3.7567, -3.6711, -3.7937,
         -3.8066, -3.8319, -3.7601, -3.6804, -3.8582, -4.0918, -3.8298, -3.7881,
         -3.7601, -3.7846, -3.8611, -3.8545, -3.7122],
        [-3.7795, -3.7394, -3.9347, -3.8260, -3.8177, -3.7251, -3.8236, -3.7337,
         -3.8545, -3.8107, -3.7864, -3.8783, -3.7562, -3.7930, -3.9473, -3.7280,
         -3.6568, -3.8576, -3.9113, -3.7663, -3.7479, -3.8081, -3.8670, -3.7979,
         -3.8608, -3.8325, -3.8843, -3.7879, -3.8240, -3.8226, -3.7083, -3.7941,
         -3.9223, -3.7505, -3.8664, -3.8225, -3.7289, -3.9755, -3.8062, -3.8003,
         -3.8000, -3.8345, -3.6527, -3.8147, -3.7720],
        [-3.6566, -3.7714, -3.8988, -3.8508, -3.8750, -3.8751, -3.9067, -3.8476,
         -3.7487, -3.7690, -3.8397, -3.8253, -3.8756, -3.7180, -3.7534, -3.8108,
         -3.6753, -3.7920, -3.9032, -3.9020, -3.8138, -3.8956, -3.7094, -3.7127,
         -3.9103, -3.7797, -3.7701, -3.7391, -3.8951, -3.7262, -3.7236, -3.7737,
         -3.8763, -3.6681, -3.8167, -3.8037, -3.8032, -3.8306, -3.8149, -3.8153,
         -3.8833, -3.8267, -3.7720, -3.8554, -3.9012],
        [-3.7675, -3.7181, -3.8743, -3.9147, -3.7952, -3.8267, -3.9955, -3.8812,
         -3.6174, -3.7832, -3.9333, -3.7849, -3.9526, -3.8020, -3.7493, -3.7639,
         -3.7449, -3.8440, -3.8255, -3.8040, -3.7161, -3.8180, -3.8691, -3.8230,
         -3.7870, -3.8424, -3.7560, -3.8753, -3.8678, -3.8473, -3.6614, -3.7385,
         -3.7661, -3.7236, -3.8554, -3.8372, -3.7889, -3.8353, -3.8465, -3.6777,
         -3.7744, -3.8650, -3.7410, -3.8450, -3.8904],
        [-3.7336, -3.8178, -3.8059, -3.9887, -3.8826, -3.8109, -3.8448, -3.7167,
         -3.7988, -3.9344, -3.8267, -3.7917, -3.8272, -3.8494, -3.8686, -3.6894,
         -3.6813, -3.8832, -3.8325, -3.9047, -3.7495, -3.8521, -3.7719, -3.8040,
         -3.7823, -3.8780, -3.6638, -3.8351, -3.9140, -3.9574, -3.8628, -3.6902,
         -3.6947, -3.8066, -3.8184, -3.7221, -3.7508, -3.8395, -3.8276, -3.8152,
         -3.6704, -3.8418, -3.6486, -4.0034, -3.7684],
        [-3.6428, -3.6903, -3.8872, -3.9267, -3.9637, -3.9185, -3.8537, -3.6970,
         -3.9063, -3.8867, -3.9348, -3.8812, -3.7554, -3.9000, -3.8211, -3.6883,
         -3.8609, -3.7764, -3.8229, -3.8394, -3.6786, -3.7303, -3.8804, -3.8648,
         -3.7171, -3.9098, -3.8103, -3.6931, -3.8557, -3.8613, -3.8200, -3.6633,
         -3.6903, -3.7636, -3.8420, -3.8298, -3.7286, -3.8714, -3.8928, -3.8384,
         -3.6701, -3.9039, -3.6694, -3.8458, -3.7947],
        [-3.7400, -3.7414, -3.7947, -3.9006, -3.8343, -3.9172, -3.8796, -3.7751,
         -3.6915, -3.7453, -3.8106, -3.9302, -3.8431, -3.8374, -3.8393, -3.8796,
         -3.8330, -3.7775, -3.8518, -3.8253, -3.7622, -3.7353, -3.9290, -3.8986,
         -3.8297, -3.8554, -3.7488, -3.7354, -3.8118, -3.7259, -3.8567, -3.7481,
         -3.7224, -3.8127, -3.8633, -3.8826, -3.8303, -3.8331, -3.8073, -3.7549,
         -3.6395, -3.8839, -3.7573, -3.7916, -3.7359],
        [-3.8293, -3.7964, -3.8912, -3.7737, -3.8711, -3.8490, -3.9733, -3.7915,
         -3.7494, -3.8033, -3.6459, -3.9215, -3.8275, -3.7582, -3.7578, -3.9423,
         -3.8585, -3.7128, -3.8380, -3.8400, -3.7046, -3.6899, -3.9365, -3.8141,
         -3.9000, -3.8374, -3.8161, -3.7728, -3.7878, -3.7544, -3.8721, -3.7127,
         -3.7017, -3.8900, -3.8251, -3.8659, -3.6874, -3.8354, -3.8168, -3.7842,
         -3.7558, -3.9144, -3.7905, -3.8085, -3.7198],
        [-3.9430, -3.6644, -3.7695, -3.8337, -3.8817, -3.8692, -3.9832, -3.7952,
         -3.6338, -3.7847, -3.6693, -3.8395, -3.8337, -3.7766, -3.8948, -3.9011,
         -3.9402, -3.7512, -3.8840, -3.7673, -3.7424, -3.6445, -3.9830, -3.8825,
         -3.9790, -3.9505, -3.9185, -3.7473, -3.7039, -3.7165, -3.8429, -3.7408,
         -3.7329, -3.8213, -3.7644, -3.8727, -3.7670, -3.7403, -3.8827, -3.6653,
         -3.8695, -3.7374, -3.7576, -3.9057, -3.7185]])

In [45]:
out.exp().max(dim=1)[1]

tensor([12, 12, 31, 12, 18, 29,  0, 18,  8, 30, 42,  0,  8, 42,  0, 40, 10,  8])

In [46]:
out.max(dim=1)[1]

tensor([12, 12, 31, 12, 18, 29,  0, 18,  8, 30, 42,  0,  8, 42,  0, 40, 10,  8])

In [47]:
out.exp()[0]

tensor([0.0222, 0.0251, 0.0207, 0.0220, 0.0220, 0.0219, 0.0226, 0.0231, 0.0229,
        0.0234, 0.0222, 0.0237, 0.0262, 0.0221, 0.0212, 0.0197, 0.0204, 0.0229,
        0.0199, 0.0239, 0.0215, 0.0204, 0.0216, 0.0205, 0.0209, 0.0223, 0.0204,
        0.0239, 0.0223, 0.0218, 0.0215, 0.0242, 0.0215, 0.0216, 0.0219, 0.0232,
        0.0261, 0.0189, 0.0230, 0.0231, 0.0235, 0.0219, 0.0224, 0.0213, 0.0221])