## Iteractively refine dataset labels

In [1]:
import os
import pathlib
import typing as t
import random
import requests
import json as json_package
import collections
import glob

import IPython.display
import ipywidgets
import datasets
import tokenizers
import ipywidgets
import colorama
import tqdm

import segmentador
from config import *
import interactive_labeling


%load_ext autoreload
%autoreload 2


CN = colorama.Fore.RED
CT = colorama.Fore.YELLOW
CLN = colorama.Fore.CYAN
CSR = colorama.Style.RESET_ALL


VOCAB_SIZE = 6000
AUTOSAVE_IN_N_INSTANCES = 20


DATASET_SPLIT = "test"
CACHED_INDEX_FILENAME = f"{DATASET_SPLIT}_refined_indices.csv"
SKIPPED_INDEX_FILENAME = f"{DATASET_SPLIT}_skipped_indices.csv"
BRUTE_DATASET_DIR = "../data"

TARGET_DATASET_NAME = f"df_tokenized_split_0_120000_{VOCAB_SIZE}_resplit"

REFINED_DATASET_DIR = os.path.join(BRUTE_DATASET_DIR, "refined_datasets", TARGET_DATASET_NAME)

BRUTE_DATASET_URI = os.path.join(BRUTE_DATASET_DIR, TARGET_DATASET_NAME)
CACHED_INDEX_URI = os.path.join(REFINED_DATASET_DIR, CACHED_INDEX_FILENAME)
SKIPPED_INDEX_URI = os.path.join(REFINED_DATASET_DIR, SKIPPED_INDEX_FILENAME)


assert BRUTE_DATASET_URI != REFINED_DATASET_DIR


logit_model = segmentador.BERTSegmenter(
    uri_model=f"../pretrained_segmenter_model/2_{VOCAB_SIZE}_layer_model",
    device="cpu",
)
# logit_model = segmentador.LSTMSegmenter(
#     uri_model=
#         os.path.join(
#             f"../pretrained_segmenter_model/512_{VOCAB_SIZE}_1_lstm",
#             "checkpoints/512_hidden_dim_6000_vocab_size_1_layer_lstm.pt"
#         ),
#     device="cpu",
# )

pbar_dump = tqdm.auto.tqdm(desc="Instances until next dump:", total=AUTOSAVE_IN_N_INSTANCES)
it_counter = 0

lock_save = False
random.seed(17)

Instances until next dump::   0%|          | 0/20 [00:00<?, ?it/s]

In [2]:
tokenizer = tokenizers.Tokenizer.from_file(f"../tokenizers/{VOCAB_SIZE}_subwords/tokenizer.json")
tokenizer.get_vocab_size()

6000

## Load dataset

In [3]:
try:
    with open(CACHED_INDEX_URI, "r") as f_index:
        cached_indices = set(map(int, f_index.read().split(",")))

    print(f"Loaded {len(cached_indices)} indices from disk.")

except FileNotFoundError:
    cached_indices = set()


try:
    with open(SKIPPED_INDEX_URI, "r") as f_index:
        skipped_indices = set(map(int, f_index.read().split(",")))

    print(f"Loaded {len(skipped_indices)} skipped indices from disk.")

except (FileNotFoundError, ValueError):
    skipped_indices = set()


new_refined_instances = collections.defaultdict(list)
df_brute = datasets.load_from_disk(BRUTE_DATASET_URI)
df_brute

Loaded 1851 indices from disk.


DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 138580
    })
    eval: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2015
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2016
    })
})

## Iteractive refinement

In [4]:
df_split = df_brute[DATASET_SPLIT]

fn_rand = lambda: random.randint(0, df_split.num_rows)

df_view_labels = df_split["labels"]
df_view_input_ids = df_split["input_ids"]

cls2id = {
    "no-op": 0,
    "seg": 1,
    "n-start": 2,
    "n-end": 3,
}


def print_labels(input_ids, labels, input_is_tokens: bool = False):
    seg_counter = 1

    print(end=CSR)
    print(end=f"{CLN}{seg_counter}.{CSR} ")

    if not input_is_tokens:
        tokens = list(map(tokenizer.id_to_token, input_ids))

    else:
        tokens = input_ids

    for i, (tok, lab) in enumerate(zip(tokens, labels)):
        if lab == cls2id["seg"]:
            seg_counter += 1
            print("\n\n", end=f"{CLN}{seg_counter}.{CSR} ")

        if lab == cls2id["n-start"]:
            print(end=CN)

        if lab == cls2id["n-end"]:
            print(end=CSR)

        print(tok, end=" ")

    return tokens


def dump_refined_dataset():
    new_subset = datasets.Dataset.from_dict(new_refined_instances, split=DATASET_SPLIT)
    shard_id = len(glob.glob(os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_*")))
    REFINED_DATASET_SHARD_URI = os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_{shard_id}")
    new_subset.save_to_disk(REFINED_DATASET_SHARD_URI)
    new_refined_instances.clear()

    with open(CACHED_INDEX_URI, "w") as f_index:
        f_index.write(",".join(map(str, sorted(cached_indices))))

    with open(SKIPPED_INDEX_URI, "w") as f_index:
        f_index.write(",".join(map(str, sorted(skipped_indices))))

    it_counter = 0

    print(f"Saved progress in '{REFINED_DATASET_SHARD_URI}'.")


def fn_run_cell(index_shift: int):
    js = IPython.display.Javascript(
        f"Jupyter.notebook.execute_cells([IPython.notebook.get_selected_index()+{index_shift}])"
    )
    IPython.display.display(js)


def fn_run_last_cell(_):
    js = IPython.display.Javascript(
        "Jupyter.notebook.execute_cells([IPython.notebook.ncells() - 12])"
    )
    IPython.display.display(js)


def fn_skip_instance(_):
    if id_ not in cached_indices:
        skipped_indices.add(id_)
        fn_run_cell(1)


button_run_cell = ipywidgets.Button(
    description="Fetch new instance",
    tooltip="Run cell below to fetch a random instance",
    layout=ipywidgets.Layout(width="20%", height="48px", margin="0 0 0 5%"),
)
button_run_cell.on_click(lambda _: fn_run_cell(1))

button_run_cell_b = ipywidgets.Button(
    description="Fetch refined instance",
    tooltip="Fetch refined instance from front-end (run cell below)",
    style=dict(button_color="lightgreen"),
)
button_run_cell_b.layout = button_run_cell.layout
button_run_cell_b.on_click(lambda _: fn_run_cell(1))

button_edit_instance = ipywidgets.Button(
    description="Edit instance",
    tooltip="Send instance to interactive front-end for refinement",
    style=dict(button_color="lightblue"),
)
button_edit_instance.layout = button_run_cell.layout
button_edit_instance.on_click(lambda _: fn_run_cell(2))

button_save = ipywidgets.Button(
    description="Save test instance",
    style=dict(button_color="salmon"),
    tooltip="Run notebook last cell (triggers code to save instance in curated dataset)",
)
button_save.layout = button_run_cell.layout
button_save.on_click(fn_run_last_cell)

button_skip = ipywidgets.Button(
    description="Skip instance",
    style=dict(button_color="black"),
    tooltip="Skip instance",
)
button_skip.layout = button_run_cell.layout
button_skip.on_click(fn_skip_instance)

In [5]:
import scipy.special
import numpy as np
import tqdm
import torch
import transformers
import os


m = 140000 if DATASET_SPLIT == "train" else 2100
k = 0
cached_margins_filename = f"cached_margins_{DATASET_SPLIT}_{m}_{k}.txt"


if DATASET_SPLIT != "test":
    torch.cuda.empty_cache()

    if isinstance(logit_model.model, transformers.BertForTokenClassification):
        logit_model._model = torch.load("bert_logit_model_finetuned.pt")
    else:
        logit_model._model = torch.load("lstm_logit_model_finetuned.pt")

    logit_model.model.to("cuda")
    logit_model.device = "cuda"

    if os.path.exists(cached_margins_filename):
        with open(cached_margins_filename, "r") as f_in:
            margins = np.asfarray(f_in.readlines())

    else:
        margins = np.full(len(df_view_input_ids), fill_value=np.inf)

        compute_diff_tokens = False
        diff_80 = 0
        total_tokens_80 = 1e-8

        pbar = tqdm.auto.tqdm(
            enumerate(df_view_input_ids[-1 - k : -m - 1 - k : -1], 1 + k),
            total=min(m, len(df_view_input_ids) - k),
        )

        for i, text in pbar:
            logits = logit_model(tokenizer.decode(text), return_logits=True).logits
            probs = scipy.special.softmax(logits, axis=-1)
            margin = np.diff(np.sort(probs, axis=-1)[:, [-2, -1]]).ravel()

            try:
                true_labels = np.asarray(df_view_labels[-i], dtype=int)
                not_middle_word = true_labels != -100

                if compute_diff_tokens:
                    try:
                        margin_80_inds = np.flatnonzero(
                            np.logical_and(not_middle_word, margin >= 0.90)
                        )
                        high_conf_preds = np.argmax(probs[margin_80_inds], axis=-1)
                        diff_inds = np.flatnonzero(true_labels[margin_80_inds] != high_conf_preds)
                        diff_80 += diff_inds.size
                        total_tokens_80 += margin_80_inds.size

                        if diff_inds.size:
                            new_labels.append(
                                (
                                    i,
                                    list(
                                        zip(margin_80_inds[diff_inds], high_conf_preds[diff_inds])
                                    ),
                                )
                            )

                    except ValueError:
                        pass

                margin = margin[not_middle_word]

            except IndexError as err:
                margin = [np.inf, np.inf]

            margins[-i] = float(np.quantile(margin, 0.01))

            if compute_diff_tokens:
                pbar.set_description(
                    f"50: {100. * diff_80 / total_tokens_80:.2f}% ({diff_80} of {int(total_tokens_80)})"
                )

    ids_to_fetch = np.argsort(margins)

In [30]:
if not os.path.exists(cached_margins_filename):
    with open(cached_margins_filename, "w") as f_out:
        f_out.write("\n".join(map(lambda x: f"{x:.6f}", margins)))

## Interative refinery with low margin (active learning)

In [6]:
IPython.display.display(
    ipywidgets.HBox((button_run_cell, button_edit_instance, button_save, button_skip))
);

HBox(children=(Button(description='Fetch new instance', layout=Layout(height='48px', margin='0 0 0 5%', width=…

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [661]:
if DATASET_SPLIT == "test":
    j = 0
    while j in cached_indices or j in skipped_indices:
        j += 1

    id_ = j

else:
    j = 0
    while ids_to_fetch[j] in cached_indices or ids_to_fetch[j] in skipped_indices:
        j += 1

    id_ = int(ids_to_fetch[j])


input_ids = df_view_input_ids[id_]
labels = df_view_labels[id_]
tokens = print_labels(input_ids, labels)

if DATASET_SPLIT != "test":
    print("\n\nmargin:", margins[ids_to_fetch[j]])

IndexError: list index out of range

In [655]:
logits = logit_model(df_split[id_], return_logits=True).logits
interactive_labeling.open_example(
    tokens, labels, logits=logits if DATASET_SPLIT != "test" else None
)
lock_save = True

In [7]:
IPython.display.display(ipywidgets.HBox((button_run_cell_b, button_save)));

HBox(children=(Button(description='Fetch refined instance', layout=Layout(height='48px', margin='0 0 0 5%', wi…

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [656]:
ret = interactive_labeling.retrieve_refined_example()
assert len(ret["labels"]) == len(tokens)
labels = ret["labels"]
print_labels(tokens, labels, input_is_tokens=True)
lock_save = False

[0m[36m1.[0m [CLS] [31mCÂMARA DOS DEPUTADOS L ##idera ##nça do Par ##tid ##o Social ##ismo e Lib ##erdade 1 

[36m2.[0m REQUERIMENTO DE INFORMAÇÃO Nº _ _ _ _ _ DE 2019 ( Da Sra . Ta ##l ##í ##ria P ##etro ##ne ) 

[36m3.[0m Solicita ao Ministro da Saúde , Sr . Luiz H ##en ##ri ##que Man ##de ##t ##ta , informações acerca da compra central ##izada do medica ##mento M ##is ##op ##ros ##to ##l . 

[36m4.[0m Senhor Presidente , Requer ##emos a Vossa Excelência , com base no art . 50 , § 2º da Constituição Federal , e na forma dos arts . 115 e 116 do Regimento Interno da Câmara dos Deputados , as seguintes informações do Ministro de Estado da Saúde Luiz H ##en ##ri ##que Man ##de ##t ##ta , acerca da compra central ##izada do medica ##mento M ##is ##op ##ros ##to ##l a ser anualmente realizada pela pas ##ta sob sua responsabilidade . 

[36m5.[0m 1 . Nos anos de 2018 e 2019 , foi realizada a compra do medica ##mento na quantidade e no prazo adequado ##s ? 

[36m6.[0m 2 . Há dis

In [663]:
if not lock_save and id_ not in cached_indices and id_ < len(df_split):
    it_counter += 1
    cached_indices.add(id_)

    for key, val in df_split[id_].items():
        if key != "labels":
            new_refined_instances[key].append(val.copy())

        else:
            new_refined_instances[key].append(labels)

    pbar_dump.update()


if it_counter % AUTOSAVE_IN_N_INSTANCES == 0 or j >= len(df_view_input_ids):
    dump_refined_dataset()
    pbar_dump.reset()


print(pbar_dump)

Saved progress in '../data/refined_datasets/df_tokenized_split_0_120000_6000_resplit/test_83'.
Instances until next dump::   0%|          | 0/20 [00:00<?, ?it/s]


## Merge curated data shards

In [664]:
shard_uris = glob.glob(os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_*"))

all_dsets = []

for shard_uri in shard_uris:
    try:
        shard = datasets.Dataset.load_from_disk(shard_uri)
    except Exception:
        shard = None

    if shard and shard.num_rows:
        all_dsets.append(shard)

merged_dset = datasets.concatenate_datasets(all_dsets)
print(merged_dset.num_rows)

output_uri_merged = os.path.join(
    REFINED_DATASET_DIR,
    f"combined_{DATASET_SPLIT}_{len(shard_uris)}_parts_{merged_dset.num_rows}_instances",
)

merged_dset.save_to_disk(output_uri_merged)

3062


## Fine-tune for curated samples

In [667]:
import torch
import torch.nn
import copy
import torchmetrics
import transformers
import pandas as pd


df_curated_train = datasets.Dataset.load_from_disk(
    os.path.join(REFINED_DATASET_DIR, "combined_train_102_parts_4221_instances")
)
df_curated_test = datasets.Dataset.load_from_disk(
    os.path.join(REFINED_DATASET_DIR, "combined_test_84_parts_3062_instances")
)


def fn_pad(X):
    rem = 1024 - len(X["labels"])
    X["labels"] = X["labels"] + rem * [-100]
    X["input_ids"] = X["input_ids"] + rem * [0]
    X["token_type_ids"] = X["token_type_ids"] + rem * [0]
    X["attention_mask"] = X["attention_mask"] + rem * [0]
    return X


#### delete repeated instances
def rem_repeated_instances(df):
    df_aux = pd.DataFrame(df)
    print(df_aux.shape)
    df_aux_2 = df_aux.applymap(lambda x: " ".join(map(str, x)))
    df_aux_2 = df_aux_2.drop_duplicates(
        subset=["input_ids"], inplace=False, ignore_index=True, keep="last"
    )
    df_aux_2 = df_aux_2.applymap(lambda x: list(map(int, x.split(" "))))
    print("before:", df_aux.shape, "after:", df_aux_2.shape)
    return datasets.Dataset.from_pandas(df_aux_2)


print("train:")
df_curated_train = rem_repeated_instances(df_curated_train)
print("test:")
df_curated_test = rem_repeated_instances(df_curated_test)

df_curated_train = df_curated_train.map(fn_pad, batch_size=32)
df_curated_test = df_curated_test.map(fn_pad, batch_size=32)

df_curated_train.set_format("torch")
df_curated_test.set_format("torch")

batch_size = 3
dl_train = torch.utils.data.DataLoader(df_curated_train, shuffle=True, batch_size=batch_size)
dl_test = torch.utils.data.DataLoader(df_curated_test, batch_size=batch_size)

train:
(4221, 4)
before: (4221, 4) after: (4157, 4)
test:
(3062, 4)
before: (3062, 4) after: (2299, 4)


  0%|          | 0/4157 [00:00<?, ?ex/s]

  0%|          | 0/2299 [00:00<?, ?ex/s]

In [1395]:
logit_model = segmentador.BERTSegmenter(
    uri_model=f"../pretrained_segmenter_model/2_{VOCAB_SIZE}_layer_model",
    device="cpu",
)
# logit_model = segmentador.LSTMSegmenter(
#     uri_model=(
#         os.path.join(
#             f"../pretrained_segmenter_model/512_{VOCAB_SIZE}_1_lstm/checkpoints",
#             "512_hidden_dim_6000_vocab_size_1_layer_lstm.pt",
#         )
#     ),
#     device="cpu",
# )
logit_model_copy = copy.deepcopy(logit_model.model)
torch.cuda.empty_cache()

In [1397]:
optim = torch.optim.Adam(logit_model_copy.parameters(), lr=5e-5)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.8)
num_epochs = 3
num_training_steps = (num_epochs * len(df_curated_train)) // batch_size
grad_acc_steps = 3 // batch_size
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

In [1398]:
progress_bar = tqdm.auto.tqdm(range(num_training_steps), total=num_training_steps)

logit_model_copy.to("cuda:0")
logit_model_copy.train()

for epoch in range(num_epochs):
    for i, batch in enumerate(dl_train, 1):
        if isinstance(logit_model_copy, transformers.BertForTokenClassification):
            batch = {k: v.to("cuda:0") for k, v in batch.items()}
            outputs = logit_model_copy(**batch)
            loss = outputs.loss

        else:
            outputs = logit_model_copy(batch["input_ids"].to("cuda:0"))
            logits = outputs["logits"].view(-1, 4)
            true = batch["labels"].to("cuda:0").view(-1)
            loss = loss_fn(logits, true)

        loss.backward()

        if i % grad_acc_steps == 0 or i == len(dl_train):
            optim.step()
            optim.zero_grad()

        progress_bar.update(1)

    lr_scheduler.step()

  0%|          | 0/4157 [00:00<?, ?it/s]

In [1400]:
if isinstance(logit_model_copy, transformers.BertForTokenClassification):
    torch.save(logit_model_copy, "bert_logit_model_finetuned.pt")
else:
    torch.save(logit_model_copy, "lstm_logit_model_finetuned.pt")

In [668]:
def eval_model(model, device="cuda:0"):
    model.eval()
    model.to(device)

    fn_precision = torchmetrics.classification.Precision(num_classes=4, average=None).to(device)
    fn_recall = torchmetrics.classification.Recall(num_classes=4, average=None).to(device)

    preds = []
    targets = []

    for batch in tqdm.auto.tqdm(dl_test):
        with torch.no_grad():
            if isinstance(model, transformers.BertForTokenClassification):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                logits = outputs.logits

            else:
                outputs = model(batch["input_ids"].to(device))
                logits = outputs["logits"].view(-1, 4)

        predictions = torch.argmax(logits, dim=-1)

        preds.append(predictions.to("cpu"))
        targets.append(batch["labels"].to("cpu"))

    try:
        preds = torch.vstack(preds).view(-1)

    except RuntimeError:
        preds = torch.concat(preds)

    targets = torch.vstack(targets).view(-1)

    preds = torch.tensor(
        [p for i, p in enumerate(preds) if targets[i] != -100], dtype=torch.long
    ).to(device)
    targets = torch.tensor([tg for tg in targets if tg != -100], dtype=torch.long).to(device)

    recall = fn_recall(preds, targets)
    precision = fn_precision(preds, targets)

    return recall.to("cpu"), precision.to("cpu")

In [669]:
recall_orig, precision_orig = eval_model(logit_model.model)
print(f"{recall_orig = }, {precision_orig = }")

# combined_test_48_parts_1036_instances
# BERT: recall_orig = ([0.9994, 0.9542, 0.8403, 0.5435]), precision_orig = ([0.9983, 0.9853, 0.9138, 0.6637])
# LSTM: recall_orig = ([0.9995, 0.9547, 0.8571, 0.7174]), precision_orig = ([0.9985, 0.9871, 0.9166, 0.6600])

# combined_test_73_parts_1476_instances
# BERT: recall_orig = ([0.9994, 0.9470, 0.8354, 0.5529]), precision_orig = ([0.9981, 0.9853, 0.9058, 0.6805])

# combined_test_84_parts_3062_instances
# BERT: recall_orig = ([0.9994, 0.9389, 0.7762, 0.4331]), precision_orig = ([0.9976, 0.9858, 0.8991, 0.6920])

  0%|          | 0/767 [00:00<?, ?it/s]

recall_orig = tensor([0.9994, 0.9389, 0.7762, 0.4331]), precision_orig = tensor([0.9976, 0.9858, 0.8991, 0.6920])


In [672]:
recall_new, precision_new = eval_model(logit_model_copy)
print(f"{recall_new = }, {precision_new = }")

## BERT
# combined_train_43_parts_860_instances (lr=1e-5)
# g=9, e=2 - recall_new = ([0.9994, 0.9593, 0.8187, 0.4855]), precision_new = ([0.9984, 0.9822, 0.9420, 0.7204])
# g=9, e=1 - recall_new = ([0.9995, 0.9552, 0.8163, 0.4710]), precision_new = ([0.9983, 0.9847, 0.9444, 0.7222])
# g=18,e=2 - recall_new = ([0.9995, 0.9563, 0.8187, 0.4855]), precision_new = ([0.9983, 0.9845, 0.9433, 0.7204])
# g=3, e=2 - recall_new = ([0.9994, 0.9610, 0.8283, 0.5145]), precision_new = ([0.9985, 0.9825, 0.9413, 0.7172])

# combined_train_56_parts_1120_instances (lr=1e-5)
# g=300,e=10- recall_new = ([0.9994, 0.9558, 0.8331, 0.5290]), precision_new = ([0.9984, 0.9839, 0.9253, 0.6759])
# g=180,e=3 - recall_new = ([0.9994, 0.9560, 0.8295, 0.5217]), precision_new = ([0.9983, 0.9837, 0.9363, 0.6857])
# g=18,e=1 - recall_new = ([0.9995, 0.9567, 0.8175, 0.4855]), precision_new = ([0.9983, 0.9838, 0.9432, 0.6979])
# g=9, e=1 - recall_new = ([0.9995, 0.9606, 0.8103, 0.3913]), precision_new = ([0.9984, 0.9817, 0.9534, 0.7714])
# g=9, e=2 - recall_new = ([0.9995, 0.9603, 0.8283, 0.5072]), precision_new = ([0.9984, 0.9833, 0.9439, 0.6863])
# g=6, e=1 - recall_new = ([0.9995, 0.9603, 0.8211, 0.4783]), precision_new = ([0.9984, 0.9825, 0.9474, 0.7253])
# g=3, e=1 - recall_new = ([0.9994, 0.9622, 0.8307, 0.5000]), precision_new = ([0.9985, 0.9815, 0.9428, 0.7263])

# combined_train_56_parts_1120_instances (lr=1e-6)
# g=3, e=1 - recall_new = ([0.9994, 0.9560, 0.8331, 0.5290]), precision_new = ([0.9984, 0.9837, 0.9241, 0.6759])

# combined_train_56_parts_1120_instances (lr=1e-4)
# g=9, e=1 - recall_new = ([0.9990, 0.9700, 0.8247, 0.5870]), precision_new = ([0.9987, 0.9753, 0.8865, 0.5400])
# g=3, e=1 - recall_new = ([0.9994, 0.9633, 0.8211, 0.5072]), precision_new = ([0.9985, 0.9851, 0.9012, 0.7071])


# combined_train_66_parts_1320_instances (lr=1e-5)
# g=9, e=2 - recall_new = ([0.9994, 0.9630, 0.8247, 0.5145]), precision_new = ([0.9985, 0.9815, 0.9463, 0.6698])
# g=9, e=1 - recall_new = ([0.9995, 0.9572, 0.8163, 0.4855]), precision_new = ([0.9983, 0.9852, 0.9471, 0.7204])
# g=3, e=1 - recall_new = ([0.9994, 0.9642, 0.8259, 0.5000]), precision_new = ([0.9985, 0.9803, 0.9477, 0.6765])
# g=30, e=1 - recall_new = ([0.9996, 0.9538, 0.8103, 0.4638]), precision_new = ([0.9982, 0.9857, 0.9534, 0.7356])

# combined_train_76_parts_1520_instances (lr=1e-5)
# g=3, e=1 - recall_new = ([0.9994, 0.9633, 0.8235, 0.5072]), precision_new = ([0.9985, 0.9805, 0.9449, 0.6731])
# g=6, e=1 - recall_new = ([0.9995, 0.9624, 0.8127, 0.4638]), precision_new = ([0.9984, 0.9816, 0.9522, 0.7442])
# g=30, e=1 - recall_new = ([0.9996, 0.9532, 0.8103, 0.4565]), precision_new = ([0.9982, 0.9858, 0.9520, 0.7326])

# combined_train_76_parts_1520_instances (lr=5e-6)
# g=6, e=3 - recall_new = ([0.9994, 0.9619, 0.8319, 0.5072]), precision_new = ([0.9985, 0.9819, 0.9416, 0.6542])

# combined_train_92_parts_1840_instances (lr=1e-5)
# g=3, e=1 - recall_new = ([0.9993, 0.9651, 0.8391, 0.5145]), precision_new = ([0.9986, 0.9812, 0.9332, 0.6339])
# g=6, e=1 - recall_new = ([0.9993, 0.9650, 0.8295, 0.5072]), precision_new = ([0.9986, 0.9787, 0.9427, 0.7143])

# combined_train_103_parts_2059_instances (lr=1e-5)
# g=6, e=1 - recall_new = ([0.9994, 0.9616, 0.8247, 0.5000]), precision_new = ([0.9985, 0.9825, 0.9463, 0.6765])

# combined_train_112_parts_2239_instances (lr=2e-5)
# g=6, e=1 - recall_new = ([0.9994, 0.9666, 0.8463, 0.5362]), precision_new = ([0.9987, 0.9831, 0.9413, 0.6667])
#            recall_new = ([0.9994, 0.9601, 0.8404, 0.5385]), precision_new = ([0.9985, 0.9844, 0.9321, 0.7000])
#            recall_orig= ([0.9994, 0.9470, 0.8354, 0.5529]), precision_orig= ([0.9981, 0.9853, 0.9058, 0.6805])

# combined_train_41_parts_3001_instances (lr=2e-5)
# g=3, e=1 - recall_new = ([0.9994, 0.9638, 0.8486, 0.5514]), precision_new = ([0.9986, 0.9844, 0.9173, 0.6782])
# g=6, e=1 - recall_new = ([0.9994, 0.9609, 0.8422, 0.5280]), precision_new = ([0.9985, 0.9830, 0.9200, 0.6933])

# combined_train_41_parts_3001_instances (lr=1e-5)
# g=3, e=2 - recall_new = ([0.9992, 0.9647, 0.8486, 0.5561]), precision_new = ([0.9986, 0.9794, 0.9126, 0.6364])
# g=3, e=1 - recall_new = ([0.9994, 0.9610, 0.8269, 0.5467]), precision_new = ([0.9985, 0.9828, 0.9345, 0.6648])

# combined_train_41_parts_3001_instances (lr=5e-5)
# g=3, e=1 - recall_new = ([0.9994, 0.9685, 0.8108, 0.4907]), precision_new = ([0.9986, 0.9806, 0.9367, 0.7447])

# combined_train_74_parts_3661_instances (lr=5e-5)
# g=3, e=1 - recall_new = ([0.9991, 0.9711, 0.8591, 0.4766]), precision_new = ([0.9988, 0.9757, 0.8892, 0.7338])

# combined_train_92_parts_4021_instances (lr=5e-5)
# g=3, e=1 - recall_new = ([0.9989, 0.9740, 0.8728, 0.5701]), precision_new = ([0.9989, 0.9731, 0.8742, 0.5837])

# combined_train_102_parts_4221_instances (lr=5e-5)
# g=3, e=1 - recall_new = ([0.9994, 0.9642, 0.8519, 0.5140]), precision_new = ([0.9986, 0.9858, 0.9020, 0.6707])
# g=3, e=3 - recall_new = ([0.9990, 0.9744, 0.8720, 0.5981]), precision_new = ([0.9989, 0.9745, 0.8906, 0.6184])

# combined_test_84_parts_3062_instances:
# recall_orig = ([0.9994, 0.9389, 0.7762, 0.4331]), precision_orig = ([0.9976, 0.9858, 0.8991, 0.6920])
# recall_new  = ([0.9993, 0.9659, 0.8491, 0.5601]), precision_new  = ([0.9985, 0.9797, 0.9228, 0.7508])


## 512 LSTM 1-layer
# combined_train_103_parts_2059_instances (lr=1e-5)
# g=6, e=1 - recall_new = ([0.9994, 0.9564, 0.8583, 0.7101]), precision_new = ([0.9985, 0.9850, 0.9202, 0.6759])
# g=6, e=2 - recall_new = ([0.9994, 0.9591, 0.8547, 0.7101]), precision_new = ([0.9986, 0.9821, 0.9163, 0.7000])

  0%|          | 0/767 [00:00<?, ?it/s]

recall_new = tensor([0.9993, 0.9659, 0.8491, 0.5601]), precision_new = tensor([0.9985, 0.9797, 0.9228, 0.7508])
