In [80]:
from pathlib import Path
import pandas as pd
import numpy as np
# from nltk.tokenize import wordpunct_tokenize
from pymystem3 import Mystem
import pickle
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
import csv


# Module-wide Mystem analyzer instance; `lemmatize` rebinds it when the pipe breaks.
_mystem = Mystem()

# Directory that holds every input and output file used in this notebook.
DATA_FOLDER = Path('~/data/EUS').expanduser()

# Destination CSV for the first-pass (question -> class) assignment.
ANSWERS_1ST_TRY = DATA_FOLDER / 'first_classification.csv'

# Question texts taken from the 'Название' column of the example questions CSV.
# (presumably "unsup" = unsupervised / unlabelled -- confirm)
unsup_quesions = list(pd.read_csv(DATA_FOLDER / 'zpp_questions.example.csv')['Название'])

# One class name per line of `classes.list`, with the newline characters removed.
with (DATA_FOLDER / 'classes.list').open() as classes_file:
    class_names = [line.replace('\n', '') for line in classes_file]
    

In [38]:
def lemmatize(s, tries=0):
    """Return the non-blank, whitespace-stripped lemmas of ``s``.

    The text is passed to the module-level ``_mystem`` analyzer. When the call
    fails with ``BrokenPipeError`` the analyzer is replaced by a fresh
    ``Mystem()`` instance and the call is repeated.

    Args:
        s: Text to lemmatize.
        tries: Number of retries already spent; the error is re-raised once
            this counter exceeds 10.

    Returns:
        list[str]: Stripped tokens, with tokens that are empty after
        stripping dropped.

    Raises:
        BrokenPipeError: If the analyzer keeps failing after the retry budget
            is exhausted.
    """
    global _mystem
    attempt = tries
    while True:
        try:
            pieces = (tok.strip() for tok in _mystem.lemmatize(s))
            return [piece for piece in pieces if piece]
        except BrokenPipeError:
            if attempt > 10:
                raise
        # The analyzer's pipe broke: start a new one and try again.
        _mystem = Mystem()
        attempt += 1

# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling.
# Only load `senses.pkl` if it comes from a trusted source.
with (DATA_FOLDER / 'senses.pkl').open('rb') as f:
    # Used below as a mapping: abbreviation -> iterable of sense strings.
    abbr2senses = pickle.load(f)

# Lookup consumed by `desynonimize`: normalized sense text -> canonical abbreviation.
_lemmatized_senses = {}

# Normalized sense text -> every abbreviation that lists this sense.
sense2abbr = defaultdict(list)
for abbr, senses in abbr2senses.items():
    for s in senses:
        # Normalize exactly the way `desynonimize` normalizes its input text
        # (stripped lemmas joined by single spaces, via the retrying
        # `lemmatize` helper) so that the keys can actually be found in it.
        lemma = ' '.join(lemmatize(s)).strip()
        if not lemma:
            # An empty key would be "contained" in every text; skip it.
            continue
        sense2abbr[lemma].append(abbr)

for sense, abbrs in sense2abbr.items():
    # The first abbreviation seen for a sense becomes the canonical one.
    main_abbr = abbrs[0]
    _lemmatized_senses[sense] = main_abbr
    # Remaining abbreviations sharing this sense are aliased to the canonical one.
    # NOTE(review): alias keys are the raw abbreviations (not lemmatized or
    # lower-cased), while `desynonimize` searches lemmatized text -- confirm
    # that these keys can match in practice.
    for a in abbrs[1:]:
        _lemmatized_senses[a] = main_abbr
        
def desynonimize(text, synset=None):
    """Lemmatize ``text`` and replace known senses by their abbreviation.

    The text is lemmatized with ``lemmatize`` and the lemmas are joined with
    single spaces. Every key of ``synset`` that occurs in that string as a
    whole word / whole phrase is then replaced by the lower-cased value
    stored under that key. Keys are processed in the mapping's iteration
    order, so a replacement made earlier can be matched by a later key.

    Matching is done on word boundaries so that a sense which merely happens
    to be a substring of a longer word does not corrupt that word. Empty keys
    are ignored (an empty string is a substring of everything).

    Args:
        text: Raw text to normalize.
        synset: Mapping from sense text to abbreviation. Defaults to the
            module-level ``_lemmatized_senses``.

    Returns:
        str: The lemmatized text with senses replaced.
    """
    import re

    if synset is None:
        synset = _lemmatized_senses

    lemtext = ' '.join(lemmatize(text)).strip()

    for sense in synset:
        # Cheap substring pre-check keeps the regex work off the common
        # "not present" path.
        if not sense or sense not in lemtext:
            continue
        pattern = r'(?<!\w)' + re.escape(sense) + r'(?!\w)'
        replacement = synset[sense].lower()
        # A callable replacement avoids backslash/group interpretation.
        lemtext = re.sub(pattern, lambda _match: replacement, lemtext)
    return lemtext

In [45]:
# Fit the TF-IDF vectorizer on the desynonymized class names together with the
# desynonymized questions; `%time` reports how long the fit (including the
# lemmatization done by `desynonimize`) takes.
%time tfidf = TfidfVectorizer().fit([desynonimize(c) for c in class_names+unsup_quesions])

CPU times: user 7.71 s, sys: 345 ms, total: 8.05 s
Wall time: 24.7 s


In [46]:
# TF-IDF matrix of the desynonymized class names (one row per class name).
class_tfidf = tfidf.transform(list(map(desynonimize, class_names)))

  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):


In [52]:
# TF-IDF matrix of the desynonymized questions (one row per question).
unsup_tfidf = tfidf.transform(list(map(desynonimize, unsup_quesions)))

  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):


In [81]:
# Dot products between TF-IDF vectors, transposed so that row i belongs to
# question i and column j to class j.
res = (class_tfidf @ unsup_tfidf.T).T

# The csv module requires files to be opened with newline=''; otherwise
# platforms that translate newlines produce an extra blank line after each row.
with ANSWERS_1ST_TRY.open('w', newline='') as f:
    cw = csv.writer(f)
    cw.writerow(['question', 'class'])

    for i, q in enumerate(unsup_quesions):
        # Dense score vector of question i against every class.
        t = res[i].toarray()[0]
        # Index of the best-scoring class.
        # NOTE(review): argmax returns the first index on ties, so a question
        # whose scores are all equal (e.g. all zero) is assigned to
        # class_names[0] -- confirm this is acceptable.
        a = t.argmax()
        cw.writerow([q, class_names[a]])