In [19]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os
from modules.utils import *

In [2]:
# Single-layer LSTM: every input step carries 1 feature and the
# hidden/output state has 3 features.
lstm = nn.LSTM(input_size=1, hidden_size=3)

In [3]:
# Toy sequence of length 3: each time step is its own (1, 1) random tensor.
inputs = []
for _step in range(3):
    inputs.append(torch.randn(1, 1))

In [4]:
# Join the per-step tensors along dim 0 and reshape to
# (len(inputs), 1, -1) so the whole sequence can be fed to the LSTM at once.
inputs_concat = torch.cat(
    inputs, dim=0
).view(len(inputs), 1, -1)

In [5]:
# Show the sequence as a plain Python list of per-step (1, 1) tensors.
inputs

[tensor([[1.5476]]), tensor([[0.5268]]), tensor([[0.5393]])]

In [6]:
# Show the same values after concatenation into a single 3-D tensor.
inputs_concat

tensor([[[1.5476]],

        [[0.5268]],

        [[0.5393]]])

In [9]:
# Random starting state for the LSTM as a (hidden state, cell state) pair;
# both tensors have shape (1, 1, 3).
hidden = tuple(torch.randn(1, 1, 3) for _ in range(2))

In [14]:
# Feed the sequence to the LSTM one time step at a time, threading the
# (hidden state, cell state) pair from each call into the next one.
# NOTE(review): `hidden` is rebound on every step, so re-running this cell
# continues from the previous run's final state instead of the initial one.
for step_input in inputs:
    out, hidden = lstm(step_input.view(1, 1, -1), hidden)
    # One blank line after each step's printout.
    print(out, hidden, end="\n\n")

tensor([[[-0.0327,  0.2701, -0.1710]]], grad_fn=<StackBackward>) (tensor([[[-0.0327,  0.2701, -0.1710]]], grad_fn=<StackBackward>), tensor([[[-0.0663,  0.5837, -0.2380]]], grad_fn=<StackBackward>))

tensor([[[-0.0858,  0.2221, -0.1303]]], grad_fn=<StackBackward>) (tensor([[[-0.0858,  0.2221, -0.1303]]], grad_fn=<StackBackward>), tensor([[[-0.1826,  0.5082, -0.2222]]], grad_fn=<StackBackward>))

tensor([[[-0.1142,  0.2000, -0.1399]]], grad_fn=<StackBackward>) (tensor([[[-0.1142,  0.2000, -0.1399]]], grad_fn=<StackBackward>), tensor([[[-0.2453,  0.4489, -0.2329]]], grad_fn=<StackBackward>))



In [15]:
# Run the whole sequence through the LSTM in a single call.
# `out` holds the output for every time step; `hidden` is rebound to the
# state returned by the call.
# NOTE(review): the starting state is whatever `hidden` currently holds, so
# the result depends on which cells ran before this one.
out, hidden = lstm(inputs_concat, hidden)
print(out, hidden, sep="\n")

tensor([[[-0.0324,  0.2709, -0.1710]],

        [[-0.0857,  0.2226, -0.1303]],

        [[-0.1141,  0.2002, -0.1399]]], grad_fn=<StackBackward>)
(tensor([[[-0.1141,  0.2002, -0.1399]]], grad_fn=<StackBackward>), tensor([[[-0.2451,  0.4496, -0.2329]]], grad_fn=<StackBackward>))


In [28]:
class CharDataset(Dataset):
    '''
    Per-article dataset of character-level labels.

    Indexing with ``i`` returns whatever ``label_article_chars`` produces for
    the i-th article file (in this notebook: a list of ``[character, label]``
    pairs). A character vocabulary, ``char2token``, is built once at
    construction time from the text of every file in ``article_directory``.

    Args:
        article_directory: directory containing the article text files, with
            or without a trailing path separator.
        label_directory: directory containing the label files. Stored as
            ``label_dir``; this class does not read it itself.
        pad: padding width. Stored as ``pad`` for the planned context-window
            padding (see the TODO in ``__getitem__``); not applied yet.
    '''

    # Position in the file name where the numeric article id starts.
    # Presumably names look like 'article<digits>.txt' -- verify against the data.
    _ID_OFFSET = 7

    def __init__(self, article_directory, label_directory, pad):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> article mapping and the token ids reproducible across runs.
        self.article_fnames = sorted(os.listdir(article_directory))
        self.article_directory = article_directory
        self.label_dir = label_directory
        self.pad = pad

        texts = []
        for fname in self.article_fnames:
            # os.path.join also works when the directory has no trailing slash.
            path = os.path.join(article_directory, fname)
            # newline='\n' turns off newline translation, so the characters
            # read here are exactly the ones stored in the file.
            with open(path, newline='\n') as article:
                texts.append(article.read())

        # Every article is followed by a single space, as a separator.
        corpus = ''.join(text + ' ' for text in texts)

        # dict.fromkeys de-duplicates in linear time and keeps first-seen order.
        # Ids start at 1; 0 is presumably reserved for padding (see the TODO in
        # __getitem__).
        self.char2token = {
            char: idx + 1 for idx, char in enumerate(dict.fromkeys(corpus))
        }

    def __len__(self):
        '''Number of article files found in ``article_directory``.'''
        return len(self.article_fnames)

    @staticmethod
    def _article_id(fname):
        '''
        Return the integer id embedded in an article file name.

        Reads the run of decimal digits that starts at ``_ID_OFFSET``, so ids
        of any length are parsed in full rather than cut to a fixed width.

        Raises:
            ValueError: if no digit is found at that position.
        '''
        tail = fname[CharDataset._ID_OFFSET:]
        end = 0
        while end < len(tail) and tail[end].isdecimal():
            end += 1
        if end == 0:
            raise ValueError('no article id found in file name: %r' % (fname,))
        return int(tail[:end])

    def __getitem__(self, i):
        '''
        Return the character labels of the i-th article.

        The result is passed through unchanged from ``label_article_chars``,
        which is not defined in this class (expected to come from the
        notebook's ``from modules.utils import *``).
        '''
        article_id = self._article_id(self.article_fnames[i])
        labels = label_article_chars(article_id)

        # TODO: planned post-processing, currently disabled:
        #   for label in labels:
        #       label[0] = self.char2token[label[0]]
        #   labels = [[0, 0]] * self.pad + labels + [[0, 0]] * self.pad

        return labels

In [29]:
# Build the dataset over the training articles and their span-identification
# labels; the third argument is the padding width.
char_dataset = CharDataset(
    article_directory='../datasets/train-articles/',
    label_directory='../datasets/train-labels-task1-span-identification/',
    pad=20,
)

In [33]:
# Sanity check: confirm the object is an instance of the class defined in this notebook.
type(char_dataset)

__main__.CharDataset

In [32]:
# Inspect the item at index 1: a list of [character, label] pairs.
# NOTE(review): this displays one pair per character of the article, so the
# output is very long -- consider slicing when sharing the notebook.
char_dataset[1]

[['U', 0],
 ['S', 0],
 [' ', 0],
 ['b', 0],
 ['l', 0],
 ['o', 0],
 ['g', 0],
 ['g', 0],
 ['e', 0],
 ['r', 0],
 ['s', 0],
 [' ', 0],
 ['b', 0],
 ['a', 0],
 ['n', 0],
 ['n', 0],
 ['e', 0],
 ['d', 0],
 [' ', 0],
 ['f', 0],
 ['r', 0],
 ['o', 0],
 ['m', 0],
 [' ', 0],
 ['e', 0],
 ['n', 0],
 ['t', 0],
 ['e', 0],
 ['r', 0],
 ['i', 0],
 ['n', 0],
 ['g', 0],
 [' ', 0],
 ['U', 0],
 ['K', 0],
 ['\n', 0],
 ['\n', 0],
 ['T', 0],
 ['w', 0],
 ['o', 0],
 [' ', 0],
 ['p', 0],
 ['r', 0],
 ['o', 0],
 ['m', 0],
 ['i', 0],
 ['n', 0],
 ['e', 0],
 ['n', 0],
 ['t', 0],
 [' ', 0],
 ['U', 0],
 ['S', 0],
 [' ', 0],
 ['b', 0],
 ['l', 0],
 ['o', 0],
 ['g', 0],
 ['g', 0],
 ['e', 0],
 ['r', 0],
 ['s', 0],
 [' ', 0],
 ['h', 0],
 ['a', 0],
 ['v', 0],
 ['e', 0],
 [' ', 0],
 ['b', 0],
 ['e', 0],
 ['e', 0],
 ['n', 0],
 [' ', 0],
 ['b', 0],
 ['a', 0],
 ['n', 0],
 ['n', 0],
 ['e', 0],
 ['d', 0],
 [' ', 0],
 ['f', 0],
 ['r', 0],
 ['o', 0],
 ['m', 0],
 [' ', 0],
 ['e', 0],
 ['n', 0],
 ['t', 0],
 ['e', 0],
 ['r', 0],
 ['i', 0