In [2]:
import re
import operator
import itertools
import random
from itertools import islice
from functools import reduce

import nlpaug.augmenter.word as naw
from datasets import Dataset, DatasetDict, ClassLabel

from jadoch.data import costep
from jadoch.data.costep import language, contains, starts_with, some
from jadoch.core.functional import ilen, save_iter

### Data Preparation
---

In [3]:
# Per-language selectors. Each is later called with a sentence predicate to
# build one half of the pair filter (see `fltr`).
german, english = map(language, ("german", "english"))

In [4]:
# Heuristic for picking out the intended usage of "ja": the German side must
# contain "ja" without starting with it and must contain none of the phrases
# below, and the English side must not contain "yes".
# NOTE(review): whether `contains` matches tokens or raw substrings is defined
# in jadoch.data.costep -- confirm before adding or removing phrases.
#
# `ignores` is materialized as a list (not a one-shot `map` iterator) so it is
# still populated if another cell reuses it after `fltr` has been built.
ignores = [
    contains(phrase)
    for phrase in (
        "sagt ja",
        "sagen ja",
        "sage ja",
        "sagten ja",
        "sagte ja",
        "ja oder",
        "ja zum",
        "ja zur",
        "ja zu",
    )
]
fltr = german(contains("ja") & ~starts_with("ja") & ~some(*ignores)) & english(~contains("yes"))

In [5]:
def labelify(label):
    """Return a one-argument function that pairs its input with ``label``.

    The returned function maps ``val`` to the tuple ``(val, label)``.
    """
    return lambda val: (val, label)


def mapify(functions, iterable):
    """Lazily apply each function in ``functions``, in order, to every item.

    Each function is wrapped around the previous stage with ``map``, so
    nothing is evaluated until the result is iterated. When ``functions`` is
    empty, ``iterable`` itself is returned unchanged.
    """
    return reduce(lambda stream, fn: map(fn, stream), functions, iterable)

In [6]:
def search(fltr, label, fn=lambda v: v):
    """Lazily produce ``(text, label)`` pairs for corpus entries accepted by ``fltr``.

    Args:
        fltr: Predicate applied to each item of
            ``costep.sentences("english", "german")``; only accepted items
            are kept.
        label: Value attached to every produced text.
        fn: Transformation applied to the entry's ``"english"`` value before
            labelling. Defaults to the identity function.

    Returns:
        A lazy iterator of ``(fn(entry["english"]), label)`` tuples.
    """
    stages = (
        operator.itemgetter("english"),  # keep only the English side
        fn,
        labelify(label),
    )
    accepted = filter(fltr, costep.sentences("english", "german"))
    return mapify(stages, accepted)

In [7]:
# Sanity check: show the first labelled example accepted by the filter ("ja")
# and the first one accepted by its negation ("na").
next(search(fltr, "ja")), next(search(~fltr, "na"))

(('After all, we have in you an expert who is in any case closely concerned with these matters.',
  'ja'),
 ('I declare resumed the session of the European Parliament adjourned on Thursday, 28 March 1996.',
  'na'))

### Augmentation
---

In [8]:
def augmentify(model_path="distilbert-base-uncased", action="substitute"):
    """Build a text-augmentation function backed by ``naw.ContextualWordEmbsAug``.

    The augmenter is constructed once, when this factory is called, and is
    reused by every call of the returned function.

    Args:
        model_path: Forwarded to ``ContextualWordEmbsAug`` as ``model_path``.
        action: Forwarded to ``ContextualWordEmbsAug`` as ``action``.

    Returns:
        A one-argument function returning ``augmenter.augment(txt)``.
    """
    augmenter = naw.ContextualWordEmbsAug(model_path=model_path, action=action)
    return lambda txt: augmenter.augment(txt)

In [9]:
# Compare the first "ja" example as-is with the first "ja" example after it
# has been passed through a freshly built augmenter.
next(search(fltr, "ja")), next(search(fltr, "ja", augmentify()))

(('After all, we have in you an expert who is in any case closely concerned with these matters.',
  'ja'),
 ('after all, we must paid you an expert who assisting in the case severely concerned with these matters.',
  'ja'))

### Data Generation
---

In [10]:
def split(itr, pct):
    """Materialize ``itr`` and cut it into a leading and a trailing list.

    Args:
        itr: Any iterable; it is fully consumed.
        pct: Fraction of the items that go into the first list. The cut
            index is ``round(len(items) * pct)``, using Python's built-in
            ``round`` (ties go to the nearest even integer).

    Returns:
        A ``(head, tail)`` tuple of lists, in the original order.
    """
    materialized = list(itr)
    head_len = round(len(materialized) * pct)
    head = materialized[:head_len]
    tail = materialized[head_len:]
    return head, tail


def partition(iterable, sizes):
    """Yield consecutive chunks of ``iterable`` as lists, one per entry in ``sizes``.

    An integer size yields the next ``size`` items (fewer, possibly none, if
    the input runs out). A size of ``None`` yields everything that remains
    and ends the generator; any sizes after it are ignored.
    """
    source = iter(iterable)

    for count in sizes:
        if count is not None:
            yield list(islice(source, count))
            continue
        # ``None`` means "take the rest" and terminates the partitioning.
        yield list(source)
        break


def generate(fltr, limit, n_augments=10, train_pct=0.8, seed=None):
    """Build a shuffled train/test ``DatasetDict`` of "ja" and "na" examples.

    "ja" examples are the sentences accepted by ``fltr`` followed by
    ``n_augments`` augmented passes over those same sentences. "na" examples
    are the sentences accepted by ``~fltr``. Each class is capped
    independently at ``limit`` and split with ``split(..., train_pct)``; no
    further balancing between the classes is performed.

    NOTE(review): because the "ja" stream is originals first and augmented
    passes afterwards, and ``split`` takes a prefix as the training part, the
    "ja" test rows are augmented variants of sentences that also occur in the
    training part. Confirm this is intended before using the test split as an
    unbiased evaluation set.

    Args:
        fltr: Pair filter selecting the "ja" class; its negation selects "na".
        limit: Maximum number of examples taken per class, or ``None`` for
            no cap.
        n_augments: Number of augmented passes appended to the "ja" stream.
        train_pct: Fraction of each class assigned to the training split.
        seed: If given, the two shuffles use ``random.Random(seed)``;
            otherwise the global ``random`` state is used. This only fixes
            the shuffle order, not the augmenter's output.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"test"`` datasets, each
        having a ``"text"`` column and an integer ``"label"`` column
        (``"na"`` -> 0, ``"ja"`` -> 1).
    """
    # Build the augmenter once and share it across passes instead of loading
    # an identical model for every pass; skip it when no passes are requested.
    augment = augmentify() if n_augments > 0 else None
    jas = itertools.chain(
        search(fltr, "ja"),
        *(search(fltr, "ja", augment) for _ in range(n_augments))
    )
    nas = search(~fltr, "na")

    train_jas, test_jas = split(islice(jas, limit), train_pct)
    train_nas, test_nas = split(islice(nas, limit), train_pct)
    training_data = train_jas + train_nas
    testing_data = test_jas + test_nas

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(training_data)
    rng.shuffle(testing_data)

    # Used only as a name -> integer id mapping; ids follow the order of `names`.
    class_label = ClassLabel(num_classes=2, names=["na", "ja"])

    def reshape(rows):
        """Turn a list of ``(text, label)`` tuples into column-oriented form."""
        return {
            "text": [text for text, _ in rows],
            "label": [class_label.str2int(name) for _, name in rows],
        }

    return DatasetDict({
        "train": Dataset.from_dict(reshape(training_data)),
        "test": Dataset.from_dict(reshape(testing_data))
    })

In [None]:
# Build the full dataset; limit=None places no cap on the number of examples
# taken per class.
data = generate(fltr, None)

In [63]:
# Display the dataset summary. To persist it instead, call:
#   data.save_to_disk("/gpfs/scratch/asoubki/data/english-balanced-train-test")
data