https://pypi.python.org/pypi/fasttext

https://github.com/facebookresearch/fastText

In [5]:
import fasttext
from nltk.stem.snowball import SnowballStemmer
from pathes import path_to_data

In [1]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

def get_accuracy(pred, y_test):
    count_corr = 0
    for i in zip(y_test, pred):
        if i[0] == i[1]:
            count_corr += 1
        
    return count_corr, (count_corr/len(y_test))

def get_metrics(predicted, correct):
    count_corr, accuracy = get_accuracy(predicted, correct)
    print ("all: " + str(len(predicted)) + " " + "correct: " + str(count_corr))
    print ("accuracy " + str(accuracy))
    print ("precision " + str(precision_score(correct, predicted)))
    print ("recall " + str(recall_score(correct, predicted)))
    print ("f1 " + str(f1_score(correct, predicted)))

In [2]:
def read_labels():
    labels = []
    with open(path_to_data + "labels") as f:
        for i in f.readlines():
            s = i.replace("\n", "")
            labels.append(s)

    return labels


def read_texts():
    texts = []

    with open(path_to_data + "texts") as f:
        for i in f.readlines():
            s = i.replace("\n", "")
            texts.append(s)

    return texts

In [3]:
def string_preprocessing(texts):
    # + english stemming?
    stemmer = SnowballStemmer("russian")

    validLetters = "abcdefghijklmnopqrstuvwxyzабвгдеёжзийклмнопрстуфхцчшщъыьэюя"

    for i in range(len(texts)):
        texts[i] = texts[i].lower()

        s = ""
        for char in texts[i]:
            if char in validLetters or char == ' ':
                s += char

        texts[i] = s.replace("  ", " ")

        l = [stemmer.stem(word) for word in texts[i].split(" ")]

        texts[i] = " ".join(l)

        if i % 10000 == 0:
            print (i)

    return texts

In [6]:
def update_text_for_fasttext():
    texts = read_texts()
    labels = read_labels()

    num_sep = len(texts) / 100 * 80
    num_sep = int(num_sep)

    num = 0
    with open(path_to_data + "texts_updated_train", "w") as f:
        for i in texts:
            f.write("__label__" + labels[num] + " " + i.lower() + "\n")
            num += 1

            if num_sep == num:
                break

    num = 0
    with open(path_to_data + "texts_updated_test", "w") as f:
        for i in texts:
            if num > num_sep:
                f.write("__label__" + labels[num] + " " + i.lower() + "\n")

            num += 1

update_text_for_fasttext()

In [7]:
def read_texts_labels_test():
    labels_test = []
    texts_test = []
    num = 0
    with open(path_to_data + "texts_updated_test", "r") as f:
        for i in f.readlines():
            s = i.replace("\n", "").replace("__label__", "")
            s_split = s.split(" ")
            labels_test.append(str(s_split[0]))
            texts_test.append(" ".join(s_split[1:]) + " " + str(num))
            
    return texts_test, labels_test

texts_test, labels_test = read_texts_labels_test()

In [8]:
classifier = fasttext.supervised(path_to_data + "texts_updated_train", 'model', label_prefix='__label__',
                                 epoch=30)

In [9]:
%time model = fasttext.load_model('model.bin', encoding='utf-8')

CPU times: user 122 ms, sys: 52.2 ms, total: 174 ms
Wall time: 178 ms


In [10]:
result = model.test(path_to_data + "texts_updated_test")
print ('P@1:', result.precision)
print ('R@1:', result.recall)
print ('Number of examples:', result.nexamples)

P@1: 0.5505917652942157
R@1: 0.5505917652942157
Number of examples: 5999


In [11]:
labels_pred = model.predict(texts_test)
labels_pred = [i[0].replace("__label__", "") for i in labels_pred]

In [12]:
labels_pred = [int(i) for i in labels_pred]
labels_test = [int(i) for i in labels_test]

In [13]:
get_metrics(labels_pred, labels_test)

all: 5999 correct: 3235
accuracy 0.5392565427571262
precision 0.588588588589
recall 0.0694296847326
f1 0.124207858048


In [92]:
texts = ['привет']
labels_probas = classifier.predict_proba(texts)

In [95]:
print (labels_probas[0][0])

('1', 0.880859)
