In [41]:
import pyconll #pip3 install this if you don't have it
import torchtext.data as tt
import torch 
import torch.nn as nn
import torch.optim as optim

In [51]:
AFRIKAANS_TRAIN = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-train.conllu'
AFRIKAANS_DEV = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-dev.conllu'
AFRIKAANS_TEST = 'UD_Afrikaans-AfriBooms/af_afribooms-ud-test.conllu'

In [58]:
# from https://github.com/soutsios/pos-tagger-bert/blob/master/pos_tagger_bert.ipynb
def make_sentences(path):
    data = pyconll.load_from_file(path)
    sentences = []
    tagged_sentences = []
    for each in data:
        tagged_sentence=[]
        sentence = []
        for token in each:
            if token.upos and token.form:
                tagged_sentence.append(token.upos)
                sentence.append(token.form.lower())
        tagged_sentences.append(tagged_sentence)
        sentences.append(sentence)
    return sentences, tagged_sentences

In [79]:
train_afr, tagged_train_afr = make_sentences(AFRIKAANS_TRAIN)
dev_afr, tagged_dev_afr = make_sentences(AFRIKAANS_DEV)
test_afr, tagged_test_afr = make_sentences(AFRIKAANS_TEST)

In [80]:
print(tagged_train_afr[0])

['DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'ADJ', 'NOUN', 'AUX', 'ADP', 'DET', 'NUM', 'NOUN', 'ADP', 'NOUN', 'PART', 'VERB', 'ADP', 'NOUN', 'PRON', 'ADP', 'DET', 'ADJ', 'NOUN', 'VERB', 'AUX', 'PUNCT']


In [81]:
print("Tagged sentences in train set: ", len(tagged_train_afr))
print("Tagged words in train set:", len([item for sublist in tagged_train_afr for item in sublist]))
print(40*'=')
print("Tagged sentences in dev set: ", len(tagged_dev_afr))
print("Tagged words in dev set:", len([item for sublist in tagged_dev_afr for item in sublist]))
print(40*'=')
print("Tagged sentences in test set: ", len(tagged_test_afr))
print("Tagged words in test set:", len([item for sublist in tagged_test_afr for item in sublist]))
print(40*'*')
print("Total sentences in dataset:", len(tagged_train_afr)+len(tagged_dev_afr)+len(tagged_dev_afr))

Tagged sentences in train set:  1315
Tagged words in train set: 33894
Tagged sentences in dev set:  194
Tagged words in dev set: 5317
Tagged sentences in test set:  425
Tagged words in test set: 10065
****************************************
Total sentences in dataset: 1703


In [82]:
# from https://github.com/tringm/POSTagger_Pytorch/blob/master/src/util/nlp.py
def build_tag_field(sentences_tokens):
    token_field = tt.Field(tokenize=list, init_token="<bos>", eos_token="<eos>")
    fields = [('tokens', token_field)]
    examples = [tt.Example.fromlist([t], fields) for t in sentences_tokens]
    torch_dataset = tt.Dataset(examples, fields)
    token_field.build_vocab(torch_dataset)
    return token_field
    
def build_text_field(sentences_words):
    text_field = tt.Field(tokenize=list, init_token="<bos>", eos_token="<eos>")
    fields = [('text', text_field)]
    examples = [tt.Example.fromlist([t], fields) for t in sentences_words]
    torch_dataset = tt.Dataset(examples, fields)
    vocab_field.build_vocab(torch_dataset)
    return text_field

In [83]:
train_afr = build_vocab_field(train_afr)
dev_afr = build_vocab_field(dev_afr)
test_afr = build_vocab_field(test_afr)

tagged_train_afr = build_tag_field(tagged_train_afr)
tagged_dev_afr = build_tag_field(tagged_dev_afr)
tagged_test_afr = build_tag_field(tagged_test_afr)

train_data = tt.Dataset([],(("text", train_afr), ("udtags", tagged_train_afr)))
valid_data = tt.Dataset([],(("text", dev_afr), ("udtags", tagged_dev_afr)))
test_data = tt.Dataset([],(("text", test_afr), ("udtags", tagged_test_afr)))

In [84]:
print(train_data)

<torchtext.data.dataset.Dataset object at 0x7fabcf2addd8>


In [86]:
# from https://github.com/bentrevett/pytorch-pos-tagging/blob/master/1%20-%20BiLSTM%20for%20PoS%20Tagging.ipynb
#model
BATCH_SIZE=128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#needs to be tuple of dataset objects
train_iterator, valid_iterator, test_iterator = tt.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

In [None]:
class BiLSTMTagger():
    #https://github.com/bentrevett/pytorch-pos-tagging/blob/master/1%20-%20BiLSTM%20for%20PoS%20Tagging.ipynb