# Setup

In [None]:
%%bash

pip install datasets
pip install transformers
pip install sentencepiece

In [None]:
%%bash

# Unpack the data and tokenizer archives; both paths are relative to the
# current working directory.
tar -xzf data.tgz
tar -xzf tokenizers.tgz

In [None]:
import os
import pathlib

DATA_DIR = 'data'
MODEL_DIR = 'models'

# Create the working directories on first run; mkdir(exist_ok=True) is a no-op
# when they already exist as directories.
pathlib.Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
pathlib.Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
print(os.listdir(DATA_DIR))
print(os.listdir(MODEL_DIR))

# Redirect the Hugging Face caches into DATA_DIR. Set these before importing
# transformers/datasets (they are presumably read at import time).
os.environ.update({
    'TRANSFORMERS_CACHE': os.path.join(DATA_DIR, 'hf_cache'),
    'HF_DATASETS_CACHE': os.path.join(DATA_DIR, 'hf_cache'),
})

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when present

# Library

In [None]:
#@markdown T5 tokenizer

import copy
import numpy as np
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration

# Default T5 SentencePiece tokenizer (cached under MODEL_DIR). A deep copy is
# used as the working `tokenizer`, presumably so the default instance stays
# untouched if `tokenizer` is modified later.
t5_default_tokenizer = T5Tokenizer.from_pretrained("t5-small", cache_dir=MODEL_DIR)
tokenizer = copy.deepcopy(t5_default_tokenizer)

# Vocabulary pieces ordered by id; VOCAB[i] is the piece with id i provided the
# ids are contiguous from 0 (get_pos_to_token_index indexes it that way).
VOCAB = [piece for piece, _ in sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])]
# SentencePiece word-boundary marker U+2581 ('▁'). Written as an escape so the
# source survives re-encoding (the previous literal had been mangled into the
# cp1252 mojibake 'â–' and never matched any piece).
SPM_SPACE = '\u2581'

## IIT Data Generation

### Character annotations

In [None]:
#@markdown Pre-compute character annotations for interventions in TSV format.

# Each output line has the format
#     "{input}\t{label}\t{task}\t{feature}\t{anno_0}\t{anno_1}...\t{anno_n}",
# where {feature} is the substring of {input} to which character interventions
# are applied. For most tasks, {feature} equals {input}.
# Each {anno_i} is a character annotation in the format
#     "{token_index}\t{char_pos}\t{char_val}"


import collections

def get_pos_to_token_index(text):
  """Map each char position of `text` to the index of its SentencePiece token.

  Args:
    text: raw input string.

  Returns:
    (pos_to_tokens, tokens): `tokens` are the pieces of `text` (looked up in
    VOCAB) without the EOS token; `pos_to_tokens[p]` is the index in `tokens`
    of the piece covering char position `p`. The leading SPM_SPACE of the first
    piece is not counted, so positions line up with `text`.
    NOTE(review): pieces such as '<unk>' whose length differs from the text
    they cover will break this alignment.
  """
  # Use the tokenizer's own EOS id instead of hard-coding T5's id (1).
  eos_id = tokenizer.eos_token_id
  token_ids = tokenizer(text).input_ids
  tokens = [VOCAB[i] for i in token_ids if i != eos_id]
  pos_to_tokens = {}
  curr_pos = 0
  for i, tok in enumerate(tokens):
    if i == 0:
      # The first piece carries an SPM_SPACE prefix that has no counterpart in
      # `text`; drop it so positions stay aligned.
      tok = tok.lstrip(SPM_SPACE)
    pos_to_tokens.update((pos, i) for pos in range(curr_pos, curr_pos + len(tok)))
    curr_pos += len(tok)
  return pos_to_tokens, tokens


# Get character locations, each specified by the token index and the position
# of the char within that token.
# Each char is a candidate for intervention and will be mapped to a fixed-size
# slice of dimensions in the transformer representations.
def get_char_level_inv_locs(input_tokens, span, include_space=False):
  """Return the (token_index, char_pos) pairs eligible for intervention.

  Every char of the tokens in range(*span) is considered. With
  include_space=False all SPM_SPACE chars are skipped; with include_space=True
  only an SPM_SPACE at char 0 of the first span token (the space the tokenizer
  adds) is skipped.
  """
  first_token_index = span[0]
  candidates = []
  for token_index in range(*span):
    for char_pos, char in enumerate(input_tokens[token_index]):
      is_space = char == SPM_SPACE
      if include_space:
        keep = not (token_index == first_token_index and char_pos == 0 and is_space)
      else:
        keep = not is_space
      if keep:
        candidates.append((token_index, char_pos))
  return candidates


def match_feature_span(input_text, feature_text):
  """Return the start index of the first occurrence of `feature_text` in
  `input_text`, or None if it does not occur."""
  try:
    return input_text.index(feature_text)
  except ValueError:
    return None


def gen_char_level_feature_annotation_tsv(
    input_tsv_path, output_tsv_path, task_name, feature_to_column, inv_config):
  """Write per-character intervention annotations for a TSV dataset.

  Each output line is
      "{input}\t{label}\t{task_name}\t{feature}\t{anno_0}\t...\t{anno_m}"
  where every annotation is "{token_index}\t{char_pos}\t{char_val}" and the
  list is truncated/padded (with "-1\t-1\t") to inv_config['max_num_anno'].

  Args:
    input_tsv_path: source TSV; must differ from output_tsv_path.
    output_tsv_path: destination TSV (overwritten).
    task_name: written verbatim into the third output column.
    feature_to_column: column indices for 'input', 'label' and 'feature'. A
      line without the feature column uses column 0 as the feature.
    inv_config: dict with 'include_space' (bool) and 'max_num_anno' (int).

  Returns:
    A dict of counters: 'total', 'annotated' and one key per skip reason
    ('span_not_found', 'span_not_in_pos_index', 'span_mismatch',
    'no_char_candidates', 'char_pos_exceed_16', 'char_count_mismatch').
  """
  assert input_tsv_path != output_tsv_path
  assert ('input' in feature_to_column and 'label' in feature_to_column and
          'feature' in feature_to_column)
  stats = collections.defaultdict(int)
  with open(output_tsv_path, 'w') as f_out, open(input_tsv_path, 'r') as f:
    for line in f:
      stats['total'] += 1
      parsed = line.strip().split('\t')
      input_text = parsed[feature_to_column['input']]
      label = parsed[feature_to_column['label']]
      feature_col = feature_to_column['feature']
      # Lines without a feature column use the input column as the feature.
      feature_text = parsed[feature_col if feature_col < len(parsed) else 0]
      # Find the substring to apply character interventions to.
      span_begin = match_feature_span(input_text, feature_text)
      if span_begin is None:
        stats['span_not_found'] += 1
        print('SPAN_NOT_FOUND:', input_text, feature_text)
        continue
      span_end = span_begin + len(feature_text)
      pos_to_token_index, input_tokens = get_pos_to_token_index(input_text)
      # Extra position so a span ending at the end of the input maps to a token.
      pos_to_token_index[len(input_text)] = len(input_tokens)
      # Drop spans whose boundaries cannot be mapped to token indices.
      if span_begin not in pos_to_token_index or span_end not in pos_to_token_index:
        stats['span_not_in_pos_index'] += 1
        print('SPAN_NOT_IN_INDEX:', span_begin, span_end, input_text, feature_text)
        continue
      span_index = (pos_to_token_index[span_begin], pos_to_token_index[span_end])
      span_reconstruct = ''.join(
          input_tokens[span_index[0]:span_index[1]]).replace(SPM_SPACE, ' ').strip()
      if span_reconstruct != feature_text:
        stats['span_mismatch'] += 1
        print('MISMATCH:', feature_text, '|', span_reconstruct)
        continue
      # Annotate each char with its location and value.
      inv_loc_candidates = get_char_level_inv_locs(
          input_tokens, span_index, include_space=inv_config['include_space'])
      if not inv_loc_candidates:
        # e.g. an empty feature column: nothing to intervene on (and max()
        # below would fail on an empty sequence).
        stats['no_char_candidates'] += 1
        print('NO_CHAR_CANDIDATES:', input_text, '|', feature_text)
        continue
      # Drop examples with a char position >= 16 (subword pieces longer than
      # 16 chars, which should not happen for the T5 vocab).
      if max(loc[1] for loc in inv_loc_candidates) >= 16:
        stats['char_pos_exceed_16'] += 1
        print('CHAR_POS_EXCEED_16:', feature_text)
        continue
      # Sanity check that the number of annotated chars matches the feature
      # length (ignoring spaces unless include_space is set).
      expected_num_chars = (len(feature_text) if inv_config['include_space']
                            else len(feature_text.replace(' ', '')))
      if len(inv_loc_candidates) != expected_num_chars:
        stats['char_count_mismatch'] += 1
        print('CHAR_COUNT_MISMATCH:', feature_text)
        print(input_tokens[span_index[0]:span_index[1]])
        print(inv_loc_candidates)
        continue
      char_annotations = [loc + (input_tokens[loc[0]][loc[1]],) for loc in inv_loc_candidates]
      # Truncate or pad to exactly max_num_anno annotations.
      max_num_anno = inv_config['max_num_anno']
      char_annotations = char_annotations[:max_num_anno]
      char_annotations.extend([(-1, -1, '')] * (max_num_anno - len(char_annotations)))
      data = [f'{input_text}\t{label}\t{task_name}\t{feature_text}'] + [
          '\t'.join(map(str, inv)) for inv in char_annotations]
      f_out.write('\t'.join(data) + '\n')
      stats['annotated'] += 1
  return dict(stats)


def _print_line_count_and_head(path, num_lines=10):
  """Print "<line count> <path>" followed by the first `num_lines` lines.

  Pure-Python stand-in for `wc -l path; head path` that does not pass the path
  through a shell (so spaces or metacharacters in it are harmless).
  """
  head = []
  num_total = 0
  with open(path, 'r') as f:
    for line in f:
      if num_total < num_lines:
        head.append(line)
      num_total += 1
  print(num_total, path)
  print(''.join(head), end='')


def gen_character_annotation_data(input_tsv_path_base, task_name,
                                  inv_config, splits=None):
  """Annotate every split of a task with char-level intervention locations.

  Args:
    input_tsv_path_base: %-format path template with one '%s' slot, filled
      with "{task_name}_{split}" (or just task_name for a falsy split).
    task_name: task identifier, also written into the annotated files.
    inv_config: forwarded to `gen_char_level_feature_annotation_tsv`.
    splits: split names; defaults to ['train', 'val'].

  Returns:
    Dict mapping each split to its annotated TSV path, i.e. the input path with
    '_char_anno' inserted before the extension.
  """
  splits = splits or ['train', 'val']
  # The input format is either "{input}\t{output}" or
  # "{input}\t{output}\t{task}\t{template}\t{feature}"; lines without the
  # feature column fall back to the input column.
  feature_to_column = {'input': 0, 'label': 1, 'feature': 4}
  dataset_split_to_path = {
      split: input_tsv_path_base % (task_name + (f'_{split}' if split else ''))
      for split in splits}
  anno_dataset_split_to_path = {}
  for split, input_path in dataset_split_to_path.items():
    # Only touch the extension (str.replace would rewrite every '.tsv' in the
    # path and return the input path unchanged when there is no '.tsv').
    root, ext = os.path.splitext(input_path)
    span_annotated_tsv_path = f'{root}_char_anno{ext}'
    anno_dataset_split_to_path[split] = span_annotated_tsv_path
    stats = gen_char_level_feature_annotation_tsv(
        input_path,
        span_annotated_tsv_path,
        task_name,
        feature_to_column,
        inv_config)
    print(split, stats)
    _print_line_count_and_head(span_annotated_tsv_path)
  return anno_dataset_split_to_path

### Generate intervention examples

In [None]:
#@markdown Simulate effects of character interventions for pure form-based tasks.


import math
import random
import string


def intervene_reversal_base_example(base_example, max_num_inv):
  # Select a set of chars to keep and randomly assign values to others
  base_val = base_example['feature']
  num_char_inv = random.randint(1, min(max_num_inv, len(base_val)))
  num_char_keep = len(base_val) - num_char_inv
  kept_pos = random.sample(range(len(base_val)), k=num_char_keep)
  inv_feature = ''.join([base_val[i] if i in kept_pos else random.choice(string.ascii_lowercase)
                  for i in range(len(base_val))])
  return {
      'input': inv_feature,
      'feature': inv_feature,
      'label': inv_feature[::-1],
      'inv_locs': [(None, None, c) for c in inv_feature]}


def shift_digit(num_str, shift):
  """Multiply a non-negative decimal string by 10**shift by moving the point.

  Works purely on digits (no floats). Leading zeros of the integer part are
  removed. For negative shifts, trailing zeros and a dangling '.' are also
  removed.
  """
  int_part, frac_part = num_str.split('.') if '.' in num_str else (num_str, '')
  if shift < 0:
    places = -shift
    if places < len(int_part):
      shifted = int_part[:-places] + '.' + int_part[-places:] + frac_part
    else:
      shifted = '0.' + '0' * (places - len(int_part)) + int_part + frac_part
    return shifted.rstrip('0').rstrip('.')
  if shift < len(frac_part):
    head = (int_part + frac_part[:shift]).lstrip('0') or '0'
    return head + '.' + frac_part[shift:]
  return (int_part + frac_part + '0' * (shift - len(frac_part))).lstrip('0') or '0'


def intervene_unit_conversion_base_example(base_example):
  """Swap the source number of a unit-conversion example for a random one.

  The conversion is inferred from the base example as the power of ten
  round(log10(label / feature)) and applied to the new number with
  `shift_digit`, so the new label is exact (no float formatting).

  Args:
    base_example: dict with 'input' (containing 'feature' as a substring),
      'feature' (source number string) and 'label' (target number string).
      Both numbers must be non-zero.

  Returns:
    A new example dict with 'input', 'feature', 'label' and 'inv_locs'
    (one (None, None, char) entry per char of the new number).
  """
  base_src_number, base_trg_number = base_example['feature'], base_example['label']
  # New number: 1-4 integer digits (leading zeros dropped, '0' if all zeros)
  # plus an optional fraction of up to 3 digits (trailing zeros dropped).
  inv_int_digits = ''.join(
      random.choices(string.digits, k=random.randint(1, 4))).lstrip('0') or '0'
  inv_frac_digits = ''.join(
      random.choices(string.digits, k=random.randint(1, 3))).rstrip('0')
  inv_src_number = inv_int_digits
  if inv_frac_digits:
    inv_src_number += '.' + inv_frac_digits
  # log10 of the ratio is already negative when the target is smaller, so one
  # numeric expression covers both directions (the old code chose the branch
  # with a lexicographic string comparison).
  shift = round(math.log10(float(base_trg_number) / float(base_src_number)))
  inv_trg_number = shift_digit(inv_src_number, shift)
  return {
      # Replace only the first occurrence, matching `match_feature_span`.
      'input': base_example['input'].replace(base_src_number, inv_src_number, 1),
      'feature': inv_src_number,
      'label': inv_trg_number,
      'inv_locs': [(None, None, c) for c in inv_src_number]}


# For pure form-based tasks where no lexicon is involved.
def gen_form_only_inv_example(base_example, max_num_inv, task):
  """Generate an intervened example for a form-only task.

  Supported tasks are 'reversal' (uses max_num_inv) and 'unit_conversion'.
  Any other task raises NotImplementedError.
  """
  if task == 'reversal':
    return intervene_reversal_base_example(base_example, max_num_inv)
  if task == 'unit_conversion':
    return intervene_unit_conversion_base_example(base_example)
  raise NotImplementedError


# Test
def run_test():
  """Print one sampled intervention per form-only task (manual smoke test).

  Outputs are random; the comments show what they can look like.
  """
  reversal_base = {'input': 'abcdef', 'label': 'fedcba', 'feature': 'abcdef'}
  print(gen_form_only_inv_example(reversal_base, max_num_inv=8, task='reversal'))
  # Possible output:
  # {'input': 'avcdgf',
  #  'feature': 'avcdgf',
  #  'label': 'fgdcva',
  #  'inv_locs': [(None, None, 'a'), (None, None, 'v'), (None, None, 'c'),
  #               (None, None, 'd'), (None, None, 'g'), (None, None, 'f')]}
  unit_base = {'input': 'convert 123 million to billion', 'label': '0.123', 'feature': '123'}
  print(gen_form_only_inv_example(unit_base, max_num_inv=8, task='unit_conversion'))
  # Possible output:
  # {'input': 'convert 6.9 million to billion',
  #  'feature': '6.9',
  #  'label': '0.0069',
  #  'inv_locs': [(None, None, '6'), (None, None, '.'), (None, None, '9')]}

In [None]:
#@markdown IIT Data Generation (The algorithm in Appendix A.1)

import itertools
import math
import string
# Each character annotation occupies ANNO_LEN consecutive TSV columns:
# token_index, char_pos, char_val.
ANNO_LEN = 3
# Re-seed Python's RNG so the sampling below starts from a fixed seed.
random.seed(0)

def parse_anno_line(line):
  """Parse one line of the character-annotation TSV.

  Line format:
      "{input}\t{label}\t{task}\t{feature}\t{anno_0}\t{anno_1}...\t{anno_n}",
  where {feature} is the substring of input that receives character
  interventions (for most tasks it equals the input), and each {anno_i} is
      "{token_index}\t{char_pos}\t{char_val}".

  Returns:
    Dict with 'input', 'label', 'feature' and 'inv_locs'; 'inv_locs' holds one
    [token_index, char_pos, char_val] list per annotation (ints, ints, str),
    followed by a [-1, -1, ''] sentinel. Trailing fields that do not form a
    complete annotation are ignored.
  """
  fields = line.strip('\n').split('\t')
  num_header_fields = 4
  anno_fields = fields[num_header_fields:]
  inv_locs = []
  for k in range(len(anno_fields) // ANNO_LEN):
    chunk = anno_fields[k * ANNO_LEN:(k + 1) * ANNO_LEN]
    inv_locs.append([int(value) for value in chunk[:-1]] + [chunk[-1]])
  inv_locs.append([-1, -1, ''])
  return {'input': fields[0], 'label': fields[1], 'feature': fields[3],
          'inv_locs': inv_locs}


def get_all_inv_val(example):
  """Concatenate the lowercased char values of all non-padding annotations.

  Entries whose token index is -1 are skipped. The result can differ from
  example['feature'], e.g. when spaces are not annotated or an annotated char
  is SPM_SPACE.
  """
  chars = []
  for loc in example['inv_locs']:
    if loc[0] == -1:
      continue
    chars.append(loc[-1])
  return ''.join(chars).lower()


def get_template_val(example, task):
  """Return the example input with its feature replaced by '{feature}'.

  Examples that share a template differ only in their feature. When the
  feature is the whole input (most tasks), the template is just '{feature}'.
  For unit_conversion, the template keeps the context that specifies which
  conversion to perform. For 'spelling_correction_contextual' the template is
  lowercased.
  """
  template = example['input'].replace(example['feature'], '{feature}')
  return template.lower() if task == 'spelling_correction_contextual' else template


def build_char_to_feature_index(feature_to_example_index, key_type):
  """Build an inverted index from chars (or n-char keys) to feature values.

  Args:
    feature_to_example_index: mapping whose keys are feature strings.
    key_type: 'char' indexes every char -> set of features containing it.
      'charN' (e.g. 'char2') indexes every sorted N-combination of a feature's
      distinct chars -> set of positions in the ordered feature list.

  Returns:
    For 'char', the dict index. For 'charN', (dict index, ordered feature list).

  Raises:
    ValueError: if the 'charN' index grows past 10,000 keys (to avoid OOM).
    NotImplementedError: for any other key_type.
  """
  if key_type == 'char':
    index = collections.defaultdict(set)
    for feature in feature_to_example_index:
      for ch in feature:
        index[ch].add(feature)
  elif key_type.startswith('char'):
    key_size = int(key_type.replace('char', ''))
    ordered_feature_index = list(feature_to_example_index.keys())
    index = collections.defaultdict(set)
    for position, feature in enumerate(ordered_feature_index):
      for combo in itertools.combinations(sorted(set(feature)), key_size):
        index[''.join(combo)].add(position)
        if len(index) > 10_000:
          print(list(index.keys())[:20])
          raise ValueError
  else:
    raise NotImplementedError
  key_sizes = {key: len(index[key]) for key in sorted(index)}
  print('char_to_feature_index:', 'len=%d' % len(index), key_sizes)
  index = dict(index)
  if key_type == 'char':
    return index
  return index, ordered_feature_index


# Approximate frequency rank of chars (0 = most frequent), used to order char
# lookups from least frequent to most frequent for faster indexing.
letter_sorted_by_freq = {
    char: rank
    for rank, char in enumerate(" etainoshrdlucmfwygpbvkqjxz9876543210,.-'")}


def get_features_containing_chars(required_chars, char_to_feature_index, key_type):
  """Return indexed feature values that contain every char in `required_chars`.

  Args:
    required_chars: iterable of single chars (duplicates allowed).
    char_to_feature_index: what `build_char_to_feature_index` returned for the
      same key_type: a dict for 'char', a (dict, ordered_feature_list) tuple
      for 'charN'.
    key_type: 'char' or 'charN' (e.g. 'char2').

  Returns:
    A sorted list of matching feature values. It is empty when nothing matches,
    when no chars are required, or when a needed key is absent from the index.
    Sorting makes `random.choice` over the result independent of set
    iteration order, which varies with string hash randomization.

  Raises:
    NotImplementedError: for an unsupported key_type.
  """
  if key_type == 'char':
    # Intersect the smallest sets first; a char missing from the index (e.g. a
    # random letter that no feature contains) matches nothing.
    char_sets = sorted(
        (char_to_feature_index.get(c, set()) for c in set(required_chars)), key=len)
    if not char_sets:
      return []
    return sorted(set.intersection(*char_sets))
  if key_type.startswith('char'):
    n = int(key_type.replace('char', ''))
    key_to_feature_ids, ordered_features = char_to_feature_index
    # Rarest chars first; chars missing from the frequency table (punctuation,
    # SPM_SPACE, ...) are treated as rarest instead of failing the sort.
    unknown_rank = len(letter_sorted_by_freq)
    chars = sorted(set(required_chars),
                   key=lambda c: letter_sorted_by_freq.get(c, unknown_rank),
                   reverse=True)
    num_grouped = len(chars) - len(chars) % n
    keys = [''.join(sorted(chars[i:i + n])) for i in range(0, num_grouped, n)]
    if len(chars) % n:
      # Cover the leftover chars with the last n chars (overlapping the previous
      # key). With fewer than n chars, pad with one space, so only features
      # containing a space can match.
      tail = chars[-n:] if len(chars) >= n else chars + [' ']
      keys.append(''.join(sorted(tail)))
    if not keys:
      return []
    id_sets = sorted((key_to_feature_ids.get(k, set()) for k in keys), key=len)
    return [ordered_features[i] for i in sorted(set.intersection(*id_sets))]
  raise NotImplementedError


def generate_triplet_index(examples, max_num_triplet, max_num_inv, task_name):
  """Yield (base_i, source_i, inv_i) example-index triplets for IIT training.

  Candidate intervention ("inv") examples for a base are those sharing its
  template (see `get_template_val`). The chars at positions where the inv
  value differs from the base value, plus any chars past the end of the base
  value, are "required". A source example is sampled among the features that
  contain all required chars, so every changed char can be copied from it.

  Args:
    examples: parsed annotation dicts (see `parse_anno_line`).
    max_num_triplet: stop after yielding this many triplets in total.
    max_num_inv: skip inv candidates needing more than this many chars.
    task_name: selects template handling and the index key type ('char2' for
      'spelling_correction_contextual', 'char' otherwise).

  Yields:
    Index tuples (base_i, source_i, inv_i) with source_i != inv_i. A dict of
    skip/valid counters is printed when the generator finishes.

  NOTE(review): raises ZeroDivisionError when `examples` is empty.
  """
  feature_to_example_index = collections.defaultdict(list)
  for i in range(len(examples)):
    feature_to_example_index[get_all_inv_val(examples[i])].append(i)
  print('#unique_index_feature=%d' % len(feature_to_example_index))
  # For tasks whose feature is the whole input, every template is '{feature}',
  # so this index has a single key.
  template_to_example_index = collections.defaultdict(list)
  for i in range(len(examples)):
    template_to_example_index[get_template_val(examples[i], task_name)].append(i)
  print('#unique_template=%d' % len(template_to_example_index))
  key_type = 'char2' if task_name == 'spelling_correction_contextual' else 'char'
  char_to_feature_index = build_char_to_feature_index(
      feature_to_example_index, key_type)
  stats = collections.defaultdict(int)
  # Spread the triplet budget roughly evenly over the base examples.
  max_num_triplet_per_base = int(max_num_triplet / len(examples)) + 1
  print(f'max_num_triplet_per_base={max_num_triplet_per_base}')
  for base_i in range(len(examples)):
    base_val = get_all_inv_val(examples[base_i])
    num_triplet_per_base = 0
    template = get_template_val(examples[base_i], task_name)
    inv_index_candids = template_to_example_index[template]
    # Skip base with only one possible intervention value.
    if len(inv_index_candids) == 1:
      stats['single_inv_candid_for_base'] += 1
      continue
    # Consider at most twice the per-base budget of inv candidates (sampled
    # without replacement; the base itself may be among them).
    inv_index_samples = random.sample(
        inv_index_candids, min(len(inv_index_candids), max_num_triplet_per_base * 2))
    # Split the per-base budget across the distinct inv labels in the sample.
    num_unique_inv_label = len(set([examples[i]["label"] for i in inv_index_samples]))
    max_num_triplet_per_inv = math.ceil(max_num_triplet_per_base / num_unique_inv_label)
    for inv_i in inv_index_samples:
      num_triplet_per_inv = 0
      inv_val = get_all_inv_val(examples[inv_i])
      # Chars where inv differs from base (or extends past it) must be copied
      # from the source.
      required_chars = [inv_val[i] for i in range(len(inv_val))
                        if i >= len(base_val) or base_val[i] != inv_val[i]]
      if len(required_chars) > max_num_inv:
        stats['exceed_max_require_chars'] += 1
        continue
      # Require at least two changed chars.
      if len(required_chars) == 0 or len(required_chars) == 1:
        stats['base_inv_too_similar'] += 1
        continue
      source_val_candids = get_features_containing_chars(
          required_chars, char_to_feature_index, key_type)
      if not source_val_candids:
        stats['no_source_candidates'] += 1
      # Sample source features with replacement, one example per draw.
      for _ in range(len(source_val_candids)):
        source_val = random.choice(source_val_candids)
        for source_i in random.sample(feature_to_example_index[source_val], 1):
          if source_i == inv_i:
            continue
          yield (base_i, source_i, inv_i)
          stats['valid'] += 1
          num_triplet_per_base += 1
          num_triplet_per_inv += 1
          if stats['valid'] >= max_num_triplet:
            print(dict(stats))
            return
        if num_triplet_per_base >= max_num_triplet_per_base or num_triplet_per_inv >= max_num_triplet_per_inv:
          break
      if num_triplet_per_base >= max_num_triplet_per_base:
        break
  print(dict(stats))
  return


def generate_form_only_triplet_index(examples, max_num_triplet, max_num_inv, task_name):
  """Yield (base_example, source_i, inv_example) triplets for form-only tasks.

  Unlike `generate_triplet_index`, the inv example is synthesized per base by
  `gen_form_only_inv_example`, so base and inv are yielded as example dicts
  while the source is an index into `examples`. The source is sampled among
  features containing all "required" chars (inv chars that differ from the
  base at the same position or extend past it). At most one source is yielded
  per synthesized inv.

  Args:
    examples: parsed annotation dicts (see `parse_anno_line`).
    max_num_triplet: stop after yielding this many triplets in total.
    max_num_inv: skip inv examples needing zero or more than this many chars.
    task_name: 'reversal' or 'unit_conversion'.

  Yields:
    (base_example, source_i, inv_example). A dict of counters is printed when
    the generator finishes.

  NOTE(review): raises ZeroDivisionError when `examples` is empty.
  """
  feature_to_example_index = collections.defaultdict(list)
  for i in range(len(examples)):
    feature_to_example_index[get_all_inv_val(examples[i])].append(i)
  print('#unique_index_feature=%d' % len(feature_to_example_index))
  key_type = 'char'
  char_to_feature_index = build_char_to_feature_index(
      feature_to_example_index, key_type)

  stats = collections.defaultdict(int)
  # Spread the triplet budget roughly evenly over the base examples.
  max_num_triplet_per_base = int(max_num_triplet / len(examples)) + 1
  print(f'max_num_triplet_per_base={max_num_triplet_per_base}')
  for base_i in range(len(examples)):
    base_example = examples[base_i]
    base_val = get_all_inv_val(base_example)
    num_triplet_per_base = 0
    # Up to twice the per-base budget of synthesized inv attempts.
    for _ in range(max_num_triplet_per_base * 2):
      # NOTE(review): the number of changed chars is capped at a hard-coded 8
      # rather than `max_num_inv`; confirm this is intended.
      inv_example = gen_form_only_inv_example(base_example, 8, task_name)
      inv_val = get_all_inv_val(inv_example)
      required_chars = [inv_val[i] for i in range(len(inv_val))
                        if i >= len(base_val) or base_val[i] != inv_val[i]]
      if len(required_chars) > max_num_inv or len(required_chars) == 0:
        stats['invalid_require_chars'] += 1
        continue
      source_val_candids = get_features_containing_chars(
          required_chars, char_to_feature_index, key_type)
      if not source_val_candids:
        stats['no_source_candidates'] += 1
      num_triplet_per_inv = 0
      for _ in range(len(source_val_candids)):
        source_val = random.choice(source_val_candids)
        for source_i in random.sample(feature_to_example_index[source_val], 1):
          num_triplet_per_inv += 1
          yield (base_example, source_i, inv_example)
          stats['valid'] += 1
          num_triplet_per_base += 1
          if stats['valid'] >= max_num_triplet:
            print(dict(stats))
            return
          # One source per synthesized inv example.
          if num_triplet_per_base >= max_num_triplet_per_base or num_triplet_per_inv >= 1:
            break
        if num_triplet_per_base >= max_num_triplet_per_base or num_triplet_per_inv >= 1:
          break
      if num_triplet_per_base >= max_num_triplet_per_base:
        break
  print(dict(stats))
  return


def test_generate_triplet_index(anno_input_tsv_path, iit_config, task_name):
  with open(anno_input_tsv_path, 'r') as f:
    lines = f.readlines()
    examples = list(map(lambda x: parse_anno_line(x), lines))
  print(f'Parsed {len(examples)} examples.')
  #stats = collections.defaultdict(int)
  for base_index, source_index, inv_index in generate_triplet_index(
      examples, max_num_triplet=20, max_num_inv=iit_config['max_num_inv'], task_name=task_name):
    base_val = examples[base_index]['feature']
    source_val = examples[source_index]['feature']
    inv_val = examples[inv_index]['feature']
    print("BASE  :", base_val, '|', examples[base_index]['label'])
    print("INV   :", inv_val, '|', examples[inv_index]['label'])
    print("SOURCE:", source_val, '|', examples[source_index]['label'])
    #stats[(len(examples[base_index]['label']) < 48, len(examples[source_index]['label']) < 48)] += 1
    for i, c in enumerate(inv_val):
      assert c.lower() in source_val.lower() or (
          i < len(base_val) and base_val[i].lower() == inv_val[i].lower()), (i, c)
  #print(dict(stats))


def test_generate_form_only_triplet_index(anno_input_tsv_path, iit_config, task_name):
  with open(anno_input_tsv_path, 'r') as f:
    lines = f.readlines()
    examples = list(map(lambda x: parse_anno_line(x), lines))
  print(f'Parsed {len(examples)} examples.')
  for base_example, source_index, inv_example in generate_form_only_triplet_index(
      examples, max_num_triplet=20, max_num_inv=iit_config['max_num_inv'], task_name=task_name):
    base_val = base_example['feature']
    source_val = examples[source_index]['feature']
    inv_val = inv_example['feature']
    print("BASE  :", base_val, '|', base_example['label'])
    print("INV   :", inv_val, '|', inv_example['label'])
    print("SOURCE:", source_val, '|', examples[source_index]['label'])
    for i, c in enumerate(inv_val):
      assert c in source_val or (i < len(base_val) and base_val[i] == inv_val[i]), (i, c)


# For testing

#task_name = 'reversal'
#test_generate_triplet_index(
#    anno_dataset_split_to_path['train'], TASK_TO_IIT_DATA[task_name]['iit_config'], task_name)
#test_generate_form_only_triplet_index(
#    anno_dataset_split_to_path['train'], TASK_TO_IIT_DATA[task_name]['iit_config'], task_name)

In [None]:
import random
# Number of char slots per token; annotated char positions must be below this.
MAX_CHAR_PER_TOKEN = 16
# Re-seed Python's RNG used by the sampling helpers below.
random.seed(0)

def get_pad_locations(source_inv_locs):
  """Return (token_index, last_char_pos) for each annotated source token.

  Annotations are scanned up to the first entry whose token index is -1.
  Tokens appear in order of first occurrence. Only tokens whose last annotated
  char position is below MAX_CHAR_PER_TOKEN are kept.

  NOTE(review): this is the position of the token's last annotated char, not
  the (empty) slot after it, although callers use these as "pad" locations;
  confirm this is intended.
  """
  last_char_pos = {}
  for loc in source_inv_locs:
    token_index = loc[0]
    if token_index == -1:
      break
    char_pos = loc[1]
    if token_index in last_char_pos:
      last_char_pos[token_index] = max(last_char_pos[token_index], char_pos)
    else:
      last_char_pos[token_index] = char_pos
  return [(token_index, char_pos) for token_index, char_pos in last_char_pos.items()
          if char_pos < MAX_CHAR_PER_TOKEN]


# Interventions that preserve source distribution
def gen_inv_base_and_source_features(
    base_example, source_example, inv_example, max_num_inv):
  """Align base and source char locations for one IIT intervention.

  For every base char that the intervention changes, take the source location
  holding the required (lowercased) char. Extra base chars (base longer than
  inv) are overwritten with a random source pad location from
  `get_pad_locations`. Extra inv chars (inv longer than base) are written into
  the slots right after the last base char.

  Args:
    base_example, source_example: parsed examples (see `parse_anno_line`).
    inv_example: parsed or generated example whose value should be realized.
    max_num_inv: pad both location lists with [-1, -1] up to this length.

  Returns:
    A dict with the inputs/labels (labels and inv_input are for debugging) and
    aligned 'base_locations' / 'source_locations'. Every location is a fresh
    [token_index, char_pos] list. Returns None if the intervention cannot be
    realized: no pad slot in the source, an empty base, or extra inv chars that
    would exceed MAX_CHAR_PER_TOKEN.

  Raises:
    KeyError: if a required char does not occur in the source (triplet
      sampling is expected to prevent this).
  """
  inv_val = get_all_inv_val(inv_example)
  base_val = get_all_inv_val(base_example)

  # Lowercased source char -> [token_index, char_pos]; the last occurrence wins.
  source_char_to_location = {}
  for source_loc in source_example['inv_locs']:
    if source_loc[0] < 0:
      break
    source_char_to_location[source_loc[-1].lower()] = list(source_loc[:-1])
  source_pad_locations = [list(loc) for loc in get_pad_locations(source_example['inv_locs'])]

  def _source_location_of(char):
    if char not in source_char_to_location:
      raise KeyError(
          f'char {char!r} of inv value {inv_val!r} (base {base_val!r}) not in '
          f'source {source_example["feature"]!r}')
    return source_char_to_location[char]

  base_locations = []
  source_locations = []
  last_base_loc = None
  for i, base_loc in enumerate(base_example['inv_locs']):
    if base_loc[0] < 0:
      break
    if (i < len(inv_val) and base_val[i] != inv_val[i]) or (
        base_val == inv_val and inv_val[i] in source_char_to_location):
      # Char changed by the intervention: copy it from the source.
      source_locations.append(_source_location_of(inv_val[i]))
      base_locations.append(list(base_loc[:-1]))
    elif i >= len(inv_val):
      # Base has an extra char: overwrite it with a source pad location.
      if not source_pad_locations:
        return None
      source_locations.append(list(random.choice(source_pad_locations)))
      base_locations.append(list(base_loc[:-1]))
    last_base_loc = base_loc

  # Inv has extra chars: place them in the slots after the last base char.
  for i in range(len(base_val), len(inv_val)):
    source_locations.append(_source_location_of(inv_val[i]))
    if last_base_loc is None or last_base_loc[1] >= MAX_CHAR_PER_TOKEN - 1:
      return None
    last_base_loc = [last_base_loc[0], last_base_loc[1] + 1]
    base_locations.append(last_base_loc)

  assert max((loc[-1] for loc in base_locations), default=-1) < MAX_CHAR_PER_TOKEN
  assert len(base_locations) == len(source_locations)
  # Pad with fresh [-1, -1] lists (no shared aliases).
  num_pad = max_num_inv - len(base_locations)
  base_locations.extend([-1, -1] for _ in range(num_pad))
  source_locations.extend([-1, -1] for _ in range(num_pad))
  return {'base_input': base_example['input'],
          'source_input': source_example['input'],
          'base_label': base_example['label'], # for debugging
          'source_label': source_example['label'], # for debugging
          'inv_label': inv_example['label'], # for debugging
          'inv_input': inv_example['input'], # for debugging
          'base_locations': base_locations,
          'source_locations': source_locations}

In [None]:
import time

def _char_slot_span(loc):
  """[token_index, char_pos] -> [token_index, char_pos, char_pos + 1].

  Padding entries (char_pos < 0) get -1 as the end. Accepts lists or tuples.
  """
  loc = list(loc)
  return loc + [loc[-1] + 1 if loc[-1] >= 0 else -1]


def _format_inv_example_row(inv_features):
  """Serialize one intervention example as a TSV row (without newline)."""
  base_locs = [_char_slot_span(loc) for loc in inv_features['base_locations']]
  source_locs = [_char_slot_span(loc) for loc in inv_features['source_locations']]
  data = [f"{inv_features['base_input']}\t{inv_features['source_input']}\t{inv_features['inv_label']}"]
  data += ['\t'.join(map(str, loc)) for loc in base_locs + source_locs]
  return '\t'.join(data)


def _print_tsv_summary(path, num_lines=5):
  """Print the line count, then the first and last `num_lines` lines of `path`.

  Pure-Python stand-in for `wc -l`, `head -n`, `tail -n` that does not pass the
  path through a shell.
  """
  head = []
  tail = collections.deque(maxlen=num_lines)
  num_total = 0
  with open(path, 'r') as f:
    for line in f:
      if num_total < num_lines:
        head.append(line)
      tail.append(line)
      num_total += 1
  print(num_total, path)
  print(''.join(head), end='')
  print(''.join(tail), end='')


def gen_inv_examples(anno_input_tsv_path, output_tsv_path, num_examples, task_name):
  """Generate IIT intervention examples from a char-annotated TSV.

  Triplets are sampled with `generate_triplet_index`, or with
  `generate_form_only_triplet_index` for 'reversal' / 'unit_conversion'.
  Each triplet is turned into aligned base/source char locations by
  `gen_inv_base_and_source_features`.

  Args:
    anno_input_tsv_path: TSV written by `gen_char_level_feature_annotation_tsv`.
    output_tsv_path: file to write; if falsy, examples are returned in memory
      instead (without the max-source-length filter).
    num_examples: maximum number of triplets to sample.
    task_name: key into TASK_TO_IIT_DATA and TASK_TO_DATASETS.

  Returns:
    The list of example dicts when `output_tsv_path` is falsy, else None.
    Each written row is
      "{base_input}\t{source_input}\t{inv_label}\t{base_locs}\t{source_locs}"
    with every location as "token_index\tchar_start\tchar_end" (-1 = padding).
    Rows whose token indices reach the task's max_src_token are dropped.
  """
  iit_config = TASK_TO_IIT_DATA[task_name]['iit_config']
  max_input_seq_len = TASK_TO_DATASETS[task_name]['seq_length']['max_src_token']
  with open(anno_input_tsv_path, 'r') as f:
    examples = [parse_anno_line(line) for line in f]
  print(f'Parsed {len(examples)} examples.')
  is_form_only = task_name in ('reversal', 'unit_conversion')
  triplet_gen_fn = generate_form_only_triplet_index if is_form_only else generate_triplet_index
  triplet_gen = triplet_gen_fn(
      examples, max_num_triplet=num_examples,
      max_num_inv=iit_config['max_num_inv'], task_name=task_name)
  inv_examples = []
  unique_inv_label_per_base = collections.defaultdict(set)
  f_out = open(output_tsv_path, 'w') if output_tsv_path else None
  start_time = time.time()
  try:
    for i, (base_ref, source_i, inv_ref) in enumerate(triplet_gen):
      # Form-only generators yield example dicts for base/inv; others yield indices.
      base_example = base_ref if is_form_only else examples[base_ref]
      inv_example = inv_ref if is_form_only else examples[inv_ref]
      features = gen_inv_base_and_source_features(
          base_example, examples[source_i], inv_example, iit_config['max_num_inv'])
      if not features:
        continue
      unique_inv_label_per_base[features['base_input']].add(features['inv_label'])
      if i % 10000 == 0:
        print('Finished %d examples in %.2f sec.' % (i, time.time() - start_time))
      if f_out is None:
        inv_examples.append(features)
        continue
      if (max((loc[0] for loc in features['base_locations']), default=-1) >= max_input_seq_len or
          max((loc[0] for loc in features['source_locations']), default=-1) >= max_input_seq_len):
        continue
      f_out.write(_format_inv_example_row(features) + '\n')
  finally:
    if f_out is not None:
      f_out.close()
  labels_per_base = [len(labels) for labels in unique_inv_label_per_base.values()]
  if labels_per_base:
    print('avg #inv_label_per_base=%.2f' % np.mean(labels_per_base),
          'min #inv_label_per_base=%.2f' % min(labels_per_base),
          'max #inv_label_per_base=%.2f' % max(labels_per_base))
  else:
    print('No intervention examples were generated.')
  if f_out is None:
    return inv_examples
  _print_tsv_summary(output_tsv_path)
  return None


def test_gen_inv_example():
  split = 'train'
  inv_examples = gen_inv_examples(anno_input_tsv_path, None, 50, task_name)
  for inv_example in inv_examples:
    if inv_example is None:
      continue
    _, base_input_tokens = get_pos_to_token_index(inv_example['base_input'])
    _, source_input_tokens = get_pos_to_token_index(inv_example['source_input'])
    print('BASE:', inv_example['base_input'], '\t', inv_example['base_label'])
    print('SOURCE:', inv_example['source_input'], '\t', inv_example['source_label'])
    print('INV:', inv_example['inv_input'], '\t', inv_example['inv_label'])
    print(base_input_tokens)
    print(inv_example['base_locations'])
    print([(base_input_tokens[s[0]], base_input_tokens[s[0]][s[1]:s[1] + 1])
           for s in inv_example['base_locations'] if s[-1] != -1])
    print(source_input_tokens)
    print(inv_example['source_locations'])
    print([(source_input_tokens[s[0]], source_input_tokens[s[0]][s[1]:s[1] + 1])
           for s in inv_example['source_locations'] if s[-1] != -1])

## Data Loaders

In [None]:
#@markdown TSV Data Loader

import copy
import json
import random
import os

import datasets
from datasets import load_dataset
from datasets import Dataset


def preproc_tokenize(examples, max_input_seq_len, max_output_seq_len,
                     input_feature=None, label_feature=None,
                     extra_feature_to_tokenize=None,
                     source_tokenizer=None,  target_tokenizer=None):
  """Tokenize a batch of seq2seq examples into padded PyTorch tensors.

  Args:
    examples: batch mapping feature name -> list of strings.
    max_input_seq_len, max_output_seq_len: pad/truncate lengths in tokens.
    input_feature, label_feature: column names (default 'input' / 'label').
    extra_feature_to_tokenize: extra text columns to tokenize with the source
      tokenizer, stored as '{feat}_ids' and '{feat}_attention_mask'.
    source_tokenizer, target_tokenizer: default to t5_default_tokenizer.

  Returns:
    A deep copy of `examples` extended with the source encoding, 'labels'
    (pad ids replaced by -100, presumably the loss's ignore index) and any
    extra encodings. `examples` itself is not modified.
  """
  src_tok = source_tokenizer or t5_default_tokenizer
  tgt_tok = target_tokenizer or t5_default_tokenizer
  batch = copy.deepcopy(examples)
  input_key = input_feature or 'input'
  label_key = label_feature or 'label'

  def encode(tok, texts, max_len):
    return tok(texts, padding="max_length", max_length=max_len,
               return_tensors="pt", truncation=True)

  batch.update(encode(src_tok, examples[input_key], max_input_seq_len))
  label_ids = encode(tgt_tok, examples[label_key], max_output_seq_len).input_ids
  label_ids[label_ids == tgt_tok.pad_token_id] = -100
  batch['labels'] = label_ids
  for feat in extra_feature_to_tokenize or []:
    encoded = encode(src_tok, examples[feat], max_input_seq_len)
    batch[f'{feat}_ids'] = encoded.input_ids
    batch[f'{feat}_attention_mask'] = encoded.attention_mask
  return batch


def parse_tsv_line(line, feature_to_column):
  """Split one raw TSV row into named features.

  Args:
    line: mapping with a 'text' entry holding a tab-separated row. Surrounding
      whitespace (including tabs) is stripped before splitting.
    feature_to_column: mapping feature name -> column index.

  Returns:
    dict feature name -> column string.
  """
  columns = line['text'].strip().split('\t')
  parsed_features = {}
  for feature_name, column_index in feature_to_column.items():
    parsed_features[feature_name] = columns[column_index]
  return parsed_features


def gen_seq2seq_dataset_from_tsv(split_to_files, feature_to_column, max_seq_len,
                                 parse_fn=None, extra_feature_to_tokenize=None,
                                 source_tokenizer=None, target_tokenizer=None):
  """Load TSV files, parse each row, tokenize, and return torch-formatted splits.

  Args:
    split_to_files: mapping split name -> file path(s), passed to
      `load_dataset("text", data_files=...)`.
    feature_to_column: mapping feature name -> TSV column index, passed to
      `parse_fn`.
    max_seq_len: (max_input_seq_len, max_output_seq_len).
    parse_fn: `parse_fn(line, feature_to_column)`; defaults to `parse_tsv_line`.
    extra_feature_to_tokenize: optional list of extra text columns to tokenize
      (they become '{feat}_ids' / '{feat}_attention_mask'; the raw text column
      is removed).
    source_tokenizer / target_tokenizer: forwarded to `preproc_tokenize`.

  Returns:
    The tokenized DatasetDict with raw text columns removed, in torch format.
  """
  max_input_seq_len, max_output_seq_len = max_seq_len
  if parse_fn is None:
    parse_fn = parse_tsv_line
  dataset = load_dataset("text", data_files=split_to_files)
  dataset = dataset.map(lambda x: parse_fn(x, feature_to_column))
  print(dataset)
  # Print a few examples per split. Index the row first so only that row is
  # materialized, and cap at the split size so tiny splits do not raise.
  for split in dataset:
    for i in range(min(3, len(dataset[split]))):
      example = dataset[split][i]
      print('%s split example %d:' % (split.upper(), (i + 1)))
      print('input: %s' % example['input'])
      print('output: %s' % example['label'])
  tokenized_datasets = dataset.map(
      lambda examples: preproc_tokenize(
          examples, max_input_seq_len, max_output_seq_len,
          # Must be passed by keyword: the 4th positional parameter of
          # preproc_tokenize is `input_feature`.
          extra_feature_to_tokenize=extra_feature_to_tokenize,
          source_tokenizer=source_tokenizer,
          target_tokenizer=target_tokenizer),
      batched=True)
  removed_text_columns = ['input', 'label', 'text'] + (extra_feature_to_tokenize or [])
  tokenized_datasets = tokenized_datasets.remove_columns(removed_text_columns)
  tokenized_datasets = tokenized_datasets.with_format("torch")
  return tokenized_datasets


def gen_seq2seq_text_dataset_from_tsv(dataset_split_to_path,
                                      parse_fn,
                                      parse_fn_args=None,
                                      keep_in_memory=None):
  """Generate raw text dataset without tokenization.

  Loads each split with the `text` builder, maps every line through
  `parse_fn`, prints a few examples per split, and drops the raw 'text' column.

  Args:
    dataset_split_to_path: mapping split name -> file path(s).
    parse_fn: called as `parse_fn(line, parse_fn_args)` when `parse_fn_args`
      is truthy, otherwise as `parse_fn(line)`.
    parse_fn_args: optional extra argument for `parse_fn`.
    keep_in_memory: forwarded to `load_dataset`.

  Returns:
    The parsed (untokenized) DatasetDict.
  """
  dataset = load_dataset("text", data_files=dataset_split_to_path,
                         keep_in_memory=keep_in_memory)
  if parse_fn_args:
    dataset = dataset.map(lambda x: parse_fn(x, parse_fn_args))
  else:
    dataset = dataset.map(parse_fn)
  print(dataset)
  # print examples
  for split in dataset:
    feature_names = list(dataset[split].features.keys())
    input_keys = [k for k in feature_names if 'input' in k]
    label_keys = [k for k in feature_names if 'label' in k]
    if not input_keys or not label_keys:
      continue
    input_key = 'input' if 'input' in input_keys else input_keys[0]
    label_key = 'label' if 'label' in label_keys else label_keys[0]
    # Index the row first (column-first indexing loads the whole column on
    # every access) and cap at the split size so tiny splits do not raise.
    for i in range(min(3, len(dataset[split]))):
      example = dataset[split][i]
      print('%s split example %d:' % (split.upper(), (i + 1)))
      print('%s: %s' % (input_key, example[input_key]))
      print('%s: %s' % (label_key, example[label_key]))
  removed_text_columns = ['text']
  dataset = dataset.remove_columns(removed_text_columns)
  return dataset

In [None]:
def load_datasets(task_name, feature_to_column=None,
                  splits=None, dataset_split_to_path=None,
                  source_tokenizer=None, target_tokenizer=None):
  """Load and tokenize the TSV splits for `task_name`.

  Args:
    task_name: key into TASK_TO_DATASETS; its entry supplies split paths and
      'seq_length' -> {'max_src_token', 'max_trg_token'}.
    feature_to_column: defaults to {'input': 0, 'label': 1}.
    splits: split names used when `dataset_split_to_path` is not given
      (default ('train', 'val')).
    dataset_split_to_path: explicit mapping split name -> path; overrides
      `splits`.
    source_tokenizer / target_tokenizer: forwarded to tokenization.

  Returns:
    (tokenized DatasetDict, dataset_split_to_path actually used).
  """
  if not feature_to_column:
    feature_to_column = {'input': 0, 'label': 1}
  if not dataset_split_to_path:
    splits = splits or ('train', 'val')
    dataset_split_to_path = {k: TASK_TO_DATASETS[task_name][k] for k in splits}
  max_seq_len = (TASK_TO_DATASETS[task_name]['seq_length']['max_src_token'],
                 TASK_TO_DATASETS[task_name]['seq_length']['max_trg_token'])
  # Named so it does not shadow the imported `datasets` module.
  tokenized_datasets = gen_seq2seq_dataset_from_tsv(
      dataset_split_to_path, feature_to_column, max_seq_len,
      source_tokenizer=source_tokenizer, target_tokenizer=target_tokenizer)

  for key in tokenized_datasets:
    # Slice rows before selecting the column so only two rows are loaded.
    head_rows = tokenized_datasets[key][:2]
    print(head_rows['input_ids'])
    print(head_rows['labels'])
  return tokenized_datasets, dataset_split_to_path

In [None]:
#@markdown IIT Examples TSV Data Loader

def parse_tsv_line_with_inv_example(line, iit_config):
  """Parse one IIT example row.

  Row layout (tab-separated): base_input, source_input, inv_label, then
  `ANNO_LEN * iit_config['max_num_inv']` integers of base locations, then the
  same number of integers of source locations. Only the trailing newline is
  stripped, so empty tab-separated fields keep their positions.

  Returns:
    dict with 'base_input', 'source_input', 'inv_label' (strings) and
    'base_locations', 'source_locations' (lists of ints).
  """
  num_loc = ANNO_LEN * iit_config['max_num_inv']
  fields = line['text'].strip('\n').split('\t')
  base_start = 3
  source_start = base_start + num_loc
  source_end = source_start + num_loc
  return {
      'base_input': fields[0],
      'source_input': fields[1],
      'inv_label': fields[2],
      'base_locations': [int(v) for v in fields[base_start:source_start]],
      'source_locations': [int(v) for v in fields[source_start:source_end]],
  }


def preproc_inv_example_fn(inv_examples, max_seq_len, iit_config):
  """Tokenize a batch of parsed IIT examples.

  'base_input' is the model input, 'inv_label' the label, and 'source_input'
  is tokenized as an extra feature. Location columns (and 'inv_values' /
  'inv_value_locations' when present) are converted to int8 tensors.

  Args:
    inv_examples: batch produced by `parse_tsv_line_with_inv_example`.
    max_seq_len: (max_input_seq_len, max_output_seq_len).
    iit_config: dict; an optional 'target_tokenizer' entry is used for labels.
  """
  max_input_seq_len, max_output_seq_len = max_seq_len
  target_tokenizer = None
  if 'target_tokenizer' in iit_config:
    target_tokenizer = iit_config['target_tokenizer']
  tokenized_batch = preproc_tokenize(
      inv_examples, max_input_seq_len, max_output_seq_len,
      input_feature='base_input', label_feature='inv_label',
      extra_feature_to_tokenize=['source_input'],
      source_tokenizer=None,  target_tokenizer=target_tokenizer)
  # Store as 8-bit signed integers (values must fit in [-128, 127]).
  int8_keys = ['base_locations', 'source_locations']
  if 'inv_values' in tokenized_batch:
    int8_keys += ['inv_values', 'inv_value_locations']
  for key in int8_keys:
    tokenized_batch[key] = torch.CharTensor(tokenized_batch[key])
  return tokenized_batch


# Columns kept in the tokenized IIT dataset; every other column is removed.
eval_inv_annotated_feature_keys = {
    'labels',
    'base_locations', 'source_locations',
    'source_input_ids', 'base_input_ids',
    'base_attention_mask', 'source_attention_mask',
    'inv_values', 'inv_value_locations',
}


def gen_iit_dataset_from_tsv(task_name):
  """Build the tokenized IIT training dataset for `task_name`.

  Reads the 'iit_train' file from TASK_TO_IIT_DATA[task_name] and parses it
  with `parse_tsv_line_with_inv_example`. It then tokenizes base/source inputs
  and intervention labels, keeps only `eval_inv_annotated_feature_keys`, and
  prints a few decoded examples.

  Side effect: sets `datasets.config.IN_MEMORY_MAX_SIZE` to 1 GiB.

  Returns:
    torch-formatted DatasetDict with a 'train' split.
  """
  dataset_split_to_path = {'train': TASK_TO_IIT_DATA[task_name]['iit_train']}
  max_seq_len = (TASK_TO_DATASETS[task_name]['seq_length']['max_src_token'],
                 TASK_TO_DATASETS[task_name]['seq_length']['max_trg_token'])
  iit_config = TASK_TO_IIT_DATA[task_name]['iit_config']

  datasets.config.IN_MEMORY_MAX_SIZE = 1024**3  # 1G
  text_datasets = gen_seq2seq_text_dataset_from_tsv(
      dataset_split_to_path,
      parse_fn_args=iit_config,
      parse_fn=parse_tsv_line_with_inv_example,
      keep_in_memory=True)

  tokenized_datasets = text_datasets.map(
      lambda x: preproc_inv_example_fn(x, max_seq_len, iit_config), batched=True)
  tokenized_datasets = tokenized_datasets.rename_columns(
      {'input_ids': 'base_input_ids',
       'attention_mask': 'base_attention_mask',
       'source_input_attention_mask': 'source_attention_mask'})
  all_features = tokenized_datasets[list(tokenized_datasets.keys())[0]].features
  removed_text_columns = [
      k for k in all_features if k not in eval_inv_annotated_feature_keys]
  tokenized_datasets = tokenized_datasets.remove_columns(removed_text_columns)
  tokenized_datasets = tokenized_datasets.with_format("torch")
  print(tokenized_datasets)
  # Labels were tokenized with the configured target tokenizer (if any), so
  # decode them with that same tokenizer. Positions marked -100 (ignored by
  # the loss) are mapped back to its pad id before decoding.
  target_tokenizer = (
      iit_config['target_tokenizer'] if 'target_tokenizer' in iit_config else None)
  label_tokenizer = target_tokenizer or tokenizer
  # print examples
  for split in tokenized_datasets:
    # Row-first indexing materializes one row instead of whole columns; cap at
    # the split size so small splits do not raise.
    for i in range(min(10, len(tokenized_datasets[split]))):
      example = tokenized_datasets[split][i]
      labels = example['labels']
      print('%s split example %d:' % (split.upper(), (i + 1)))
      print('base_input: %s' % tokenizer.decode(
          example['base_input_ids'], skip_special_tokens=True))
      print('source_input: %s' % tokenizer.decode(
          example['source_input_ids'], skip_special_tokens=True))
      print('labels: %s' % label_tokenizer.decode(
          labels.masked_fill(labels == -100, label_tokenizer.pad_token_id),
          skip_special_tokens=True))
      print('label tokens:', labels.tolist())
  return tokenized_datasets

## Causal Abstraction

In [None]:
#@markdown `t5stack_forward_pre_block`
def t5stack_forward_pre_block(
        t5stack,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
  """Run the part of a T5Stack forward pass that precedes the block loop.

  The steps are:
    - validate inputs and embed `input_ids`, unless `inputs_embeds` is given;
    - build the extended self-attention mask and, for decoders, the
      cross-attention mask;
    - prepare head masks;
    - apply `t5stack.dropout` to the embeddings;
    - initialize the per-layer accumulators.

  The structure appears to follow HuggingFace `T5Stack.forward` up to its
  loop over `t5stack.block`. NOTE(review): confirm against the installed
  transformers version.

  Args:
    t5stack: the T5Stack (encoder or decoder) to run.
    input_ids / inputs_embeds: exactly one of them must be provided.
    attention_mask: defaults to all ones of shape (batch, mask_seq_length).
    encoder_hidden_states / encoder_attention_mask: only used when
      `t5stack.is_decoder`.
    use_cache, output_attentions, output_hidden_states, return_dict: when
      None, resolved from `t5stack.config`.

  Returns:
    A dict whose keys are exactly the keyword parameters of
    `t5stack_forward_single_layer_in_block`, so it can be splatted into it.

  Raises:
    ValueError: if both or neither of input_ids / inputs_embeds are given.
    AssertionError: if `t5stack.embed_tokens` is missing when embedding is
      needed, or if use_cache is True on a non-decoder stack.
  """
  # Model parallel
  if t5stack.model_parallel:
      torch.cuda.set_device(t5stack.first_device)
      t5stack.embed_tokens = t5stack.embed_tokens.to(t5stack.first_device)
  # Resolve unset flags from the stack's config.
  use_cache = use_cache if use_cache is not None else t5stack.config.use_cache
  output_attentions = output_attentions if output_attentions is not None else t5stack.config.output_attentions
  output_hidden_states = (
      output_hidden_states if output_hidden_states is not None else t5stack.config.output_hidden_states
  )
  return_dict = return_dict if return_dict is not None else t5stack.config.use_return_dict

  if input_ids is not None and inputs_embeds is not None:
      err_msg_prefix = "decoder_" if t5stack.is_decoder else ""
      raise ValueError(
          f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
      )
  elif input_ids is not None:
      input_shape = input_ids.size()
      # Flatten any leading dims into the batch dim.
      input_ids = input_ids.view(-1, input_shape[-1])
  elif inputs_embeds is not None:
      input_shape = inputs_embeds.size()[:-1]
  else:
      err_msg_prefix = "decoder_" if t5stack.is_decoder else ""
      raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

  if inputs_embeds is None:
      assert t5stack.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
      inputs_embeds = t5stack.embed_tokens(input_ids)

  # Requires a 2-D input_shape (batch, seq).
  batch_size, seq_length = input_shape

  # required mask seq length can be calculated via length of past
  mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

  if use_cache is True:
      assert t5stack.is_decoder, f"`use_cache` can only be set to `True` if {t5stack} is used as a decoder"

  if attention_mask is None:
      attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  if t5stack.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
      encoder_seq_length = encoder_hidden_states.shape[1]
      encoder_attention_mask = torch.ones(
          batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
      )

  # initialize past_key_values with `None` if past does not exist
  if past_key_values is None:
      past_key_values = [None] * len(t5stack.block)

  # We can provide a t5stack-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  # ourselves in which case we just need to make it broadcastable to all heads.
  extended_attention_mask = t5stack.get_extended_attention_mask(attention_mask, input_shape)

  # If a 2D or 3D attention mask is provided for the cross-attention
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if t5stack.is_decoder and encoder_hidden_states is not None:
      encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
      encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
      if encoder_attention_mask is None:
          encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
      encoder_extended_attention_mask = t5stack.invert_attention_mask(encoder_attention_mask)
  else:
      encoder_extended_attention_mask = None

  # Prepare head mask if needed
  head_mask = t5stack.get_head_mask(head_mask, t5stack.config.num_layers)
  cross_attn_head_mask = t5stack.get_head_mask(cross_attn_head_mask, t5stack.config.num_layers)
  # Position biases are filled in by the first block and then reused.
  position_bias = None
  encoder_decoder_position_bias = None

  hidden_states = t5stack.dropout(inputs_embeds)

  # Initialize accumulating variables.
  all_hidden_states = () if output_hidden_states else None
  present_key_value_states = () if use_cache else None
  all_attentions = () if output_attentions else None
  all_cross_attentions = () if (output_attentions and t5stack.is_decoder) else None

  # Keys mirror the keyword parameters of t5stack_forward_single_layer_in_block.
  return {'hidden_states': hidden_states,
          'encoder_hidden_states': encoder_hidden_states,
          'encoder_attention_mask': encoder_attention_mask,
          'head_mask': head_mask,
          'cross_attn_head_mask': cross_attn_head_mask,
          'past_key_values': past_key_values,
          'position_bias': position_bias,
          'encoder_decoder_position_bias': encoder_decoder_position_bias,
          'extended_attention_mask': extended_attention_mask,
          'encoder_extended_attention_mask': encoder_extended_attention_mask,
          # Parsed parameters.
          'use_cache': use_cache,
          'output_hidden_states': output_hidden_states,
          'output_attentions': output_attentions,
          'return_dict': return_dict,
          # Accumulating vars.
          'all_hidden_states': all_hidden_states,
          'present_key_value_states': present_key_value_states,
          'all_attentions': all_attentions,
          'all_cross_attentions': all_cross_attentions}
# Test
# model.eval()
# pre_block_outputs = t5stack_forward_pre_block(
#     model.encoder,
#     input_ids=test_input_batch['input_ids'],
#     attention_mask=test_input_batch['attention_mask'])
# print(pre_block_outputs['hidden_states'].shape)
# print(pre_block_outputs['hidden_states'][:1, :3, :5])

In [None]:
#@markdown `t5stack_forward_single_layer_in_block`
def t5stack_forward_single_layer_in_block(
        t5stack,
        layer_index,
        hidden_states,
        encoder_hidden_states,
        encoder_attention_mask,
        head_mask,
        cross_attn_head_mask,
        past_key_values,
        position_bias,
        encoder_decoder_position_bias,
        extended_attention_mask,
        encoder_extended_attention_mask,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        all_hidden_states=None,
        present_key_value_states=None,
        all_attentions=None,
        all_cross_attentions=None
    ):
  """Run block `layer_index` of a T5Stack and update the accumulators.

  Intended to be called with the dict returned by `t5stack_forward_pre_block`
  (or a previous call of this function) splatted as keyword arguments.

  Before running the layer, `hidden_states` (the layer's input) is appended to
  `all_hidden_states` when `output_hidden_states` is set.

  Returns:
    dict with the layer's output 'hidden_states', the shared 'position_bias'
    and 'encoder_decoder_position_bias', and the updated accumulators
    ('present_key_value_states', 'all_hidden_states', 'all_attentions',
    'all_cross_attentions'), plus 'output_hidden_states' and 'return_dict'.
  """
  layer_module, past_key_value = t5stack.block[layer_index], past_key_values[layer_index]
  layer_head_mask = head_mask[layer_index]
  cross_attn_layer_head_mask = cross_attn_head_mask[layer_index]
  # Model parallel
  if t5stack.model_parallel:
      torch.cuda.set_device(hidden_states.device)
      # Ensure the mask actually passed to the layer is on the same device as
      # hidden_states.
      if extended_attention_mask is not None:
          extended_attention_mask = extended_attention_mask.to(hidden_states.device)
      if position_bias is not None:
          position_bias = position_bias.to(hidden_states.device)
      if encoder_hidden_states is not None:
          encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
      if encoder_extended_attention_mask is not None:
          encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
      if encoder_decoder_position_bias is not None:
          encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
      if layer_head_mask is not None:
          layer_head_mask = layer_head_mask.to(hidden_states.device)
      if cross_attn_layer_head_mask is not None:
          cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
  # Update accumulating variables (record this layer's input).
  if output_hidden_states:
    all_hidden_states = all_hidden_states + (hidden_states,)

  if t5stack.gradient_checkpointing and t5stack.training:
      # Local imports: neither `logger` nor `checkpoint` is defined elsewhere
      # in this notebook.
      import logging
      from torch.utils.checkpoint import checkpoint
      if use_cache:
          logging.getLogger(__name__).warning(
              "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
          )
          use_cache = False

      def create_custom_forward(module):
          def custom_forward(*inputs):
              return tuple(module(*inputs, use_cache, output_attentions))

          return custom_forward

      layer_outputs = checkpoint(
          create_custom_forward(layer_module),
          hidden_states,
          extended_attention_mask,
          position_bias,
          encoder_hidden_states,
          encoder_extended_attention_mask,
          encoder_decoder_position_bias,
          layer_head_mask,
          cross_attn_layer_head_mask,
          None,  # past_key_value is always None with gradient checkpointing
      )
  else:
      layer_outputs = layer_module(
          hidden_states,
          attention_mask=extended_attention_mask,
          position_bias=position_bias,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_extended_attention_mask,
          encoder_decoder_position_bias=encoder_decoder_position_bias,
          layer_head_mask=layer_head_mask,
          cross_attn_layer_head_mask=cross_attn_layer_head_mask,
          past_key_value=past_key_value,
          use_cache=use_cache,
          output_attentions=output_attentions,
      )

  # layer_outputs is a tuple with:
  # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  if use_cache is False:
      # Insert a None placeholder for key-value-states so indices stay fixed.
      layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

  hidden_states, present_key_value_state = layer_outputs[:2]

  # We share the position biases between the layers - the first layer stores them.
  position_bias = layer_outputs[2]
  if t5stack.is_decoder and encoder_hidden_states is not None:
      encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

  # Model Parallel: If it's the last layer for that device, put things on the next device
  if t5stack.model_parallel:
      for k, v in t5stack.device_map.items():
          if layer_index == v[-1] and "cuda:" + str(k) != t5stack.last_device:
              hidden_states = hidden_states.to("cuda:" + str(k + 1))

  # Update accumulating variables.
  if use_cache:
    present_key_value_states = present_key_value_states + (present_key_value_state,)
  if output_attentions:
    all_attentions = all_attentions + (layer_outputs[3],)
    if t5stack.is_decoder:
      all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

  return {'hidden_states': hidden_states,
          'position_bias': position_bias,
          # Propagated so later decoder layers reuse the shared cross-attention bias.
          'encoder_decoder_position_bias': encoder_decoder_position_bias,
          'present_key_value_states': present_key_value_states,
          'all_hidden_states': all_hidden_states,
          'all_attentions': all_attentions,
          'all_cross_attentions': all_cross_attentions,
          'output_hidden_states': output_hidden_states,
          'return_dict': return_dict,}

# Test
#pre_layer_outputs = pre_block_outputs.copy()
#for i in range(4):
#  layer_outputs = t5stack_forward_single_layer_in_block(
#          model.encoder,
#          i,
#          **pre_layer_outputs)
#  for k in pre_block_outputs:
#    if k in layer_outputs:
#      pre_layer_outputs[k] = layer_outputs[k]
#print(pre_layer_outputs['hidden_states'].shape)
#print(pre_layer_outputs['hidden_states'][:1, :3, :5])

In [None]:
#@markdown `t5stack_forward_post_block`

from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def t5stack_forward_post_block(
        t5stack,
        hidden_states,
        all_hidden_states=None,
        all_attentions=None,
        all_cross_attentions=None,
        present_key_value_states=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
  """Finish a T5Stack forward pass after the last block.

  Applies `t5stack.final_layer_norm` then `t5stack.dropout`. When
  `output_hidden_states` is set, it appends the result to `all_hidden_states`.
  Extra keyword arguments (e.g. leftovers from the per-layer state dict) are
  ignored.

  Returns:
    A BaseModelOutputWithPastAndCrossAttentions when `return_dict` is truthy,
    otherwise a tuple of the non-None values among (last hidden state,
    present key/values, all hidden states, attentions, cross attentions).
  """
  final_states = t5stack.dropout(t5stack.final_layer_norm(hidden_states))

  # Add last layer
  if output_hidden_states:
      all_hidden_states = all_hidden_states + (final_states,)

  if return_dict:
      return BaseModelOutputWithPastAndCrossAttentions(
          last_hidden_state=final_states,
          past_key_values=present_key_value_states,
          hidden_states=all_hidden_states,
          attentions=all_attentions,
          cross_attentions=all_cross_attentions,
      )
  candidates = (final_states, present_key_value_states, all_hidden_states,
                all_attentions, all_cross_attentions)
  return tuple(value for value in candidates if value is not None)


# Test
#block_outputs = t5stack_forward_post_block(
#        model.encoder,
#        **layer_outputs)
#
#print(block_outputs[0].shape)
#print(block_outputs[0][:1, :3, :5])

In [None]:
#@markdown encoder causal abstraction
# interchange intervention: [batch, layer, step, pos_start, pos_end]
# setting values: [batch, layer, step, pos_start, pos_end], index to some external embedding

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.t5.modeling_t5 import T5Stack, T5ForConditionalGeneration

# The encoder needs to be overridden because `generate` directly invokes the
# encoder's forward function to compute the encoder output.
# https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/generation_utils.py#L506


class TransformerEncoderCausalAbstraction(T5Stack):
  """T5 encoder wrapper that supports interchange interventions.

  Wraps a pretrained T5 encoder (`transformer_encoder`) and shares its
  modules. Interventions registered with `set_interchange_interventions` are
  applied on every `forward` call until `reset_interchange_interventions` is
  called.

  Locations are tuples `(batch_index, layer_index, step_index)` or
  `(batch_index, layer_index, step_index, dim_begin, dim_end)`:
    - layer_index 0 is the (dropout-applied) embedding fed to the first block;
    - layer_index k is the input to block k;
    - layer_index `len(self.encoder.block)` is the final layer-normalized
      encoder output.
  """
  def __init__(self, transformer_encoder):
      super().__init__(transformer_encoder.config)
      # Store a copy of pretrained Encoder.
      self.encoder = transformer_encoder
      # Store interventions.
      self.encoder_inv_locations_to_values = None
      # Copy over all attributes.
      self.embed_tokens = transformer_encoder.embed_tokens
      self.is_decoder = transformer_encoder.config.is_decoder
      self.block = transformer_encoder.block
      self.final_layer_norm = transformer_encoder.final_layer_norm
      self.dropout = transformer_encoder.dropout
      self.model_parallel = transformer_encoder.model_parallel
      self.device_map = transformer_encoder.device_map
      self.gradient_checkpointing = transformer_encoder.gradient_checkpointing

  def get_hidden_states(self, input_ids, attention_mask, locations, partial_only=True):
    """GetVals: read encoder representations at the requested locations.

    Args:
      input_ids, attention_mask: encoder inputs.
      locations: iterable of location tuples (see class docstring).
      partial_only: If True, only run part of the Encoder up to the layer
                    requested in the locations.

    Returns:
      dict mapping each location to the selected slice of hidden states. Both
      modes agree: index `len(self.encoder.block)` is the final
      layer-normalized output.
    """
    locations = list(locations)
    if not locations:
      # Nothing to read (e.g. every annotation in the batch was padding);
      # also avoids calling max() on an empty sequence.
      return {}
    # Encoder is a T5Stack (https://github.com/huggingface/transformers/blob/v4.22.2/src/transformers/models/t5/modeling_t5.py#L899)
    if partial_only:
      num_layers = len(self.encoder.block)
      max_layer = max(loc[1] for loc in locations)
      prev_layer_outputs = t5stack_forward_pre_block(
          self.encoder,
          input_ids=input_ids,
          attention_mask=attention_mask,
          inputs_embeds=None,
          head_mask=None,
          use_cache=None,
          output_hidden_states=True,
          output_attentions=None,
          # Force a ModelOutput from the post-block so it can be indexed by name.
          return_dict=True)
      hidden_states = [prev_layer_outputs['hidden_states']]
      # Run block, i.e. transformer layers.
      for layer_index in range(min(num_layers, max_layer)):
        layer_outputs = t5stack_forward_single_layer_in_block(
                self.encoder, layer_index, **prev_layer_outputs)
        for k in prev_layer_outputs:
          if k in layer_outputs:
            prev_layer_outputs[k] = layer_outputs[k]
        hidden_states.append(prev_layer_outputs['hidden_states'])
      if max_layer >= num_layers:
        # Index num_layers must hold the final layer-normalized output (as in
        # the full-encoder path and in forward_with_intervention), not the raw
        # output of the last block.
        block_outputs = t5stack_forward_post_block(
            self.encoder,
            **prev_layer_outputs)
        hidden_states[num_layers] = block_outputs['last_hidden_state']
    else:
      # Run full Encoder.
      outputs = self.encoder(input_ids=input_ids,
                             attention_mask=attention_mask,
                             output_hidden_states=True)
      # There are num_layer+1 hidden_states
      # The last layer is the normalized output
      # https://github.com/huggingface/transformers/blob/v4.22.2/src/transformers/models/t5/modeling_t5.py#L1087
      # hidden_states: (num_layer + 1) * B * input_seq_len * dimension
      hidden_states = outputs['hidden_states']
    return {loc: hidden_states[loc[1]][loc[0], loc[2], :] if len(loc) == 3 else
                 hidden_states[loc[1]][loc[0], loc[2], loc[3]:loc[4]]
            for loc in locations}

  def reset_interchange_interventions(self):
    """Remove any registered interventions."""
    self.encoder_inv_locations_to_values = None

  def set_interchange_interventions(
      self, inv_input_ids, inv_attention_mask, locations, base_locations=None,
      inv_values=None, inv_value_locations=None):
    """Set intervention parameters.

    There are two ways to provide intervention values:
      1) by setting source locations through locations
      2) by directly providing values through inv_values

    inv_input_ids: source input ids
    inv_attention_mask: source attention mask
    locations: locations to retrieve the source values of the internal variables
    base_locations: locations to swap the values of the internal variables with
                    source input. Default to the same locations as source.
                    If set, must have the same length as source locations.
    inv_values: intervention values. Indexed positionally and written to the
                matching entry of inv_value_locations.
    inv_value_locations: locations to set to the intervention values.
    """
    assert locations is not None or inv_values is not None
    if locations is not None:
      if base_locations is None:
        # Documented default: intervene at the same locations in the base.
        base_locations = locations
      assert len(locations) == len(base_locations)
    if inv_values is not None:
      assert len(inv_values) == len(inv_value_locations)
    self.encoder_inv_locations_to_values = (
        inv_input_ids, inv_attention_mask, locations, base_locations,
        inv_values, inv_value_locations)

  def get_interchange_interventions(self):
    """Compute intervention values.

    Returns:
      dict keyed by source location (values read from the source input) and
      by integer index into `inv_values` (directly provided values).
    """
    inv_loc_to_values = {}
    if self.encoder_inv_locations_to_values is not None:
      (inv_input_ids, inv_attention_mask, source_locations, _, inv_values, _
          ) = self.encoder_inv_locations_to_values
      if source_locations is not None:
        inv_loc_to_values.update(
            self.get_hidden_states(inv_input_ids, inv_attention_mask, source_locations))
      if inv_values is not None:
        inv_loc_to_values.update({i: inv_values[i] for i in range(len(inv_values))})
    return inv_loc_to_values

  def forward_with_intervention(
      self, input_ids, attention_mask,
      inv_locations_to_values,
      encoder_hidden_states=None,
      encoder_attention_mask=None,
      inputs_embeds=None,
      head_mask=None,
      cross_attn_head_mask=None,
      past_key_values=None,
      use_cache=None, output_hidden_states=None,
      output_attentions=None, return_dict=None):
    """IntInv: run the encoder while overwriting representations in place.

    inv_locations_to_values: A dict of
      (batch_index, layer_index, step_index): Tensor(float32) or
      (batch_index, layer_index, step_index, dim_begin, dim_end): Tensor(float32)
      indicating the location of the representations to interchange with
      the intervention values.

    Returns a ModelOutput, or its tuple form when `return_dict` resolves to
    False.
    """
    resolved_return_dict = (
        return_dict if return_dict is not None
        else self.encoder.config.use_return_dict)
    # Sort locations by layer (ties keep full-tuple order); the early `break`
    # below relies on this ordering.
    sorted_loc = sorted(inv_locations_to_values, key=lambda loc: (loc[1], loc))
    # Run pre-block, i.e. process inputs and embedding layers.
    prev_layer_outputs = t5stack_forward_pre_block(
      self.encoder,
      input_ids=input_ids,
      attention_mask=attention_mask,
      inputs_embeds=inputs_embeds,
      head_mask=head_mask,
      use_cache=use_cache,
      output_hidden_states=output_hidden_states,
      output_attentions=output_attentions,
      # Always build a ModelOutput internally so the final-layer intervention
      # can index it by name; converted to a tuple at the end if requested.
      return_dict=True)
    # Run block, i.e. transformer layers.
    for layer_index in range(len(self.encoder.block)):
      # Apply intervention on layer i step t BEFORE the hidden state update,
      # as the first set of hidden state returned by the model are the input
      # embeddings, not the outputs of the first layer.
      for loc in sorted_loc:
        b_i, l_i, s_i = loc[:3]
        if l_i > layer_index:
          break
        if l_i == layer_index:
          if len(loc) == 3:
            prev_layer_outputs['hidden_states'][b_i, s_i, :] = (
                inv_locations_to_values[(b_i, l_i, s_i)])
          else:
            prev_layer_outputs['hidden_states'][b_i, s_i, loc[3]:loc[4]] = (
                inv_locations_to_values[(b_i, l_i, s_i, loc[3], loc[4])])
      layer_outputs = t5stack_forward_single_layer_in_block(
              self.encoder, layer_index, **prev_layer_outputs)
      for k in prev_layer_outputs:
        if k in layer_outputs:
          prev_layer_outputs[k] = layer_outputs[k]
    # Run post-block. prev_layer_outputs holds every updated value and is
    # defined even when the stack has no blocks.
    block_outputs = t5stack_forward_post_block(
        self.encoder,
        **prev_layer_outputs)
    # Apply intervention on the last layer output, which is past the layer norm.
    for loc in sorted_loc:
      b_i, l_i, s_i = loc[:3]
      if l_i == len(self.encoder.block):
        if len(loc) == 3:
          block_outputs['last_hidden_state'][b_i, s_i, :] = inv_locations_to_values[(b_i, l_i, s_i)]
        else:
          block_outputs['last_hidden_state'][b_i, s_i, loc[3]:loc[4]] = (
              inv_locations_to_values[(b_i, l_i, s_i, loc[3], loc[4])])
    # Update all hidden_states after updating the last_hidden_state.
    if 'hidden_states' in block_outputs:
      block_outputs['hidden_states'] = (
          block_outputs['hidden_states'][:-1] + (block_outputs['last_hidden_state'],))
    if not resolved_return_dict:
      return block_outputs.to_tuple()
    return block_outputs

  def forward(
      self,
      input_ids=None,
      attention_mask=None,
      encoder_hidden_states=None,
      encoder_attention_mask=None,
      inputs_embeds=None,
      head_mask=None,
      cross_attn_head_mask=None,
      past_key_values=None,
      use_cache=None,
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None):
    """Encoder forward pass applying the registered interventions, if any."""
    # The values are indexed by source_locations or index.
    source_locations_to_values = self.get_interchange_interventions()
    base_locations_to_values = {}
    # Update intervention locations if base and source use different sets of
    # locations.
    if source_locations_to_values:
      (_, _, source_locations, base_locations, _, inv_value_locations
       ) = self.encoder_inv_locations_to_values
      if base_locations is not None:
        base_locations_to_values.update({
            base_locations[i]: source_locations_to_values[source_locations[i]]
            for i in range(len(base_locations))})
      if inv_value_locations is not None:
        base_locations_to_values.update({
            inv_value_locations[i]: source_locations_to_values[i]
            for i in range(len(inv_value_locations))})
      # Set base and source with the same locations.
      if base_locations is None and inv_value_locations is None:
        base_locations_to_values.update(source_locations_to_values)
    return self.forward_with_intervention(
          inv_locations_to_values=base_locations_to_values,
          input_ids=input_ids,
          attention_mask=attention_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_attention_mask,
          inputs_embeds=inputs_embeds,
          head_mask=head_mask,
          cross_attn_head_mask=cross_attn_head_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

class TransformerCausalAbstraction(T5ForConditionalGeneration):
  """T5 seq2seq model whose encoder is wrapped for interchange interventions.

  Every submodule except the encoder is taken from the wrapped pretrained
  `transformer`. The encoder is a `TransformerEncoderCausalAbstraction`
  around `transformer.encoder`.
  """
  def __init__(self, transformer):
      """
      Causal abstraction of Transformer models.
      """
      super().__init__(transformer.config)
      # Keep a reference to the wrapped pretrained model.
      self.transformer = transformer
      source_model = self.transformer
      # Reuse the pretrained modules and settings in place of whatever the
      # parent constructor built.
      self.model_dim = source_model.model_dim
      self.shared = source_model.shared
      self.encoder = TransformerEncoderCausalAbstraction(source_model.encoder)
      self.decoder = source_model.decoder
      self.lm_head = source_model.lm_head
      self.model_parallel = source_model.model_parallel
      self.device_map = source_model.device_map

  def get_encoder(self):
    """Return the intervention-aware encoder."""
    return self.encoder

## Training Utils

In [None]:
#@title IIT train step

def add_batch_layer_dim_location_map_fn(locations, iit_layer, feat_dim):
  """Map per-example annotation tensors to model-representation locations.

  Args:
    locations: integer tensor whose first dim is the batch; it is reshaped to
      (B, NUM_INV, ANNO_LEN). The last two entries of each annotation are
      scaled by `feat_dim` (presumably char begin/end offsets — TODO confirm).
    iit_layer: layer index inserted into every mapped location.
    feat_dim: number of hidden units per character slot.

  Returns:
    Flat list of tuples (batch_index, iit_layer, *scaled_annotation).
    Annotations whose (scaled) last entry is negative are treated as padding
    and dropped. The input tensor is never modified.
  """
  # clone(): `.int()` is a no-op (not a copy) when `locations` is already
  # int32, and the in-place scaling below would then write into the caller's
  # batch tensor.
  locations = locations.view(locations.shape[0], -1, ANNO_LEN).int().clone()
  locations[:, :, -2:] *= feat_dim
  per_example = locations.tolist()
  return [tuple([b_i, iit_layer] + loc)
          for b_i, example_locs in enumerate(per_example)
          for loc in example_locs if loc[-1] >= 0]

def _offset_batch_index(locations, offset):
  """Shift element 0 (the example index) of every location by `offset`."""
  if not offset:
    return list(locations)
  return [(loc[0] + offset,) + tuple(loc[1:]) for loc in locations]


def iit_train_step(causal_abstraction, input_batch, iit_input_batch_list,
                   location_map_fn, base_factor=1.0, inv_factor=1.0):
  """IIT training step: base task loss plus interchange-intervention loss.

  Args:
    causal_abstraction: TransformerCausalAbstraction being trained.
    input_batch: kwargs (including labels) for a plain forward pass.
    iit_input_batch_list: one IIT batch dict or a list of them, each with
      'base_input_ids', 'base_attention_mask', 'source_input_ids',
      'source_attention_mask', 'base_locations', 'source_locations' and
      'labels' (labels expected after the intervention).
    location_map_fn: batch of location in input => list of mapped locations in
      model representations; element 0 of each mapped location is the example
      index within that batch.
    base_factor: weight of the base loss.
    inv_factor: weight of the intervention loss.

  Returns:
    (loss, outputs, inv_outputs); inv_outputs is None when there are no IIT
    batches.
  """
  if isinstance(iit_input_batch_list, dict):
    iit_input_batch_list = [iit_input_batch_list]
  # Base training: plain forward pass with no interventions active.
  causal_abstraction.encoder.reset_interchange_interventions()
  outputs = causal_abstraction(**input_batch)
  loss = base_factor * outputs.loss
  inv_outputs = None
  if not iit_input_batch_list:
    return loss, outputs, inv_outputs

  base_input_ids, base_attention_mask, labels = [], [], []
  source_input_ids, source_attention_mask = [], []
  base_locations, source_locations = [], []
  # Sub-batches are concatenated below, so example indices in the mapped
  # locations must be shifted by the number of examples already collected.
  base_offset = source_offset = 0
  for iit_batch in iit_input_batch_list:
    base_locations.extend(_offset_batch_index(
        location_map_fn(iit_batch["base_locations"]), base_offset))
    source_locations.extend(_offset_batch_index(
        location_map_fn(iit_batch["source_locations"]), source_offset))
    source_input_ids.append(iit_batch["source_input_ids"])
    source_attention_mask.append(iit_batch["source_attention_mask"])
    base_input_ids.append(iit_batch["base_input_ids"])
    base_attention_mask.append(iit_batch["base_attention_mask"])
    labels.append(iit_batch["labels"])
    base_offset += len(iit_batch["base_input_ids"])
    source_offset += len(iit_batch["source_input_ids"])

  iit_base_input_ids = torch.cat(base_input_ids, dim=0).to(device)
  iit_base_attention_mask = torch.cat(base_attention_mask, dim=0).to(device)
  iit_source_input_ids = torch.cat(source_input_ids, dim=0).to(device)
  iit_source_attention_mask = torch.cat(source_attention_mask, dim=0).to(device)
  iit_labels = torch.cat(labels, dim=0).to(device)

  # Register the interchange interventions: representations of the source
  # inputs at `source_locations` replace the base run's representations at
  # `base_locations`.
  # NOTE(review): positional order assumed to be (source_input_ids,
  # source_attention_mask, base_locations, source_locations), i.e. the four
  # positional slots eval_with_interchange_interventions passes as None —
  # confirm against TransformerEncoderCausalAbstraction.
  causal_abstraction.encoder.set_interchange_interventions(
      iit_source_input_ids, iit_source_attention_mask,
      base_locations, source_locations)
  try:
    inv_outputs = causal_abstraction(
        input_ids=iit_base_input_ids,
        attention_mask=iit_base_attention_mask,
        labels=iit_labels)
  finally:
    # Never leave interventions active for later forward passes.
    causal_abstraction.encoder.reset_interchange_interventions()
  loss = loss + inv_factor * inv_outputs.loss
  return loss, outputs, inv_outputs

In [None]:
#@markdown Training metrics

from torch.nn import CrossEntropyLoss
from transformers import TrainingArguments, Trainer

eval_loss_fct = CrossEntropyLoss(ignore_index=-100)

def compute_metrics(eval_pred):
  """Token accuracy, sequence accuracy and cross-entropy loss for one batch.

  Args:
    eval_pred: (outputs, labels). outputs[0] is a numpy array of logits whose
      last axis is the vocabulary; labels is a numpy int array matching the
      logits without that axis (-100 marks ignored positions).

  Returns:
    dict with 'token_accuracy', 'sequence_accuracy' and 'loss'.
  """
  model_outputs, labels = eval_pred
  logits = model_outputs[0]
  predictions = np.argmax(logits, axis=-1)
  # Score only real target tokens: skip ignored (-100), EOS (1) and PAD (0).
  scored = (labels != -100) & (labels != 1) & (labels != 0)
  scored_f = scored.astype(float)
  correct = predictions == labels

  flat_logits = torch.from_numpy(logits.reshape(-1, logits.shape[-1])).to(device)
  flat_labels = torch.from_numpy(labels.reshape(-1)).to(device)
  batch_loss = eval_loss_fct(flat_logits, flat_labels).mean()
  loss_value = float(batch_loss.detach().cpu().numpy())

  token_accuracy = float(
      (correct.astype(float) * scored_f).sum() / max(1, scored_f.sum()))
  # A sequence is correct when every scored position is predicted correctly.
  sequence_ok = (correct | ~scored).all(axis=-1)
  sequence_accuracy = sequence_ok.astype(float).sum() / len(scored_f)
  return {
      'token_accuracy': token_accuracy,
      'sequence_accuracy': sequence_accuracy,
      'loss': loss_value}

In [None]:
#@markdown Other training utils

def freeze_encoder_first_n_layers(model, n):
  """Freeze the encoder except its blocks n.. and its final layer norm.

  Every encoder parameter (embeddings, blocks 0..n-1, ...) is frozen first;
  then `model.encoder.block[n:]` (indexed as range(n, len(blocks))) and
  `model.encoder.final_layer_norm` are made trainable again.

  Returns:
    The same `model`, modified in place.
  """
  encoder = model.encoder
  encoder.requires_grad_(False)
  for block_index in range(n, len(encoder.block)):
    encoder.block[block_index].requires_grad_(True)
  encoder.final_layer_norm.requires_grad_(True)
  return model

def train_step(model, input_batch):
  """Run one forward pass and return (loss, model outputs)."""
  step_outputs = model(**input_batch)
  return step_outputs.loss, step_outputs

def eval_step(model, input_batch):
  """Evaluation forward pass; returns (loss, model outputs).

  Runs under torch.no_grad() so no autograd graph (and the activation memory
  it holds) is built for evaluation batches; callers such as run_eval only
  read/detach the outputs.
  """
  with torch.no_grad():
    outputs = model(**input_batch)
  return outputs.loss, outputs

def run_eval(model, eval_dataloader, first_n_batch=None, metric_prefix=None):
  """Evaluate `model` on (up to the first `first_n_batch` batches of) a loader.

  Args:
    model: model whose outputs expose `.logits` (and `.loss` via eval_step).
    eval_dataloader: iterable of dicts of tensors that include 'labels'.
    first_n_batch: if not None, evaluate only the first `first_n_batch`
      batches.
    metric_prefix: string prepended to every returned metric name.

  Returns:
    dict mapping `metric_prefix + name` to the mean over batches of each
    metric returned by compute_metrics.

  The model's train/eval mode is restored before returning.
  """
  metric_prefix = metric_prefix or ''
  was_training = model.training
  model.eval()
  per_batch = collections.defaultdict(list)
  try:
    with torch.no_grad():
      for i, eval_input_batch in enumerate(eval_dataloader):
        if first_n_batch is not None and i >= first_n_batch:
          break
        for k in eval_input_batch:
          eval_input_batch[k] = eval_input_batch[k].to(device)
        _, outputs = eval_step(model, eval_input_batch)
        metrics = compute_metrics(([outputs.logits.detach().cpu().numpy()],
                                   eval_input_batch['labels'].cpu().numpy()))
        for key, value in metrics.items():
          per_batch[key].append(value)
  finally:
    model.train(was_training)
  # Build a new dict instead of adding prefixed keys while iterating.
  return {metric_prefix + key: float(np.mean(values))
          for key, values in per_batch.items()}

## Evaluation Utils

In [None]:
#@markdown Evaluation metrics

import collections
import json

from torch.utils.data import DataLoader

def _load_data_json(filename):
  """Parse and return the JSON file `filename` located in DATA_DIR.

  Uses a context manager so the file handle is closed right after parsing.
  """
  with open(os.path.join(DATA_DIR, filename)) as json_file:
    return json.load(json_file)

# Lookup tables used by compute_matching_stats for relaxed matching.
ENGLISH_WORDS_30K = _load_data_json('english_words_30k.json')
ENGLISH_WORDS_200K = _load_data_json('english_words_200k.json')
ANAGRAM_DICT = _load_data_json('anagrams_from_200k.json')
AMBIGUOUS_TYPOS = _load_data_json('ambiguous_typos_18k.json')

def print_stats(eval_outputs):
  """Print string-level, token-level and relaxed-matching accuracies.

  Args:
    eval_outputs: (stats, metrics), where stats maps 'count' and 'match_*'
      counters to ints and metrics['token_accuracy'] is a list of per-batch
      token accuracies.
  """
  stats, metrics = eval_outputs
  total = stats['count']
  string_accuracy = 100 * stats['match_fullstring'] / total
  print('String-level accuracy %.2f%%' % string_accuracy)
  token_accuracy = 100 * np.mean(metrics['token_accuracy'])
  print('Token-level accuracy  %.2f%%' % token_accuracy)

  print('\nRelaxed matching:')
  for key in sorted(stats):
    is_relaxed = key.startswith('match_') and not key.startswith('match_fullstring')
    if not is_relaxed:
      continue
    print('%s +%.2f%%' % (key, 100 * stats[key] / total))


def compute_matching_stats(input_text, target_text, output_text):
  """Classify how a predicted string relates to the target.

  Returns a dict of 0/1 counters. For predictions shorter than 20 chars it
  always includes 'match_dictionary_30k' / 'match_dictionary_200k' (whether
  the lower-cased, space-stripped prediction is in ENGLISH_WORDS_30K/200K).
  It additionally sets at most one of, checked in this order:
  match_fullstring, match_spelling_correction, match_target_substring,
  match_input_substring, match_anagram_valid / match_anagram_invalid,
  match_char_set.

  Empty (or whitespace-only) strings never count as substring matches: the
  empty string is a substring of everything, so an empty prediction would
  otherwise be credited as a relaxed match.
  """
  def space_normalize(text):
    return text.strip().replace(' ', '')
  input_norm = space_normalize(input_text)
  output_norm = space_normalize(output_text)
  target_norm = space_normalize(target_text)
  stats = {}
  if len(output_text) < 20:
    stats['match_dictionary_30k'] = int(output_norm.lower() in ENGLISH_WORDS_30K)
    stats['match_dictionary_200k'] = int(output_norm.lower() in ENGLISH_WORDS_200K)
  if output_text.strip() == target_text.strip():
    stats['match_fullstring'] = 1
  elif input_text in AMBIGUOUS_TYPOS and output_text in AMBIGUOUS_TYPOS[input_text]:
    stats['match_spelling_correction'] = 1
  elif output_norm and target_norm and (
      output_text in target_text or target_text in output_text):
    stats['match_target_substring'] = 1
  elif output_norm and output_norm in input_norm:
    stats['match_input_substring'] = 1
  elif sorted(output_norm) == sorted(target_norm):
    valid = 'valid' if output_norm in ANAGRAM_DICT else 'invalid'
    stats[f'match_anagram_{valid}'] = 1
  elif set(output_text) == set(target_text):
    stats['match_char_set'] = 1
  return stats


def eval_topk_accuracy(model, test_dataset, max_output_seq_len,
                       first_n=None, beam_size=None, batch_size=128,
                       target_tokenizer=None):
  """Decode a test set and report exact-match and relaxed-match statistics.

  Args:
    model: seq2seq model with `generate`.
    test_dataset: torch-formatted dataset, or a dict of columns (converted
      via Dataset.from_dict); batches need 'input_ids', 'attention_mask' and
      'labels'.
    max_output_seq_len: `max_length` passed to `generate`.
    first_n: if set, evaluate only the first `first_n` batches.
    beam_size: if set, beam search with this many beams; the top beam is
      scored. Otherwise greedy decoding.
    batch_size: evaluation batch size.
    target_tokenizer: tokenizer used to decode inputs, labels and outputs;
      defaults to t5_default_tokenizer.

  Returns:
    ((stats, metrics), inputs_to_outputs): stats are summed
    compute_matching_stats counters plus 'count', metrics hold per-batch
    compute_metrics values, inputs_to_outputs maps (input, target) text to
    the top prediction.
  """
  target_tokenizer = target_tokenizer or t5_default_tokenizer
  if isinstance(test_dataset, dict):
    test_dataset = Dataset.from_dict(test_dataset).with_format("torch")
  eval_dataloader = DataLoader(test_dataset, batch_size=batch_size)
  # Sequences returned per input. `beam_size` itself is never modified, so
  # every batch uses the same decoding strategy.
  num_returned = beam_size or 1
  model.eval()
  stats = collections.defaultdict(int)
  metrics = collections.defaultdict(list)
  inputs_to_outputs = {}
  for b_i, batch in enumerate(eval_dataloader):
    if first_n and b_i >= first_n:
      break
    if b_i % 10 == 0:
      print(f'Finished {b_i} batch')
    for k in batch:
      batch[k] = batch[k].to(device)
    input_texts = target_tokenizer.batch_decode(
      batch['input_ids'], skip_special_tokens=True)
    # Labels use -100 for ignored positions; clamp them to 0 before decoding.
    target_texts = target_tokenizer.batch_decode(
      torch.maximum(batch['labels'], torch.zeros_like(batch['labels'])),
      skip_special_tokens=True)
    with torch.no_grad():
      predictions = model(**batch)
      if beam_size:
        outputs = model.generate(input_ids=batch['input_ids'],
                                 attention_mask=batch['attention_mask'],
                                 num_beams=beam_size,
                                 num_return_sequences=beam_size,
                                 do_sample=False,
                                 max_length=max_output_seq_len,
                                 length_penalty=0.05)  # (B*K, MAX_OUTPUT_SEQ_LEN)
      else:
        outputs = model.generate(input_ids=batch['input_ids'],
                                 attention_mask=batch['attention_mask'],
                                 max_length=max_output_seq_len)
      output_texts = target_tokenizer.batch_decode(
          outputs, skip_special_tokens=True)
      for i in range(len(input_texts)):
        # Outputs are grouped per input; the first of each group is the top-1.
        output_text = output_texts[i * num_returned]
        inputs_to_outputs[(input_texts[i], target_texts[i])] = output_text
        stats['count'] += 1
        matching_stats = compute_matching_stats(
            input_texts[i], target_texts[i], output_text)
        for k, v in matching_stats.items():
          stats[k] += v
        # Log the first few examples
        if b_i < 10 and i == 0:
          print('Input: %s' % input_texts[i])
          print('Label: %s' % target_texts[i])
          print('Pred:  %s' % output_text)
    batch_metrics = compute_metrics(([predictions.logits.detach().cpu().numpy()],
                                     batch['labels'].detach().cpu().numpy()))
    for k, v in batch_metrics.items():
      metrics[k].append(v)
  print_stats((stats, metrics))
  return (stats, metrics), inputs_to_outputs

In [None]:
#@markdown Extract character representations

import string

def parse_annotated_examples(anno_tsv_path, first_n_line, parse_fn, parse_fn_args):
  """Parse the first `first_n_line` lines of a TSV file.

  Each raw line (trailing newline included) is passed as
  `parse_fn(line, parse_fn_args)`; the parsed results are returned as a list.
  `first_n_line=None` parses every line.
  """
  with open(anno_tsv_path, 'r') as anno_file:
    head_lines = anno_file.readlines()[:first_n_line]
    parsed = [parse_fn(raw_line, parse_fn_args) for raw_line in head_lines]
  return parsed

def extract_char_representations(model, anno_tsv_path, data_config,
                                 first_n_line=2048,
                                 positions=None, iit_layer=1, char_dim=16,
                                 eval_batch_size = 128,
                                 source_tokenizer=None):
  """Collect per-character slices of encoder hidden states.

  Parses up to `first_n_line` annotated lines with `parse_anno_line`, encodes
  their inputs in batches of `eval_batch_size`, and for every annotation
  `loc` in an example's 'inv_locs' (token index loc[0], char span
  loc[1]:loc[2]) takes hidden_states[iit_layer][b, loc[0],
  loc[1]*char_dim : loc[2]*char_dim].

  Args:
    model: seq2seq model with an `encoder`.
    anno_tsv_path: annotation TSV file.
    data_config: passed to `parse_anno_line`; must contain
      'max_input_seq_len'.
    first_n_line: maximum number of annotation lines to read.
    positions: optional collection of char positions (loc[1]) to keep; a
      falsy value keeps all positions.
    iit_layer: index into the encoder's hidden_states.
    char_dim: hidden units per character slot.
    eval_batch_size: encoding batch size.
    source_tokenizer: defaults to t5_default_tokenizer.

  Returns:
    {char: {char_position: [np.ndarray, ...]}}. Annotations with a negative
    loc[-1] are skipped, and each (token id, char position) pair contributes
    at most one vector.
  """
  if not source_tokenizer:
    source_tokenizer = t5_default_tokenizer
  examples = parse_annotated_examples(anno_tsv_path, first_n_line=first_n_line,
                                      parse_fn=parse_anno_line,
                                      parse_fn_args=data_config)
  model.eval()
  char_to_vec = collections.defaultdict(lambda: collections.defaultdict(list))
  # Avoid collecting the same (token, char position) multiple times.
  seen_keys = set()
  # Iterate over the examples actually parsed: the file may be shorter than
  # first_n_line, and the final partial batch must not be dropped.
  for start in range(0, len(examples), eval_batch_size):
    batch_examples = examples[start:start + eval_batch_size]
    input_text = [example['input'] for example in batch_examples]
    input_batch = source_tokenizer(
        input_text, return_tensors="pt", padding="max_length",
        max_length=data_config['max_input_seq_len'], truncation=True)
    for key in input_batch:
      input_batch[key] = input_batch[key].to(device)
    with torch.no_grad():
      enc_outputs = model.encoder(**input_batch, output_hidden_states=True)
      layer_states = enc_outputs['hidden_states'][iit_layer]
      for b_i, example in enumerate(batch_examples):
        for loc in example['inv_locs']:
          if loc[-1] < 0:
            continue
          if positions and loc[1] not in positions:
            continue
          tid = int(input_batch['input_ids'][b_i][loc[0]])
          if (tid, loc[1]) in seen_keys:
            continue
          seen_keys.add((tid, loc[1]))
          c = VOCAB[tid][loc[1]: loc[2]]
          char_to_vec[c][loc[1]].append(
              layer_states[b_i, loc[0], loc[1] * char_dim : loc[2] * char_dim]
              .detach().cpu().numpy())
  return {k: dict(v) for k, v in char_to_vec.items()}

In [None]:
#@markdown Visualization of char representations

import numpy as np

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Base palette of 26 (name, hex) pairs, one per lowercase letter a-z in order.
# visualize_pca_2d maps lowercase chars to indices 0-25 and all other chars to
# 26 + (ord(c) - ord('A')), modulo the palette length; the extended palette
# built below appends a lighter variant of each color at indices 26-51.
# The trailing '# a', '# e', ... markers flag the vowel slots.
char_colors = [
    ('bisque', '#FFE4C4'), # a
    ('blue', '#0000FF'),
    ('blueviolet', '#8A2BE2'),
    ('brown', '#A52A2A'),
    ('burlywood', '#DEB887'), # e
    ('cadetblue', '#5F9EA0'),
    ('chartreuse', '#7FFF00'),
    ('chocolate', '#D2691E'),
    ('coral', '#FF7F50'), # i
    ('cornflowerblue', '#6495ED'),
    ('crimson', '#DC143C'),
    ('cyan', '#00FFFF'),
    ('deepskyblue', '#00BFFF'),
    ('darkgoldenrod', '#B8860B'),
    ('darkorange', '#FF8C00'), # o
    ('darkgreen', '#006400'),
    ('darkmagenta', '#8B008B'),
    ('darkolivegreen', '#556B2F'),
    ('darkorchid', '#9932CC'),
    ('darksalmon', '#E9967A'),
    ('gold', '#FFD700'), # u
    ('darkseagreen', '#8FBC8F'),
    ('darkslateblue', '#483D8B'),
    ('forestgreen', '#228B22'),
    ('deeppink', '#FF1493'),
    ('goldenrod', '#DAA520'),]


def lighter_color(c):
  """Blend a '#RRGGBB' color halfway toward white.

  Each channel v becomes int(128 + 0.5 * v): '#000000' -> '#808080' and
  '#FFFFFF' -> '#ffffff'. Returns a lowercase '#rrggbb' string. Hex pairs are
  parsed with int(..., 16) rather than eval().
  """
  channels = [int(128 + int(c[i:i + 2], 16) * 0.5) for i in (1, 3, 5)]
  return '#' + ''.join(f'{v:02x}' for v in channels)

# Append a lighter variant ('light_<name>') of every base color, doubling the
# palette so non-lowercase characters get their own shades.
char_colors = char_colors + [
    ('light_' + color_name, lighter_color(hex_code))
    for color_name, hex_code in char_colors]


# PCA & cluster
def visualize_pca_2d(char_to_vec, task_name):
  """Scatter-plot character representations on their first two PCs.

  Vectors are standardized per feature, projected with a 2-component PCA and
  colored by character; each character's label is drawn at the mean of its
  points.

  Args:
    char_to_vec: {char: {position: [vector, ...]}}, e.g. the output of
      extract_char_representations.
    task_name: task key such as 'reversal' or
      'spelling_correction_contextual'; used for the plot title.
  """
  labels = [(c, p) for c in char_to_vec for p in char_to_vec[c]
            for _ in range(len(char_to_vec[c][p]))]
  X = np.concatenate([np.stack(char_to_vec[c][p], axis=0)
                      for c in char_to_vec for p in char_to_vec[c]], axis=0)
  print(f'PCA over {len(labels)} vectors.')
  # Standardize each feature. Constant features (std 0) are only centered
  # instead of producing 0/0 = NaN, which PCA cannot handle.
  std = X.std(axis=0)
  X = (X - np.mean(X, axis=0, keepdims=True)) / np.where(std > 0, std, 1.0)
  pca = PCA(n_components=2)
  pca.fit(X)
  print(f'explained_variance_ratio={pca.explained_variance_ratio_}')

  # Plot on a single figure of the requested size.
  plt.rcParams['figure.dpi'] = 300
  plt.rcParams['savefig.dpi'] = 300
  plt.rc('font', **{'size': 6})
  fig, ax = plt.subplots(figsize=(8, 5))

  colors = [v for _, v in char_colors]
  X_2d = pca.transform(X)
  pc_index = [0, 1]
  x = X_2d[:, pc_index[0]]
  y = X_2d[:, pc_index[1]]

  def char_to_color_index(c):
    # Lowercase letters use the first 26 colors; any other char is offset
    # from 'A' past them, wrapping around the palette.
    if c in string.ascii_lowercase:
      idx = ord(c) - ord('a')
    else:
      idx = 26 + (ord(c) - ord('A'))
    return idx % len(colors)

  ax.scatter(x=x, y=y, c=[colors[char_to_color_index(label[0])] for label in labels])

  anno_char = collections.defaultdict(list)
  for i, label in enumerate(labels):
    anno_char[label[0]].append([x[i], y[i]])
  # White copies drawn at small offsets form a halo behind the black label.
  halo_offsets = [(0.01, 0), (-0.03, 0), (0, 0.03), (0, -0.03)]
  for char, coords in anno_char.items():
    center = np.array(coords).mean(axis=0)
    for offset in halo_offsets:
      ax.annotate(char, center + np.array(offset), fontsize=12, weight='bold', color='white')
    ax.annotate(char, center, fontsize=12)

  task_titles = {
      'spelling_correction_contextual': 'Spelling Correction with Context',
      'word_search': 'Word Search',
      'spelling_correction': 'Spelling Correction',
      'unscramble': 'Unscramble',
      'unit_conversion': 'Unit Conversion',
      'reversal': 'Reversal',
  }
  task = task_titles.get(task_name) or str(task_name).replace('_', ' ').title()
  ax.set_title(f'Character Representations from {task} IIT Model (Encoder Layer 1)')
  ax.set_xlabel(f'Principal Component {pc_index[0] + 1}')
  ax.set_ylabel(f'Principal Component {pc_index[1] + 1}')

In [None]:
#@markdown Utils for processing IIT data for evaluation

def parse_tsv_line_with_inv_anno(line, feature_to_column):
  """Parse one annotated TSV line into a {feature_name: value} dict.

  Args:
    line: the raw TSV string, or a dict holding it under 'text'.
    feature_to_column: {feature_name: column_index}.

  Names containing 'inv' (location annotations) are converted to int, except
  'inv_val*' and 'inv_label*' names, which stay strings like every other
  column.
  """
  text = line if isinstance(line, str) else line['text']
  fields = text.strip().split('\t')

  def is_int_feature(name):
    return 'inv' in name and 'inv_val' not in name and 'inv_label' not in name

  parsed = {}
  for name, column in feature_to_column.items():
    raw_value = fields[column]
    parsed[name] = int(raw_value) if is_int_feature(name) else raw_value
  return parsed


def parse_tsv_line_with_length_preserving_interchange_inv_anno(line, args):
  """Parse an annotated line and move its feature onto a same-length base.

  The example's intervened feature (column 3) is replaced in its input by the
  feature of a randomly chosen indexed example of the same length (for
  unit_conversion: same input template and length), and all 'inv_loc*'
  annotations are taken from that example. 'inv_val*' values keep the
  original example's characters.

  Args:
    line: raw TSV string or a dict holding it under 'text'.
    args: (feature_to_column, length_indexed_base_examples), the latter as
      returned by index_examples_by_length.

  Returns:
    The parsed (and substituted) example dict. If no candidate exists,
    'No substitution' is printed and the example is returned unchanged.
  """
  import random  # Local import: the module-level import is not guaranteed.

  feature_to_column, length_indexed_base_examples = args
  text = line if isinstance(line, str) else line['text']
  parsed = text.strip().split('\t')
  # replace the input inv locations with ones from a in-vocab word of the
  # same length.
  if parsed[2] == 'unit_conversion':
    key = (parsed[0].replace(parsed[3], '{feature}'), len(parsed[3]))
  else:
    key = len(parsed[3])
  # .get(): no KeyError for plain dicts and no empty entries inserted into a
  # shared defaultdict index.
  candidates = length_indexed_base_examples.get(key)
  inv_id_examples = random.choice(candidates) if candidates else None
  example = {k: parsed[v] if 'inv' not in k or 'inv_val' in k or 'inv_label' in k
                          else int(parsed[v])
                          for k, v in feature_to_column.items()}
  # If can't find an equivalent input, return the example itself.
  if inv_id_examples is None:
    print('No substitution')
    inv_id_examples = example
  # The feature column is named 'feature_val' in the schema built by
  # eval_with_clustered_representations; 'feature' is accepted as well.
  feature_key = 'feature_val' if 'feature_val' in example else 'feature'
  # Keep the inv_val matches the actual value, but update the locations and
  # input to match in-domain intervention example.
  example['input'] = example['input'].replace(
      example[feature_key], inv_id_examples[feature_key])
  for k in example:
    if 'inv_loc' in k:
      example[k] = inv_id_examples[k]
  return example


def index_examples_by_length(
    anno_input_tsv_path, anno_feature_to_column, first_n_line=10240):
  """Index annotated examples by the length of their intervened feature.

  Args:
    anno_input_tsv_path: annotation TSV file.
    anno_feature_to_column: {feature_name: column}, as passed to
      parse_tsv_line_with_inv_anno.
    first_n_line: maximum number of lines to read.

  Returns:
    defaultdict(list) from key to parsed examples. The key is len(feature),
    except for 'unit_conversion' examples, which are keyed by
    (input with the feature replaced by '{feature}', len(feature)). This is
    the same key parse_tsv_line_with_length_preserving_interchange_inv_anno
    looks up.
  """
  print(anno_input_tsv_path)
  examples = parse_annotated_examples(
      anno_input_tsv_path, first_n_line=first_n_line,
      parse_fn=parse_tsv_line_with_inv_anno,
      parse_fn_args=anno_feature_to_column)
  length_indexed_examples = collections.defaultdict(list)
  for exp in examples:
    # The schema built by eval_with_clustered_representations names the
    # feature column 'feature_val'; 'feature' is accepted as well.
    feature = exp['feature_val'] if 'feature_val' in exp else exp['feature']
    if exp.get('task') == 'unit_conversion':
      key = (exp['input'].replace(feature, '{feature}'), len(feature))
    else:
      key = len(feature)
    length_indexed_examples[key].append(exp)
  print(f'Indexed {len(examples)} examples into {len(length_indexed_examples)} keys.')
  return length_indexed_examples

In [None]:
#@markdown Eval interchange intervention accuracy

alphanum = set(string.ascii_letters + string.digits)

def eval_with_interchange_interventions(
    model, eval_dataloader, data_config,
    iit_layer=1, iit_dim=16, num_beams=None, num_return_sequences=1,
    char_to_vec=None):
  """Evaluate a model with character-level interchange interventions.

  For every annotated alphanumeric character slot (inv_loc_{i} >= 0) of each
  base input, the encoder representation at layer `iit_layer` (units
  inv_loc_char_begin_{i} * iit_dim to inv_loc_char_end_{i} * iit_dim of the
  token inv_loc_{i}) is replaced by the mean of all vectors collected for that
  character in `char_to_vec`. The intervened model and the plain model are
  both decoded and scored against the label with compute_matching_stats.

  Args:
    model: T5ForConditionalGeneration to evaluate.
    eval_dataloader: yields dicts with 'input', 'label', 'feature_val' and
      per-annotation 'inv_loc_{i}', 'inv_loc_char_begin_{i}',
      'inv_loc_char_end_{i}', 'inv_val_{i}' columns.
    data_config: dict with 'max_num_anno', 'max_input_seq_len' and
      'max_output_seq_len'.
    iit_layer: encoder layer to intervene on.
    iit_dim: hidden units per character slot.
    num_beams, num_return_sequences: forwarded to `generate`.
    char_to_vec: {char: {position: [vectors]}}, e.g. from
      extract_char_representations. Defaults to the notebook-level
      `char_to_vec`.

  Returns:
    (metrics, eval_outputs): summed counters ('iit_*', 'base_*', 'total') and
    {(input, label): {'base', 'iit', 'source_val'}}.
  """
  if char_to_vec is None:
    # Backward compatibility: callers that do not pass `char_to_vec` get the
    # notebook-level variable of that name.
    char_to_vec = globals()['char_to_vec']
  model.eval()
  causal_abstraction = TransformerCausalAbstraction(model).to(device)
  causal_abstraction.eval()
  num_anno = data_config['max_num_anno']

  # The intervention value depends only on the character, so compute each
  # per-character mean once.
  mean_char_vecs = {}
  def _mean_char_vec(c):
    """Mean over every collected vector of `c`, across all positions."""
    if c not in mean_char_vecs:
      vecs = np.stack([v for vs in char_to_vec[c].values() for v in vs], axis=0)
      mean_char_vecs[c] = torch.mean(torch.from_numpy(vecs), dim=0).to(device)
    return mean_char_vecs[c]

  eval_outputs = {}
  metrics = collections.defaultdict(int)
  for batch in eval_dataloader:
    for key in batch:
      if key.startswith('inv_') and not isinstance(batch[key], list):
        batch[key] = batch[key].tolist()
    # Locations are (batch index, layer, token index, first unit, end unit);
    # inv_values[j] is the value written at base_locations[j].
    base_locations, inv_values = [], []
    for b_i in range(len(batch['input'])):
      for inv_i in range(num_anno):
        token_loc = batch[f'inv_loc_{inv_i}'][b_i]
        char_val = batch[f'inv_val_{inv_i}'][b_i]
        if token_loc < 0 or char_val not in alphanum:
          continue
        base_locations.append(
            (b_i, iit_layer, token_loc,
             batch[f'inv_loc_char_begin_{inv_i}'][b_i] * iit_dim,
             batch[f'inv_loc_char_end_{inv_i}'][b_i] * iit_dim))
        inv_values.append(_mean_char_vec(char_val))

    base_input_batch = tokenizer(batch['input'], return_tensors='pt', padding="max_length",
                max_length=data_config['max_input_seq_len'], truncation=True)
    for key in base_input_batch:
      base_input_batch[key] = base_input_batch[key].to(device)

    causal_abstraction.encoder.set_interchange_interventions(
          None, None, None, None,
          inv_values=inv_values, inv_value_locations=base_locations)
    try:
      inv_outputs = causal_abstraction.generate(
          input_ids=base_input_batch['input_ids'],
          attention_mask=base_input_batch['attention_mask'],
          max_length=data_config['max_output_seq_len'],
          num_beams=num_beams, num_return_sequences=num_return_sequences)
    finally:
      causal_abstraction.encoder.reset_interchange_interventions()
    pred_text_batch = tokenizer.batch_decode(inv_outputs, skip_special_tokens=True)

    base_outputs = model.generate(
        input_ids=base_input_batch['input_ids'],
        attention_mask=base_input_batch['attention_mask'],
        max_length=data_config['max_output_seq_len'],
        num_beams=num_beams, num_return_sequences=num_return_sequences)
    base_pred_text_batch = tokenizer.batch_decode(base_outputs, skip_special_tokens=True)
    for b_i in range(len(batch['input'])):
      source_input_text = batch['feature_val'][b_i]
      label_text = batch['label'][b_i]
      # Generated sequences are grouped per input, top-1 first.
      first, last = b_i * num_return_sequences, (b_i + 1) * num_return_sequences
      iit_top1_pred = pred_text_batch[first]
      iit_topk_preds = pred_text_batch[first:last]
      base_top1_pred = base_pred_text_batch[first]
      base_topk_preds = base_pred_text_batch[first:last]
      iit_match_stats = compute_matching_stats(source_input_text, label_text, iit_top1_pred)
      for k, v in iit_match_stats.items():
        metrics[f'iit_{k}'] += v
      if label_text in iit_topk_preds:
        metrics['iit_match_top%d' % num_return_sequences] += 1
      base_match_stats = compute_matching_stats(source_input_text, label_text, base_top1_pred)
      for k, v in base_match_stats.items():
        metrics[f'base_{k}'] += v
      if label_text in base_topk_preds:
        metrics['base_match_top%d' % num_return_sequences] += 1
      eval_outputs[(batch['input'][b_i], batch['label'][b_i])] = {
          'base': base_topk_preds,
          'iit': iit_topk_preds,
          'source_val': batch['feature_val'][b_i]}
      metrics['total'] += 1
  print(dict(metrics))
  return metrics, eval_outputs


def eval_with_clustered_representations(
    model, test_iit_anno_tsv_path, data_config, char_to_vec,
    base_iit_anno_tsv_path=None,
    iit_layer=1, iit_dim=16, batch_size=128, num_beams=None,
    num_return_sequences=1, num_repeat=1):
  """Evaluate IIT models with clustered character representations.

    There are two ways to use the clustered representations, controlled by the
    parameter `base_iit_anno_tsv_path`:
      1) Directly replace aligned representations of each char with the
        average pooled clustered representation, when `base_iit_anno_tsv_path`
        is set to None.
      2) Apply interchange interventions on a randomly sampled training examples,
         where the intervention values are from the test examples, when
         `base_iit_anno_tsv_path` is set to the path of the iit annotations.

    `char_to_vec` supplies the clustered representations and is forwarded to
    eval_with_interchange_interventions. Metrics are summed over `num_repeat`
    runs; eval_outputs are those of the last run ({} if num_repeat is 0).
  """

  # Load the test dataset and apply interchange interventions on the base
  # examples if `base_iit_anno_tsv_path` is provided.
  anno_feature_to_column = {'input': 0, 'label': 1, 'task': 2, 'feature_val': 3}
  anno_offset = len(anno_feature_to_column)
  anno_feats = ['inv_loc', 'inv_loc_char_begin', 'inv_loc_char_end',
                'inv_val', 'inv_out_begin', 'inv_out_end']
  # Each annotation i occupies len(anno_feats) consecutive columns.
  for inv_i in range(data_config['max_num_anno']):
    anno_feature_to_column.update(
        {f'{feat}_{inv_i}': anno_offset + feat_i + len(anno_feats) * inv_i
          for feat_i, feat in enumerate(anno_feats)})
  print(anno_feature_to_column)

  parse_fn = parse_tsv_line_with_inv_anno
  parse_fn_args = anno_feature_to_column
  if base_iit_anno_tsv_path:
    # Evaluate with interchange interventions on the base examples.
    length_indexed_base_examples = index_examples_by_length(
        base_iit_anno_tsv_path, anno_feature_to_column)
    parse_fn = parse_tsv_line_with_length_preserving_interchange_inv_anno
    parse_fn_args = (anno_feature_to_column, length_indexed_base_examples)

  datasets.config.IN_MEMORY_MAX_SIZE = 1024**3  # 1G

  acc_metrics = collections.defaultdict(int)
  eval_outputs = {}
  for _ in range(num_repeat):
    # Re-generated each repeat so random base substitutions are re-sampled.
    text_anno_datasets = gen_seq2seq_text_dataset_from_tsv(
          {'test': test_iit_anno_tsv_path},
          parse_fn=parse_fn, parse_fn_args=parse_fn_args, keep_in_memory=True)
    eval_dataloader = DataLoader(text_anno_datasets['test'], batch_size=batch_size)
    metrics, eval_outputs = eval_with_interchange_interventions(
        model, eval_dataloader, data_config,
        iit_layer=iit_layer, iit_dim=iit_dim,
        num_beams=num_beams, num_return_sequences=num_return_sequences,
        char_to_vec=char_to_vec)
    for k in metrics:
      acc_metrics[k] += metrics[k]
  return acc_metrics, eval_outputs

In [None]:
#@markdown load models

import string
from transformers import AutoTokenizer

def load_model(model_name, ckpt, seq_len_config):
  """Load a fine-tuned checkpoint plus matching character tokenizers.

  Args:
    model_name: directory under MODEL_DIR. Substrings select the tokenizers:
      'byt5' means ByT5 on both sides; otherwise 'char' together with
      'src'/'trg' marks a character-tokenized side, and the task name picks
      the character tokenizer variant.
    ckpt: checkpoint sub-directory inside the model directory.
    seq_len_config: dict with 'max_src_char', 'max_src_token',
      'max_trg_char' and 'max_trg_token' (as in
      TASK_TO_DATASETS[task]['seq_length']).

  Returns:
    (model, (source_tokenizer, target_tokenizer),
     (max_input_seq_len, max_output_seq_len)). A tokenizer is None when that
    side is not character-tokenized.
  """
  def is_char(model_name, side):
    return ('char' in model_name and side in model_name) or 'byt5' in model_name
  model = T5ForConditionalGeneration.from_pretrained(
      os.path.join(MODEL_DIR, model_name, ckpt))
  model = model.to(device)
  if 'byt5' not in model_name:
    char_tokenizer = copy.deepcopy(t5_default_tokenizer)
    if 'unit_conversion' in model_name:
      char_tokenizer.add_special_tokens(
          {'additional_special_tokens': [c for c in string.digits + '.']})
    else:
      if 'spelling_correction_contextual' in model_name or 'word_search' in model_name:
        spm_path = os.path.join(MODEL_DIR, 't5_char_spiece.model')
        if not os.path.exists(spm_path):
          # The training cell loads this SentencePiece model from
          # tokenizers/ (presumably where the setup cell extracts
          # tokenizers.tgz); fall back to that location.
          spm_path = os.path.join('tokenizers', 't5_char_spiece.model')
        char_tokenizer = T5Tokenizer(
            spm_path,
            eos_token=tokenizer.eos_token,
            unk_token=tokenizer.unk_token, pad_token=tokenizer.pad_token,
            extra_ids=100,
            additional_special_tokens=tokenizer.additional_special_tokens)
      else:
        char_tokenizer.add_special_tokens(
            {'additional_special_tokens': [chr(i) for i in range(128) if chr(i) in VOCAB]})
  else:
    char_tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
  source_tokenizer = char_tokenizer if is_char(model_name, 'src') else None
  target_tokenizer = char_tokenizer if is_char(model_name, 'trg') else None
  print('SOURCE_TOKENIZER:', source_tokenizer)
  print('TARGET_TOKENIZER:', target_tokenizer)
  max_input_seq_len = seq_len_config['max_src_char'] if is_char(model_name, 'src') else seq_len_config['max_src_token']
  max_output_seq_len = seq_len_config['max_trg_char'] if is_char(model_name, 'trg') else seq_len_config['max_trg_token']
  return model, (source_tokenizer, target_tokenizer), (max_input_seq_len, max_output_seq_len)

# Tasks

In [None]:
# Per-task dataset files and sequence-length limits.
# Every task has 'train' and 'val' TSVs plus task-specific test splits.
# 'seq_length' gives the max source/target lengths in tokens ('*_token') and in
# characters ('*_char'); load_model picks the '*_char' limit for a
# character-tokenized side and the '*_token' limit otherwise.
# NOTE(review): paths are hard-coded under 'data/' instead of being built from
# DATA_DIR — keep them in sync if DATA_DIR changes.
TASK_TO_DATASETS = {
    'reversal': {
        'train': 'data/reversal_train.tsv',
        'val': 'data/reversal_val.tsv',
        'test_iv': 'data/reversal_test_iv.tsv',
        'test_oov': 'data/reversal_test_oov.tsv',
        'seq_length': {'max_src_token': 24, 'max_src_char': 24, 'max_trg_token': 16, 'max_trg_char': 16},
    },
    'unit_conversion': {
        'train': 'data/unit_conversion_train.tsv',
        'val': 'data/unit_conversion_val.tsv',
        'test_iv': 'data/unit_conversion_test_iv.tsv',
        'test_oov': 'data/unit_conversion_test_oov.tsv',
        'seq_length': {'max_src_token': 16, 'max_src_char': 48, 'max_trg_token': 16, 'max_trg_char': 16},
    },
    'unscramble': {
        'train': 'data/unscramble_train.tsv',
        'val': 'data/unscramble_val.tsv',
        'test_iv': 'data/unscramble_test_iv.tsv',
        'test_oov': 'data/unscramble_test_oov.tsv',
        'seq_length': {'max_src_token': 24, 'max_src_char': 24, 'max_trg_token': 16, 'max_trg_char': 16},
    },
    'spelling_correction': {
        'train': 'data/spelling_correction_train.tsv',
        'val': 'data/spelling_correction_val.tsv',
        'test_iv': 'data/spelling_correction_test_iv.tsv',
        'test_oov': 'data/spelling_correction_test_oov.tsv',
        'test_real': 'data/spelling_correction_test_real.tsv',
        'seq_length': {'max_src_token': 24, 'max_src_char': 24, 'max_trg_token': 16, 'max_trg_char': 16},
    },
    'spelling_correction_contextual': {
        'train': 'data/spelling_correction_contextual_train.tsv',
        'val': 'data/spelling_correction_contextual_val.tsv',
        'test_context_independent': 'data/spelling_correction_contextual_test_context_independent.tsv',
        'test_context_dependent': 'data/spelling_correction_contextual_test_context_dependent.tsv',
        'seq_length': {'max_src_token': 48, 'max_src_char': 64, 'max_trg_token': 48, 'max_trg_char': 64},
    },
    'word_search': {
        'train': 'data/word_search_train.tsv',
        'val': 'data/word_search_val.tsv',
        'test_oov': 'data/word_search_test_oov.tsv',
        'test_paraphrase': 'data/word_search_test_paraphrase.tsv',
        'test_overlap': 'data/word_search_test_overlap.tsv',
        'test_paraphrase_overlap': 'data/word_search_test_paraphrase_overlap.tsv',
        'seq_length': {'max_src_token': 48, 'max_src_char': 128, 'max_trg_token': 8, 'max_trg_char': 16},
    },
}

# Generating IIT Datasets

In [None]:
# Per-task interchange-intervention (IIT) settings.
# 'max_num_anno': number of inv_* annotation column groups per annotated TSV
#   line (read as data_config['max_num_anno'] by the evaluation utils).
# 'max_num_inv', 'include_space', 'eval_with_separate_base': consumed by the
#   data-generation helpers defined earlier in the notebook — presumably the
#   max interventions per example, whether spaces are annotated, and whether
#   evaluation uses separately sampled base examples (TODO confirm).
# An 'iit_train' path entry is added per task after IIT examples are generated.
TASK_TO_IIT_DATA = {
    'reversal': {
        'iit_config': {'max_num_inv': 16, 'max_num_anno': 16, 'include_space': False, 'eval_with_separate_base': True},
    },
    'unit_conversion': {
        'iit_config': {'max_num_inv': 16, 'max_num_anno': 16, 'include_space': False, 'eval_with_separate_base': False},
    },
    'unscramble': {
       'iit_config': {'max_num_inv': 16, 'max_num_anno': 16, 'include_space': False, 'eval_with_separate_base': True},
    },
    'spelling_correction': {
       'iit_config': {'max_num_inv': 16, 'max_num_anno': 16, 'include_space': False, 'eval_with_separate_base': True},
    },
    'spelling_correction_contextual': {
       'iit_config': {'max_num_inv': 64, 'max_num_anno': 64, 'include_space': True, 'eval_with_separate_base': False},
    },
    'word_search': {
       'iit_config': {'max_num_inv': 24, 'max_num_anno': 24, 'include_space': False, 'eval_with_separate_base': False},
    },
}

In [None]:
# Annotate the location and value of each character
# Annotate the location and value of each character.
# Produces character annotations for the chosen task's training split and
# returns a {split: annotation TSV path} mapping used by the next cell.
task_name = 'reversal'
# '%s' is presumably filled with '<task>_<split>' inside
# gen_character_annotation_data (cf. data/reversal_train.tsv) — TODO confirm.
base_input_path = os.path.join(DATA_DIR, '%s.tsv')

anno_dataset_split_to_path = gen_character_annotation_data(
    base_input_path, task_name,
    inv_config=TASK_TO_IIT_DATA[task_name]['iit_config'],
    splits=['train'])

In [None]:
# Generate IIT Dataset

# Generate the IIT dataset: interchange-intervention examples built from the
# training-split character annotations and written next to them as
# '*_iit_examples.tsv'. 100_000 is presumably the number of examples to
# generate — TODO confirm against gen_inv_examples.
anno_input_tsv_path = anno_dataset_split_to_path['train']
output_tsv_path = anno_dataset_split_to_path['train'].replace('_char_anno.tsv', '_iit_examples.tsv')
print(anno_input_tsv_path)
print(output_tsv_path)

inv_examples = gen_inv_examples(anno_input_tsv_path, output_tsv_path, 100_000, task_name)

In [None]:
# Register the generated IIT training file for this task; presumably read by
# gen_iit_dataset_from_tsv(task_name) in the training section.
TASK_TO_IIT_DATA[task_name]['iit_train'] = output_tsv_path

# Training

In [None]:
import copy
import string

# Character-level tokenizer: digits (+ '.') as special tokens for unit
# conversion, the character SentencePiece model for every other task.
char_set = 'digit' if task_name == 'unit_conversion' else 'ascii'

if char_set == 'ascii':
  char_tokenizer = T5Tokenizer(
      os.path.join('tokenizers/t5_char_spiece.model'),
      eos_token=tokenizer.eos_token,
      unk_token=tokenizer.unk_token,
      pad_token=tokenizer.pad_token,
      extra_ids=100,
      additional_special_tokens=tokenizer.additional_special_tokens)
elif char_set == 'digit':
  special_tokens = string.digits + '.'
  char_tokenizer = copy.deepcopy(t5_default_tokenizer)
  char_tokenizer.add_special_tokens({'additional_special_tokens': list(special_tokens)})


print(f'char_set="{char_set}"')
# Sanity check: show how two probe strings are split into tokens.
for probe_text in ('test xxx', '0.12 333'):
  print([VOCAB[i] for i in char_tokenizer(probe_text).input_ids])

In [None]:
from transformers import AutoTokenizer


# Tokenizer switches. ByT5 is byte-level, so enabling it forces character
# tokenization on both the source and target side.
USE_BYT5_TOKENIZER = False
CHAR_TOKENIZE_TARGET = False
CHAR_TOKENIZE_SOURCE = False
if USE_BYT5_TOKENIZER:
  CHAR_TOKENIZE_TARGET = True
  CHAR_TOKENIZE_SOURCE = True

print(f'CHAR_TOKENIZE_TARGET={CHAR_TOKENIZE_TARGET}')
print(f'CHAR_TOKENIZE_SOURCE={CHAR_TOKENIZE_SOURCE}')
print(f'USE_BYT5_TOKENIZER={USE_BYT5_TOKENIZER}')

# Start from the T5 subword tokenizer on both sides, then apply overrides.
source_tokenizer = t5_default_tokenizer
target_tokenizer = t5_default_tokenizer
if USE_BYT5_TOKENIZER:
  byt5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
  source_tokenizer = byt5_tokenizer
  target_tokenizer = byt5_tokenizer
  # Rebinds the global VOCAB so later id->token lookups use ByT5 tokens.
  VOCAB = byt5_tokenizer.convert_ids_to_tokens(range(270))
  print('Use ByT5 tokenizer.')
if CHAR_TOKENIZE_SOURCE and not USE_BYT5_TOKENIZER:
  source_tokenizer = char_tokenizer
  print('Use char tokenizer for source.')
if CHAR_TOKENIZE_TARGET and not USE_BYT5_TOKENIZER:
  target_tokenizer = char_tokenizer
  print('Use char tokenizer for target.')

# Reuse `task_name` from the IIT data-generation cell rather than hard-coding it
# again here. The IIT examples and `char_tokenizer` were both built for that
# task, and a second literal could silently make the base datasets loaded
# below disagree with them.
print(f'task_name={task_name}')
trainval_datasets, _ = load_datasets(
    task_name, source_tokenizer=source_tokenizer, target_tokenizer=target_tokenizer)

In [None]:
# Load IIT dataset
# ENABLE_IIT gates the IIT part of training. In the training cell it sets
# inv_factor and controls whether IIT batches are drawn. The dataset itself is
# loaded unconditionally because the DataLoader cell indexes iit_dataset['train'].
ENABLE_IIT = True

iit_dataset = gen_iit_dataset_from_tsv(task_name)

In [None]:
from torch.utils.data import DataLoader

training_batch_size = 16
eval_batch_size = 32

# Every loader shuffles, including validation. The run_eval calls in the
# training loop pass first_n_batch, so they presumably score a random subset
# of the validation split each time.
train_dataloader = DataLoader(
    trainval_datasets['train'], batch_size=training_batch_size, shuffle=True)
val_dataloader = DataLoader(
    trainval_datasets['val'], batch_size=eval_batch_size, shuffle=True)
# IIT batches use the training batch size, since they are consumed alongside base batches.
iit_dataloader = DataLoader(
    iit_dataset['train'], batch_size=training_batch_size, shuffle=True)

In [None]:
import numpy as np

from torch.optim import AdamW
from transformers import get_scheduler
import collections

# Drop the `model` name left by a previous run of this cell, if any. Note that a
# previous `causal_abstraction` may still reference the old model.
try:
  del model
except NameError:
  pass

print(task_name, 'training_batch_size=%d' % training_batch_size)
print(f"SEQ_LENGTH={TASK_TO_DATASETS[task_name]['seq_length']}")
print(f'CHAR_TOKENIZE_TARGET={CHAR_TOKENIZE_TARGET}')
print(f'CHAR_TOKENIZE_SOURCE={CHAR_TOKENIZE_SOURCE}')


def _mean_metrics(metrics):
  """Return a copy of `metrics` with every value reduced to a float mean.

  Scalars pass through (as floats) and sequences are averaged, so train and val
  results from run_eval are always reported the same way.
  """
  return {key: float(np.array(value).mean()) for key, value in metrics.items()}


# Set to True to resume from a checkpoint written by this loop.
restore_ckpt = False

if not restore_ckpt:
  start_epoch = 0
  if not USE_BYT5_TOKENIZER:
    # T5
    model = T5ForConditionalGeneration.from_pretrained(
        "t5-small", cache_dir=MODEL_DIR)
    # To train from scratch instead:
    #model = T5ForConditionalGeneration(config=model.config)
  else:
    # ByT5
    model = T5ForConditionalGeneration.from_pretrained(
        "google/byt5-small", cache_dir=MODEL_DIR)
    print('Use pre-trained byt5-small.')
else:
  start_epoch = 20
  pretrain_task_name = task_name
  ckpt = f'ckpt-ep{start_epoch}'
  print(f'Resume from {pretrain_task_name} {ckpt}')
  # NOTE(review): only model weights are restored. The optimizer and LR
  # scheduler below start from scratch, including the warmup.
  model = T5ForConditionalGeneration.from_pretrained(
      os.path.join(MODEL_DIR, pretrain_task_name, ckpt))
model = model.to(device)


num_epochs = 40
optimizer = AdamW(model.parameters(), lr=5e-4)
# Need to keep num_training_steps the same for the same learning rate schedule.
num_training_steps = (num_epochs + 20) * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=1_000,
    num_training_steps=num_training_steps
)

# The base loss is always weighted 1.0; the IIT loss weight follows ENABLE_IIT.
base_factor, inv_factor = 1.0, float(ENABLE_IIT)
print(f'base_factor={base_factor} inv_factor={inv_factor}')

iit_layer = 1
char_dim = 16
location_map_fn = lambda x: add_batch_layer_dim_location_map_fn(x, iit_layer, char_dim)
if ENABLE_IIT:
  print(f'iit_layer={iit_layer} char_dim={char_dim}')
  # NOTE(review): not read in this cell; kept in case code elsewhere uses it.
  enable_iit = True

# Placeholder IIT metrics logged when no IIT batch was drawn this step.
metrics_default_value = {'token_accuracy': 0, 'loss': 0, 'sequence_accuracy': 0}

causal_abstraction = TransformerCausalAbstraction(model)

for epoch in range(num_epochs):
  epoch = epoch + start_epoch
  if ENABLE_IIT:
    # Fresh IIT iterator(s) each epoch; restarted below if exhausted mid-epoch.
    iit_dataloader_itrs = [iter(iit_dataloader)]
  model.train()
  for step, input_batch in enumerate(train_dataloader):
    iit_batch = []
    if ENABLE_IIT:
      try:
        iit_batch = [next(itr) for itr in iit_dataloader_itrs]
      except StopIteration:
        iit_dataloader_itrs = [iter(iit_dataloader)]
        iit_batch = [next(itr) for itr in iit_dataloader_itrs]
    for k in input_batch:
      input_batch[k] = input_batch[k].to(device)
    loss, outputs, iit_outputs = iit_train_step(
        causal_abstraction, input_batch, iit_batch,
        location_map_fn, base_factor, inv_factor)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    if step % 500 == 0:
      metrics = compute_metrics(([outputs.logits.detach().cpu().numpy()],
                                  input_batch['labels'].cpu().numpy()))
      # iit metrics only work if there is one iit dataset
      iit_metrics = metrics_default_value
      if iit_batch:
        iit_metrics = compute_metrics(([iit_outputs.logits.detach().cpu().numpy()],
                                       iit_batch[0]['labels'].cpu().numpy()))
        iit_metrics['loss'] = iit_outputs.loss.detach().cpu().numpy()
      print('Epoch %d Step %d: Loss %.4f Loss_iit %.4f Acc %.4f Acc_iit %.4f LR %.2E' % (
          epoch, step, metrics['loss'],
          iit_metrics['loss'],
          metrics['token_accuracy'],
          iit_metrics['token_accuracy'],
          lr_scheduler.get_last_lr()[0]))
    if step % 10000 == 0:
      model.eval()
      train_metrics = _mean_metrics(run_eval(model, train_dataloader, first_n_batch=16))
      val_metrics = _mean_metrics(run_eval(model, val_dataloader, first_n_batch=16))
      model.train()
      print('Epoch %d Step %d: TRAIN Loss %.4f Accuracy %.4f VAL Accuracy %.4f' % (
            epoch, step, train_metrics['loss'], train_metrics['token_accuracy'],
            val_metrics['token_accuracy']))
  ckpt_dir = os.path.join(MODEL_DIR, task_name, 'ckpt-ep%d' % (epoch + 1))
  model.save_pretrained(ckpt_dir)
  print('Checkpoint saved at %s' % ckpt_dir)
  # run eval; val and train are averaged the same way before reporting.
  model.eval()
  train_metrics = _mean_metrics(run_eval(model, train_dataloader, first_n_batch=16))
  val_metrics = _mean_metrics(run_eval(model, val_dataloader, first_n_batch=100))
  print('Epoch %d Done: TRAIN Loss %.4f Accuracy %.4f VAL Accuracy %.4f' % (
          epoch, train_metrics['loss'], train_metrics['token_accuracy'],
          val_metrics['token_accuracy']))

# Evaluation

In [None]:
# Metric names reported per task. Every task reports 'match_fullstring';
# unscramble and spelling_correction each add one task-specific metric.
TASK_TO_METRICS = dict(
    reversal=['match_fullstring'],
    unit_conversion=['match_fullstring'],
    unscramble=['match_fullstring', 'match_anagram_valid'],
    spelling_correction=['match_fullstring', 'match_spelling_correction'],
    spelling_correction_contextual=['match_fullstring'],
    word_search=['match_fullstring'],
)

In [None]:
# Task whose test splits are scored, and the saved model evaluated on them.
eval_task, model_name = 'reversal', 'reversal_subword_iit'

In [None]:
# Evaluate `model_name` on every test split of `eval_task`. Results are keyed
# by (split, model_name).
model_to_eval_outputs = {}
test_split_to_path = {
    split: path for split, path in TASK_TO_DATASETS[eval_task].items()
    if split.startswith('test')}
if test_split_to_path:
  # The checkpoint and tokenizers do not depend on the split, so load them once
  # instead of once per test split.
  model, (source_tokenizer, target_tokenizer), max_seq_len = load_model(
      model_name, '', TASK_TO_DATASETS[eval_task]['seq_length'])
for split, test_data_tsv_path in test_split_to_path.items():
  test_datasets = gen_seqvc2seq_dataset_from_tsv(
    {'test': test_data_tsv_path}, {'input': 0, 'label': 1}, max_seq_len,
    source_tokenizer=source_tokenizer, target_tokenizer=target_tokenizer)
  # max_seq_len[1] is presumably the target-side length limit; TODO confirm.
  eval_outputs = eval_topk_accuracy(
      model, test_datasets['test'], max_seq_len[1],
      target_tokenizer=target_tokenizer if 'byt5' in model_name else None,
      batch_size=128)
  model_to_eval_outputs[(split, model_name)] = eval_outputs