In [1]:
import wandb
import click
import torch
import transformers
from pathlib import Path
from copy import deepcopy
from functools import partial
from typing import List, Callable, Dict, Iterable

# Local imports
try:
    import _pathfix
except ImportError:
    from . import _pathfix
from loops import training_loop
from config import LOCATIONS as LOC, CONFIG
from models.multitask import BasicMTL
from dataiter import MultiTaskDataset
from eval import ner_all, ner_only_annotated, ner_span_recog_recall, ner_span_recog_precision, \
    pruner_p, pruner_r


Fixing paths from /home/priyansh/Dev/research/coref/mtl/src


In [2]:

def make_optimizer(model, optimizer_class: Callable, lr: float, freeze_encoder: bool):
    if freeze_encoder:
        return optimizer_class(
            [param for name, param in model.named_parameters() if not name.startswith("encoder")],
            lr=lr
        )
    else:
        return optimizer_class(model.parameters(), lr=lr)


def get_pretrained_dirs(nm: str):
    """Check if the given nm is stored locally. If so, load that. Else, pass it on as is."""
    plausible_parent_dir: Path = LOC.root / "models" / "huggingface" / nm

    if (
            (plausible_parent_dir / "config").exists()
            and (plausible_parent_dir / "tokenizer").exists()
            and (plausible_parent_dir / "encoder").exists()
    ):
        return (
            str(plausible_parent_dir / "config"),
            str(plausible_parent_dir / "tokenizer"),
            str(plausible_parent_dir / "encoder"),
        )
    else:
        return nm, nm, nm


def compute_metrics(metrics: Dict[str, Callable], logits, labels) -> Dict[str, float]:
    return {metric_nm: metric_fn(logits=logits, labels=labels).item() for metric_nm, metric_fn in metrics.items()}


def aggregate_metrics(inter_epoch: dict, intra_epoch: dict):
    for task_nm in inter_epoch.keys():
        for metric_nm, metric_list in intra_epoch[task_nm].items():
            inter_epoch[task_nm][metric_nm].append(np.mean(metric_list))
    return inter_epoch


def simplest_loop(
        epochs: int,
        tasks: Iterable[str],
        opt: torch.optim,
        train_fn: Callable,
        predict_fn: Callable,
        trn_dl: Callable,
        dev_dl: Callable,
        eval_fns: dict,
) -> (list, list, list):
    train_loss = {task_nm: [] for task_nm in tasks}
    train_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}
    valid_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}

    # Make data
    trn_ds = trn_dl()
    dev_ds = dev_dl()

    # Epoch level
    for e in range(epochs):

        per_epoch_loss = {task_nm: [] for task_nm in tasks}
        per_epoch_tr_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}
        per_epoch_vl_metrics = {task_nm: {metric_nm: [] for metric_nm in eval_fns[task_nm].keys()} for task_nm in tasks}

        # Train
        with Timer() as timer:

            # Train Loop
            for instance in tqdm(trn_ds):

                # Reset the gradients.
                opt.zero_grad()

                # Forward Pass
                outputs = train_fn(**instance)

                """
                    Depending on instance.tasks list, do the following:
                        - task specific loss (added to losses)
                        - task specific metrics (added to metrics)
                """
                for task_nm in instance['tasks']:
                    loss = outputs["loss"][task_nm]
                    per_epoch_loss[task_nm].append(loss.item())

                    # TODO: add other metrics here
                    instance_metrics = compute_metrics(eval_fns[task_nm],
                                                       logits=outputs[task_nm]["logits"],
                                                       labels=outputs[task_nm]["labels"])
                    for metric_nm, metric_vl in instance_metrics.items():
                        per_epoch_tr_metrics[task_nm][metric_nm].append(metric_vl)

                loss.backward()
                opt.step()

            # Val
            with torch.no_grad():

                for instance in tqdm(dev_ds):
                    outputs = predict_fn(**instance)

                    for task_nm in instance["tasks"]:
                        logits = outputs[task_nm]["logits"]
                        # TODO: make the label puller task specific somehow
                        labels = instance["ner"]["gold_labels"]

                        instance_metrics = compute_metrics(eval_fns[task_nm], logits=logits, labels=labels)
                        for metric_nm, metric_vl in instance_metrics.items():
                            per_epoch_vl_metrics[task_nm][metric_nm].append(metric_vl)

        # Bookkeep
        for task_nm in tasks:
            train_loss[task_nm].append(np.mean(per_epoch_loss[task_nm]))
            train_metrics = aggregate_metrics(train_metrics, per_epoch_tr_metrics)
            valid_metrics = aggregate_metrics(valid_metrics, per_epoch_vl_metrics)

        print(f"\nEpoch: {e:3d}" +
              ''.join([f" | {task_nm} Loss: {float(np.mean(per_epoch_loss[task_nm])):.5f}" +
                       ''.join([f" | {task_nm} Tr_{metric_nm}: {float(metric_vls[-1]):.3f}"
                                for metric_nm, metric_vls in train_metrics[task_nm].items()]) +
                       ''.join([f" | {task_nm} Vl_{metric_nm}: {float(metric_vls[-1]):.3f}"
                                for metric_nm, metric_vls in valid_metrics[task_nm].items()])
                       # f" | {task_nm} Tr_c: {float(np.mean(per_epoch_tr_acc[task_nm])):.5f}" +
                       # f" | {task_nm} Vl_c: {float(np.mean(per_epoch_vl_acc[task_nm])):.5f}"
                       for task_nm in tasks]))

    return train_metrics, valid_metrics, train_loss

In [3]:
dataset: str = 'ontonotes'
epochs: int = 10
encoder: str = "bert-base-uncased"
tasks: List[str] = ('coref',)
device: str = "cpu"
trim: bool = False
train_encoder: bool = False,
ner_unweighted: bool = False
filter_candidates_pos = True

In [4]:

dir_config, dir_tokenizer, dir_encoder = get_pretrained_dirs(encoder)

tokenizer = transformers.BertTokenizer.from_pretrained(dir_tokenizer)
config = transformers.BertConfig(dir_config)
config.max_span_width = 5
config.coref_dropout = 0.3
config.metadata_feature_size = 20
config.unary_hdim = 1000
config.binary_hdim = 2000
config.top_span_ratio = 0.4
config.max_top_antecedents = 50
config.device = device
config.epochs = epochs
config.trim = trim
config.freeze_encoder = not train_encoder
config.ner_ignore_weights = ner_unweighted
config.filter_candidates_pos_threshold = CONFIG['filter_candidates_pos_threshold'] \
    if filter_candidates_pos else -1


if 'ner' in tasks or 'pruner' in tasks:
    # Need to figure out the number of classes. Load a DL. Get the number. Delete the DL.
    temp_ds = MultiTaskDataset(
        src=dataset,
        config=config,
        tasks=tasks,
        split="development",
        tokenizer=tokenizer,
    )
    if 'ner' in tasks:
        config.ner_n_classes = deepcopy(temp_ds.ner_tag_dict.__len__())
        config.ner_class_weights = temp_ds.estimate_class_weights('ner')
    else:
        config.ner_n_classes = 1
        config.ner_class_weights = [1.0, ]
    if 'pruner' in tasks:
        config.pruner_class_weights = temp_ds.estimate_class_weights('pruner')
    del temp_ds
else:
    config.ner_n_classes = 1
    config.ner_class_weights = [1.0, ]

# Make the model
model = BasicMTL(dir_encoder, config=config)

# Load the data
train_ds = partial(
    MultiTaskDataset,
    src=dataset,
    config=config,
    tasks=tasks,
    split="train",
    tokenizer=tokenizer,
)
valid_ds = partial(
    MultiTaskDataset,
    src=dataset,
    config=config,
    tasks=tasks,
    split="development",
    tokenizer=tokenizer,
)

# Make the optimizer
opt = make_optimizer(model=model, optimizer_class=torch.optim.SGD, lr=0.005, freeze_encoder=config.freeze_encoder)
# opt = torch.optim.SGD(model.parameters(), lr=0.001)

# Make the evaluation suite (may compute multiple metrics corresponding to the tasks)
eval_fns: Dict[str, Dict[str, Callable]] = {
    'ner': {'acc': ner_all,
            'acc_l': ner_only_annotated,
            'span_p': ner_span_recog_precision,
            'span_r': ner_span_recog_recall},
    'coref': {

    },
    'pruner': {'p': pruner_p,
               'r': pruner_r}
}

print(config)
print("Training commences!")

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "binary_hdim": 2000,
  "classifier_dropout": null,
  "coref_dropout": 0.3,
  "device": "cpu",
  "epochs": 10,
  "filter_candidates_pos_threshold": 2000,
  "freeze_encoder": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "max_span_width": 5,
  "max_top_antecedents": 50,
  "metadata_feature_size": 20,
  "model_type": "bert",
  "ner_class_weights": [
    1.0
  ],
  "ner_ignore_weights": false,
  "ner_n_classes": 1,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "top_span_ratio": 0.4,
  "transformers_version": "4.13.0",
  "trim": false,
  "type_vocab_size": 2,
  "unary_hdim": 1000,
  "use_cache": true,
  "vocab_size": "../models/huggingface/bert-base-uncased/config"
}

Training commences!


In [5]:
dl = train_ds()

Pulled 2775 instances from ../data/parsed/ontonotes/train/MultiTaskDatasetDump_coref.pkl.


# Eval for Coref

In [6]:

def b_cubed(clusters, mention_to_gold):
    num, dem = 0, 0

    for c in clusters:
        if len(c) == 1:
            continue

        gold_counts = Counter()
        correct = 0
        for m in c:
            if m in mention_to_gold:
                gold_counts[tuple(mention_to_gold[m])] += 1
        for c2, count in gold_counts.items():
            if len(c2) != 1:
                correct += count * count

        num += correct / float(len(c))
        dem += len(c)

    return num, dem


def muc(clusters, mention_to_gold):
    tp, p = 0, 0
    for c in clusters:
        p += len(c) - 1
        tp += len(c)
        linked = set()
        for m in c:
            if m in mention_to_gold:
                linked.add(mention_to_gold[m])
            else:
                tp -= 1
        tp -= len(linked)
    return tp, p


def phi4(c1, c2):
    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))


def ceafe(clusters, gold_clusters):
    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for i in range(len(gold_clusters)):
        for j in range(len(clusters)):
            scores[i, j] = phi4(gold_clusters[i], clusters[j])
    matching = linear_assignment(-scores)
    similarity = sum(scores[matching[0], matching[1]])

    # similarity = sum(scores[matching[:, 0], matching[:, 1]])
    return similarity, len(clusters), similarity, len(gold_clusters)

In [7]:
for i, instance in enumerate(dl):
    outputs = model.pred_with_labels(**instance)
    break

In [8]:
! free -h

              total        used        free      shared  buff/cache   available
Mem:            15G         11G        1,4G        827M        2,3G        2,6G
Swap:          979M        759M        220M


In [9]:
instance['coref'].keys(), outputs['coref'].keys(), outputs.keys()

(dict_keys(['gold_cluster_ids_on_candidates', 'gold_starts', 'gold_ends', 'gold_cluster_ids']),
 dict_keys(['logits', 'labels']),
 dict_keys(['loss', 'ner', 'coref', 'pruner']))

In [10]:
# We go for each cluster id. That's not too difficult
gold_clusters = {}
for i, val in enumerate(instance['coref']['gold_cluster_ids']):
    cluster_id = val.item()
    
    # Populate the dict
    gold_clusters[cluster_id] = gold_clusters.get(cluster_id, []) + \
        [(instance['coref']['gold_starts'][i].item(), instance['coref']['gold_ends'][i].item())]

gold_clusters = [tuple(v) for v in gold_clusters.values()]
mention_to_gold = {}
for c in gold_clusters:
    for mention in c:
        mention_to_gold[mention] = c

In [11]:
gold_clusters.__len__(), instance['coref']['gold_starts'].shape

(11, torch.Size([96]))

In [None]:
top_indices = torch.argmax(outputs['coref']['top_antecedent_scores'], dim=1, keepdim=False)
ids = outputs['input_ids']
top_span_starts = outputs['pruner']['top_span_starts']
top_span_ends = outputs['pruner']['top_span_ends']

In [None]:
top_indices.max()

In [None]:
outputs['pruner'].keys()

In [None]:
# example = {name: tensor.cpu() for name, tensor in example.items()} 
# outputs = {name: tensor.cpu() for name, tensor in outputs.items()}
# gold_clusters = {}
# for i in range(len(example["cluster_ids"])):
#     assert len(example["cluster_ids"]) == len(
#         example["gold_starts"]) == len(example["gold_ends"])
#     cid = example["cluster_ids"][i].item()
#     if cid in gold_clusters:
#         gold_clusters[cid].append((example["gold_starts"][i].item(),
#                                    example["gold_ends"][i].item()))
#     else:
#         gold_clusters[cid] = [(example["gold_starts"][i].item(),
#                                example["gold_ends"][i].item())]

# gold_clusters = [tuple(v) for v in gold_clusters.values()]
# mention_to_gold = {}
# for c in gold_clusters:
#     for mention in c:
#         mention_to_gold[mention] = c

    top_indices = torch.argmax(outputs["top_antecedent_scores"], dim=-1, keepdim=False)
    ids = outputs["flattened_ids"]
    top_span_starts = outputs["top_span_starts"]
    top_span_ends = outputs["top_span_ends"]
    top_antecedents = outputs["top_antecedents"]
    mention_indices = []
    antecedent_indices = []
    predicted_antecedents = []
    for i in range(len(outputs["top_span_ends"])):
        if top_indices[i] > 0:
            mention_indices.append(i)
            antecedent_indices.append(top_antecedents[i][top_indices[i] - 1].item())
            predicted_antecedents.append(top_indices[i] - 1)
    