In [1]:
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import precision_recall_fscore_support, classification_report
from itertools import chain
import os
from xml.etree import ElementTree as ET
from tempfile import NamedTemporaryFile
import numpy as np
import csv
from collections import defaultdict, Counter, namedtuple
import fastText

### Źródło danych 
[Polski Korpus Sejmowy](http://clip.ipipan.waw.pl/PSC) - Sittings, kadencja 8

In [2]:
root = "PSC/Posiedzenia/kadencja8"
dirs = os.listdir(root)
ns = {'namespace': 'http://www.tei-c.org/ns/1.0'}

In [3]:
def fullname(header_elem):
    _id = "#" + header_elem.attrib['{http://www.w3.org/XML/1998/namespace}id']
    function_name = header_elem.find('./namespace:persName', ns).text
    name = " ".join(function_name.split()[-2:])
    return name, _id 

In [4]:
def parse_texts():
    texts = map(lambda d: ET.parse(f"{root}/{d}/text_structure.xml").getroot(),
                dirs)
    texts_elems = chain.from_iterable(
        map(lambda e: e.findall('.//namespace:u', ns), texts))
    sentences = map(lambda e: (e.get('who'), e.text), texts_elems)
    return sentences


def parse_headers():
    headers = map(lambda d: ET.parse(f"{root}/{d}/header.xml").getroot(), dirs)
    person_elems = chain.from_iterable(
        map(lambda x: x.findall(".//namespace:person", ns), headers))
    persons = map(fullname, person_elems)
    name_id = dict(set(persons))
    return name_id


In [5]:
def name_partia_dict(name_id):
    name_partia = dict()
    with open('poslowie.csv') as csv_file:
        reader = csv.reader(csv_file, delimiter=';')
        for x in reader:
            if x[0] in name_id:
                _id = name_id[x[0]]
                name_partia[_id] = x[1]
    return name_partia

In [6]:
def labeled_sentences():
    name_id_dict = parse_headers()
    name_partia = name_partia_dict(name_id_dict)
    for sentence in parse_texts():
        name_id = sentence[0]
        if name_id in name_partia:
            partia = name_partia[sentence[0]] 
            text = sentence[1]
            labeled_text = f"__label__{partia} {text}"
            yield  labeled_text, partia

In [7]:
def print_results(y_true, predictions, average):
    precision, recall, fscore, _ = precision_recall_fscore_support(
        y_true, predictions, average=average)
    print()
    print(average)
    print('Precision', precision)
    print('Recall', recall)
    print('F1', fscore)

In [8]:
def prepare_train_file(x_train):
    train_file = NamedTemporaryFile(mode='w')
    for x in x_train:
        train_file.write(x + "\n")
    return train_file

In [9]:
def evaluate(model,x_test):
    predictions = list(map(lambda s: model.predict(s)[0][0], x_test))
    y_true = [sentence.split()[0] for sentence in x_test]
    print_results(y_true, predictions, average='micro')
    print_results(y_true, predictions, average='macro')
    print(classification_report(y_true, predictions))

In [10]:
data_x, data_y = zip(*labeled_sentences())
x_data = np.array(data_x)

In [11]:
stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=0.25)

for train_index, test_index in stratified_split.split(data_x, data_y):
    x_train, x_test = x_data[train_index], x_data[test_index]
    train_file = prepare_train_file(x_train)
    model = fastText.train_supervised(
        input=train_file.name,
        epoch=25,
        lr=1.0,
        wordNgrams=2,
        verbose=2,
        minCount=1)
    evaluate(model,x_test)



micro
Precision 0.5708110287365332
Recall 0.5708110287365332
F1 0.5708110287365332

macro
Precision 0.5202335429680457
Recall 0.4142591483921965
F1 0.4482066781834884
                     precision    recall  f1-score   support

   __label__Kukiz15       0.50      0.35      0.41      3459
__label__Nowoczesna       0.46      0.36      0.40      4802
       __label__PIS       0.63      0.67      0.65     14627
        __label__PO       0.56      0.70      0.62     15729
       __label__PSL       0.49      0.32      0.39      3345
       __label__WIS       0.66      0.38      0.49      1545
     __label__other       0.34      0.12      0.18      1140

        avg / total       0.56      0.57      0.56     44647

