In [0]:
!pip install sklearn_crfsuite

In [0]:
import re
import random


class DataLoader:
    """Read a three-column corpus file and split it into train/dev/test sets.

    The corpus is expected to hold one token per line as three
    whitespace-separated fields (``word tag ezafe_tag``), with blank lines
    separating sentences. After a seeded shuffle, the first 10% of the
    sentences become the test set, the next 10% the dev set and the
    remaining 80% the training set.
    """

    # Default corpus location (Google Drive mounted inside Colab).
    DEFAULT_CORPUS_PATH = '/content/drive/My Drive/Colab Notebooks/ezafe/data/bijankhan_corpus.tsv'

    def __init__(self, corpus_path=None, seed=17):
        """Load the corpus and build the three splits.

        Args:
            corpus_path: Path of the corpus file. ``None`` (the default)
                selects ``DEFAULT_CORPUS_PATH``.
            seed: Seed passed to ``random.seed`` before shuffling, so the
                split is reproducible. Note that this reseeds the global
                ``random`` generator.
        """
        if corpus_path is None:
            corpus_path = self.DEFAULT_CORPUS_PATH
        sents = self.data_reader(corpus_path)

        random.seed(seed)
        random.shuffle(sents)
        data_split_1 = int(len(sents) * .1)
        data_split_2 = int(len(sents) * .2)

        self.test_data = sents[:data_split_1]
        self.dev_data = sents[data_split_1:data_split_2]
        self.train_data = sents[data_split_2:]

    def data_reader(self, directory):
        """Parse the corpus file into a list of sentences.

        Args:
            directory: Path of the corpus *file* (the name is historical).

        Returns:
            A list of sentences; each sentence is a list of
            ``((word, ezafe_flag), tag)`` tuples where ``ezafe_flag`` is an
            ``int``. Empty sentences (e.g. from consecutive blank lines)
            are skipped.

        Raises:
            ValueError: If a non-blank line does not have exactly three
                fields, or its third field is not an integer.
        """
        sents, sent = [], []
        # Decode explicitly as UTF-8 so the Persian text is read the same
        # way on every platform instead of relying on the locale default.
        with open(directory, encoding='utf-8') as corpus:
            for line_no, line in enumerate(corpus, start=1):
                fields = line.split()
                if not fields:
                    # Blank or whitespace-only line: sentence boundary.
                    if sent:
                        sents.append(sent)
                        sent = []
                    continue
                if len(fields) != 3:
                    raise ValueError(
                        'line %d of %s: expected 3 fields (word, tag, ezafe_tag), got %d'
                        % (line_no, directory, len(fields))
                    )
                word, tag, ezafe_tag = fields
                # Normalise Arabic code points to their Persian counterparts.
                word = word.replace('ي', 'ی').replace('ك', 'ک').replace('ة', 'ه')
                sent.append(((word, int(ezafe_tag)), tag))

        # Keep the final sentence when the file has no trailing blank line.
        if sent:
            sents.append(sent)

        return sents

    def data_loader(self):
        """Return the ``(train, dev, test)`` sentence lists, in that order."""
        return self.train_data, self.dev_data, self.test_data

In [0]:
import pickle

from nltk.tag.util import untag
from sklearn_crfsuite import CRF
from sklearn_crfsuite import metrics


def features(sentence, index):
    """Build the feature dict for the token at position ``index``.

    sentence: sequence of items whose first element is the word; any
        further elements (e.g. an ezafe flag) are ignored here.
    index: position of the token to describe.

    Returns a dict with the word itself, position flags, prefixes and
    suffixes of up to three characters, and the five preceding and five
    following words ('' where the window runs off the sentence).
    """
    words = [item[0] for item in sentence]
    total = len(words)
    current = words[index]

    feats = {
        'word': current,
        'is_first': index == 0,
        'is_last': index == total - 1,
        'is_2_to_last': index == total - 2,
        'prefix-1': current[0],
    }
    for width, key in ((2, 'prefix-2'), (3, 'prefix-3')):
        feats[key] = current[:width]
    feats['suffix-1'] = current[-1]
    for width, key in ((2, 'suffix-2'), (3, 'suffix-3')):
        feats[key] = current[-width:]

    # Left context: the word `distance` positions before the current one.
    left_keys = ('prev_char', '2_prev_char', '3_prev_char', '4_prev_char', '5_prev_char')
    for distance, key in enumerate(left_keys, start=1):
        feats[key] = '' if index in range(distance) else words[index - distance]

    # Right context: the word `distance` positions after the current one.
    right_keys = ('next_char', 'next_char_2', 'next_char_3', 'next_char_4', 'next_char_5')
    for distance, key in enumerate(right_keys, start=1):
        feats[key] = '' if index >= total - distance else words[index + distance]

    return feats

def transform_to_dataset(tagged_sentences):
    """Turn tagged sentences into parallel feature and label sequences.

    Args:
        tagged_sentences: iterable of sentences, each a list of
            ``(token, tag)`` pairs, e.g.
            ``[[(('ali', 0), 'N'), (('be', 0), 'P'), (('raft', 0), 'V')]]``.

    Returns:
        ``(X, y)`` where ``X[i]`` is the list of feature dicts for sentence
        ``i`` (one per token, built by ``features``) and ``y[i]`` is the
        matching list of tags.
    """
    X, y = [], []
    for tagged in tagged_sentences:
        # The untagged token list is the same for every index, so build it
        # once per sentence instead of once per token.
        tokens = untag(tagged)
        X.append([features(tokens, index) for index in range(len(tagged))])
        y.append([tag for _, tag in tagged])

    return X, y

# Build the corpus splits: DataLoader() reads the corpus file and shuffles it
# with a fixed seed; data_loader() returns the (train, dev, test) sentence lists.
data_loader = DataLoader()
training_sentences, dev_sentences, test_sentences = data_loader.data_loader()

# Show one parsed training sentence as a sanity check of the input format.
print(training_sentences[0])

# Convert every split into per-token feature dicts (X) and tag sequences (y).
X_train, y_train = transform_to_dataset(training_sentences)
X_dev, y_dev = transform_to_dataset(dev_sentences)
X_test, y_test = transform_to_dataset(test_sentences)

# Debug output: number of training / test sentences, plus the features and
# tags of the first training sentence.
print(len(X_train))     
print(len(X_test))
print(X_train[0])
print(y_train[0])

# Define the CRF model, trained with L-BFGS.
# Per the sklearn-crfsuite docs, c1 and c2 are the L1 and L2 regularisation
# coefficients, and all_possible_transitions=True also generates transition
# features for label pairs that never occur in the training data.
model = CRF(algorithm='lbfgs',
    	      c1=0.1,
            c2=0.1,
            max_iterations=100,
            all_possible_transitions=True)

# Train on the training split only.
model.fit(X_train, y_train)
  
# Validate on the dev split and print per-label precision / recall / F1.
y_pred = model.predict(X_dev)
print(metrics.flat_classification_report(y_dev, y_pred, digits=4))

# Persist the fitted model to the current working directory.
# NOTE: only unpickle this file from a trusted source; pickle.load can
# execute arbitrary code.
with open('crf_model.pickle', 'wb') as handle:
	 pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
# Final evaluation on the held-out test split.
# NOTE: this cell relies on `model`, `X_test` and `y_test` created by the
# training cell, so that cell must have been run first in the same kernel.
# `y_pred` is rebound here and no longer holds the dev-set predictions.
y_pred = model.predict(X_test)
print(metrics.flat_classification_report(y_test, y_pred, digits=4))