In [40]:
import pyconll #pip3 install this if you don't have it
import torchtext.data as tt
import torch 
import torch.nn as nn
import torch.optim as optim
import time

In [41]:
AFRIKAANS_TRAIN = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-train.conllu'
AFRIKAANS_DEV = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-dev.conllu'
AFRIKAANS_TEST = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-test.conllu'

In [60]:
# from https://github.com/soutsios/pos-tagger-bert/blob/master/pos_tagger_bert.ipynb
def make_sentences(path):
    data = pyconll.load_from_file(path)
    sentences = []
    tagged_sentences = []
    for each in data:
        tagged_sentence=[]
        sentence = []
        for token in each:
            if token.upos and token.form:
                tagged_sentence.append(token.upos)
                sentence.append(token.form.lower())
        tagged_sentences.append(tagged_sentence)
        sentences.append(sentence)
    return sentences, tagged_sentences

In [61]:
train_afr_raw, tagged_train_afr_raw = make_sentences(AFRIKAANS_TRAIN)
dev_afr_raw, tagged_dev_afr_raw = make_sentences(AFRIKAANS_DEV)
test_afr_raw, tagged_test_afr_raw = make_sentences(AFRIKAANS_TEST)

In [64]:
print("Tagged sentences in train set: ", len(tagged_train_afr_raw))
print("Tagged words in train set:", len([item for sublist in tagged_train_afr_raw for item in sublist]))
print(40*'=')
print("Tagged sentences in dev set: ", len(tagged_dev_afr_raw))
print("Tagged words in dev set:", len([item for sublist in tagged_dev_afr_raw for item in sublist]))
print(40*'=')
print("Tagged sentences in test set: ", len(tagged_test_afr_raw))
print("Tagged words in test set:", len([item for sublist in tagged_test_afr_raw for item in sublist]))
print(40*'*')
print("Total sentences in dataset:", len(tagged_train_afr_raw)+len(tagged_dev_afr_raw)+len(tagged_dev_afr_raw))

Tagged sentences in train set:  1315
Tagged words in train set: 33894
Tagged sentences in dev set:  194
Tagged words in dev set: 5317
Tagged sentences in test set:  425
Tagged words in test set: 10065
****************************************
Total sentences in dataset: 1703


In [65]:
# from https://github.com/tringm/POSTagger_Pytorch/blob/master/src/util/nlp.py
def build_tag_field(sentences_tokens):
    token_field = tt.Field(tokenize=list, init_token="<bos>", eos_token="<eos>")
    fields = [('tokens', token_field)]
    examples = [tt.Example.fromlist([t], fields) for t in sentences_tokens]
    torch_dataset = tt.Dataset(examples, fields)
    token_field.build_vocab(torch_dataset)
    return token_field
    
def build_text_field(sentences_words):
    text_field = tt.Field(tokenize=list, init_token="<bos>", eos_token="<eos>")
    fields = [('text', text_field)]
    examples = [tt.Example.fromlist([t], fields) for t in sentences_words]
    torch_dataset = tt.Dataset(examples, fields)
    text_field.build_vocab(torch_dataset)
    return text_field

In [74]:
#fields
train_afr = build_text_field(train_afr_raw)
dev_afr = build_text_field(dev_afr_raw)
test_afr = build_text_field(test_afr_raw)
tagged_train_afr = build_tag_field(tagged_train_afr_raw)
tagged_dev_afr = build_tag_field(tagged_dev_afr_raw)
tagged_test_afr = build_tag_field(tagged_test_afr_raw)

fields_train = (("text", train_afr), ("udtags", tagged_train_afr))
examples_train = [tt.Example.fromlist(item, fields_train) for item in zip(train_afr_raw, tagged_train_afr_raw)]
fields_dev = (("text", dev_afr), ("udtags", tagged_dev_afr))
examples_dev = [tt.Example.fromlist(item, fields_dev) for item in zip(dev_afr_raw, tagged_dev_afr_raw)]
fields_test = (("text", test_afr), ("udtags", tagged_test_afr))
examples_test = [tt.Example.fromlist(item, fields_test) for item in zip(test_afr_raw, tagged_test_afr_raw)]
train_data = tt.Dataset(examples_train, fields_train)
valid_data = tt.Dataset(examples_dev, fields_dev)
test_data = tt.Dataset(examples_dev, fields_dev)

In [75]:
print(train_afr)

<torchtext.data.field.Field object at 0x7f83fed2ca90>


In [76]:
print(train_data)

<torchtext.data.dataset.Dataset object at 0x7f83fed30080>


In [77]:
len(train_data)

1315

In [96]:
# from https://github.com/bentrevett/pytorch-pos-tagging/blob/master/1%20-%20BiLSTM%20for%20PoS%20Tagging.ipynb
#model
batch_size=128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#needs to be tuple of dataset objects
train_iterator, valid_iterator, test_iterator = tt.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = batch_size,
    device = device, sort=False)

In [97]:
# try without dropout first
class BiLSTMTagger(nn.Module):
    #https://github.com/bentrevett/pytorch-pos-tagging/blob/master/1%20-%20BiLSTM%20for%20PoS%20Tagging.ipynb
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional)
        #fully connected layer
        self.fc = nn.Linear((hidden_dim * 2 if bidirectional else hidden_dim), output_dim)
     
    
    def forward(self, text):
        embedded = self.embedding(text)
        outputs, (hidden, cell) = self.lstm(embedded)
        predictions = self.fc(outputs)
        return predictions

In [98]:
in_dim = len(train_afr.vocab)
emb_dim = 100
hid_dim = 128
out_dim = len(tagged_train_afr.vocab)
n_layers = 2
bidirectional = True
pad_index = train_afr.vocab.stoi[train_afr.pad_token]
tag_pad_idx = tagged_train_afr.vocab.stoi[tagged_train_afr.pad_token]

In [99]:
model = BiLSTMTagger(in_dim, emb_dim, hid_dim, out_dim, n_layers, bidirectional, pad_index)
criterion = nn.CrossEntropyLoss(ignore_index = tag_pad_idx)
optimizer = optim.Adam(model.parameters())

In [100]:
def categorical_accuracy(preds, y, tag_pad_idx):
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]])

In [101]:
def train(model, iterator, optimizer, criterion, tag_pad_idx):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    
    for batch in iterator:
        text = batch.text
        tags = batch.udtags
        
        optimizer.zero_grad()       
        predictions = model(text)        
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)
        
        loss = criterion(predictions, tags) 
        acc = categorical_accuracy(predictions, tags, tag_pad_idx)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [102]:
def evaluate(model, iterator, criterion, tag_pad_idx):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.text
            tags = batch.udtags
            
            predictions = model(text)
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            
            loss = criterion(predictions, tags)
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [103]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
N_EPOCHS = 10

best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, tag_pad_idx)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, tag_pad_idx)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 12s
	Train Loss: 2.735 | Train Acc: 22.42%
	 Val. Loss: 2.572 |  Val. Acc: 19.63%
Epoch: 02 | Epoch Time: 0m 15s
	Train Loss: 2.226 | Train Acc: 32.53%
	 Val. Loss: 2.228 |  Val. Acc: 35.88%
Epoch: 03 | Epoch Time: 0m 16s
	Train Loss: 1.741 | Train Acc: 50.04%
	 Val. Loss: 2.036 |  Val. Acc: 43.07%
Epoch: 04 | Epoch Time: 0m 10s
	Train Loss: 1.242 | Train Acc: 66.40%
	 Val. Loss: 2.183 |  Val. Acc: 42.45%
Epoch: 05 | Epoch Time: 0m 9s
	Train Loss: 0.907 | Train Acc: 73.46%
	 Val. Loss: 2.492 |  Val. Acc: 42.80%
