In [1]:
# Change the working directory to the `ppo` experiment folder.
# NOTE(review): the path is relative, so re-running this cell from inside
# `ppo` will not land in the same directory again.
%cd ppo

/data/group1/z44383r/dev/rl-nlg/experiments/ppo


In [2]:
import os
# Expose only GPU 0 to this process. This is set before `torch` is imported
# below, presumably so the restriction is in effect when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import nltk
from collections import defaultdict
from tqdm import tqdm
import wandb
import time
import numpy as np
import torch

# Device used for rollout tensors in the training cell; falls back to CPU
# when no CUDA device is available.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
from common_utils.path import ROOT_DPATH
from common_utils.random_seed import set_seed
from common_utils.multiwoz_data import MultiWOZData
from sys_nlg.gpt2.model import build_gpt2
from sys_nlg.gpt2.nlg import GPT2NLG

from ppo_utils.ppo_updator import PPOUpdator, Rollouts
from ppo_utils.ppo_train_data import PPOTrainData, ActionIDF
from ppo_utils.reward import da_accuracy, ComputeReward

from experiments.speech_error_simulation.error_simulator import SpeechErrorSimulator

In [6]:
# Name of the checkpoint directory and of the language-model checkpoint that
# the GPT-2 paths below are built from.
checkpoints_dname = "checkpoints"
lm_name = "act_resp.4"

# The experiment configuration is assembled one section at a time; the
# finished dict is what the training cell passes to wandb.init(config=...).
config = {
    "checkpoints_dname": checkpoints_dname,
    "lm_name": lm_name,
}

# GPT-2 settings. The policy and the reference model point at the same
# checkpoint directory. `lm_task_type` is `lm_name` with its last dotted
# suffix stripped ("act_resp.4" -> "act_resp").
config["gpt2_config"] = {
    "pretrained_model_dpath": os.path.join(ROOT_DPATH, "sys_nlg/gpt2/outputs", checkpoints_dname, lm_name),
    "ref_model_dpath": os.path.join(ROOT_DPATH, "sys_nlg/gpt2/outputs", checkpoints_dname, lm_name),
    "tokenizer_name": "gpt2",
    "lm_task_type": os.path.splitext(lm_name)[0],
    "act_bos_token": "[ACT]",
    "resp_bos_token": "[RSP]",
    "separate_vf": False,
}

# PPO hyper-parameters, consumed by PPOUpdator and by the generation loop
# (batch_size, forward_batch_size, max_length, temperature).
config["ppo_config"] = {
    "checkpoint_output_dpath": os.path.join(ROOT_DPATH, "experiments/ppo/outputs/checkpoints"),
    "batch_size": 128,
    "forward_batch_size": 128,
    "minibatch_size": 1,
    "ppo_epochs": 4,
    "max_length": 256,
    "lr": 5e-6,
    "lr_linear_decay": False,
    "gamma": 1.0,
    "lam": 0.95,
    "cliprange": 0.2,
    "cliprange_value": 0.2,
    "vf_coef": 0.1,
    "target_kl": 8,
    "init_kl_coef": 0.2,
    "horizon": 10000,
    "temperature": 1.2,
}

# Which user-side NLU backend scores the generated responses
# ("bert", "milu" or "svm" are the names handled by the training cell).
config["user_nlu_config"] = {
    "nlu_name": "milu",
    "nlu_model_name": "size1-sys",
}

# Training-loop settings: seed, number of PPO iterations, reward definition
# and where the best policy checkpoint is written.
config["train_config"] = {
    "random_seed": 42,
    "total_iterations": 50,
    "iterations_vf_pretrain": 0,
    "reward_type": "F1",
    "action_idf_weighted": True,
    "checkpoint_save_dpath": os.path.join(ROOT_DPATH, "experiments/ppo/outputs/checkpoints", "model_save_test"),
}

# Speech-error simulation applied to system responses before NLU.
config["noise_config"] = {
    "apply_noise": True,
    "part": "val",
    "side": "sys",
    "noise_type": "background(20)",
}

In [7]:
# Weights & Biases project and run name; both are passed to wandb.init()
# in the training cell.
project_id = "rl_nlg_test-noise"
run_id = "model_save_test"

In [16]:
# Split the experiment config into its sections.
gpt2_config, ppo_config, user_nlu_config, train_config, noise_config = (
    config[section]
    for section in ("gpt2_config", "ppo_config", "user_nlu_config", "train_config", "noise_config")
)

# Seeding is optional: a seed of None skips set_seed() entirely.
if train_config["random_seed"] is not None:
    set_seed(train_config["random_seed"])

# build_gpt2 returns the tokenizer plus the policy, value and reference-policy models.
tokenizer, policy_gpt2, value_gpt2, ref_policy_gpt2 = build_gpt2(gpt2_config)

# One generation wrapper per model, sharing the tokenizer and prompt markers:
# s_nlg wraps the policy model, s_ref_nlg the reference policy model.
s_nlg, s_ref_nlg = (
    GPT2NLG(
        gpt2=model,
        tokenizer=tokenizer,
        lm_task_type=gpt2_config["lm_task_type"],
        act_bos_token=gpt2_config["act_bos_token"],
        resp_bos_token=gpt2_config["resp_bos_token"],
    )
    for model in (policy_gpt2, ref_policy_gpt2)
)

# Optimiser wrapper that consumes the rollouts collected in the training loop.
ppo_updator = PPOUpdator(
    policy_model=policy_gpt2,
    value_model=value_gpt2,
    ref_policy_model=ref_policy_gpt2,
    total_iterations=train_config["total_iterations"],
    ppo_config=ppo_config,
)

# Previously saved speech-error simulator, selected by the noise settings.
speech_error_simulator = SpeechErrorSimulator.from_saved(
    part=noise_config["part"],
    side=noise_config["side"],
    noise_type=noise_config["noise_type"],
)

# Select the user-side NLU that interprets system responses. Each backend is
# imported inside its own branch, so only the chosen one gets loaded.
if user_nlu_config["nlu_name"] == "bert":
    from user_nlu.joint_bert.nlu import SimulatorBERTNLU
    u_nlu = SimulatorBERTNLU(config_fname=f"{user_nlu_config['nlu_model_name']}.json")
elif user_nlu_config["nlu_name"] == "milu":
    from user_nlu.milu.nlu import UserMILU
    u_nlu = UserMILU(archive_dname=user_nlu_config["nlu_model_name"])
elif user_nlu_config["nlu_name"] == "svm":
    from convlab2.nlu.svm.multiwoz import SVMNLU
    u_nlu = SVMNLU(mode="sys")
else:
    # Fail fast on an unrecognised name. Without this branch `u_nlu` would be
    # left undefined or, in a long-lived kernel, silently keep the NLU object
    # created by an earlier run of this cell.
    raise ValueError(
        f"Unknown nlu_name: {user_nlu_config['nlu_name']!r} "
        "(expected 'bert', 'milu' or 'svm')"
    )

def run_nlu(system_response):
    """Predict the dialogue action of a system response with the user NLU.

    Relies on the module-level ``noise_config``, ``speech_error_simulator``
    and ``u_nlu`` objects.

    Args:
        system_response: Text of the system response to interpret.

    Returns:
        A ``(noised_system_response, pred_action)`` tuple. When
        ``noise_config["apply_noise"]`` is falsy, the NLU sees the clean
        response and the first element is the empty string.
    """
    if not noise_config["apply_noise"]:
        return "", u_nlu.predict(system_response)

    # apply_error returns a pair; only the noised text is used here.
    noised_system_response, _ = speech_error_simulator.apply_error(src_text=system_response)
    return noised_system_response, u_nlu.predict(noised_system_response)

# Training data: batches of ppo_config["batch_size"] examples drawn from the
# "train" part, with shuffling and infinite sampling switched on.
multiwoz_data = MultiWOZData()
ppo_train_data = PPOTrainData(
    multiwoz_data=multiwoz_data,
    parts_used=["train"],
    batch_size=ppo_config["batch_size"],
    shuffle=True,
    infinite=True,
)

# ActionIDF lookup built from the "train" part only ("val" is not included).
action_idf = ActionIDF(
    multiwoz_data=multiwoz_data,
    parts_used=["train"],
)

# Reward function configured from the training settings.
compute_reward = ComputeReward(
    reward_type=train_config["reward_type"],
    action_idf_weighted=train_config["action_idf_weighted"],
)

# Start the wandb run; the full experiment config is attached to it.
run = wandb.init(
    project=project_id,
    name=run_id,
    config=config,
)
# Helper for the evaluation step of the training loop below.
def _normalized_edit_distance(tokens_a, tokens_b):
    """Token-level edit distance divided by the length of the longer sequence.

    Args:
        tokens_a: First token list.
        tokens_b: Second token list.

    Returns:
        ``nltk.edit_distance(tokens_a, tokens_b) / max(len(tokens_a), len(tokens_b))``,
        or ``0.0`` when both sequences are empty (two empty responses are
        identical; dividing by the zero length would raise ZeroDivisionError).
    """
    longest = max(len(tokens_a), len(tokens_b))
    if longest == 0:
        return 0.0
    return nltk.edit_distance(tokens_a, tokens_b) / longest


# Column layout of the three wandb tables; constant across iterations.
ref_table_columns = ["ref/response", "ref/noised_response", "ref/f1", "ref/L_distance"]
gen_table_columns = ["gen/response", "gen/noised_response", "gen/f1", "gen/L_distance"]
test_table_columns = ["test/response", "test/noised_response", "test/f1", "test/L_distance"]

# Best mean test F1 seen so far; a checkpoint is written whenever it improves.
best_score = {'f1': float('-inf')}
for iteration_id in tqdm(range(train_config["total_iterations"])):
    table_log = defaultdict(list)   # per-example rows for the wandb tables
    env_log = defaultdict(list)     # per-example scalars, averaged below
    timing = dict()
    t0 = time.time()

    rollouts = Rollouts(batch_size=ppo_config["batch_size"])

    # ---- 1. Response generation -------------------------------------------
    # For every example three responses are produced:
    #   ref  - greedy decoding from the reference policy,
    #   gen  - sampled from the current policy (these feed the PPO rollouts),
    #   test - greedy decoding from the current policy.
    t = time.time()
    batch = ppo_train_data.sample_batch()
    ref_batch = []
    gen_batch = []
    test_batch = []
    for fbi in range(0, ppo_config["batch_size"], ppo_config["forward_batch_size"]):
        forward_batch = batch[fbi:fbi+ppo_config["forward_batch_size"]]

        ref_batch += s_ref_nlg.batch_generate(batch=forward_batch,
                                              max_length=ppo_config["max_length"],
                                              temperature=1.0,
                                              do_sample=False)

        gen_batch_ = s_nlg.batch_generate(batch=forward_batch,
                                          max_length=ppo_config["max_length"],
                                          temperature=ppo_config["temperature"],
                                          do_sample=True)
        gen_batch += gen_batch_
        for gen in gen_batch_:
            rollouts.insert_response(query_ids=gen['query_ids'].unsqueeze(0),
                                     response_ids=gen['response_ids'].unsqueeze(0),
                                     device=DEVICE)

        test_batch += s_nlg.batch_generate(batch=forward_batch,
                                           max_length=ppo_config["max_length"],
                                           temperature=1.0,
                                           do_sample=False)
    timing['time/response_generation'] = time.time()-t

    # ---- 2. Response evaluation and reward --------------------------------
    t = time.time()
    for bi in range(ppo_config["batch_size"]):
        action = batch[bi]["system_action"]
        ref_response = ref_batch[bi]["response_txt"].replace(tokenizer.eos_token, "")
        gen_response = gen_batch[bi]["response_txt"].replace(tokenizer.eos_token, "")
        test_response = test_batch[bi]["response_txt"].replace(tokenizer.eos_token, "")

        # Optionally noise each response, then let the user NLU predict its action.
        ref_noised_response, ref_pred_action = run_nlu(system_response=ref_response)
        gen_noised_response, gen_pred_action = run_nlu(system_response=gen_response)
        test_noised_response, test_pred_action = run_nlu(system_response=test_response)

        ref_acc = da_accuracy(true_action=action, pred_action=ref_pred_action)
        gen_acc = da_accuracy(true_action=action, pred_action=gen_pred_action)
        test_acc = da_accuracy(true_action=action, pred_action=test_pred_action)

        # IDF values of the true-positive acts; a single 0. when there are none.
        if ref_acc["tp_acts"]:
            ref_action_idfs = np.array([action_idf[gt_act] for gt_act in ref_acc["tp_acts"]])
        else:
            ref_action_idfs = np.array([0.])
        if gen_acc["tp_acts"]:
            gen_action_idfs = np.array([action_idf[gen_act] for gen_act in gen_acc["tp_acts"]])
        else:
            gen_action_idfs = np.array([0.])

        ref_tokens = ref_response.lower().split()
        gen_tokens = gen_response.lower().split()
        test_tokens = test_response.lower().split()

        # Normalised token-level edit distance to the reference response
        # (safe when both responses are empty).
        ref_gen_nld = _normalized_edit_distance(ref_tokens, gen_tokens)
        ref_test_nld = _normalized_edit_distance(ref_tokens, test_tokens)

        # Only the sampled ("gen") response is rewarded and added to the rollouts.
        reward = compute_reward(ref_acc=ref_acc, ref_action_idfs=ref_action_idfs, ref_tokens=ref_tokens,
                                gen_acc=gen_acc, gen_action_idfs=gen_action_idfs, gen_tokens=gen_tokens,
                                ref_gen_nld=ref_gen_nld)
        rollouts.insert_reward(reward=torch.tensor([reward]), device=DEVICE)

        table_log["ref/response"].append(ref_response)
        table_log["gen/response"].append(gen_response)
        table_log["test/response"].append(test_response)
        table_log["ref/noised_response"].append(ref_noised_response)
        table_log["gen/noised_response"].append(gen_noised_response)
        table_log["test/noised_response"].append(test_noised_response)
        table_log["ref/f1"].append(ref_acc["f1"])
        table_log["gen/f1"].append(gen_acc["f1"])
        table_log["test/f1"].append(test_acc["f1"])
        table_log["ref/L_distance"].append(0.)
        table_log["gen/L_distance"].append(ref_gen_nld)
        table_log["test/L_distance"].append(ref_test_nld)

        env_log["reward"].append(reward)
        env_log["ref/f1"].append(ref_acc["f1"])
        env_log["gen/f1"].append(gen_acc["f1"])
        env_log["test/f1"].append(test_acc["f1"])
        env_log["ref/acc"].append(ref_acc["acc"])
        env_log["gen/acc"].append(gen_acc["acc"])
        env_log["test/acc"].append(test_acc["acc"])
        # NOTE(review): both values are len(ref) - len(other), i.e. positive
        # when the other response is SHORTER than the reference, despite the
        # "length_increase" name. Confirm the intended sign.
        env_log["gen/length_increase"].append(len(ref_tokens)-len(gen_tokens))
        env_log["test/length_increase"].append(len(ref_tokens)-len(test_tokens))
    timing['time/response_evaluation'] = time.time()-t

    # NOTE(review): 'env/length_increase' reports the test/ variant only;
    # the gen/ variant is collected above but never logged.
    env_result = {'env/reward_mean': np.mean(env_log['reward']).item(),
                  'env/reward_std': np.std(env_log['reward']).item(),
                  'env/reward_dist': env_log['reward'],
                  'env/gen_f1': np.mean(env_log["gen/f1"]).item(),
                  'env/test_f1': np.mean(env_log["test/f1"]).item(),
                  'env/ref_f1': np.mean(env_log["ref/f1"]).item(),
                  'env/gen_acc': np.mean(env_log["gen/acc"]).item(),
                  'env/test_acc': np.mean(env_log["test/acc"]).item(),
                  'env/ref_acc': np.mean(env_log["ref/acc"]).item(),
                  'env/length_increase': np.mean(env_log["test/length_increase"]).item()}

    # ---- 3. Checkpointing --------------------------------------------------
    # Saved before the PPO step, so the stored weights are the ones that
    # produced the test responses scored above.
    if best_score["f1"] < env_result["env/test_f1"]:
        t = time.time()
        best_score["f1"] = env_result["env/test_f1"]
        policy_gpt2.save_checkpoint(tokenizer=tokenizer,
                                    output_dpath=train_config["checkpoint_save_dpath"],
                                    prefix=f"policy.{iteration_id}",
                                    eval_results=env_result)
        timing['time/checkpoint_save'] = time.time()-t

    # ---- 4. PPO optimisation -----------------------------------------------
    # During the first `iterations_vf_pretrain` iterations the updator is
    # called with update_vf_only=True.
    t = time.time()
    if iteration_id < train_config["iterations_vf_pretrain"]:
        stats = ppo_updator.step(rollouts=rollouts, update_vf_only=True)
    else:
        stats = ppo_updator.step(rollouts=rollouts)
    timing['time/optimization'] = time.time()-t
    timing['time/epoch'] = time.time()-t0

    # ---- 5. Logging --------------------------------------------------------
    # Transpose the column-wise logs into row lists for the wandb tables.
    ref_table_rows = [list(row) for row in zip(*(table_log[col] for col in ref_table_columns))]
    gen_table_rows = [list(row) for row in zip(*(table_log[col] for col in gen_table_columns))]
    test_table_rows = [list(row) for row in zip(*(table_log[col] for col in test_table_columns))]
    wandb.log({**env_result,
               **stats,
               **timing,
               'table/ref': wandb.Table(columns=ref_table_columns, rows=ref_table_rows),
               'table/gen': wandb.Table(columns=gen_table_columns, rows=gen_table_rows),
               'table/test': wandb.Table(columns=test_table_columns, rows=test_table_rows)})


2022-05-04 10:39:10,448    INFO random_seed.py Set random seed to 42
Using pad_token, but it is not set yet.
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

  0%|          | 0/50 [00:00<?, ?it/s]2022-05-04 10:39:47,503    INFO     model.py Saving model checkpoint to /data/group1/z44383r/dev/rl-nlg/experiments/ppo/outputs/checkpoints/model_save_test/best.policy.0
  2%|▏         | 1/50 [00:31<25:58, 31.81s/it]2022-05-04 10:40:21,641    INFO     model.py Saving model checkpoint to /data/group1/z44383r/dev/rl-nlg/experiments/ppo/outputs/checkpoints/model_save_test/best.policy.1
  6%|▌         | 3/50 [01:26<22:02, 28.13s/it]2022-05-04 10:41:14,196    INFO     model.py Saving model checkpoint to /data/group1/z44383r/dev/rl-nlg/experiments/ppo/outputs/checkpoints/model_save_test/best.policy.3
 12%|█▏        | 6/50 [02:45<19:32, 26.65s/it]2022-05-04 10:42:33,421    INFO     model.py Saving model checkpoint to /data/group1/z44383r/dev/rl-nlg/experiments/ppo/outputs/checkpoints/model_save_test/best.policy.6
 18%|█▊        | 9/50 [04:03<17:49, 26.08s/it]2022-05-04 10:44:21,513    INFO     model.py Saving model checkpoint to /data/group1/z44383r/dev/r

KeyboardInterrupt: 

In [17]:
# Finish the wandb run started in the training cell; the run history and
# summary shown below are the output of this call.
run.finish()

VBox(children=(Label(value='2.154 MB of 2.154 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
env/gen_acc,▁▆▄▄▂▇▃▅▄▅▅▃█▆▆▆
env/gen_f1,▁▆▄▄▂▇▃▅▃▆▅▃█▆▆▆
env/length_increase,▃▅▆▆▅▄▃▃▁▇▇█▅▅█▆
env/ref_acc,▄█▄▃▆▄▄▂▁▇▁▃▃▅▆▆
env/ref_f1,▄█▅▃▇▄▅▁▂█▁▃▃▆▇▆
env/reward_mean,▁▄▄▂▁▅▂▅▂▅▅▃█▅▆▅
env/reward_std,▆▄▇▃▇▆▂▇▇▁▇▅▆▂█▂
env/test_acc,▃▆▅▇▅▆▇▄▂▇▁▂█▃▇▇
env/test_f1,▃▆▅▆▆▆█▃▃█▁▂█▄█▇
objective/entropy,█▅▆▄█▃▅▄▅▆▃▃▁▄▃▃

0,1
env/gen_acc,0.48924
env/gen_f1,0.56448
env/length_increase,0.49219
env/ref_acc,0.61652
env/ref_f1,0.68849
env/reward_mean,1.74817
env/reward_std,1.19398
env/test_acc,0.62072
env/test_f1,0.69301
objective/entropy,25.23724
