In [5]:
import tqdm, os
import json, requests, torch
from multiprocessing import Pool
import json, pandas as pd, numpy as np, os
from datasets import load_dataset
import evaluate

# import sys
# sys.path.append("../")

In [6]:
def read_training_dynamics(model_dir: str,
                           strip_last: bool = False,
                           id_field: str = "guid",
                           burn_out: int = None):
  """
  Given path to logged training dynamics, merge stats across epochs.

  Expects files ``<model_dir>/training_dynamics/dynamics_epoch_{k}.jsonl`` numbered
  consecutively from 0, one JSON record per line. Each record must contain
  ``id_field``, ``"gold"`` and ``"logits_epoch_{k}"`` keys.

  Args:
  - model_dir: directory containing the ``training_dynamics`` sub-directory.
  - strip_last: if True, drop the last character of each ID before using it as a key.
  - id_field: record field holding the instance ID.
  - burn_out: if truthy, read only the first ``burn_out`` epochs instead of all files found.

  Returns:
  - Dict between ID of a train instance and its gold label, and the list of logits across epochs,
    i.e. ``{guid: {"gold": ..., "logits": [epoch_0_logits, epoch_1_logits, ...]}}``.

  Raises:
  - AssertionError if an expected epoch file is missing, or an ID first appears after epoch 0.
  """
  train_dynamics = {}

  td_dir = os.path.join(model_dir, "training_dynamics")
  # Count only per-epoch dynamics files so unrelated files in the directory
  # (logs, checkpoints, ...) do not inflate the number of epochs.
  num_epochs = len([f for f in os.listdir(td_dir)
                    if f.startswith("dynamics_epoch_") and f.endswith(".jsonl")
                    and os.path.isfile(os.path.join(td_dir, f))])
  if burn_out:
    num_epochs = burn_out

  print(f"Reading {num_epochs} files from {td_dir} ...")
  for epoch_num in tqdm.tqdm(range(num_epochs)):
    epoch_file = os.path.join(td_dir, f"dynamics_epoch_{epoch_num}.jsonl")
    assert os.path.exists(epoch_file), f"Missing training dynamics file: {epoch_file}"
    logits_key = f"logits_epoch_{epoch_num}"

    with open(epoch_file, "r") as infile:
      for line in infile:
        line = line.strip()
        if not line:
          # Tolerate blank lines (e.g. extra newlines at the end of the file).
          continue
        record = json.loads(line)
        guid = record[id_field][:-1] if strip_last else record[id_field]
        if guid not in train_dynamics:
          # Every instance must already be present in the first epoch.
          assert epoch_num == 0, f"Instance {guid!r} first seen in epoch {epoch_num}"
          train_dynamics[guid] = {"gold": record["gold"], "logits": []}
        train_dynamics[guid]["logits"].append(record[logits_key])

  print(f"Read training dynamics for {len(train_dynamics)} train instances.")
  return train_dynamics

In [7]:
from collections import defaultdict
from typing import List

def compute_forgetfulness(correctness_trend: List[float]) -> int:
  """
  Count forgetting events in an epoch-wise trend of train predictions.

  A forgetting event is an epoch in which the example is predicted incorrectly
  immediately after an epoch in which it was predicted correctly.
  Based on: https://arxiv.org/abs/1812.05159
  Entries are interpreted by truthiness.

  Returns the sentinel value 1000 when the example is never predicted correctly
  (this includes an empty trend).
  """
  if not any(correctness_trend):
    return 1000
  flags = [bool(entry) for entry in correctness_trend]
  # Pair each epoch with the one before it; a right -> wrong step is one forgetting event.
  return sum(1 for was_right, is_right in zip(flags, flags[1:])
             if was_right and not is_right)


def compute_correctness(trend: List[float]) -> float:
  """
  Count how many training epochs an example was predicted correctly.

  The entries of ``trend`` are summed, so booleans count as 0/1.
  """
  num_correct_epochs = sum(trend)
  return num_correct_epochs



def compute_train_dy_metrics(training_dynamics):
  """
  Given the training dynamics (logits for each training instance across epochs), compute metrics
  based on it, for data map coordinates.
  Computed metrics are: confidence, variability, correctness, forgetfulness, threshold_closeness---
  the last two being baselines from prior work
  (Example Forgetting: https://arxiv.org/abs/1812.05159 and
   Active Bias: https://arxiv.org/abs/1704.07433 respectively).

  Args:
  - training_dynamics: mapping ``{guid: {"gold": class_index, "logits": [epoch_0_logits, ...]}}``
    as produced by ``read_training_dynamics``. The number of epochs is taken from the first
    instance; every instance is assumed to have the same number of epochs.

  Returns:
  - DataFrame with one row per instance: guid, index, threshold_closeness, confidence,
    variability, correctness, forgetfulness.
  - DataFrame with one row per epoch: epoch, loss (mean cross-entropy over instances), train_acc.

  Raises:
  - ValueError if ``training_dynamics`` is empty.
  """
  if not training_dynamics:
    raise ValueError("training_dynamics is empty; nothing to compute.")

  confidence_ = {}
  variability_ = {}
  threshold_closeness_ = {}
  correctness_ = {}
  forgetfulness_ = {}

  # Functions to be applied to the data.
  variability_func = lambda conf: np.std(conf)
  threshold_closeness_func = lambda conf: conf * (1 - conf)

  # Default reduction="mean": the returned loss is already averaged over instances.
  loss = torch.nn.CrossEntropyLoss()

  num_tot_epochs = len(next(iter(training_dynamics.values()))["logits"])
  print(f"Computing training dynamics across {num_tot_epochs} epochs")
  print("Metrics computed: confidence, variability, correctness, forgetfulness, threshold_closeness")

  # Per-epoch collections used to build the epoch-level loss / accuracy table.
  logits = {i: [] for i in range(num_tot_epochs)}
  targets = {i: [] for i in range(num_tot_epochs)}
  training_accuracy = defaultdict(float)

  for guid in tqdm.tqdm(training_dynamics):
    correctness_trend = []
    true_probs_trend = []

    record = training_dynamics[guid]
    for i, epoch_logits in enumerate(record["logits"]):
      # Probability assigned to the gold class in this epoch.
      probs = torch.nn.functional.softmax(torch.Tensor(epoch_logits), dim=-1)
      true_class_prob = float(probs[record["gold"]])
      true_probs_trend.append(true_class_prob)

      prediction = np.argmax(epoch_logits)
      is_correct = bool(prediction == record["gold"])
      correctness_trend.append(is_correct)

      training_accuracy[i] += is_correct
      logits[i].append(epoch_logits)
      targets[i].append(record["gold"])

    correctness_[guid] = compute_correctness(correctness_trend)
    # Confidence: mean gold-class probability across epochs; variability: its std.
    confidence_[guid] = np.mean(true_probs_trend)
    variability_[guid] = variability_func(true_probs_trend)

    forgetfulness_[guid] = compute_forgetfulness(correctness_trend)
    threshold_closeness_[guid] = threshold_closeness_func(confidence_[guid])

  column_names = ['guid',
                  'index',
                  'threshold_closeness',
                  'confidence',
                  'variability',
                  'correctness',
                  'forgetfulness',]
  df = pd.DataFrame([[guid,
                      i,
                      threshold_closeness_[guid],
                      confidence_[guid],
                      variability_[guid],
                      correctness_[guid],
                      forgetfulness_[guid],
                      ] for i, guid in enumerate(correctness_)], columns=column_names)

  num_instances = len(training_dynamics)
  # The mean-reduced CrossEntropyLoss is already a per-instance average, so it must
  # not be divided by the number of instances again.
  df_train = pd.DataFrame([[i,
                            loss(torch.Tensor(logits[i]), torch.LongTensor(targets[i])).item(),
                            training_accuracy[i] / num_instances,
                            ] for i in range(num_tot_epochs)],
                          columns=['epoch', 'loss', 'train_acc'])
  return df, df_train

In [8]:
# Directory containing a `training_dynamics/` sub-directory with one
# `dynamics_epoch_{k}.jsonl` file per epoch.
model_dir = "../cartography/outputs"

training_dynamics = read_training_dynamics(model_dir,
                                            strip_last=False,
                                            burn_out=None)
df_cart, _ = compute_train_dy_metrics(training_dynamics)

# Fraction of epochs in which each instance was predicted correctly. Normalise by the
# number of epochs actually read, not by the best score in the data: if no instance is
# correct in every epoch, `correctness.max()` is below the epoch count and would
# inflate every fraction.
num_epochs_read = len(next(iter(training_dynamics.values()))["logits"])
df_cart = df_cart.assign(
    corr_frac=lambda d: d.correctness / num_epochs_read,
    # Two-decimal copy of corr_frac (float64), used for difficulty binning.
    correct=lambda d: d.corr_frac.round(2),
)

Reading 6 files from ../cartography/outputs/training_dynamics ...


100%|██████████| 6/6 [00:00<00:00,  6.53it/s]


Read training dynamics for 25000 train instances.
Computing training dynamics across 6 epochs
Metrics computed: confidence, variability, correctness, forgetfulness, threshold_closeness


100%|██████████| 25000/25000 [00:04<00:00, 5939.36it/s]


In [10]:
# Preview the first rows of the per-instance training-dynamics metrics.
df_cart.head(n=5)

Unnamed: 0,guid,index,threshold_closeness,confidence,variability,correctness,forgetfulness,corr_frac,correct
0,101761,0,0.081642,0.910315,0.191507,5,0,0.833333,0.83
1,124457,1,0.078446,0.914191,0.186969,5,0,0.833333,0.83
2,107079,2,0.074944,0.918396,0.180779,6,0,1.0,1.0
3,102292,3,0.246512,0.559061,0.417149,4,1,0.666667,0.67
4,120856,4,0.082579,0.909171,0.199691,5,0,0.833333,0.83


In [32]:
# Bin the per-instance correctness fraction into data-map regions.
# pd.cut bins are right-inclusive: (-1, 0.2] -> hard, (0.2, 0.8] -> ambiguous, (0.8, 1] -> easy.
DIFFICULTY_BINS = [-1.0, 0.2, 0.8, 1.0]
DIFFICULTY_LABELS = ["hard", "ambiguous", "easy"]

# Build a new frame rather than aliasing `df_cart` and adding a column in place,
# so `df_cart` (and its earlier displays) stay unchanged and the cell is idempotent.
df = df_cart.assign(
    difficulty=lambda d: pd.cut(d["correct"], bins=DIFFICULTY_BINS, labels=DIFFICULTY_LABELS)
)

In [33]:
# Preview the metrics together with the assigned difficulty label.
df.head(n=5)

Unnamed: 0,guid,index,threshold_closeness,confidence,variability,correctness,forgetfulness,corr_frac,correct,difficulty
0,101761,0,0.081642,0.910315,0.191507,5,0,0.833333,0.83,easy
1,124457,1,0.078446,0.914191,0.186969,5,0,0.833333,0.83,easy
2,107079,2,0.074944,0.918396,0.180779,6,0,1.0,1.0,easy
3,102292,3,0.246512,0.559061,0.417149,4,1,0.666667,0.67,ambiguous
4,120856,4,0.082579,0.909171,0.199691,5,0,0.833333,0.83,easy


In [34]:
# Number of instances falling into each difficulty region.
df["difficulty"].value_counts()

difficulty
easy         24278
ambiguous      681
hard            41
Name: count, dtype: int64