In [1]:
# Before running this notebook, start Redis for each namespace in a terminal:
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=train
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=test
# Both must be running so the notebook can access the train and test namespaces.

In [2]:
import itertools
import json
import os
import sys
from typing import Literal, TypedDict

import numpy as np
import pandas as pd

In [3]:
# Make the project root importable (presumably the notebook lives one level
# below it, next to the `misc`, `model` and `system` packages imported next).
sys.path.append("..")
# Project user-data location. A USER_PATH already present in the environment
# (e.g. the one used for `make run-redis` in the first cell) takes precedence
# over this default instead of being silently overwritten.
os.environ.setdefault("USER_PATH", "/mnt/d/workspace/clotho/userdata/")
# Final artifacts are written to MODEL_OUTPUT_BASE; per-epoch harness
# checkpoints and stats go to its "checkpoints" subdirectory.
MODEL_OUTPUT_BASE = "/mnt/d/workspace/clotho/notebooks"
MODEL_OUTPUT_CP = os.path.join(MODEL_OUTPUT_BASE, "checkpoints")

In [4]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from misc.io import open_write
from model.datagenerator import create_train_test
from model.transformer_embed import (
    get_epoch_and_load,
    limit_epoch_data,
    limit_epoch_data,
    get_model_filename,
)
from system.namespace.store import get_namespace

In [5]:
import torch

# Whether a CUDA device is visible; selects the device and the GPU/CPU sizes used below.
(is_cuda := torch.cuda.is_available())

True

In [6]:
set_redis_slow_mode("never")
# Resolve the test and train namespaces (in that order) served by the Redis
# instances started from the first cell.
ns_test, ns_train = get_namespace("test"), get_namespace("train")
# Pinned reference timestamp handed to `create_train_test` as `now`.
now = pd.Timestamp("2022-12-17", tz="UTC")
# Training curriculum. Every step shares skip_weak=True, skip_topics=True,
# flip_lr=0.5, last_epoch=None and a right-side flip_pc of 0.0; the table
# below lists only what varies. A fresh dict is built for every entry, so no
# two steps share mutable sub-dicts.
train_plan = [
    {
        "left": None if left is None else {"mode": left[0], "flip_pc": left[1]},
        "right": {"mode": right_mode, "flip_pc": 0.0},
        "min_text_length": min_len,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": first_epoch,
        "last_epoch": None,
        "weight": weight,
    }
    for left, right_mode, min_len, first_epoch, weight in (
        # (left (mode, flip_pc) or None, right mode, min_text_length, first_epoch, weight)
        (("valid", 1.0), "valid", 20, 10, 50),
        (("valid", 1.0), "valid", None, 10, 50),
        (("random", 0.0), "path", 20, None, 60),
        (None, "path", 20, None, 40),
        (("random", 0.0), "path", None, 5, 60),
        (None, "path", None, 5, 40),
        (("random", 0.0), "valid", None, None, 60),
        (None, "valid", None, None, 40),
        (("valid", 1.0), "valid", None, 15, 50),
    )
]
# Evaluation mix: for each min_text_length (20, then unrestricted) a
# random-left step (weight 60) and a no-left step (weight 40), always against
# a "valid" right side. Each entry gets its own freshly built dicts.
eval_plan = [
    {
        "left": None if left_mode is None else {"mode": left_mode, "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": min_len,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": weight,
    }
    for min_len in (20, None)
    for left_mode, weight in (("random", 60), (None, 40))
]
# Build the train/test data generator. The train_val, test and test_val
# splits all draw from `eval_plan`; only training uses `train_plan`.
# Sizes differ between GPU and CPU runs.
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    # NOTE(review): the GPU batch (4) is smaller than the CPU batch (8), which
    # looks inverted -- confirm this is intentional (e.g. GPU memory limits).
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=10000 if is_cuda else 1000,
    test_size=10000 if is_cuda else 1000,
    test_val_size=10000 if is_cuda else 1000,
    # Same value on GPU and CPU; the former `100 if is_cuda else 100` was a
    # no-op conditional.
    compute_batch_size=100,
    now=now)

In [7]:
import torch.nn as nn
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertModel

In [8]:
# Single device used for the model and for every tokenized batch.
device = torch.device("cuda" if is_cuda else "cpu")
device

device(type='cuda')

In [9]:
# Which encoder a tokenized input is routed to (see `Model.forward`).
ProviderRole = Literal["child", "parent"]

# Width of the DistilBERT hidden states and of the optional dense heads.
EMBED_SIZE = 768
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Pooling strategies for turning per-token encoder states into one vector.
AggType = Literal["cls", "mean"]
AGG_CLS: AggType = "cls"
AGG_MEAN: AggType = "mean"


class TokenizedInput(TypedDict):
    """Tensor mapping produced by `tokens()` and fed to the DistilBERT encoders."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def tokens(texts, *, tok=None, dev=None) -> "TokenizedInput":
    """Tokenize a batch of strings and move the resulting tensors to a device.

    Args:
        texts: the strings to encode -- a pandas Series (as passed by the
            callers in this notebook), a NumPy array, a list/iterable of str,
            or a single str (treated as a batch of one).
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.
        dev: target device; defaults to the module-level ``device``.

    Returns:
        Mapping with ``input_ids`` and ``attention_mask`` tensors, padded to
        the longest text in the batch and truncated by the tokenizer.
    """
    if tok is None:
        tok = tokenizer
    if dev is None:
        dev = device
    # HF tokenizers want a plain list of str. The old annotation promised
    # `list[str]`, but the body called `.tolist()`, which lists don't have.
    # Series/ndarray still go through `.tolist()` (unwrapping NumPy scalars);
    # everything else is listified.
    if isinstance(texts, str):
        batch = [texts]
    elif hasattr(texts, "tolist"):
        batch = texts.tolist()
    else:
        batch = list(texts)
    res = tok(batch, return_tensors="pt", padding=True, truncation=True)
    return {k: v.to(dev) for k, v in res.items()}


class Noise(nn.Module):
    """Training-time regularizer: adds Gaussian noise to a random subset of elements.

    Each element of the input is selected independently with probability
    ``p``; selected elements receive ``N(0, std**2)`` noise. In eval mode the
    input is returned unchanged.
    """

    def __init__(self, std: float = 1.0, p: float = 0.5) -> None:
        super().__init__()
        self._std = std
        self._p = p
        # Frozen placeholder parameter. It is never used numerically, but as a
        # registered parameter it follows `.to(device)`, so its device tells
        # forward() where to allocate the random tensors.
        self._dhold = nn.Parameter(torch.zeros(1), requires_grad=False)

    def set_std(self, std: float) -> None:
        """Set the standard deviation used by subsequent forward passes."""
        self._std = std

    def get_std(self) -> float:
        """Return the current noise standard deviation."""
        return self._std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            target = self._dhold.device
            shape = x.shape
            # Draw the selection mask first, then the noise, so seeded runs
            # consume the RNG in a fixed order.
            selected = torch.rand(size=shape, device=target) < self._p
            noise = torch.normal(
                mean=0.0, std=self._std, size=shape, device=target)
            return x + selected * noise
        return x


class Model(nn.Module):
    """Dual-encoder scorer for (parent, child) text pairs.

    Parent and child texts are embedded by two separate DistilBERT encoders,
    pooled to one vector each, optionally passed through a dense head and
    `Noise`, and combined into one score per pair, returned as ``(batch, 1)``.
    ``version`` selects the architecture variant:

    * versions 1, 3, 4, 6 -- Linear/Dropout/ReLU/Linear head on each side;
    * versions 4, 5       -- `Noise` on the embeddings (active only in train mode);
    * versions 2, 3, 4    -- cosine-similarity score (otherwise dot product);
    * versions >= 6       -- mean pooling over tokens (otherwise first token).
    """

    def __init__(self, version: int) -> None:
        # NOTE: submodules are registered in a fixed order, which fixes the
        # order of `parameters()` (and of any optimizer state built from it).
        # Keep the order stable when editing.
        super().__init__()
        self._bert_parent = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        self._bert_child = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        if version in (1, 3, 4, 6):
            # Separate (unshared) projection heads for the parent and child sides.
            self._pdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
            self._cdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
        else:
            self._pdense = None
            self._cdense = None
        # Noise only for versions 4 and 5.
        if version < 4 or version > 5:
            self._noise = None
        else:
            self._noise = Noise(std=1.0, p=0.2)
        # Cosine similarity only for versions 2-4; otherwise forward() uses a dot product.
        if version < 2 or version > 4:
            self._cos = None
        else:
            self._cos = torch.nn.CosineSimilarity()
        if version < 6:
            self._agg = AGG_CLS
        else:
            self._agg = AGG_MEAN
        self._version = version

    def set_epoch(self, epoch: int) -> None:
        """Anneal the noise to std = 1 / 1.2**epoch; no-op when noise is disabled."""
        noise = self._noise
        if noise is not None:
            noise.set_std(1 / (1.2 ** epoch))

    def get_version(self) -> int:
        """Return the architecture version this model was built with."""
        return self._version

    def get_agg(self, lhs: torch.Tensor) -> torch.Tensor:
        """Pool per-token encoder states into one vector per row.

        ``lhs`` is an encoder's ``last_hidden_state``; presumably shaped
        (batch, seq_len, hidden) -- the indexing below pools over dim 1.

        Raises:
            ValueError: if ``self._agg`` is not a known aggregation.
        """
        if self._agg == AGG_CLS:
            # First token of each sequence.
            return lhs[:, 0]
        if self._agg == AGG_MEAN:
            # NOTE(review): the attention mask is not applied, so padded
            # positions (inputs are padded to the longest text in the batch)
            # are averaged in as well. Fixing this changes outputs for
            # existing version >= 6 checkpoints, so it would need a new version.
            return torch.mean(lhs, dim=1)
        raise ValueError(f"unknown aggregation: {self._agg}")

    def get_parent_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        """Embed parent-side tokens: encoder -> pooling -> optional head -> optional noise."""
        outputs_parent = self._bert_parent(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_parent.last_hidden_state)
        if self._pdense is not None:
            out = self._pdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def get_child_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        """Embed child-side tokens: encoder -> pooling -> optional head -> optional noise."""
        outputs_child = self._bert_child(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_child.last_hidden_state)
        if self._cdense is not None:
            out = self._cdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def forward(self, x: dict[ProviderRole, TokenizedInput]) -> torch.Tensor:
        """Score each (parent, child) pair in the batch.

        Args:
            x: tokenized inputs keyed by role, ``"parent"`` and ``"child"``.

        Returns:
            Tensor of shape ``(batch, 1)``: cosine similarity (versions 2-4)
            or the dot product of the two embeddings.
        """
        # `*_cls` names hold the pooled embedding (first token or mean, per version).
        parent_cls = self.get_parent_embed(
            input_ids=x["parent"]["input_ids"],
            attention_mask=x["parent"]["attention_mask"])
        child_cls = self.get_child_embed(
            input_ids=x["child"]["input_ids"],
            attention_mask=x["child"]["attention_mask"])
        if self._cos is not None:
            return self._cos(parent_cls, child_cls).reshape([-1, 1])
        batch_size = parent_cls.shape[0]
        # Row-wise dot product: (B, 1, D) @ (B, D, 1) -> (B, 1, 1) -> (B, 1).
        return torch.bmm(
            parent_cls.reshape([batch_size, 1, -1]),
            child_cls.reshape([batch_size, -1, 1])).reshape([-1, 1])


class TrainingHarness(nn.Module):
    """Pairwise-choice training wrapper around a `Model`.

    Each example has two candidate (parent, child) pairs, "left" and "right".
    The wrapped model scores both, the two scores are turned into a 2-way
    softmax, and the result is trained with binary cross-entropy against
    one-hot labels (column 0 = left is correct, column 1 = right is correct).
    """

    def __init__(self, model: "Model") -> None:
        super().__init__()
        self._model = model
        self._softmax = nn.Softmax(dim=1)
        # NOTE: softmax followed by BCELoss is less numerically stable than
        # a loss computed from raw scores (e.g. CrossEntropyLoss); kept as-is
        # so loss values stay comparable with existing runs and checkpoints.
        self._loss = nn.BCELoss()

    def get_version(self) -> int:
        """Return the architecture version of the wrapped model."""
        return self._model.get_version()

    def forward(
            self,
            left: "dict[ProviderRole, TokenizedInput]",
            right: "dict[ProviderRole, TokenizedInput]",
            labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Score both candidates and compute the loss.

        Args:
            left: role-keyed (``"parent"``/``"child"``) tokenized inputs for the
                left candidate. Previously annotated as a single
                `TokenizedInput`, but the model indexes it by role.
            right: same structure for the right candidate.
            labels: float targets of shape (batch, 2), one-hot over (left, right).

        Returns:
            ``(preds, loss)``: softmax probabilities of shape (batch, 2) and
            the scalar mean BCE loss.
        """
        out_left = self._model(left)
        out_right = self._model(right)
        preds = self._softmax(torch.hstack((out_left, out_right)))
        return preds, self._loss(preds, labels)

In [10]:
# Passed as `force_restart` to `get_epoch_and_load`; presumably True ignores existing checkpoints.
FORCE_RESTART: bool = False

In [11]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate
import time


def compute(harness, df):
    """Tokenize one batch of pairwise examples and run it through ``harness``.

    Args:
        harness: a `TrainingHarness` (or any callable accepting ``left``,
            ``right`` and ``labels`` keywords).
        df: batch DataFrame with text columns ``parent_left``, ``child_left``,
            ``parent_right``, ``child_right`` and a boolean-like
            ``correct_is_right`` column (bool or 0/1).

    Returns:
        Whatever ``harness`` returns -- for `TrainingHarness`, ``(preds, loss)``.
    """
    plefts = tokens(df["parent_left"])
    clefts = tokens(df["child_left"])
    prights = tokens(df["parent_right"])
    crights = tokens(df["child_right"])
    # Coerce to bool before negating: `~` on an integer 0/1 column is a
    # bitwise NOT (-1/-2), which would silently produce invalid BCE targets.
    is_right = torch.as_tensor(df["correct_is_right"].to_numpy(dtype=bool))
    # One-hot targets: column 0 = "left is correct", column 1 = "right is correct".
    labels = torch.stack((~is_right, is_right), dim=1).to(
        device=device, dtype=torch.float32)
    return harness(
        left={"parent": plefts, "child": clefts},
        right={"parent": prights, "child": crights},
        labels=labels)


def _train_epoch(harness, optimizer, lr_scheduler):
    """Run one optimisation pass over ``ttgen.train_dfs()``.

    Takes one optimizer step and one scheduler step per yielded batch.
    Returns the accuracy metric (batches added, not yet computed) and the
    list of per-batch losses.
    """
    metric = evaluate.load("accuracy")
    losses = []
    with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
        for train_df in ttgen.train_dfs():
            preds, loss = compute(harness, train_df)
            losses.append(loss.item())
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(train_df.shape[0])

            metric.add_batch(
                predictions=torch.argmax(preds, dim=-1),
                references=train_df["correct_is_right"].astype(int))
    return metric, losses


def _evaluate(harness, desc, get_total, get_batches):
    """Score ``harness`` on every batch from ``get_batches()`` without updating it.

    Callers wrap this in ``torch.no_grad()`` and put the modules in eval mode.
    ``get_total`` and ``get_batches`` are callables, so the size is queried
    and the batch iterator created only when the progress bar is set up.
    Returns the accuracy metric (not yet computed) and the per-batch losses.
    """
    metric = evaluate.load("accuracy")
    losses = []
    with tqdm(desc=desc, total=get_total()) as progress_bar:
        for batch_df in get_batches():
            preds, loss = compute(harness, batch_df)
            losses.append(loss.item())
            metric.add_batch(
                predictions=torch.argmax(preds, dim=-1),
                references=batch_df["correct_is_right"].astype(int))
            progress_bar.update(batch_df.shape[0])
    return metric, losses


def run_training(num_epochs, version):
    """Train (or resume training) a `Model` of the given version.

    The harness is restored via `get_epoch_and_load` from MODEL_OUTPUT_CP
    (unless FORCE_RESTART); the returned epoch offset is subtracted from
    ``num_epochs``. After every epoch this saves a harness checkpoint and a
    JSON file of accuracy/loss/time statistics, then calls `limit_epoch_data`
    on the stats files (count=5).

    Args:
        num_epochs: total epoch budget, including epochs already covered by
            the loaded checkpoint.
        version: architecture version passed to `Model`.

    Returns:
        ``(model, harness, optimizer)``. If no epochs remain, the loaded
        model/harness and a freshly created optimizer are returned.
    """
    model = Model(version=version)
    model.to(device)
    harness = TrainingHarness(model)
    harness.to(device)

    mprev, epoch_offset = get_epoch_and_load(
        harness,
        MODEL_OUTPUT_CP,
        ftype="harness",
        is_cuda=is_cuda,
        device=device,
        force_restart=FORCE_RESTART)

    # Only harness weights are restored; optimizer state and the LR schedule
    # below start fresh on every (re)start.
    optimizer = AdamW(harness.parameters(), lr=5e-5)
    print(mprev, epoch_offset)

    num_epochs -= epoch_offset
    if num_epochs <= 0:
        print("already computed all epochs. nothing to do!")
        return model, harness, optimizer

    num_training_steps = num_epochs * ttgen.get_epoch_train_size()
    warmup = 10000 if is_cuda else 10
    # NOTE(review): HF's `num_training_steps` is the TOTAL step count
    # *including* warmup, so subtracting `warmup` makes the LR reach 0
    # `warmup` steps early. Also, `get_epoch_train_size()` is the progress
    # bar total (advanced by rows), while the scheduler is stepped once per
    # batch; if those units differ the decay is stretched. Left unchanged so
    # resumed runs keep the same schedule -- confirm before changing.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=num_training_steps - warmup)
    ttgen.set_epoch(epoch_offset)

    for _ in range(num_epochs):
        epoch = ttgen.get_epoch()
        print(f"epoch {epoch} version: {harness.get_version()}")
        real_time = time.monotonic()

        model.train()
        harness.train()
        model.set_epoch(epoch)
        metric_train, train_loss = _train_epoch(harness, optimizer, lr_scheduler)

        model_fname = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="harness",
            epoch=epoch)
        torch.save(harness.state_dict(), model_fname)

        model.eval()
        harness.eval()
        with torch.no_grad():
            metric_val_train, train_val_loss = _evaluate(
                harness,
                "train val",
                ttgen.get_epoch_train_validation_size,
                ttgen.train_validation_dfs)
            metric_test, test_loss = _evaluate(
                harness,
                "test",
                ttgen.get_epoch_test_size,
                ttgen.test_dfs)
            stats = {
                "epoch": int(epoch),
                "train_acc": float(metric_train.compute()['accuracy']),
                "train_loss": float(np.mean(train_loss)),
                "train_val_acc": float(metric_val_train.compute()['accuracy']),
                "train_val_loss": float(np.mean(train_val_loss)),
                "test_acc": float(metric_test.compute()['accuracy']),
                "test_loss": float(np.mean(test_loss)),
                "time": 0.0,  # filled in once the epoch is fully done
                "version": harness.get_version(),
                "fname": model_fname,
            }

        print(f"train: {stats['train_acc']} loss: {stats['train_loss']}")
        print(f"train val: {stats['train_val_acc']} loss: {stats['train_val_loss']}")
        print(f"test: {stats['test_acc']} loss: {stats['test_loss']}")
        ttgen.advance_epoch()
        stats["time"] = float((time.monotonic() - real_time) / 60.0)
        print(f"epoch time: {stats['time']:.2f}min")
        stats_fn = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            epoch=epoch,
            ext=".json")
        with open_write(stats_fn, text=True) as fout:
            print(json.dumps(stats, indent=2, sort_keys=True), file=fout)

        # NOTE(review): only the small stats JSON files are passed to
        # `limit_epoch_data`; the per-epoch harness checkpoints (two full
        # DistilBERT encoders each) are never pruned. Confirm this is intended.
        limit_epoch_data(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            ext=".json",
            count=5)
    return model, harness, optimizer

In [None]:
# Train architecture version 7 up to a total of 120 epochs (resuming from a checkpoint if one is found).
model_final, harness_final, optimizer_final = run_training(num_epochs=120, version=7)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- T

('harness_v7_lg_74.pkl', 74) 75
epoch 75 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Persist the final model, harness and optimizer state dicts under
# MODEL_OUTPUT_BASE (not the per-epoch checkpoint directory). File names are
# always derived from the harness, varying only by `ftype`.
for _ftype, _obj in (
        ("model", model_final),
        ("harness", harness_final),
        ("optimizer", optimizer_final)):
    torch.save(_obj.state_dict(), get_model_filename(
        harness_final,
        MODEL_OUTPUT_BASE,
        is_cuda=is_cuda,
        ftype=_ftype,
        epoch=None))
del _ftype, _obj

In [None]:
# Evaluate the final harness on the held-out test-validation split and keep
# per-example predictions for inspection.
ttgen.reset()
model_final.eval()
harness_final.eval()
dfs = []
with torch.no_grad():
    metric_val_test = evaluate.load("accuracy")
    test_val_loss = []
    with tqdm(desc="test val", total=ttgen.get_epoch_test_validation_size()) as progress_bar:
        for test_val_df in ttgen.test_validation_dfs():
            preds, loss = compute(harness_final, test_val_df)
            test_val_loss.append(loss.item())
            predictions = torch.argmax(preds, dim=-1)
            metric_val_test.add_batch(
                predictions=predictions,
                references=test_val_df["correct_is_right"].astype(int))
            # Convert to NumPy explicitly so pandas stores plain numeric
            # columns instead of depending on how it interprets torch tensors.
            preds_np = preds.cpu().numpy()
            cur_df = test_val_df.copy()
            # NOTE: despite the column names, these are softmax probabilities, not logits.
            cur_df["logit_left"] = preds_np[:, 0]
            cur_df["logit_right"] = preds_np[:, 1]
            cur_df["preds"] = predictions.cpu().numpy()
            cur_df["truth"] = test_val_df["correct_is_right"].astype(int)
            dfs.append(cur_df)
            progress_bar.update(test_val_df.shape[0])
print(f"test val: {metric_val_test.compute()} loss: {np.mean(test_val_loss)}")
validation_df = pd.concat(dfs)

In [None]:
# Write the per-example test-validation predictions next to the final model files.
validation_df.to_csv(path_or_buf=get_model_filename(
    harness_final,
    MODEL_OUTPUT_BASE,
    ftype="validation",
    epoch=None,
    ext=".csv",
    is_cuda=is_cuda))

In [None]:
validation_df.loc[validation_df["preds"].eq(validation_df["truth"])].head()  # correctly classified

In [None]:
validation_df.loc[validation_df["preds"].ne(validation_df["truth"])].head()  # misclassified

In [None]:
# Spot-check: display the child-side embeddings of the left and right
# candidates for the first few test-validation batches. Only the child texts
# are tokenized; the parent texts were previously tokenized and moved to the
# device but never used.
N_PREVIEW_BATCHES = 5
ttgen.reset()
model_final.eval()
harness_final.eval()
with torch.no_grad():
    for test_val_df in itertools.islice(ttgen.test_validation_dfs(), N_PREVIEW_BATCHES):
        clefts = tokens(test_val_df["child_left"])
        crights = tokens(test_val_df["child_right"])
        display(model_final.get_child_embed(
            clefts["input_ids"],
            clefts["attention_mask"]).cpu().numpy())
        display(model_final.get_child_embed(
            crights["input_ids"],
            crights["attention_mask"]).cpu().numpy())