In [None]:
!pip install datasets pandas transformers simpletransformers torch tables >/dev/null

In [None]:
from datasets import load_dataset

# Load the five "silicone" configurations used in this notebook, one
# load_dataset call per configuration, in the same order as the target names.
dyda_da, maptask, mrda, oasis, swda = (
    load_dataset("silicone", config_name)
    for config_name in ("dyda_da", "maptask", "mrda", "oasis", "swda")
)

In [None]:
# Unified dialogue-act taxonomy.
#
# Top-level keys are the merged act names; their insertion order matters
# because the position of each key becomes the integer class label.
# Each value maps a subset name to the original Dialogue_Act labels of that
# subset that are folded into the merged act. An empty list means the subset
# contributes nothing to that act.
mapping = {
    "acknowledge": {
        "swda": ["aap_am", "b", "bk"],
        "mrda": [],
        "oasis": ["ackn", "accept", "complete"],
        "maptask": ["acknowledge", "align"],
        "dyda_da": ["commissive"],
    },
    "answer": {
        "swda": ["bf"],
        "mrda": [],
        "oasis": ["answ", "informCont", "inform", "answElab", "directElab", "refer"],
        "maptask": ["reply_w", "explain"],
        "dyda_da": ["inform"],
    },
    "backchannel": {
        "swda": ["ad", "bh", "bd", "b^m"],
        "mrda": ["b"],
        "oasis": ["backch", "selfTalk", "init"],
        "maptask": ["ready"],
        "dyda_da": [],
    },
    "reply_yes": {
        "swda": ["na", "aa"],
        "mrda": [],
        "oasis": ["confirm"],
        "maptask": ["reply_y"],
        "dyda_da": [],
    },
    "exclaim": {
        "swda": ["ft", "fa", "fc", "fp"],
        "mrda": [],
        "oasis": [
            "appreciate", "bye", "exclaim", "greet",
            "thank", "pardon", "thank-identitySelf", "expressRegret",
        ],
        "maptask": [],
        "dyda_da": [],
    },
    "say": {
        "swda": ["qh", "sd"],
        "mrda": ["s"],
        "oasis": ["expressPossibility", "expressOpinion", "suggest"],
        "maptask": [],
        "dyda_da": [],
    },
    "reply_no": {
        "swda": ["nn", "ng", "ar"],
        "mrda": [],
        "oasis": ["refuse", "negate"],
        "maptask": ["reply_n"],
        "dyda_da": [],
    },
    "hold": {
        "swda": ["^h", "t1"],
        "mrda": ["f"],
        "oasis": ["hold"],
        "maptask": [],
        "dyda_da": [],
    },
    "ask": {
        "swda": ["qw", "qo", "qw^d", "br", "qrr"],
        "mrda": ["q"],
        "oasis": ["reqInfo", "reqDirect", "offer"],
        "maptask": ["query_w"],
        "dyda_da": ["question"],
    },
    "intent": {
        "swda": [],
        "mrda": [],
        "oasis": [
            "informIntent", "informIntent-hold", "expressWish",
            "direct", "raiseIssue", "correct",
        ],
        "maptask": ["instruct", "clarify"],
        "dyda_da": ["directive"],
    },
    "ask_yes_no": {
        "swda": ["qy^d", "^g"],
        "mrda": [],
        "oasis": ["reqModal"],
        "maptask": ["query_yn", "check"],
        "dyda_da": [],
    },
}

# Ordered list of merged act names; an act's position in this list is the
# integer class label used when building the training frames.
new_acts = list(mapping.keys())

print("Acts:")
# Plain loop instead of a side-effect list comprehension, so the cell no
# longer leaves a throwaway value as its output.
for pair in enumerate(new_acts):
    print(pair)

# maptask["train"].filter(lambda x: x["Dialogue_Act"] == "align") [:100]["Utterance"]

In [None]:
from functools import reduce
import pandas as pd

# Subset name -> loaded dataset, in the order the subsets are merged.
subsets = {
    "dyda_da": dyda_da,
    "maptask": maptask,
    "swda": swda,
    "oasis": oasis,
    "mrda": mrda,
}

# One entry per merged act: {subset name: {original act: merged act}}.
rev = [
    {name: dict.fromkeys(original_acts, merged_act) for name, original_acts in per_subset.items()}
    for merged_act, per_subset in mapping.items()
]

# Per subset, a single lookup {original act: merged act}. Entries are folded
# in the order of `rev`, so a later merged act wins if an original act is
# listed under more than one merged act.
reversed_mapping = {
    name: {
        original_act: merged_act
        for entry in rev
        for original_act, merged_act in entry[name].items()
    }
    for name in subsets
}

def try_else(fn, args=None, kwargs=None, else_val=None):
    """Call ``fn(*args, **kwargs)`` and fall back to ``else_val`` on failure.

    Args:
        fn: Callable to invoke.
        args: Optional sequence of positional arguments (defaults to none).
        kwargs: Optional mapping of keyword arguments (defaults to none).
        else_val: Value returned when ``fn`` raises an ``Exception``.

    Returns:
        The result of ``fn`` on success, otherwise ``else_val``. The caught
        exception is printed, not re-raised.
    """
    # None sentinels instead of mutable [] / {} defaults, which would be
    # shared between calls.
    call_args = () if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    try:
        return fn(*call_args, **call_kwargs)
    except Exception as e:
        print(e)
        return else_val
    
def flatten(t):
    """Return one flat list holding the items of every sub-iterable of ``t``, in order."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat

def merge_data(split="train"):
    """Build the merged examples for one split across every subset.

    For each subset in ``subsets`` (in order), every row whose
    ``Dialogue_Act`` appears in that subset's ``reversed_mapping`` yields one
    example ``[previous_utterance, utterance, label]`` where:

    * ``previous_utterance`` is the ``Utterance`` of the row immediately
      before it in the split (``""`` for the first row), whether or not that
      preceding row was itself kept;
    * ``label`` is the index of the merged act in ``new_acts``.

    Rows whose act has no mapping are skipped.

    Args:
        split: Name of the split to read from each subset.

    Returns:
        A flat list of ``[text_a, text_b, label]`` lists.
    """
    # Resolve merged act -> integer label once instead of a linear
    # new_acts.index() search for every kept row.
    label_ids = {act: idx for idx, act in enumerate(new_acts)}

    merged = []
    for name, subset in subsets.items():
        act_lookup = reversed_mapping[name]
        # NOTE(review): the context is the preceding row even if it belongs to
        # a different dialogue — confirm whether it should reset at dialogue
        # boundaries.
        previous_utterance = ""
        # Single pass over the rows: each row is read once rather than being
        # re-indexed several times per example.
        for row in subset[split]:
            utterance = row["Utterance"]
            original_act = row["Dialogue_Act"]
            if original_act in act_lookup:
                merged.append(
                    [previous_utterance, utterance, label_ids[act_lookup[original_act]]]
                )
            previous_utterance = utterance
    return merged

def _load_or_build_split(cache_file, split):
    """Return the frame for ``split``, read from ``cache_file`` when possible.

    The HDF5 cache (key ``"data"``) is tried first; if reading fails for any
    reason the error is printed by ``try_else`` and the frame is rebuilt with
    ``merge_data(split)``. The result always has the columns
    ``text_a``, ``text_b``, ``labels``.
    """
    column_names = ["text_a", "text_b", "labels"]
    cached = try_else(pd.read_hdf, [cache_file, "data"])
    if cached is not None:
        cached.columns = column_names
        return cached
    # Passing the names at construction also works when the split yields no
    # rows, where assigning .columns to an empty frame would raise.
    return pd.DataFrame(merge_data(split), columns=column_names)


# Check cache too
# NOTE(review): a cache written before the act mapping was edited will be
# reused as-is — delete the *.h5 files after changing the mapping.
train_df = _load_or_build_split("train_cache.h5", "train")
eval_df = _load_or_build_split("eval_cache.h5", "validation")
test_df = _load_or_build_split("test_cache.h5", "test")

In [None]:
# Store into cache
# `key` is passed by keyword: newer pandas releases deprecate positional
# arguments to to_hdf other than the path.
train_df.to_hdf("train_cache.h5", key="data")
eval_df.to_hdf("eval_cache.h5", key="data")
test_df.to_hdf("test_cache.h5", key="data")

In [None]:
import wandb
project = "da-silicone-combined"

# Turn off debug logs
import transformers
transformers.logging.set_verbosity_error()

# model configuration
from simpletransformers.classification import (
    ClassificationModel, ClassificationArgs
)

# Single source of truth for the checkpoint, used by both the args and the
# model constructor so the two cannot drift apart.
MODEL_TYPE = "deberta"
MODEL_NAME = "microsoft/deberta-large"
# Fixed seed so that repeated runs of this cell are comparable.
SEED = 42

model_args = ClassificationArgs(
    manual_seed=SEED,
    # NOTE(review): assumes a machine with 4 GPUs — adjust for other hardware.
    n_gpu=4,
    num_train_epochs=8,
    learning_rate=3e-5,
    model_type=MODEL_TYPE,
    model_name=MODEL_NAME,
    wandb_project=project,
    use_multiprocessing=False,
    train_batch_size=80,
    eval_batch_size=80,
    # Evaluate on eval_df every `evaluate_during_training_steps` steps; early
    # stopping is enabled on top of these periodic evaluations.
    evaluate_during_training=True,
    evaluate_during_training_verbose=True,
    use_early_stopping=True,
    scheduler="polynomial_decay_schedule_with_warmup",
    overwrite_output_dir=True,
    evaluate_during_training_steps=1000,
)

# Create a ClassificationModel with one output label per merged act
model = ClassificationModel(MODEL_TYPE, MODEL_NAME, args=model_args, num_labels=len(new_acts))

# Train model
model.train_model(train_df, eval_df=eval_df)

In [None]:
# Evaluate the model on the validation split and keep the results under
# dedicated names, because a later cell reassigns the generic names
# (result, model_outputs, wrong_predictions) with the test-split results.
eval_result, eval_model_outputs, eval_wrong_predictions = model.eval_model(
    eval_df
)

# Keep the generic names populated exactly as before.
result, model_outputs, wrong_predictions = (
    eval_result, eval_model_outputs, eval_wrong_predictions
)

In [None]:
# Score the trained model on the held-out test split; this rebinds
# result, model_outputs and wrong_predictions.
result, model_outputs, wrong_predictions = model.eval_model(test_df)

In [None]:
# Recursively copy the local ./outputs directory to the GCS bucket, under a
# folder named "large-<MM-DD-YYYY_HH_MM_SS>" stamped with the current time
# (the backticks run `date` in the shell).
# NOTE(review): presumably ./outputs is where training wrote the model — confirm.
!gsutil cp -r ./outputs gs://internal.whitehead.ai/silicone/large-`date '+%m-%d-%Y_%H_%M_%S'`/outputs