In [1]:
import pandas as pd
import random

### Читаем данные

In [2]:
df = pd.read_csv('data.csv', sep='\t')
df = df.sample(frac=1.0, random_state=42)[['query', 'title', 'label']].copy()

labels_map = {
    2: 1,
    1: -1,
    0: 0,
}

df['label'] = [labels_map[l] for l in df['label']]

In [3]:
dev = df.iloc[:1000].copy()
train = df.iloc[1000:9000][['query', 'title']].copy()
test = df.iloc[9000:].copy()

In [4]:
toloka_solutions = {(s['query'], s['title']): s['label'] for _, s in dev.iterrows()}

### Делаем разметчиков

In [5]:
from snorkel.labeling import LabelingFunction, labeling_function
from utils import sanitize, normalize


ABSTAIN = -1

@labeling_function()
def toloka_lf(x):
    return toloka_solutions.get((x['query'], x['title']), ABSTAIN)


@labeling_function()
def full_match_lf(x):
    q = x['query']
    t = x['title']
    q = sanitize(q)
    t = sanitize(t)
    if q == t:
        return 1
    return ABSTAIN


@labeling_function()
def no_match_lf(x):
    q = x['query']
    t = x['title']
    q = normalize(sanitize(q))
    t = normalize(sanitize(t))
    for word in q.split():
        if word in t:
            return ABSTAIN
    return 0


lfs = [
    toloka_lf, 
    full_match_lf, 
    no_match_lf,
]

In [6]:
from snorkel.labeling import PandasLFApplier

applier = PandasLFApplier(lfs)
L_train = applier.apply(train)
L_dev = applier.apply(dev)

100%|██████████| 8000/8000 [00:03<00:00, 2600.23it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2630.59it/s]


In [7]:
from snorkel.labeling import LFAnalysis

LFAnalysis(L_dev, lfs).lf_summary(dev.label.values).tail(5)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
toloka_lf,0,"[0, 1]",0.98,0.268,0.057,977,0,0.996939
full_match_lf,1,[1],0.161,0.161,0.0,161,0,1.0
no_match_lf,2,[0],0.114,0.107,0.057,50,57,0.438596


In [8]:
print(f"Training set coverage: {100 * LFAnalysis(L_train).label_coverage(): 0.1f}%")
print(f"Dev set coverage: {100 * LFAnalysis(L_dev).label_coverage(): 0.1f}%")

Training set coverage:  30.4%
Dev set coverage:  98.7%


### Обучаем разметочную модель

In [9]:
from snorkel.labeling import LabelModel

label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train, n_epochs=100, seed=123, log_freq=20, l2=0.1, lr=0.01)

In [10]:
from snorkel.analysis import metric_score

preds_dev = label_model.predict(L_dev)

acc = metric_score(dev.label.values, preds_dev, probs=None, metric="accuracy")
print(f"LabelModel Accuracy: {acc:.3f}")

LabelModel Accuracy: 0.990
