In [1]:
from pathlib import Path
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import pandas as pd
import json
import tqdm
import transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import Dataset, DatasetDict
import evaluate
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
NUM_EPOCHS = 50
EXPERIMENT_NAME = "t5-small_falcon2-default_annotation-k5"
EXPERIMENT_DIR = Path('experiments')
MODEL_ARTIFACTS = EXPERIMENT_DIR / EXPERIMENT_NAME
WEIGHTS_DIR = MODEL_ARTIFACTS / 'weights'
VALS_DIR = MODEL_ARTIFACTS / 'validations'

Make appropriate directoreis

In [2]:
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
VALS_DIR.mkdir(parents=True, exist_ok=True)

Defining the model and tokenizer

In [3]:
model_path = "t5-small"
tokenizer_path = "t5-small"

model = T5ForConditionalGeneration.from_pretrained(model_path, device_map ='auto')
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)

In [4]:
from pprint import pprint
pprint(model.hf_device_map)

{'': 0}


Define dataset maker

In [5]:
def split_dataframe(df):
  # ratios from Bannerjee
  train = 0.7
  dev = 0.1
  test = 0.2
  assert train + dev + test == 1.0
  data_len = len(df)
  train_set = Dataset.from_pandas(df[:round(data_len * train)])
  dev_set = Dataset.from_pandas(df[round(data_len * train):round(data_len* (train + dev))])
  test_set = Dataset.from_pandas(df[round(data_len * (train + dev)):])
  
  dataset = DatasetDict()
  dataset['train'] = train_set
  dataset['dev'] = dev_set
  dataset['test'] = test_set

  return dataset

Define dataset tokenizer

In [6]:
def tokenize_data(dataset, column):
  model_inputs = tokenizer(dataset[column], padding=True, truncation=True, return_tensors="pt")
  return model_inputs

Define unmasker

In [7]:
from pipeline import T5Converter
converter = T5Converter()

Defining the validation function

In [8]:
def val(val_dataloader, val_path = None):
  model.eval()
  eval_dict = []

  iters = len(val_dataloader)

  progress_bar = tqdm.tqdm(iters, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}")
  progress_bar.set_description(f"Eval")

  correct_preds = 0
  total_preds = 0

  for val_batch in val_dataloader:
    batch = {}
    for k,v in val_batch.items():
      if k in {"input_ids", "labels", "attention_mask"}:
        batch[k] = v.to("cuda")

    with torch.no_grad():
      outputs = model(**batch)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    for i, pred in enumerate(tokenizer.batch_decode(predictions)):
      gold = val_batch['gold'][i]
      gold = gold.strip().replace(" ","")
      gold2 = gold.replace(">", "> ").replace("<"," <").replace("  ", " ").strip()
      pred = pred.replace(" ","").replace("</s>", "").replace("<pad>","").replace('<unk>','').replace('<s>','').strip().replace(" ","")
      pred2 = pred.replace(">", "> ").replace("<"," <").replace("  ", " ").strip()
      entry_dict = {
        "Utte": val_batch['utterance'][i],
        "Anno": val_batch['annotated'][i],
        "Gold": val_batch['gold'][i],
        "Gene": pred, # THIS NEEDS TO BE UNMASKED
        "Gol2": converter._unmask_generic(gold2),
        "Gen2": converter._unmask_generic(pred2),
      }
      eval_dict.append(entry_dict)
      total_preds += 1
      if entry_dict['Gol2'] == entry_dict['Gen2']:
        correct_preds += 1
    progress_bar.update(1)
  
  if val_path:
    with open(val_path, "w") as f:
      json.dump(eval_dict, f, indent=2)

  accuracy = correct_preds/total_preds

  meta = {
    'accuracy': accuracy
  }
  
  model.train()
  return eval_dict, accuracy

In [9]:
def training_loop(df):
  print("beginning training")

  assert 'utterance' in df.columns
  assert 'annotated' in df.columns
  assert 'gold' in df.columns

  dataset = split_dataframe(df)
  tokenized_dataset = dataset \
    .map(lambda x: tokenize_data(x, 'gold'), batched=True) \
    .rename_column('input_ids', 'labels') \
    .map(lambda x: tokenize_data(x, 'annotated'), batched=True)

  tokenized_dataset.set_format("pt", columns=["input_ids", "attention_mask", "labels"], output_all_columns=True)
  print("data loaded")
  
  train_dataset = tokenized_dataset["train"]
  dev_dataset = tokenized_dataset["dev"]
  test_dataset = tokenized_dataset["test"]

  train_dataloader = DataLoader(train_dataset, batch_size = 10)
  dev_dataloader = DataLoader(dev_dataset, batch_size = 10)

  scalar = 0

  optimizer = optim.AdamW(model.parameters(), lr = 0.0015)
  lr_scheduler=transformers. \
    get_polynomial_decay_schedule_with_warmup(optimizer, 5000, 30000, power=0.5)
  
  epoch_data = {}

  for epoch in range(NUM_EPOCHS):
    print("Beginning Epoch:", epoch)
    i = 0
    iters = len(train_dataloader)
    for batch in train_dataloader:
      newbatch = {}
      for k,v in batch.items():
        if k in ["labels", "input_ids", "attention_mask"]:
          newbatch[k] = v.to("cuda")
      
      batch = newbatch
      newbatch = {}

      outputs = model(**batch)
      loss = outputs.loss
      scalar += loss.mean().item()

      if (i+1) % 100 == 0:
        print(f'iteration = {i+1}/{iters}, training loss={scalar/100}')
        scalar = 0

      loss /= 10 
      loss.mean().backward()
      if (i+1) % 1 == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
      
      del batch
      i += 1
    
    print(f"Validating epoch {epoch}")
    val_filename = f"val_{epoch}.json"
    _, acc = val(dev_dataloader, VALS_DIR / val_filename)
    print(f"Accuracy: {acc:.2f}")
    assert val_filename not in epoch_data
    epoch_data[val_filename] = {"accuracy": acc}

    with open(MODEL_ARTIFACTS / "meta_data.json", "w") as f:
      json.dump(epoch_data, f, indent=2)

    torch.save(model.state_dict(),
      WEIGHTS_DIR / f"cp_{epoch}.pth")
  print(f"Validating final")
  val_filename = f"val_final.json"

  _, meta = val(dev_dataloader, VALS_DIR / val_filename)

  epoch_data[val_filename] = {"accuracy": acc}

  with open(MODEL_ARTIFACTS / "meta_data.json", "w") as f:
    json.dump(epoch_data, f, indent=2)

  torch.save(model.state_dict(),
    WEIGHTS_DIR / f"cp_final.pth")

In [10]:
df_json = []
with open('falcon_links/2/link_24066.json') as f:
  data_json = json.load(f)

print(data_json[0])


[{'utterance': 'What is Delta Air Lines periodical literature mouthpiece?', 'ents': [{'uri': 'http://www.wikidata.org/entity/Q188920', 'prefix': 'wd:', 'id': 'Q188920'}, {'uri': 'http://www.wikidata.org/entity/Q1002697', 'prefix': 'wd:', 'id': 'Q1002697'}, {'uri': 'http://www.wikidata.org/entity/Q416938', 'prefix': 'wd:', 'id': 'Q416938'}, {'uri': 'http://www.wikidata.org/entity/Q523753', 'prefix': 'wd:', 'id': 'Q523753'}, {'uri': 'http://www.wikidata.org/entity/Q671722', 'prefix': 'wd:', 'id': 'Q671722'}], 'rels': []}, {'utterance': 'What is Delta Air Lines periodical literature mouthpiece?', 'fragments': ['[DEF]', 'wd:', 'Q188920 Delta', '[DEF]', 'wd:', 'Q1002697 periodical literature', '[DEF]', 'wd:', 'Q416938 Mouthpiece', '[DEF]', 'wd:', 'Q523753 mouthpiece', '[DEF]', 'wd:', 'Q671722 mouthpiece']}, {'inputs': 'What is Delta Air Lines periodical literature mouthpiece? <extra_id_59> <extra_id_53> Q188920 Delta <extra_id_59> <extra_id_53> Q1002697 periodical literature <extra_id_59> <

Main

In [11]:
# df_json = []
# with open('weekend.json') as f:
#   data_json = json.load(f)

for data in data_json:
  data_dict = {
    "utterance": data[0]["utterance"],
    "annotated": data[2]["inputs"],
    "gold": data[2]["labels"]
  }
  df_json.append(data_dict)

In [12]:
df = pd.DataFrame.from_dict(df_json)
df.head()

Unnamed: 0,utterance,annotated,gold
0,What is Delta Air Lines periodical literature ...,What is Delta Air Lines periodical literature ...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...
1,What is the name of Ranavalona Is husbands child?,What is the name of Ranavalona Is husbands chi...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
2,Are Jeff Bridges and Lane Chandler both photog...,Are Jeff Bridges and Lane Chandler both photog...,<extra_id_4> <extra_id_19> <extra_id_33> <extr...
3,What range are the papers at the Monique Genon...,What range are the papers at the Monique Genon...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
4,Which is the operating income for Qantas?,Which is the operating income for Qantas? <ext...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...


In [13]:
training_loop(df)

beginning training


Map:   0%|          | 0/16845 [00:00<?, ? examples/s]

Map:   0%|          | 0/2406 [00:00<?, ? examples/s]

Map:   0%|          | 0/4813 [00:00<?, ? examples/s]

Map:   0%|          | 0/16845 [00:00<?, ? examples/s]

Map:   0%|          | 0/2406 [00:00<?, ? examples/s]

Map:   0%|          | 0/4813 [00:00<?, ? examples/s]

data loaded
Beginning Epoch: 0


iteration = 100/1685, training loss=12.442191910743713


iteration = 200/1685, training loss=4.148782947063446


iteration = 300/1685, training loss=2.005652048587799


iteration = 400/1685, training loss=1.4345639818906784


iteration = 500/1685, training loss=1.2477442651987076


iteration = 600/1685, training loss=1.1298221987485886


iteration = 700/1685, training loss=1.0355184239149093


iteration = 800/1685, training loss=0.9431734985113144


iteration = 900/1685, training loss=0.8437490284442901


iteration = 1000/1685, training loss=0.7663650840520859


iteration = 1100/1685, training loss=0.702564115524292


iteration = 1200/1685, training loss=0.659597929418087


iteration = 1300/1685, training loss=0.5990251889824867


iteration = 1400/1685, training loss=0.5786932513117791


iteration = 1500/1685, training loss=0.5245084393024445


iteration = 1600/1685, training loss=0.5086952415108681


Validating epoch 0


|                                                                                                                                                                                                             | 0/?

Eval: |                                                                                                                                                                                                       | 0/?

Eval: |                                                                                                                                                                                                       | 2/?

Eval: |                                                                                                                                                                                                       | 4/?

Eval: |                                                                                                                                                                                                       | 6/?

Eval: |                                                                                                                                                                                                       | 8/?

Eval: |                                                                                                                                                                                                      | 10/?

Eval: |                                                                                                                                                                                                      | 12/?

Eval: |                                                                                                                                                                                                      | 14/?

Eval: |                                                                                                                                                                                                      | 16/?

Eval: |                                                                                                                                                                                                      | 18/?

Eval: |                                                                                                                                                                                                      | 20/?

Eval: |                                                                                                                                                                                                      | 22/?

Eval: |                                                                                                                                                                                                      | 24/?

Eval: |                                                                                                                                                                                                      | 26/?

Eval: |                                                                                                                                                                                                      | 28/?

Eval: |                                                                                                                                                                                                      | 30/?

Eval: |                                                                                                                                                                                                      | 32/?

Eval: |                                                                                                                                                                                                      | 34/?

Eval: |                                                                                                                                                                                                      | 36/?

Eval: |                                                                                                                                                                                                      | 38/?

Eval: |                                                                                                                                                                                                      | 40/?

Eval: |                                                                                                                                                                                                      | 42/?

Eval: |                                                                                                                                                                                                      | 44/?

Eval: |                                                                                                                                                                                                      | 46/?

Eval: |                                                                                                                                                                                                      | 48/?

Eval: |                                                                                                                                                                                                      | 50/?

Eval: |                                                                                                                                                                                                      | 52/?

Eval: |                                                                                                                                                                                                      | 54/?

Eval: |                                                                                                                                                                                                      | 56/?

Eval: |                                                                                                                                                                                                      | 58/?

Eval: |                                                                                                                                                                                                      | 60/?

Eval: |                                                                                                                                                                                                      | 62/?

Eval: |                                                                                                                                                                                                      | 64/?

Eval: |                                                                                                                                                                                                      | 66/?

Eval: |                                                                                                                                                                                                      | 68/?

Eval: |                                                                                                                                                                                                      | 70/?

Eval: |                                                                                                                                                                                                      | 72/?

Eval: |                                                                                                                                                                                                      | 74/?

Eval: |                                                                                                                                                                                                      | 76/?

Eval: |                                                                                                                                                                                                      | 78/?

Eval: |                                                                                                                                                                                                      | 80/?

Eval: |                                                                                                                                                                                                      | 82/?

Eval: |                                                                                                                                                                                                      | 84/?

Eval: |                                                                                                                                                                                                      | 86/?

Eval: |                                                                                                                                                                                                      | 88/?

Eval: |                                                                                                                                                                                                      | 90/?

Eval: |                                                                                                                                                                                                      | 92/?

Eval: |                                                                                                                                                                                                      | 94/?

Eval: |                                                                                                                                                                                                      | 96/?

Eval: |                                                                                                                                                                                                      | 98/?

Eval: |                                                                                                                                                                                                     | 100/?

Eval: |                                                                                                                                                                                                     | 102/?

Eval: |                                                                                                                                                                                                     | 104/?

Eval: |                                                                                                                                                                                                     | 105/?

Eval: |                                                                                                                                                                                                     | 106/?

Eval: |                                                                                                                                                                                                     | 107/?

Eval: |                                                                                                                                                                                                     | 108/?

Eval: |                                                                                                                                                                                                     | 109/?

Eval: |                                                                                                                                                                                                     | 110/?

Eval: |                                                                                                                                                                                                     | 111/?

Eval: |                                                                                                                                                                                                     | 112/?

Eval: |                                                                                                                                                                                                     | 113/?

Eval: |                                                                                                                                                                                                     | 114/?

Eval: |                                                                                                                                                                                                     | 115/?

Eval: |                                                                                                                                                                                                     | 116/?

Eval: |                                                                                                                                                                                                     | 117/?

Eval: |                                                                                                                                                                                                     | 118/?

Eval: |                                                                                                                                                                                                     | 119/?

Eval: |                                                                                                                                                                                                     | 120/?

Eval: |                                                                                                                                                                                                     | 121/?

Eval: |                                                                                                                                                                                                     | 122/?

Eval: |                                                                                                                                                                                                     | 123/?

Eval: |                                                                                                                                                                                                     | 124/?

Eval: |                                                                                                                                                                                                     | 125/?

Eval: |                                                                                                                                                                                                     | 126/?

Eval: |                                                                                                                                                                                                     | 127/?

Eval: |                                                                                                                                                                                                     | 128/?

Eval: |                                                                                                                                                                                                     | 129/?

Eval: |                                                                                                                                                                                                     | 130/?

Eval: |                                                                                                                                                                                                     | 131/?

Eval: |                                                                                                                                                                                                     | 132/?

Eval: |                                                                                                                                                                                                     | 133/?

Eval: |                                                                                                                                                                                                     | 134/?

Eval: |                                                                                                                                                                                                     | 135/?

Eval: |                                                                                                                                                                                                     | 136/?

Eval: |                                                                                                                                                                                                     | 137/?

Eval: |                                                                                                                                                                                                     | 138/?

Eval: |                                                                                                                                                                                                     | 139/?

Eval: |                                                                                                                                                                                                     | 140/?

Eval: |                                                                                                                                                                                                     | 141/?

Eval: |                                                                                                                                                                                                     | 142/?

Eval: |                                                                                                                                                                                                     | 143/?

Eval: |                                                                                                                                                                                                     | 144/?

Eval: |                                                                                                                                                                                                     | 145/?

Eval: |                                                                                                                                                                                                     | 146/?

Eval: |                                                                                                                                                                                                     | 147/?

Eval: |                                                                                                                                                                                                     | 148/?

Eval: |                                                                                                                                                                                                     | 149/?

Eval: |                                                                                                                                                                                                     | 150/?

Eval: |                                                                                                                                                                                                     | 151/?

Eval: |                                                                                                                                                                                                     | 152/?

Eval: |                                                                                                                                                                                                     | 153/?

Eval: |                                                                                                                                                                                                     | 154/?

Eval: |                                                                                                                                                                                                     | 155/?

Eval: |                                                                                                                                                                                                     | 156/?

Eval: |                                                                                                                                                                                                     | 157/?

Eval: |                                                                                                                                                                                                     | 158/?

Eval: |                                                                                                                                                                                                     | 159/?

Eval: |                                                                                                                                                                                                     | 160/?

Eval: |                                                                                                                                                                                                     | 161/?

Eval: |                                                                                                                                                                                                     | 162/?

Eval: |                                                                                                                                                                                                     | 163/?

Eval: |                                                                                                                                                                                                     | 164/?

Eval: |                                                                                                                                                                                                     | 165/?

Eval: |                                                                                                                                                                                                     | 166/?

Eval: |                                                                                                                                                                                                     | 167/?

Eval: |                                                                                                                                                                                                     | 168/?

Eval: |                                                                                                                                                                                                     | 169/?

Eval: |                                                                                                                                                                                                     | 170/?

Eval: |                                                                                                                                                                                                     | 171/?

Eval: |                                                                                                                                                                                                     | 172/?

Eval: |                                                                                                                                                                                                     | 173/?

Eval: |                                                                                                                                                                                                     | 174/?

Eval: |                                                                                                                                                                                                     | 175/?

Eval: |                                                                                                                                                                                                     | 176/?

Eval: |                                                                                                                                                                                                     | 177/?

Eval: |                                                                                                                                                                                                     | 178/?

Eval: |                                                                                                                                                                                                     | 179/?

Eval: |                                                                                                                                                                                                     | 180/?

Eval: |                                                                                                                                                                                                     | 181/?

Eval: |                                                                                                                                                                                                     | 182/?

Eval: |                                                                                                                                                                                                     | 183/?

Eval: |                                                                                                                                                                                                     | 184/?

Eval: |                                                                                                                                                                                                     | 185/?

Eval: |                                                                                                                                                                                                     | 186/?

Eval: |                                                                                                                                                                                                     | 187/?

Eval: |                                                                                                                                                                                                     | 188/?

Eval: |                                                                                                                                                                                                     | 189/?

Eval: |                                                                                                                                                                                                     | 190/?

Eval: |                                                                                                                                                                                                     | 191/?

Eval: |                                                                                                                                                                                                     | 192/?

Eval: |                                                                                                                                                                                                     | 193/?

Eval: |                                                                                                                                                                                                     | 194/?

Eval: |                                                                                                                                                                                                     | 195/?

Eval: |                                                                                                                                                                                                     | 196/?

Eval: |                                                                                                                                                                                                     | 197/?

Eval: |                                                                                                                                                                                                     | 198/?

Eval: |                                                                                                                                                                                                     | 199/?

Eval: |                                                                                                                                                                                                     | 200/?

Eval: |                                                                                                                                                                                                     | 202/?

Eval: |                                                                                                                                                                                                     | 204/?

Eval: |                                                                                                                                                                                                     | 206/?

Eval: |                                                                                                                                                                                                     | 208/?

Eval: |                                                                                                                                                                                                     | 210/?

Eval: |                                                                                                                                                                                                     | 212/?

Eval: |                                                                                                                                                                                                     | 214/?

Eval: |                                                                                                                                                                                                     | 216/?

Eval: |                                                                                                                                                                                                     | 218/?

Eval: |                                                                                                                                                                                                     | 220/?

Eval: |                                                                                                                                                                                                     | 222/?

Eval: |                                                                                                                                                                                                     | 224/?

Eval: |                                                                                                                                                                                                     | 226/?

Eval: |                                                                                                                                                                                                     | 228/?

Eval: |                                                                                                                                                                                                     | 230/?

Eval: |                                                                                                                                                                                                     | 232/?

Eval: |                                                                                                                                                                                                     | 234/?

Eval: |                                                                                                                                                                                                     | 236/?

Eval: |                                                                                                                                                                                                     | 238/?

Eval: |                                                                                                                                                                                                     | 240/?

Eval: |                                                                                                                                                                                                     | 241/?




Accuracy: 0.01


Beginning Epoch: 1


iteration = 100/1685, training loss=1.639268020093441


iteration = 200/1685, training loss=0.8514760440587997


In [None]:
import tqdm
import time

pb = tqdm.tqdm(total=5, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}")

for i in range(5):
  time.sleep(1)
  pb.update(1)