## Loading files

### Setup

In [1]:
# !pip install pyconll torchtext livelossplot

In [2]:
# !curl https://raw.githubusercontent.com/UniversalDependencies/UD_Ancient_Greek-Perseus/master/grc_perseus-ud-train.conllu -o data/perseus-conllu/grc_perseus-ud-train.conllu
# !curl https://raw.githubusercontent.com/UniversalDependencies/UD_Ancient_Greek-Perseus/master/grc_perseus-ud-dev.conllu -o data/perseus-conllu/grc_perseus-ud-dev.conllu
# !curl https://raw.githubusercontent.com/UniversalDependencies/UD_Ancient_Greek-Perseus/master/grc_perseus-ud-test.conllu -o data/perseus-conllu/grc_perseus-ud-test.conllu

### Parsing

In [1]:
import pyconll

In [2]:
def parse_into_list(body):
    data = []
    for sentence in body:
        sentence_words = []
        sentence_tags = []
        for token in sentence:
            sentence_words.append(token.form)
            sentence_tags.append(token.upos)

        if len(sentence_words) > 0:
            data.append((sentence_words, sentence_tags))
    
    return data

In [3]:
train_file = pyconll.load_from_file('data/perseus-conllu/grc_perseus-ud-train.conllu')
val_file = pyconll.load_from_file('data/perseus-conllu/grc_perseus-ud-dev.conllu')

train = parse_into_list(train_file)
val = parse_into_list(val_file)

len(train),len(val)

(11476, 1137)

In [4]:
word_to_ix = {}
for words, tags in train:
    for word in words:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
            
len(word_to_ix)

33237

In [5]:
tag_to_ix = {}
for sent, tags in train:
    for tag in tags:
        if tag not in tag_to_ix:
            tag_to_ix[tag] = len(tag_to_ix)

len(tag_to_ix)

14

## Model setup

Based on https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils

from model.lstm import LSTMTagger

In [8]:
torch.manual_seed(1)

EMBEDDING_DIM = 10
HIDDEN_DIM = 10

In [9]:
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [10]:
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

## Model training

In [12]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

%matplotlib inline

In [13]:
# TODO: evaluate loss on separate validation dataset

train_loader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True, num_workers=0)
train_losses = []

for epoch in range(50):
    total_loss = 0

    for i, data in enumerate(train_loader):        
        model.zero_grad()
        
        for i in range(len(data[0])):
            sentence = data[0][i]
            tags = data[1][i]
            
            sentence_in = prepare_sequence(sentence, word_to_ix)
            targets = prepare_sequence(tags, tag_to_ix)

            tag_scores = model(sentence_in)

            loss = loss_function(tag_scores, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss

    epoch_loss = total_loss / len(train)
    
    print('Epoch %d: %.4f' % (epoch, total_loss / len(train)))
    train_losses.append(total_loss / len(train))

Epoch 0: 0.2277
Epoch 1: 0.1658
Epoch 2: 0.1308
Epoch 3: 0.1091
Epoch 4: 0.0969
Epoch 5: 0.0887
Epoch 6: 0.0814
Epoch 7: 0.0745
Epoch 8: 0.0693


KeyboardInterrupt: 

## Inference

In [14]:
ix_to_tag = {v: k for k, v in tag_to_ix.items()}
ix_to_tag

{0: 'VERB',
 1: 'ADV',
 2: 'ADJ',
 3: 'NOUN',
 4: 'PUNCT',
 5: 'CCONJ',
 6: 'ADP',
 7: 'DET',
 8: 'PRON',
 9: 'SCONJ',
 10: 'INTJ',
 11: 'NUM',
 12: 'X',
 13: 'PART'}

In [52]:
sentence = train[0][0] # first sentence in the validation dataset
targets = train[0][1]

with torch.no_grad():
    inputs = prepare_sequence(sentence, word_to_ix)
    token_scores = model(inputs)
    scores = [score.tolist() for score in token_scores]
    tag_ix = [score.index(max(score)) for score in scores]
    tags = [ix_to_tag[tag] if tag in ix_to_tag else '' for tag in tag_ix]

    for i, (word, tag) in enumerate(zip(sentence, tags)):
        print('%s=%s (should be %s)' % (word, tag, targets[i]))

    

ἐρᾷ=ADJ (should be VERB)
μὲν=PART (should be ADV)
ἁγνὸς=ADJ (should be ADJ)
οὐρανὸς=VERB (should be NOUN)
τρῶσαι=NOUN (should be VERB)
χθόνα=NOUN (should be NOUN)
,=PUNCT (should be PUNCT)
ἔρως=NOUN (should be NOUN)
δὲ=PART (should be CCONJ)
γαῖαν=NOUN (should be NOUN)
λαμβάνει=PRON (should be VERB)
γάμου=VERB (should be NOUN)
τυχεῖν=VERB (should be VERB)
·=PUNCT (should be PUNCT)
