In [205]:
import numpy as np
import gezi
from gezi import tqdm
import sys,os
sys.path.append('..')
sys.path.append('../../../../utils')
sys.path.append('../../../../third')
from src.config import *

In [86]:
file = '../working/offline/36/0/pt.tiny.roberta-base.sm=start.tiny/valid_ori.pkl'

In [87]:
inputs = gezi.load(file)

In [88]:
inputs.keys()

dict_keys(['id', 'pred', 'word_ids', 'num_words', 'label', 'start', 'start_logits', 'parts', 'para_logits'])

In [121]:
def token2word(start_logits, token_logits, starts, labels, word_ids):
  start_logits_ , token_logits_ , starts_, labels_ = [], [], [], []
  for i in range(len(word_ids)):
    if word_ids[i] < 0:
      continue
    start_logits_.append(start_logits[i])
    token_logits_.append(token_logits[i])
    starts_.append(starts[i])
    labels_.append(labels[i])
  return np.asarray(start_logits_), np.asarray(token_logits_), np.asarray(starts_), np.asarray(labels_)

In [153]:
def decode_label(starts, labels):
  m = {}
  for c in range(1, NUM_CLASSES):
    m[c] = []
  start, c = 0, labels[0]
  for i in range(len(labels)):
    if labels[i] == 0:
      continue
    # if labels[i] != c:
    if starts[i] > 0:
      if i > start:
        if c:
          m[c].append([start, i])
      start = i
      c = labels[i]
  i += 1
  if i > start:
    if c:
      m[c].append([start, i])
  return m
      

In [154]:
def greedy_decode(start_logits, token_logits):
  m = {}
  for c in range(1, NUM_CLASSES):
    m[c] = []
  starts = start_logits[:,1] > start_logits[:,0]
  start, c = 0, None
  for i in range(len(starts)):
    if starts[i]:
      c = token_logits[start:i].sum(0).argmax()
      if c:
        m[c].append([start, i])
      start = i
  i += 1
  if i > start:
    c = token_logits[start:i].sum(0).argmax()
    if c:
      m[c].append([start, i])
  return m

In [197]:
def sample_decode(start_logits, token_logits):
  m = {}
  for c in range(1, NUM_CLASSES):
    m[c] = []
  starts = []
  start_probs = gezi.softmax(start_logits)
  for i in range(len(start_logits)):
    idx = np.random.choice(2, None, p=start_probs[i])
    starts.append(idx)
  start, c = 0, None
  for i in range(len(starts)):
    if starts[i]:
      probs = gezi.softmax(token_logits[start:i].sum(0))
      c = np.random.choice(NUM_CLASSES, None, p=probs)
      if c:
        m[c].append([start, i])
      start = i
  i += 1
  if i > start:
    probs = gezi.softmax(token_logits[start:i].sum(0))
    c = np.random.choice(NUM_CLASSES, None, p=probs)
    if c:
      m[c].append([start, i])
  return m

In [194]:
def calc_f1(m):
  f1_scores = []
  ignores = 0
  for c in range(1, NUM_CLASSES):
    TP = m[c]['match']
    FP = m[c]['pred'] - TP
    FN = m[c]['gt'] - TP
    if m[c]['gt'] == 0 and m[c]['pred'] == 0:
      f1_score = 0
      ignores += 1
    else:
      f1_score = TP / (TP + 0.5 * (FP + FN))
    f1_scores.append(f1_score)
  return np.sum(f1_scores) / (len(f1_scores) - ignores)

In [211]:
def calc_rewards(starts_list, labels_list, start_logits_list, token_logits_list, word_ids_list):
  m = {}
  for c in range(1, NUM_CLASSES):
    m[c] = {'match': 0, 'gt': 0, 'pred': 0}
  m2 = {}
  for c in range(1, NUM_CLASSES):
    m2[c] = {'match': 0, 'gt': 0, 'pred': 0}
  for i, (starts, labels, start_logits, token_logits, word_ids) in tqdm(enumerate(zip(starts_list, labels_list, start_logits_list, token_logits_list, word_ids_list)), total=len(starts_list)):
    start_logits, token_logits, starts, labels = token2word(start_logits, token_logits, starts, labels, word_ids)
    gt = decode_label(starts, labels)
    greedy = greedy_decode(start_logits, token_logits)
    res = prepare_f1(gt, greedy)
    for c in range(1, NUM_CLASSES):
      m[c]['match'] += res[c]['match']
      m[c]['gt'] += res[c]['gt']
      m[c]['pred'] += res[c]['pred']
    
    sample = sample_decode(start_logits, token_logits)
    res = prepare_f1(gt, sample)
    for c in range(1, NUM_CLASSES):
      m2[c]['match'] += res[c]['match']
      m2[c]['gt'] += res[c]['gt']
      m2[c]['pred'] += res[c]['pred']
    if i == 3:
      break
  
  greedy_f1 = calc_f1(m)
  sample_f1 = calc_f1(m2)
  rewards = sample_f1 - greedy_f1
  return rewards

In [212]:
calc_rewards(inputs['start'], inputs['label'], inputs['start_logits'], inputs['pred'], inputs['word_ids'])

  0%|          | 0/3117 [00:00<?, ?it/s]

ic| <ipython-input-211-adca87f232df>:29 in calc_rewards()
    greedy_f1: 0.800199390100868
    sample_f1: 0.7942176870748299


-0.0059817030260381765

In [139]:
# test label
start_logits, token_logits, starts, labels =token2word(inputs['start_logits'][0], inputs['pred'][0], inputs['start'][0], inputs['label'][0], inputs['word_ids'][0])

In [129]:
def is_match(gt, pred):
  s = min(gt[1], pred[1]) - max(gt[0], pred[0]) + 1
  intersect = max(0, s)
  return intersect / (gt[1] - gt[0]) >= 0.5 and intersect / (pred[1] - pred[0]) >= 0.5

In [130]:
def calc_match(gts, preds):
  matches = 0
  for gt in gts:
    for pred in preds:
      if is_match(gt, pred):
        matches += 1
  return matches
      

In [134]:
def prepare_f1(gt, pred):
  m = {}
  for c in range(1, NUM_CLASSES):
    m[c] = {
      'match': calc_match(gt[c], pred[c]),
      'gt': len(gt[c]),
      'pred': len(pred[c])
    }
  return m

In [126]:
gt = decode_label(starts, labels)
gt

{1: [[63, 88], [172, 185], [273, 285]],
 2: [[88, 172], [185, 273], [285, 367]],
 3: [[56, 63]],
 4: [[367, 421]],
 5: [[0, 56]],
 6: [],
 7: []}

In [181]:
greedy = greedy_decode(start_logits, token_logits)
greedy

{1: [[63, 88], [172, 185], [273, 285]],
 2: [[88, 172], [185, 273], [285, 367]],
 3: [[40, 63]],
 4: [[367, 421]],
 5: [[0, 40]],
 6: [],
 7: []}

In [193]:
random_decode(start_logits, token_logits)

{1: [[63, 88], [172, 185], [273, 285]],
 2: [[88, 172], [185, 273], [285, 367]],
 3: [[40, 63]],
 4: [[367, 421]],
 5: [[0, 40]],
 6: [],
 7: []}

In [83]:
inputs['label'][0]

array([0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0, 5, 0, 5, 5, 0, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0, 5, 5, 5, 5, 5,
       5, 5, 0, 5, 5, 5, 5, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0, 3,
       3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 2,
       2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 2, 0, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2,
       2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2,

In [84]:
inputs['word_ids'][0]

array([ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,  -1,   9,  -1,
        10,  11,  -1,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,
        -1,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  -1,
        33,  34,  35,  36,  37,  38,  39,  -1,  40,  41,  42,  43,  -1,
        44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  -1,
        56,  57,  58,  59,  60,  61,  62,  -1,  -1,  -1,  -1,  -1,  63,
        -1,  64,  65,  66,  -1,  67,  68,  -1,  69,  70,  71,  72,  73,
        74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  -1,
        86,  87,  -1,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107,  -1, 108, 109,
       110, 111, 112, 113, 114, 115, 116, 117, 118, 119,  -1, 120, 121,
        -1, 122, 123, 124, 125, 126, 127, 128, 129,  -1, 130, 131, 132,
       133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,  -1, 144,
       145, 146, 147, 148, 149, 150, 151, 152,  -1, 153, 154, 15

In [53]:
inputs['pred'].shape

(3117, 512, 8)

In [52]:
inputs['label'].shape

(3117, 512)

In [54]:
inputs['word_ids']

array([[ -1,   0,   1, ...,  -1,  -1,  -1],
       [ -1,   0,   1, ..., 458, 459, 460],
       [ -1,   0,   1, ...,  -1,  -1,  -1],
       ...,
       [ -1,   0,   1, ..., 429, 430, 431],
       [ -1,   0,   1, ...,  -1,  -1,  -1],
       [ -1,   0,   1, ...,  -1,  -1,  -1]], dtype=int32)

In [None]:
# rewards(inputs['label'], inputs['start_logits'], inputs['pred'])