In [1]:
import json
from pathlib import Path
import numpy as np
from copy import deepcopy
import pandas as pd

from deeppavlov.core.commands.train import read_data_by_config, train_evaluate_model_from_config
from deeppavlov.core.commands.infer import interact_model, build_model_from_config
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.params import from_params
from deeppavlov.core.common.errors import ConfigError

In [2]:
# read unlabelled data for label propagation
def read_unlabelled_data(UNLABELLED_DATA_PATH, encoding="utf-8"):
    """Read the pool of unlabelled texts, one sample per line.

    Args:
        UNLABELLED_DATA_PATH: path to a plain-text file with one sample per line.
        encoding: text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale default encoding.

    Returns:
        list of str: one entry per line, without line terminators
        (an empty file gives an empty list).
    """
    with open(UNLABELLED_DATA_PATH, "r", encoding=encoding) as f:
        return f.read().splitlines()

In [3]:
def make_pl_config(CONFIG_PATH):
    """Derive a pseudo-labelling config from CONFIG_PATH and save it next to it.

    The copy differs from the original only in ``dataset_reader.train``: it
    points to ``<train stem>_pl<train suffix>`` in the same relative directory
    as the original training file, so the extended dataset never overwrites
    the original one. The derived config is written as ``<config stem>_pl.json``
    in the directory of CONFIG_PATH.

    Args:
        CONFIG_PATH: path to the original JSON config.

    Returns:
        tuple ``(config, config_pl)`` with the original and the derived config dicts.
    """
    config_path = Path(CONFIG_PATH)
    config_path_pl = config_path.with_name(config_path.stem + "_pl.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    config_pl = deepcopy(config)
    train_path = Path(config_pl["dataset_reader"].get("train", "train.csv"))
    # Keep the original extension so e.g. a JSON-format dataset is not saved under
    # a ``.csv`` name; fall back to ``.csv`` (the previous behaviour) when there is none.
    suffix = train_path.suffix or ".csv"
    config_pl["dataset_reader"]["train"] = train_path.with_name(
        train_path.stem + "_pl" + suffix).as_posix()

    with open(config_path_pl, "w", encoding="utf-8") as f:
        json.dump(config_pl, f, indent=2)

    return config, config_pl

In [4]:
def save_extended_data(config, samples, labels, new_config=None):
    """Append (pseudo-)labelled samples to the training data of ``config`` and save it.

    The train split is read with ``read_data_by_config(config)``. ``samples`` and
    ``labels`` are appended to it, and the result is written to
    ``dataset_reader.data_path / dataset_reader.train`` of ``new_config`` (or of
    ``config`` when ``new_config`` is None), using that config's ``format``.
    Column names and the label separator come from ``config``.

    Args:
        config: config whose dataset reader provides the current training data.
        samples: texts to append.
        labels: one collection of class names per sample; each is joined with
            ``class_sep`` when written. A bare string is treated as a single label.
        new_config: optional config that decides where and how the data is saved.

    Raises:
        ValueError: if ``samples`` and ``labels`` have different lengths.
        ConfigError: if the target format is neither ``csv`` nor ``json``.
    """
    if len(samples) != len(labels):
        raise ValueError("Got {} samples but {} labels".format(len(samples), len(labels)))

    reader_config = config["dataset_reader"]
    x_column, y_column = reader_config["x"], reader_config["y"]
    class_sep = reader_config.get("class_sep", ",")

    train_data = read_data_by_config(deepcopy(config))
    train_data["train"].extend(zip(samples, labels))

    def _join_labels(sample_labels):
        # a single string label must not be joined character by character
        if isinstance(sample_labels, str):
            sample_labels = [sample_labels]
        return class_sep.join(str(label) for label in sample_labels)

    df = pd.DataFrame(train_data["train"], columns=[x_column, y_column])
    df[y_column] = df[y_column].apply(_join_labels)

    target_reader_config = (new_config if new_config is not None else config)["dataset_reader"]
    out_path = expand_path(Path(target_reader_config["data_path"]) /
                           Path(target_reader_config["train"]))

    data_format = target_reader_config.get("format", "csv")
    if data_format == "csv":
        df.to_csv(out_path,
                  index=False,
                  sep=target_reader_config.get("sep", ","))
    elif data_format == "json":
        orient = target_reader_config.get("orient", None)
        json_kwargs = {}
        # pandas rejects index=False unless orient is 'split' or 'table'
        # (other orients either omit the index anyway or require it).
        if orient in ("split", "table"):
            json_kwargs["index"] = False
        df.to_json(out_path,
                   orient=orient,
                   lines=target_reader_config.get("lines", False),
                   **json_kwargs)
    else:
        raise ConfigError("Can not work with current data format")

In [5]:
# ---- parameters of the pseudo-labelling run (edit by hand) ----

# DeepPavlov config of the classifier that is trained and then used to pseudo-label
CONFIG_PATH = "../deeppavlov/configs/classifiers/yahoo_answers_L31_fulltext.json"
# plain-text pool of unlabelled samples, one sample per line
UNLABELLED_DATA_PATH = "../download/YahooAnswers/yahoo_answers_data/question_L6.txt"

# how many unlabelled samples are drawn and scored in each iteration
ONE_ITERATION_PORTION = 2000
# how many train -> predict -> extend rounds are run
N_ITERATIONS = 10

# position in config["chainer"]["pipe"] of the component whose keys are the class names
CLASSES_VOCAB_ID_IN_PIPE = 0
# a sample receives every class whose predicted probability is above this threshold
CONFIDENT_PROBA = 0.9

In [6]:
# Load the pool of unlabelled texts.
unlabelled_data = read_unlabelled_data(UNLABELLED_DATA_PATH)

# Build the pseudo-labelling config; it is also written to disk as <config stem>_pl.json.
config, config_pl = make_pl_config(CONFIG_PATH)

# Seed the pseudo-labelling training file with the original training data:
# nothing is appended yet, the data is just re-saved where config_pl points.
save_extended_data(config, samples=[], labels=[], new_config=config_pl)

In [None]:
# how many texts are passed to the model per call while scoring a portion
PREDICT_BATCH_SIZE = 64

available_unlabelled_ids = np.arange(len(unlabelled_data))

np.random.seed(42)

for i in range(N_ITERATIONS):
    if len(available_unlabelled_ids) == 0:
        print("Iteration {}: no unlabelled samples left, stopping".format(i))
        break

    # Draw distinct *positions* (without replacement) in the remaining pool, so no
    # sample is drawn twice and np.delete removes exactly the drawn entries.
    portion = min(ONE_ITERATION_PORTION, len(available_unlabelled_ids))
    drawn_positions = np.random.choice(len(available_unlabelled_ids), size=portion, replace=False)
    ids_to_label = available_unlabelled_ids[drawn_positions]
    available_unlabelled_ids = np.delete(available_unlabelled_ids, drawn_positions)

    # Train on the current (possibly already extended) data, then build the model from the same config.
    train_evaluate_model_from_config(deepcopy(config_pl))
    model = build_model_from_config(deepcopy(config_pl))
    classes = np.array(list(from_params(
        deepcopy(config_pl["chainer"]["pipe"][CLASSES_VOCAB_ID_IN_PIPE])).keys()))

    samples = []
    labels = []
    for start in range(0, len(ids_to_label), PREDICT_BATCH_SIZE):
        batch_texts = [unlabelled_data[sample_id]
                       for sample_id in ids_to_label[start:start + PREDICT_BATCH_SIZE]]
        # assumes model(batch) returns one prediction vector per text, as model([text])[0] did
        for text, prediction in zip(batch_texts, model(batch_texts)):
            confident_idx = np.flatnonzero(np.asarray(prediction) > CONFIDENT_PROBA)
            if confident_idx.size:
                samples.append(text)
                labels.append(classes[confident_idx])

    print("Iteration {}: add {} samples to train dataset".format(i, len(samples)))
    # NOTE: data saved in the last iteration is not trained on inside this loop;
    # run train_evaluate_model_from_config(config_pl) once more to use it.
    save_extended_data(config_pl, samples, labels)

2018-11-08 16:25:53.211 INFO in 'deeppavlov.core.data.simple_vocab'['simple_vocab'] at line 89: [saving vocabulary to /home/dilyara/Documents/GitHub/DeepPavlov/download/YahooAnswers/models/model_v8/yahoo_answers_classes.dict]
[nltk_data] Downloading package punkt to /home/dilyara/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/dilyara/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package perluniprops to
[nltk_data]     /home/dilyara/nltk_data...
[nltk_data]   Package perluniprops is already up-to-date!
[nltk_data] Downloading package nonbreaking_prefixes to
[nltk_data]     /home/dilyara/nltk_data...
[nltk_data]   Package nonbreaking_prefixes is already up-to-date!
Using TensorFlow backend.
2018-11-08 16:25:56.73 INFO in 'tensorflow'['tf_logging'] at line 159: Using /tmp/tfhub_modules to cache modules.
2018-11-08 16:25:56.573 DEBUG in 'tensorflow'['tf_logging'

2018-11-08 16:25:56.661 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/bias:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/bias
2018-11-08 16:25:56.664 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/kernel
2018-11-08 16:25:56.666 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/projection/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_0/RNN/MultiRNNCell/Cell1/rnn/lstm_cell/projection/kernel
2018-11-08 16:25:56.668 DEBUG in '

{"valid": {"eval_examples_count": 403, "metrics": {"roc_auc": 0.5001, "sets_accuracy": 0.5484, "f1_macro": 0.4676}, "time_spent": "0:00:20", "epochs_done": 0, "batches_seen": 0, "train_examples_seen": 0, "impatience": 0, "patience_limit": 5}}
{"train": {"epochs_done": 1, "batches_seen": 4, "train_examples_seen": 3613, "metrics": {"roc_auc": 0.5038, "sets_accuracy": 0.5173, "f1_macro": 0.5016}, "time_spent": "0:06:50", "loss": 1.427233949303627}}


2018-11-08 16:33:05.807 INFO in 'deeppavlov.core.commands.train'['train'] at line 518: New best roc_auc of 0.5319
2018-11-08 16:33:05.808 INFO in 'deeppavlov.core.commands.train'['train'] at line 520: Saving model
2018-11-08 16:33:05.809 INFO in 'deeppavlov.models.classifiers.keras_classification_model'['keras_classification_model'] at line 375: [saving model to /home/dilyara/Documents/GitHub/DeepPavlov/download/YahooAnswers/models/model_v8/model_opt.json]
2018-11-08 16:33:05.834 INFO in 'deeppavlov.core.data.simple_vocab'['simple_vocab'] at line 100: [loading vocabulary from /home/dilyara/Documents/GitHub/DeepPavlov/download/YahooAnswers/models/model_v8/yahoo_answers_classes.dict]


{"valid": {"eval_examples_count": 403, "metrics": {"roc_auc": 0.5319, "sets_accuracy": 0.5583, "f1_macro": 0.3684}, "time_spent": "0:07:08", "epochs_done": 1, "batches_seen": 4, "train_examples_seen": 3613, "impatience": 0, "patience_limit": 5}}


2018-11-08 16:33:06.253 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/aggregation/scaling:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with aggregation/scaling
2018-11-08 16:33:06.257 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/aggregation/weights:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with aggregation/weights
2018-11-08 16:33:06.260 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/CNN/W_cnn_0:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/CNN/W_cnn_0
2018-11-08 16:33:06.266 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/CNN/W_cnn_1:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/CNN/W_cnn_1
2018-11-08 16:33:06.269 DEBU

2018-11-08 16:33:06.361 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/bias:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/bias
2018-11-08 16:33:06.365 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/kernel
2018-11-08 16:33:06.369 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/projection/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/projection/kernel
2018-11-08 16:33:06.373 DEBUG in '

{"valid": {"eval_examples_count": 403, "metrics": {"roc_auc": 0.5319, "sets_accuracy": 0.5583, "f1_macro": 0.3684}, "time_spent": "0:00:20"}}


  'precision', 'predicted', average, warn_for)
2018-11-08 16:33:28.294 INFO in 'deeppavlov.core.data.simple_vocab'['simple_vocab'] at line 100: [loading vocabulary from /home/dilyara/Documents/GitHub/DeepPavlov/download/YahooAnswers/models/model_v8/yahoo_answers_classes.dict]


{"test": {"eval_examples_count": 100, "metrics": {"roc_auc": 0.5, "sets_accuracy": 0.62, "f1_macro": 0.3827}, "time_spent": "0:00:02"}}


2018-11-08 16:33:28.737 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/aggregation/scaling:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with aggregation/scaling
2018-11-08 16:33:28.740 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/aggregation/weights:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with aggregation/weights
2018-11-08 16:33:28.743 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/CNN/W_cnn_0:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/CNN/W_cnn_0
2018-11-08 16:33:28.746 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/CNN/W_cnn_1:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/CNN/W_cnn_1
2018-11-08 16:33:28.752 DEBU

2018-11-08 16:33:28.845 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/bias:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/bias
2018-11-08 16:33:28.848 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/kernel
2018-11-08 16:33:28.852 DEBUG in 'tensorflow'['tf_logging'] at line 100: Initialize variable module/bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/projection/kernel:0 from checkpoint b'/tmp/tfhub_modules/9bb74bc86f9caffc8c47dd7b33ec4bb354d9602d/variables/variables' with bilm/RNN_1/RNN/MultiRNNCell/Cell0/rnn/lstm_cell/projection/kernel
2018-11-08 16:33:28.856 DEBUG in '