In [None]:
# %cd drive/MyDrive/

In [None]:
!pip install torch
!pip install transformers
# !pip install py-readability-metrics
!python -m nltk.downloader punkt

In [None]:
# !git clone https://github.com/dimi1357/Loora-cond-bot.git
# %cd Loora-cond-bot/

In [None]:
%cd py-readability-metrics/
!pip install .
%cd ..

In [None]:
## Only needed if data/train_ready.csv and data/test_ready.csv do not exist yet
# !python ./prepare_data.py data/combined_train.csv data/train.csv
# !python ./prepare_data.py data/combined_test.csv data/test.csv
# !python convert_to_signal.py data/train.csv data/test.csv easy medium hard

In [None]:
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base model used when there is no local checkpoint to resume from.
model_path = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse EOS as the padding token so that batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# Directory the fine-tuned model is saved to / resumed from.
path = "dialog-model"
# path = "temp_fix_model"

# `start` is the index of the first training line to process. It is restored
# from index.torch only when the saved model directory exists as well;
# otherwise a stale index would be paired with the untouched base model.
start = 0
if os.path.exists('index.torch') and os.path.isdir(path):
  print(f"loading checkpoint: {path}")
  start = int(torch.load("index.torch"))
  model_path = path
elif os.path.exists('index.torch'):
  print(f"index.torch found but model directory '{path}' is missing; starting from {model_path}")
model = AutoModelForCausalLM.from_pretrained(model_path)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
  torch.cuda.empty_cache()
model.to(device)
# Trailing `;` keeps the notebook from dumping the whole model repr as output.
model.train();

In [None]:
import csv

# `lines` holds every usable training row; `grades` collects the distinct
# values of the third column (the grade/level signal of each row).
grades = set()
lines = []
contexts = []
# targets = []
inputs = []
# newline="" is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly.
with open("data/train_ready.csv", newline="") as f:
  reader = csv.reader(f)
  next(reader, None)  # skip the header row (no-op on an empty file)
  for row in reader:
    # Blank or truncated rows have no grade column and cannot be unpacked
    # into (context, target, grade) later on, so drop them here.
    if len(row) < 3:
      continue
    lines.append(row)
    grades.add(row[2])

In [None]:
import gc
from transformers import AdamW
import torch
import tqdm, sys, os

In [None]:
## Train
batch_size = 1
save_every = 1000  # write a checkpoint every `save_every` optimisation steps (batches)
max_retries = 3    # attempts per batch before it is skipped after a RuntimeError

optimizer = AdamW(model.parameters(), lr=5e-5)


def encode_batch(rows):
  """Tokenize training rows and build labels covering only the target tokens.

  Args:
    rows: list of CSV rows whose first three fields are (context, target, grade).

  Returns:
    (inputs_dic, labels): the tokenized "grade context <eos> target <eos>"
    batch, and a tensor shaped like its input_ids in which every position
    outside the target span (context and padding) is set to -100.
  """
  contexts = []
  inputs = []
  for context, target, grade, *_ in rows:
    inputs.append(f"{grade} {context} {tokenizer.eos_token} {target} {tokenizer.eos_token}")
    contexts.append(f"{grade} {context} {tokenizer.eos_token}")

  context_dic = tokenizer.batch_encode_plus(contexts, padding='longest', return_tensors='pt', truncation=True)
  inputs_dic = tokenizer.batch_encode_plus(inputs, padding='longest', return_tensors='pt', truncation=True)

  # Right-pad the context mask with zeros to the width of the full input, then
  # subtract: what is left (== 1) marks exactly the target tokens.
  context_mask = context_dic.attention_mask
  full_mask = inputs_dic.attention_mask
  filler = torch.zeros((context_mask.size(0), full_mask.size(-1) - context_mask.size(-1)), dtype=full_mask.dtype)
  is_target = (full_mask - torch.cat((context_mask, filler), dim=-1)) == 1
  labels = inputs_dic["input_ids"].masked_fill(~is_target, -100)
  return inputs_dic, labels


def save_checkpoint(next_index):
  """Save the model to `path` and record the next unprocessed line index."""
  model.save_pretrained(path)
  torch.save(next_index, 'index.torch')
  print(f"\nModel checkpoint is saved in: {path}")


# One entry per batch still to be trained on; empty when `start` is already
# past the end of `lines`.
batch_starts = range(start, len(lines), batch_size)
num_batches = len(batch_starts)
next_index = start  # absolute index of the first line not yet processed

model.train()
with tqdm.tqdm(desc="Training...", total=num_batches,
                      file=sys.stdout) as pbar:
  for step, batch_start in enumerate(batch_starts, 1):
    rows = lines[batch_start:batch_start + batch_size]
    last_error = None
    for attempt in range(max_retries):
      try:
        optimizer.zero_grad()
        inputs_dic, labels = encode_batch(rows)
        # A batch without any target token yields a NaN loss, which would
        # poison the weights on optimizer.step(); skip the update instead.
        if (labels != -100).any():
          inputs_dic = inputs_dic.to(device)
          labels = labels.to(device)
          # Use the masked labels so the loss is computed on the target only
          # (the original passed input_ids and ignored the computed labels).
          outputs = model(**inputs_dic, labels=labels)
          loss = outputs.loss
          loss.backward()
          # Apply the gradients; without this call the model never changes.
          optimizer.step()
          pbar.set_description(f"loss: {loss.item()}")
        break
      except RuntimeError as e:
        # Keep only the message: holding the exception object would keep the
        # failed step's tensors alive through its traceback.
        last_error = str(e)
        inputs_dic = labels = outputs = loss = None
        optimizer.zero_grad()
        gc.collect()
        if torch.cuda.is_available():
          torch.cuda.empty_cache()
    else:
      # Bounded retries: give up on this batch instead of looping forever.
      print(f"\nSkipping lines {batch_start}-{batch_start + len(rows) - 1} "
            f"after {max_retries} failed attempts: {last_error}")

    next_index = batch_start + len(rows)
    pbar.update()
    if step % save_every == 0:
      save_checkpoint(next_index)

# Final checkpoint; the stored index is absolute so a later run resumes
# right after the last processed line.
model.eval()
save_checkpoint(next_index)


In [None]:
# Save the current model to the `path` directory (manual save, separate from
# the checkpoints written by the training cell).
model.save_pretrained(path)

In [None]:
import csv

# Replace `lines` with the rows of the test set (header excluded).
lines = []
contexts = []
# targets = []
inputs = []
# newline="" is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly.
with open("data/test_ready.csv", newline="") as f:
  reader = csv.reader(f)
  next(reader, None)  # skip the header row (no-op on an empty file)
  for row in reader:
    # Blank or truncated rows cannot be unpacked into
    # (context, target, grade) by the evaluation cells, so drop them here.
    if len(row) < 3:
      continue
    lines.append(row)

In [None]:
## If you want to load the checkpoint
# path = "temp_fix_model"
# model = AutoModelForCausalLM.from_pretrained(path)
# model.eval()

In [None]:
## Eval method 1
# For every test row, score the reference target under each known grade
# prefix and predict the grade that gives the lowest loss on the target
# tokens. Accuracy is the fraction of rows whose predicted grade matches the
# row's own grade.
model.eval()
batch_size = 1
num_batches = len(lines) // batch_size
num_correct = 0
# Fixed candidate order so that ties are broken the same way on every run
# (iteration order of a set of strings is not stable across sessions).
candidate_grades = sorted(grades)

with tqdm.tqdm(desc="Running eval method 1...", total=len(lines),
                      file=sys.stdout) as pbar:
  for i, line in enumerate(lines):
    context, target, grade, *_ = line
    best_loss, best_grade = float("inf"), None
    for temp_grade in candidate_grades:
      inputs = [f"{temp_grade} {context} {tokenizer.eos_token} {target} {tokenizer.eos_token}"]
      contexts = [f"{temp_grade} {context} {tokenizer.eos_token}"]

      context_dic = tokenizer.batch_encode_plus(contexts, padding='longest', return_tensors='pt', truncation=True)
      inputs_dic = tokenizer.batch_encode_plus(inputs, padding='longest', return_tensors='pt', truncation=True)

      # Pad the context mask to the full input width and subtract it: the
      # remaining ones mark the target tokens, everything else becomes -100.
      context_mask = context_dic.attention_mask
      full_mask = inputs_dic.attention_mask
      filler = torch.zeros((context_mask.size(0), full_mask.size(-1) - context_mask.size(-1)), dtype=full_mask.dtype)
      is_target = (full_mask - torch.cat((context_mask, filler), dim=-1)) == 1
      labels = inputs_dic["input_ids"].masked_fill(~is_target, -100)

      inputs_dic = inputs_dic.to(device)
      labels = labels.to(device)
      with torch.no_grad():
        outputs = model(**inputs_dic, labels=labels)
      # Plain float: no device tensors are kept alive between candidates.
      loss = outputs.loss.item()
      # A NaN loss (no target tokens) never compares lower, so it is ignored.
      if loss < best_loss:
        best_loss, best_grade = loss, temp_grade
    if grade == best_grade:
      num_correct += 1
    pbar.set_description(f"loss: {best_loss}, running accuracy: {float(num_correct) / (i+1)}")
    pbar.update()

contexts = []
inputs = []

# Guard against an empty test set instead of dividing by zero.
accuracy = float(num_correct) / len(lines) if lines else 0.0
print(f"The model accuracy is: {accuracy}, using eval method 1")

In [None]:
## Eval method 2
# Generate a response for every test context (prefixed with the row's grade),
# measure the Flesch-Kincaid grade level of the generated text, bucket it into
# easy/medium/hard, and count how often the bucket matches the row's grade.
from readability import Readability

model.eval()
num_correct = 0
# signals[j] is predicted when the grade level is below thresholds[j]
# (first match wins).
# thresholds = [2.0, 6.0, 10e8]
thresholds = [0.0, 7.0, 10e8]
signals = ["easy", "medium", "hard"]


def grade_level_to_signal(grade_level):
  """Map a readability grade level to the first signal whose threshold exceeds it.

  Returns `grade_level` unchanged when it is not below any threshold.
  """
  for limit, signal in zip(thresholds, signals):
    if float(grade_level) < limit:
      return signal
  return grade_level


with tqdm.tqdm(desc="Running eval method 2...", total=len(lines),
                      file=sys.stdout) as pbar:
  for i, line in enumerate(lines):
    context, target, grade, *_ = line
    contexts = [f"{grade} {context} {tokenizer.eos_token}"]

    context_dic = tokenizer.batch_encode_plus(contexts, padding='longest', return_tensors='pt', truncation=True)
    context_dic = context_dic.to(device)
    context_ids = context_dic["input_ids"]
    # Pass the attention mask explicitly: pad and EOS share one token id, so
    # generate() cannot infer which positions are padding on its own.
    response_ids = model.generate(context_ids, attention_mask=context_dic["attention_mask"],
                                  max_length=1000, pad_token_id=tokenizer.eos_token_id)
    # Keep only the newly generated tokens (everything after the prompt).
    out = tokenizer.decode(response_ids[:, context_ids.shape[-1]:][0], skip_special_tokens=True)
    try:
        score = Readability(out).flesch_kincaid()
        pred_grade = grade_level_to_signal(score.grade_level)
    except ZeroDivisionError:
        # Raised for text the metric cannot score; counted as a miss.
        # NOTE(review): presumably this means an empty generation -- confirm.
        pred_grade = ""
    if grade == pred_grade:
      num_correct += 1

    pbar.update()
    pbar.set_description(f"running accuracy: {float(num_correct) / (i+1)}")

contexts = []

# Guard against an empty test set instead of dividing by zero.
accuracy = float(num_correct) / len(lines) if lines else 0.0
print(f"The model accuracy is: {accuracy}, using eval method 2")

In [None]:
# Chat
# Interactive loop: each turn asks for a user message and a response level,
# prepends the level to the running conversation and prints the model's reply.
# Stop it with Ctrl-D / an interrupted kernel.
chat_history_ids = None  # conversation so far, without any level prefix
for step in range(10**8):
    try:
        inp = input(">> User: ")
        level = input("Response level: ")
    except (EOFError, KeyboardInterrupt):
        # End the chat cleanly instead of surfacing a traceback.
        print()
        break
    new_user_input_ids = tokenizer.encode(inp + tokenizer.eos_token, return_tensors='pt')
    level_ids = tokenizer.encode(level + tokenizer.eos_token, return_tensors='pt')
    # The level prefix is applied on every turn, including the first one
    # (previously the first turn silently dropped it).
    if chat_history_ids is None:
        pieces = [level_ids, new_user_input_ids]
    else:
        pieces = [level_ids, chat_history_ids, new_user_input_ids]
    bot_input_ids = torch.cat(pieces, dim=-1).to(device)
#     import pdb; pdb.set_trace()
    # All prompt tokens are real (no padding), so attend to every position;
    # the explicit mask is needed because pad and EOS share one token id.
    # NOTE(review): temperature/top_p only take effect with do_sample=True; as
    # written this is plain beam search -- confirm which one is intended.
    chat_history_ids = model.generate(bot_input_ids, attention_mask=torch.ones_like(bot_input_ids),
                                      max_length=1000, pad_token_id=tokenizer.eos_token_id,
                                      num_beams=3, temperature=0.8, top_p=0.9)
    chat_history_ids = chat_history_ids.cpu()
    bot_input_ids = bot_input_ids.cpu()
    # The reply is whatever was generated after the prompt.
    out = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Strip this turn's level prefix from the stored history so the next turn
    # can prepend its own level without the prefixes piling up.
    if (chat_history_ids[:,:level_ids.size(-1)] == level_ids).all():
        chat_history_ids = chat_history_ids[:,level_ids.size(-1):]

    print("DialoGPT: {}".format(out))