In [None]:
import numpy as np
import pandas as pd
import sh

from tqdm import tqdm_notebook

In [None]:
# Input corpus: a CoNLL file whose sentences are interleaved with `META` lines.
sensem_conll = '../../resources/sensem.conll'

# Fraction of each sense's examples reserved for the test partition.
test_size = 0.2

# Schema of the META records. This order is also the field order used when the
# records are written back out.
metadata_cols = [
    'META',
    'sentence',
    'corpus',
    'sensem_sense',
    'sensem_sentence',
    'verb_lemma',
    'verb_position',
    'wn',
    'wn16',
    'wn30',
]

In [None]:
# Parse every META line of the corpus into a {field: value} record.
# Fields are whitespace-separated `key:value` tokens. Split on the FIRST ':'
# only, so a value that itself contains ':' still yields a (key, value) pair
# instead of making dict() raise ValueError.
metadata = []
for mdata in sh.grep('^META', sensem_conll):
    tokens = mdata.strip().split()
    metadata.append(dict(token.split(':', 1) for token in tokens))

# Fix the column set/order to metadata_cols: keys outside it are dropped and
# keys missing from a line become NaN.
metadata = pd.DataFrame(metadata, columns=metadata_cols)

In [None]:
def _is_unusable_sense_group(group):
    """Return True for a (verb_lemma, sensem_sense) group that is excluded from the split.

    A group is excluded when it has fewer than two rows (a stratified split
    needs at least one train and one test example) or when its sense is '-'.
    """
    return len(group) < 2 or group.sensem_sense.values[0] == '-'


# Use ONE grouping and ONE predicate for both sides so that `filtered` and
# `non_filtered` are disjoint and agree on every row. Previously the second
# half regrouped by sensem_sense alone, which could overlap with `filtered`
# whenever a sense label is shared across lemmas.
sense_groups = metadata.groupby(['verb_lemma', 'sensem_sense'])

filtered = sense_groups.filter(_is_unusable_sense_group).index
metadata.loc[filtered, 'corpus'] = 'filtered'

non_filtered = sense_groups.filter(lambda group: not _is_unusable_sense_group(group)).index

In [None]:
# Stratified, deterministic split. For every sense, the first examples (in row
# order, no shuffling) go to train. The remaining ~test_size share, at least
# one example, goes to test.
labels = metadata.loc[non_filtered, 'sensem_sense']

classes, y_counts = np.unique(labels, return_counts=True)
n_cls = len(classes)
n_test = len(labels) * test_size
n_train = len(labels) - n_test

# Each partition must be able to hold at least one example of every class.
assert n_train >= n_cls and n_test >= n_cls

test_count = np.clip(np.round(y_counts * test_size), 1, None).astype(np.int32)
train_count = (y_counts - test_count).astype(np.int32)

train_indices, test_indices = [], []
for cls, n_tr, n_te in zip(classes, train_count, test_count):
    # Row labels of this class, in their original order; sliced positionally.
    cls_index = labels.index[(labels == cls).to_numpy()]
    train_indices.extend(cls_index[:n_tr])
    test_indices.extend(cls_index[n_tr:n_tr + n_te])

train_indices = np.asarray(train_indices, dtype=np.int32)
test_indices = np.asarray(test_indices, dtype=np.int32)

metadata.loc[train_indices, 'corpus'] = 'train'
metadata.loc[test_indices, 'corpus'] = 'test'

In [None]:
# Rewrite the corpus: each META line, in file order, is replaced by the updated
# record from `metadata`, and every other line is copied through.
sensem_new_conll = '../../resources/sensem.new.conll'
n_conll_lines = 840705  # line count of sensem.conll; only sizes the progress bar


def _format_meta(row):
    """Serialize one metadata row as tab-separated `key:value` fields.

    NaN fields are skipped. NaN only arises for keys absent from the original
    META line, and joining a float NaN into the string would raise TypeError.
    """
    return '\t'.join(':'.join((key, value)) for key, value in row.items() if pd.notna(value))


sensem_meta = (_format_meta(row) for _, row in metadata.iterrows())

with open(sensem_conll, 'r') as fin, open(sensem_new_conll, 'w') as fout:
    for line in tqdm_notebook(fin, total=n_conll_lines):
        if line.startswith("META"):
            print(next(sensem_meta), file=fout)
        else:
            # NOTE(review): strip() also removes leading/trailing tabs, which would
            # drop an empty trailing CoNLL column -- confirm the format never has one.
            print(line.strip(), file=fout)

# Every parsed record must have been written back; leftovers mean the META
# lines seen here did not line up with those parsed earlier.
assert next(sensem_meta, None) is None, 'metadata rows left unwritten'