In [1]:
import wandb
import click
import torch
import transformers
from pathlib import Path
from copy import deepcopy
from functools import partial
from typing import List, Callable, Dict, Iterable

# Local imports
try:
    import _pathfix
except ImportError:
    from . import _pathfix
from loops import training_loop
from config import LOCATIONS as LOC, CONFIG
from models.multitask import BasicMTL
from dataiter import MultiTaskDataIter
# from eval import ner_all, ner_only_annotated, ner_span_recog_recall, ner_span_recog_precision, \
#     pruner_p, pruner_r


Fixing paths from /home/priyansh/Dev/research/coref/mtl/src


In [2]:
! free -h

              total        used        free      shared  buff/cache   available
Mem:            15G        6,1G        2,5G        934M        6,7G        7,9G
Swap:          979M          0B        979M


In [3]:

def make_optimizer(model, optimizer_class: Callable, lr: float, freeze_encoder: bool):
    if freeze_encoder:
        return optimizer_class(
            [param for name, param in model.named_parameters() if not name.startswith("encoder")],
            lr=lr
        )
    else:
        return optimizer_class(model.parameters(), lr=lr)


def get_pretrained_dirs(nm: str):
    """Check if the given nm is stored locally. If so, load that. Else, pass it on as is."""
    plausible_parent_dir: Path = LOC.root / "models" / "huggingface" / nm

    if (
            (plausible_parent_dir / "config").exists()
            and (plausible_parent_dir / "tokenizer").exists()
            and (plausible_parent_dir / "encoder").exists()
    ):
        return (
            str(plausible_parent_dir / "config"),
            str(plausible_parent_dir / "tokenizer"),
            str(plausible_parent_dir / "encoder"),
        )
    else:
        return nm, nm, nm


def compute_metrics(metrics: Dict[str, Callable], logits, labels) -> Dict[str, float]:
    return {metric_nm: metric_fn(logits=logits, labels=labels).item() for metric_nm, metric_fn in metrics.items()}


def aggregate_metrics(inter_epoch: dict, intra_epoch: dict):
    for task_nm in inter_epoch.keys():
        for metric_nm, metric_list in intra_epoch[task_nm].items():
            inter_epoch[task_nm][metric_nm].append(np.mean(metric_list))
    return inter_epoch


def simplest_loop(
        epochs: int,
        tasks: Iterable[str],
        opt: torch.optim,
        train_fn: Callable,
        predict_fn: Callable,
        trn_dl: Callable,
        dev_dl: Callable,
        eval_fns: dict,
) -> (list, list, list):
    train_loss = {task_nm: [] for task_nm in tasks}
    train_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}
    valid_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}

    # Make data
    trn_ds = trn_dl()
    dev_ds = dev_dl()

    # Epoch level
    for e in range(epochs):

        per_epoch_loss = {task_nm: [] for task_nm in tasks}
        per_epoch_tr_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}
        per_epoch_vl_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}

        # Train
        with Timer() as timer:

            # Train Loop
            for instance in tqdm(trn_ds):

                # Reset the gradients.
                opt.zero_grad()

                # Forward Pass
                outputs = train_fn(**instance)

                """
                    Depending on instance.tasks list, do the following:
                        - task specific loss (added to losses)
                        - task specific metrics (added to metrics)
                """
                for task_nm in instance['tasks']:
                    loss = outputs["loss"][task_nm]
                    per_epoch_loss[task_nm].append(loss.item())

                    # TODO: add other metrics here
                    instance_metrics = compute_metrics(eval_fns[task_nm],
                                                       logits=outputs[task_nm]["logits"],
                                                       labels=outputs[task_nm]["labels"])
                    for metric_nm, metric_vl in instance_metrics.items():
                        per_epoch_tr_metrics[task_nm][metric_nm].append(metric_vl)

                loss.backward()
                opt.step()

            # Val
            with torch.no_grad():

                for instance in tqdm(dev_ds):
                    outputs = predict_fn(**instance)

                    for task_nm in instance["tasks"]:
                        logits = outputs[task_nm]["logits"]
                        # TODO: make the label puller task specific somehow
                        labels = instance["ner"]["gold_labels"]

                        instance_metrics = compute_metrics(eval_fns[task_nm], logits=logits, labels=labels)
                        for metric_nm, metric_vl in instance_metrics.items():
                            per_epoch_vl_metrics[task_nm][metric_nm].append(metric_vl)

        # Bookkeep
        for task_nm in tasks:
            train_loss[task_nm].append(np.mean(per_epoch_loss[task_nm]))
            train_metrics = aggregate_metrics(train_metrics, per_epoch_tr_metrics)
            valid_metrics = aggregate_metrics(valid_metrics, per_epoch_vl_metrics)

        print(f"\nEpoch: {e:3d}" +
              ''.join([f" | {task_nm} Loss: {float(np.mean(per_epoch_loss[task_nm])):.5f}" +
                       ''.join([f" | {task_nm} Tr_{metric_nm}: {float(metric_vls[-1]):.3f}"
                                for metric_nm, metric_vls in train_metrics[task_nm].items()]) +
                       ''.join([f" | {task_nm} Vl_{metric_nm}: {float(metric_vls[-1]):.3f}"
                                for metric_nm, metric_vls in valid_metrics[task_nm].items()])
                       # f" | {task_nm} Tr_c: {float(np.mean(per_epoch_tr_acc[task_nm])):.5f}" +
                       # f" | {task_nm} Vl_c: {float(np.mean(per_epoch_vl_acc[task_nm])):.5f}"
                       for task_nm in tasks]))

    return train_metrics, valid_metrics, train_loss

# Make MTL A

In [4]:
dataset: str = 'ontonotes'
epochs: int = 10
encoder: str = "bert-base-uncased"
tasks: List[str] = ('coref', 'ner', 'pruner')
device: str = "cpu"
trim: bool = True
train_encoder: bool = False,
ner_unweighted: bool = False
filter_candidates_pos = True

In [5]:

dir_config, dir_tokenizer, dir_encoder = get_pretrained_dirs(encoder)

tokenizer = transformers.BertTokenizer.from_pretrained(dir_tokenizer)
config = transformers.BertConfig(dir_config)
config.max_span_width = 5
config.coref_dropout = 0.3
config.metadata_feature_size = 20
config.unary_hdim = 1000
config.binary_hdim = 2000
config.top_span_ratio = 0.4
config.max_top_antecedents = 50
config.device = device
config.epochs = epochs
config.trim = trim
config.freeze_encoder = not train_encoder
config.ner_ignore_weights = ner_unweighted
config.filter_candidates_pos_threshold = CONFIG['filter_candidates_pos_threshold'] \
    if filter_candidates_pos else -1


if 'ner' in tasks or 'pruner' in tasks:
    # Need to figure out the number of classes. Load a DL. Get the number. Delete the DL.
    temp_ds = MultiTaskDataIter(
        src=dataset,
        config=config,
        tasks=tasks,
        split="development",
        tokenizer=tokenizer,
    )
    if 'ner' in tasks:
        config.ner_n_classes = deepcopy(temp_ds.ner_tag_dict.__len__())
        config.ner_class_weights = temp_ds.estimate_class_weights('ner')
    else:
        config.ner_n_classes = 1
        config.ner_class_weights = [1.0, ]
    if 'pruner' in tasks:
        config.pruner_class_weights = temp_ds.estimate_class_weights('pruner')
    del temp_ds
else:
    config.ner_n_classes = 1
    config.ner_class_weights = [1.0, ]

# # Make the model
# model = BasicMTL(dir_encoder, config=config)

# Load the data
train_ds = partial(
    MultiTaskDataIter,
    src=dataset,
    config=config,
    tasks=tasks,
    split="train",
    tokenizer=tokenizer,
)
# valid_ds = partial(
#     MultiTaskDataIter,
#     src=dataset,
#     config=config,
#     tasks=tasks,
#     split="development",
#     tokenizer=tokenizer,
# )

# Make the optimizer
# opt = make_optimizer(model=model, optimizer_class=torch.optim.SGD, lr=0.005, freeze_encoder=config.freeze_encoder)
# opt = torch.optim.SGD(model.parameters(), lr=0.001)

# # Make the evaluation suite (may compute multiple metrics corresponding to the tasks)
# eval_fns: Dict[str, Dict[str, Callable]] = {
#     'ner': {'acc': ner_all,
#             'acc_l': ner_only_annotated,
#             'span_p': ner_span_recog_precision,
#             'span_r': ner_span_recog_recall},
#     'coref': {

#     },
#     'pruner': {'p': pruner_p,
#                'r': pruner_r}
# }

print(config)
print("Training commences!")

Pulled 318 instances from ../data/parsed/ontonotes/development/MultiTaskDatasetDump_coref_ner_pruner.pkl.
BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "binary_hdim": 2000,
  "classifier_dropout": null,
  "coref_dropout": 0.3,
  "device": "cpu",
  "epochs": 10,
  "filter_candidates_pos_threshold": 2000,
  "freeze_encoder": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "max_span_width": 5,
  "max_top_antecedents": 50,
  "metadata_feature_size": 20,
  "model_type": "bert",
  "ner_class_weights": [
    0.05330982413982514,
    24.68555023923445,
    30.717313646106216,
    51.42822966507177,
    20.887773279352228,
    43.23902111967818,
    101.32128829536528,
    1697.1315789473683,
    38.79157894736842,
    174.06477732793522,
    678.8526315789474,
    135.77052631578948,
    141.42763157894737,
    102.85645933014354,
 



# Make MTL B

In [6]:
dataset: str = 'scierc'
epochs: int = 10
encoder: str = "bert-base-uncased"
tasks: List[str] = ('ner',)
device: str = "cpu"
trim: bool = True
train_encoder: bool = False
ner_unweighted: bool = False
filter_candidates_pos = True


dir_config, dir_tokenizer, dir_encoder = get_pretrained_dirs(encoder)

tokenizer = transformers.BertTokenizer.from_pretrained(dir_tokenizer)
config = transformers.BertConfig(dir_config)
config.max_span_width = 5
config.coref_dropout = 0.3
config.metadata_feature_size = 20
config.unary_hdim = 1000
config.binary_hdim = 2000
config.top_span_ratio = 0.4
config.max_top_antecedents = 50
config.device = device
config.epochs = epochs
config.trim = trim
config.freeze_encoder = not train_encoder
config.ner_ignore_weights = ner_unweighted
config.filter_candidates_pos_threshold = CONFIG['filter_candidates_pos_threshold'] \
    if filter_candidates_pos else -1


# if 'ner' in tasks or 'pruner' in tasks:
if False:
    # Need to figure out the number of classes. Load a DL. Get the number. Delete the DL.
    temp_ds = MultiTaskDataIter(
        src=dataset,
        config=config,
        tasks=tasks,
        split="dev",
        tokenizer=tokenizer,
    )
    if 'ner' in tasks:
        config.ner_n_classes = deepcopy(temp_ds.ner_tag_dict.__len__())
        config.ner_class_weights = temp_ds.estimate_class_weights('ner')
    else:
        config.ner_n_classes = 1
        config.ner_class_weights = [1.0, ]
    if 'pruner' in tasks:
        config.pruner_class_weights = temp_ds.estimate_class_weights('pruner')
    del temp_ds
else:
    config.ner_n_classes = 1
    config.ner_class_weights = [1.0, ]

# # Make the model
# model = BasicMTL(dir_encoder, config=config)

# Load the data
train_ds_b = partial(
    MultiTaskDataIter,
    src=dataset,
    config=config,
    tasks=tasks,
    split="train",
    tokenizer=tokenizer,
)

sci_train_ds = train_ds_b()

Pulled 346 instances from ../data/parsed/scierc/train/MultiTaskDatasetDump_ner.pkl.


In [7]:
from dataiter import DataIterCombiner

In [8]:
dc = DataIterCombiner([train_ds, train_ds_b])

Pulled 2455 instances from ../data/parsed/ontonotes/train/MultiTaskDatasetDump_coref_ner_pruner.pkl.
Pulled 346 instances from ../data/parsed/scierc/train/MultiTaskDatasetDump_ner.pkl.


In [10]:
len(dc)

100

In [22]:
dc.history

{0: (1, 1), 1: (1, 2), 2: (1, 3)}

In [23]:
dc[0]

{'tasks': ['ner'],
 'input_ids': tensor([[ 1999,  2023,  3259,  1010,  1037,  3117,  4118,  2000,  4553,  1996,
          23807,  4874,  3252,  2005, 15873,  5107,  9651,  2003,  3818,  1012,
           1996,  3937, 11213,  2003,  2008,  1996, 16381,  3550,  4874,  2110,
           3658,  2006,  1037,  2659,  8789, 19726,  1998,  2064,  2022,  4342,
           2013,  2731,  2951,  1012,  2241,  2006,  2023, 11213,  1010, 15847,
           2057,  5173,  1996,  8789,  3012,  7312,  1998,  4304, 24155,  9896,
           2005,  4895,  6342,  4842, 11365,  2098,  4083,  1997,  4874, 23807,
           6630,  1010,  1996,  4663,  2512,  1011, 11841,  2112,  1997,  4874,
           2110, 13416,  2130,  2000,  1016,  9646,  1012, 16378,  1996,  8790,
           2389,  2944,  2003,  5173,  1998,  4738,  2241,  2006,  2023, 23807,
           6630,  1012,  2353,  2135,  1996,  4342, 23807,  4874,  3252,  2003,
           6377,  2046,  1037, 10811,  1011, 11307,  2806, 27080,  1012,  2057,
        

In [24]:
for x in dc:
    break

In [25]:
x

{'tasks': ['ner'],
 'input_ids': tensor([[10629, 16905,  3464,  1996,  2087,  6450,  2112,  1997, 16905,  1011,
           2241,  8035, 11968,  7741,  1012,  2057,  3579,  2006,  2028,  3177,
           1011,  2039,  5783,  1999,  1996,  2640,  1997, 16905, 13792,  1024,
          24685,  1997, 24731,  1997,  4895,  5302,  4305, 10451,  4942, 27341,
           1012,  2057, 16599,  1037,  4118,  1997, 26615,  2107,  1037,  2640,
           2083,  1037,  4118,  1997,  3252,  1011,  6631,  2029, 26777,  8833,
           1006,  1040,  1007,  8964,  2015,  2411,  3378,  2007,  3252,  1011,
           6631,  1997, 19287,  2302,  2151,  2224,  1997, 17047, 24394, 20884,
           2015,  1012,  1996,  3818,  5679, 11027,  2015, 21707, 24731,  2096,
           8498,  1996, 17982,  1011, 15615,  5679,  1005,  1055,  3754,  2000,
           4468,  2058, 24731,  1998,  2220, 24731,  4117,  2007,  2049,  3754,
           2000,  5047, 23750,  5090,  2302,  9896,  2594, 13134,  1012,     0,
        

In [None]:
x.keys()

In [None]:
la = onto_train_di.__len__()
lb = sci_train_ds.__len__()

la, lb, la+lb

In [None]:
sampling_ratio = [0.5, 1]

In [None]:
llen = la + lb

In [None]:
pointers = [int(x* llen / float(sum(sampling_ratio))) for x in sampling_ratio]
pointers

In [None]:
pointers = [2, 6]

In [None]:
import numpy as np

In [None]:
source_indices = []
for i, dataset_specific_ratio in enumerate(pointers):
    source_indices += [i]*dataset_specific_ratio
    
source_indices

In [None]:
np.random.shuffle(source_indices)

source_indices

# Testing Eval

In [None]:
from eval import Evaluator, NERAcc, NERSpanRecognitionPR, PrunerPR, CorefBCubed, CorefMUC, CorefCeafe

In [None]:
eval_bench = Evaluator(
    predict_fn = model.pred_with_labels,
    dataset_partial = valid_ds,
    metrics = [NERAcc(), NERSpanRecognitionPR(), PrunerPR(), CorefBCubed(), CorefMUC(), CorefCeafe()],
    device = 'cpu'
)

In [None]:
eval_bench.run()

# Eval for Coref

In [None]:


def b_cubed(clusters, mention_to_gold):
    num, dem = 0, 0

    for c in clusters:
        if len(c) == 1:
            continue

        gold_counts = Counter()
        correct = 0
        for m in c:
            if m in mention_to_gold:
                gold_counts[tuple(mention_to_gold[m])] += 1
        for c2, count in gold_counts.items():
            if len(c2) != 1:
                correct += count * count

        num += correct / float(len(c))
        dem += len(c)

    return num, dem


def muc(clusters, mention_to_gold):
    tp, p = 0, 0
    for c in clusters:
        p += len(c) - 1
        tp += len(c)
        linked = set()
        for m in c:
            if m in mention_to_gold:
                linked.add(mention_to_gold[m])
            else:
                tp -= 1
        tp -= len(linked)
    return tp, p


def phi4(c1, c2):
    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))


def ceafe(clusters, gold_clusters):
    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for i in range(len(gold_clusters)):
        for j in range(len(clusters)):
            scores[i, j] = phi4(gold_clusters[i], clusters[j])
    matching = linear_assignment(-scores)
    similarity = sum(scores[matching[0], matching[1]])

    # similarity = sum(scores[matching[:, 0], matching[:, 1]])
    return similarity, len(clusters), similarity, len(gold_clusters)

In [None]:
for i, instance in enumerate(dl):
    outputs = model.pred_with_labels(**instance)
    break

In [None]:
! free -h

In [None]:
instance['coref'].keys(), outputs['coref'].keys(), outputs.keys(), outputs['coref']['eval'].keys()

In [None]:
print('=None, '.join(['clusters', 'gold_clusters', 'mention_to_predicted', 'mention_to_gold']))

In [None]:
ll =  outputs['coref']['eval']

In [None]:
ceafe(ll['clusters'], ll['gold_clusters'])

In [None]:
phi4(ll['clusters'], ll['gold_clusters'])

In [None]:
muc(ll['clusters'], ll['mention_to_gold'])

In [None]:
b_cubed(ll['clusters'], ll['mention_to_gold'])