In [None]:
# Before running this notebook, start Redis for both namespaces in a terminal:
# > USER_PATH=/home/krause/userdata/ make run-redis NS=train
# > USER_PATH=/home/krause/userdata/ make run-redis NS=test
# so that the notebook can access the train and test namespaces.

In [None]:
import os
import sys
import json
import pandas as pd
import numpy as np
from typing import Literal, TypedDict

In [None]:
sys.path.append("..")
os.environ["USER_PATH"] = "/home/krause/userdata/"
MODEL_OUTPUT_BASE = "/mnt/d/workspace/clotho/notebooks"
MODEL_OUTPUT_CP = os.path.join(MODEL_OUTPUT_BASE, "checkpoints")

In [None]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from misc.io import open_write
from model.datagenerator import create_train_test
from model.transformer_embed import (
    get_epoch_and_load,
    limit_epoch_data,
    limit_epoch_data,
    get_model_filename,
)
from system.namespace.store import get_namespace

In [None]:
import torch

# True when PyTorch can use a CUDA device. Batch sizes, epoch sizes,
# evaluation-set sizes and the LR warmup later in the notebook are chosen
# from this flag.
is_cuda = torch.cuda.is_available()
is_cuda

In [None]:
set_redis_slow_mode("never")
ns_test = get_namespace("test")
ns_train = get_namespace("train")
now = pd.Timestamp("2022-12-17", tz="UTC")
train_plan = [
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 50,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 20,
        "last_epoch": None,
        "weight": 50,
    },
    # {
    #     "left": {"mode": "random", "flip_pc": 0.0},
    #     "right": {"mode": "path", "flip_pc": 0.0},
    #     "min_text_length": None,
    #     "skip_weak": True,
    #     "skip_topics": True,
    #     "flip_lr": 0.5,
    #     "first_epoch": None,
    #     "last_epoch": None,
    #     "weight": 60,
    # },
    # {
    #     "left": None,
    #     "right": {"mode": "path", "flip_pc": 0.0},
    #     "min_text_length": None,
    #     "skip_weak": True,
    #     "skip_topics": True,
    #     "flip_lr": 0.5,
    #     "first_epoch": None,
    #     "last_epoch": None,
    #     "weight": 40,
    # },
    # {
    #     "left": {"mode": "random", "flip_pc": 0.0},
    #     "right": {"mode": "path", "flip_pc": 0.0},
    #     "min_text_length": None,
    #     "skip_weak": False,
    #     "skip_topics": True,
    #     "flip_lr": 0.5,
    #     "first_epoch": 5,
    #     "last_epoch": None,
    #     "weight": 60,
    # },
    # {
    #     "left": None,
    #     "right": {"mode": "path", "flip_pc": 0.0},
    #     "min_text_length": None,
    #     "skip_weak": True,
    #     "skip_topics": True,
    #     "flip_lr": 0.5,
    #     "first_epoch": 5,
    #     "last_epoch": None,
    #     "weight": 40,
    # },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 15,
        "last_epoch": None,
        "weight": 50,
    }
]
eval_plan = [
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": False,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
]
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=10000 if is_cuda else 1000,
    test_size=10000 if is_cuda else 1000,
    test_val_size=10000 if is_cuda else 1000,
    use_fast_gen_only=True,
    compute_batch_size=100 if is_cuda else 100,
    now=now)

In [None]:
import torch.nn as nn
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertModel

In [None]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if is_cuda else "cpu")
device

In [None]:
from model.transformer_embed import (
    EMBED_SIZE,
    TokenizedInput,
    Model,
    BaselineModel,
    EitherModel,
    TrainingHarness,
    get_tokenizer,
)


# Shared tokenizer callable used by compute() and embeds(). It is applied to
# lists of strings, and its result is indexed with "input_ids" and
# "attention_mask".
tokens = get_tokenizer()

In [None]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate
import time


def create_model(version: int) -> EitherModel:
    """Instantiate the network for ``version``.

    Negative versions select ``BaselineModel``; every other version selects
    the ``Model`` variant with that number.
    """
    if version < 0:
        return BaselineModel(version)
    return Model(version)


def compute(harness, df):
    """Run the training harness on one batch of paired examples.

    Tokenizes the parent and child texts of the left and right candidates in
    ``df``. It then builds two-column labels from ``correct_is_right``:
    column 0 is 1.0 when the left side is correct, column 1 is 1.0 when the
    right side is correct.

    Args:
        harness: the training harness. It is called with ``left``, ``right``
            and ``labels`` keyword arguments.
        df: batch frame with ``parent_left``, ``child_left``,
            ``parent_right`` and ``child_right`` text columns, plus a
            ``correct_is_right`` column (boolean or 0/1).

    Returns:
        ``(preds, loss)`` exactly as returned by ``harness``.
    """
    plefts = tokens(df["parent_left"].tolist())
    clefts = tokens(df["child_left"].tolist())
    prights = tokens(df["parent_right"].tolist())
    crights = tokens(df["child_right"].tolist())
    # Coerce to a real bool array first. On an int or object column, `~` is a
    # bitwise complement (-1/-2), not a logical NOT. Going through numpy also
    # avoids handing pandas Series straight to torch.tensor.
    is_right = df["correct_is_right"].to_numpy(dtype=bool)
    labels = torch.from_numpy(
        np.stack([~is_right, is_right], axis=1).astype(np.float32)
    ).to(device)

    preds, loss = harness(
        left={"parent": plefts, "child": clefts},
        right={"parent": prights, "child": crights},
        labels=labels)
    # TODO add selective push losses
    return preds, loss


def _eval_split(harness, make_dfs, total, desc):
    """Score ``harness`` on one evaluation split without tracking gradients.

    Args:
        harness: the training harness, evaluated through ``compute``.
        make_dfs: zero-argument callable returning the batch frames (e.g.
            ``ttgen.test_dfs``). It is called inside ``torch.no_grad()``.
        total: expected number of rows; used only for the progress bar.
        desc: progress bar label.

    Returns:
        ``(accuracy, mean_loss)`` as floats.

    The caller is responsible for putting the model and harness in eval mode.
    """
    metric = evaluate.load("accuracy")
    losses = []
    with torch.no_grad(), tqdm(desc=desc, total=total) as progress_bar:
        for df in make_dfs():
            preds, loss = compute(harness, df)
            losses.append(loss.item())
            predictions = torch.argmax(preds, dim=-1)
            metric.add_batch(
                predictions=predictions,
                references=df["correct_is_right"].astype(int))
            progress_bar.update(df.shape[0])
    return float(metric.compute()["accuracy"]), float(np.mean(losses))


def run_training(num_epochs, version, force_restart):
    """Train model ``version`` until ``num_epochs`` epochs exist, resuming if possible.

    Builds the model and its ``TrainingHarness``, then restores state through
    ``get_epoch_and_load``, which also reports the epoch offset to resume
    from. Only the remaining epochs are trained. After each epoch it:

    * saves a harness checkpoint,
    * evaluates on the train-validation and test splits,
    * writes a stats JSON,
    * appends a row to the version's validation-log CSV,
    * prunes old stats files with ``limit_epoch_data``.

    Args:
        num_epochs: total epoch budget; already-trained epochs count toward it.
        version: model variant; negative selects the baseline (see
            ``create_model``).
        force_restart: forwarded to ``get_epoch_and_load`` (presumably
            ignores existing checkpoints when True).

    Returns:
        ``(model, harness, optimizer)``.
    """
    model = create_model(version)
    model.to(device)
    harness = TrainingHarness(model)
    harness.to(device)

    mprev, epoch_offset = get_epoch_and_load(
        harness,
        MODEL_OUTPUT_CP,
        ftype="harness",
        is_cuda=is_cuda,
        device=device,
        force_restart=force_restart)

    # NOTE(review): optimizer and LR-scheduler state are not restored from a
    # checkpoint, so a resumed run restarts the AdamW moments and the warmup.
    optimizer = AdamW(harness.parameters(), lr=5e-5)
    print(mprev, epoch_offset)

    # Only the epochs that are still missing get trained.
    num_epochs -= epoch_offset
    if num_epochs <= 0:
        print("already computed all epochs. nothing to do!")
        return model, harness, optimizer

    # NOTE(review): the progress bar below treats get_epoch_train_size() as a
    # row count, while lr_scheduler.step() runs once per batch — confirm the
    # schedule length is meant to be counted in rows rather than batches.
    num_training_steps = num_epochs * ttgen.get_epoch_train_size()
    warmup = 10000 if is_cuda else 10
    # `num_training_steps` is the full schedule length *including* warmup, and
    # the linear decay reaches zero exactly there. Subtracting the warmup
    # would pin the LR at zero for the final `warmup` steps.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=num_training_steps)
    ttgen.set_epoch(epoch_offset)

    log_csv = get_model_filename(
        harness,
        MODEL_OUTPUT_BASE,
        is_cuda=is_cuda,
        ftype="val_log",
        epoch=None,
        ext=".csv")
    columns = [
        "epoch",
        "train_acc",
        "train_loss",
        "train_val_acc",
        "train_val_loss",
        "test_acc",
        "test_loss",
        "time",
        "version",
        "fname",
    ]
    if not os.path.exists(log_csv):
        # Header only; one row per epoch is appended below. Both writes
        # include the index column, so existing log files stay consistent.
        pd.DataFrame([], columns=columns).to_csv(
            log_csv, header=True, mode="w", columns=columns)

    for _ in range(num_epochs):
        epoch = ttgen.get_epoch()
        print(f"epoch {epoch} version: {harness.get_version()}")
        real_time = time.monotonic()

        model.train()
        harness.train()
        model.set_epoch(epoch)
        metric_train = evaluate.load("accuracy")
        train_loss = []
        with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
            for train_df in ttgen.train_dfs():
                preds, loss = compute(harness, train_df)
                train_loss.append(loss.item())
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(train_df.shape[0])

                predictions = torch.argmax(preds, dim=-1)
                metric_train.add_batch(
                    predictions=predictions,
                    references=train_df["correct_is_right"].astype(int))

        # Per-epoch harness checkpoint; its path is also recorded in the stats.
        model_fname = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="harness",
            epoch=epoch)
        torch.save(harness.state_dict(), model_fname)

        model.eval()
        harness.eval()
        train_val_acc, train_val_loss = _eval_split(
            harness,
            ttgen.train_validation_dfs,
            ttgen.get_epoch_train_validation_size(),
            "train val")
        test_acc, test_loss = _eval_split(
            harness, ttgen.test_dfs, ttgen.get_epoch_test_size(), "test")
        stats = {
            "epoch": int(epoch),
            "train_acc": float(metric_train.compute()['accuracy']),
            "train_loss": float(np.mean(train_loss)),
            "train_val_acc": train_val_acc,
            "train_val_loss": train_val_loss,
            "test_acc": test_acc,
            "test_loss": test_loss,
            "time": 0.0,
            "version": harness.get_version(),
            "fname": model_fname,
        }

        print(f"train: {stats['train_acc']} loss: {stats['train_loss']}")
        print(f"train val: {stats['train_val_acc']} loss: {stats['train_val_loss']}")
        print(f"test: {stats['test_acc']} loss: {stats['test_loss']}")
        ttgen.advance_epoch()
        # Wall-clock minutes for the whole epoch (training + evaluation).
        stats["time"] = float((time.monotonic() - real_time) / 60.0)
        print(f"epoch time: {stats['time']:.2f}min")
        stats_fn = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            epoch=epoch,
            ext=".json")
        with open_write(stats_fn, text=True) as fout:
            print(json.dumps(stats, indent=2, sort_keys=True), file=fout)
        stats_df = pd.DataFrame(
            {key: [val] for key, val in stats.items()},
            columns=columns)
        stats_df.to_csv(
            log_csv, header=False, mode="a")

        # NOTE(review): only the stats JSON files are pruned here; per-epoch
        # harness checkpoints are never removed — confirm that is intended.
        limit_epoch_data(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            ext=".json",
            count=5)
    return model, harness, optimizer

In [None]:
def save_model(model, harness, optimizer):
    """Persist the final state dicts of the model, harness and optimizer.

    Each file goes under ``MODEL_OUTPUT_BASE``. Its name comes from
    ``get_model_filename`` keyed on the harness, with ``epoch=None`` and
    ftype "model", "harness" or "optimizer" respectively. Files are written
    in that order.
    """
    artifacts = (
        (model, "model"),
        (harness, "harness"),
        (optimizer, "optimizer"),
    )
    for component, ftype in artifacts:
        state = component.state_dict()
        target = get_model_filename(
            harness,
            MODEL_OUTPUT_BASE,
            is_cuda=is_cuda,
            ftype=ftype,
            epoch=None)
        torch.save(state, target)

In [None]:
def validation(model, harness):
    """Evaluate on the test-validation split and save per-row predictions.

    Resets ``ttgen``, runs ``harness`` without gradients over every
    test-validation batch, and prints accuracy and mean loss. All rows are
    written to the "validation" CSV with added ``logit_left``,
    ``logit_right``, ``preds`` and ``truth`` columns. Finally it displays the
    first few correctly and incorrectly classified rows.
    """
    ttgen.reset()
    model.eval()
    harness.eval()
    scored_batches = []
    with torch.no_grad():
        accuracy = evaluate.load("accuracy")
        batch_losses = []
        expected_rows = ttgen.get_epoch_test_validation_size()
        with tqdm(desc="test val", total=expected_rows) as bar:
            for batch in ttgen.test_validation_dfs():
                logits, loss = compute(harness, batch)
                batch_losses.append(loss.item())
                predicted = torch.argmax(logits, dim=-1)
                truth = batch["correct_is_right"].astype(int)
                accuracy.add_batch(predictions=predicted, references=truth)

                scored = batch.copy()
                scored["logit_left"] = logits[:, 0].cpu()
                scored["logit_right"] = logits[:, 1].cpu()
                scored["preds"] = predicted.cpu()
                scored["truth"] = truth
                scored_batches.append(scored)
                bar.update(batch.shape[0])
    print(f"test val: {accuracy.compute()} loss: {np.mean(batch_losses)}")

    validation_df = pd.concat(scored_batches)
    out_csv = get_model_filename(
        harness,
        MODEL_OUTPUT_BASE,
        is_cuda=is_cuda,
        ftype="validation",
        epoch=None,
        ext=".csv")
    validation_df.to_csv(out_csv)

    is_correct = validation_df["preds"] == validation_df["truth"]
    print("correct")
    display(validation_df[is_correct].head())
    print("incorrect")
    display(validation_df[~is_correct].head())

In [None]:
def embeds(model, max_batches=5):
    """Display child embeddings for the first few test-validation batches.

    Resets ``ttgen`` and, for up to ``max_batches`` batches, displays as
    numpy arrays the output of ``model.get_child_embed`` for the left child
    texts and then for the right child texts. This is a quick visual sanity
    check.

    Args:
        model: model exposing ``get_child_embed(input_ids, attention_mask)``.
        max_batches: number of batches to inspect (default 5, as before).
    """
    from itertools import islice

    ttgen.reset()
    model.eval()
    with torch.no_grad():
        for test_val_df in islice(ttgen.test_validation_dfs(), max_batches):
            # Only the child texts are embedded, so parents are not tokenized.
            for column in ("child_left", "child_right"):
                tok = tokens(test_val_df[column].tolist())
                display(model.get_child_embed(
                    tok["input_ids"],
                    tok["attention_mask"]).cpu().numpy())

In [None]:
def full_run(*, num_epochs, version, force_restart):
    """Train, save, validate and inspect one model version end to end.

    Calls ``run_training``, then ``save_model``, ``validation`` and
    ``embeds``, in that order, on the trained objects.
    """
    trained = run_training(num_epochs, version, force_restart)
    model, harness, _optimizer = trained
    save_model(*trained)
    validation(model, harness)
    embeds(model)

In [None]:
# Train every model variant (-1 is the baseline) on a growing epoch budget.
# run_training subtracts the epoch offset reported by get_epoch_and_load, so
# each budget only trains the epochs that are still missing.
for num_epochs in (10, 30, 60, 90, 120):
    for version in (-1, 7, 0, 5):
        full_run(num_epochs=num_epochs, version=version, force_restart=False)