In [1]:
# in a terminal run
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=train
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=test
# to allow access to the train and test namespaces

In [2]:
import os
import sys
import json
import pandas as pd
import numpy as np
from typing import Literal, TypedDict

In [3]:
sys.path.append("..")
os.environ["USER_PATH"] = "/mnt/d/workspace/clotho/userdata/"
MODEL_OUTPUT_BASE = "/mnt/d/workspace/clotho/notebooks"
MODEL_OUTPUT_CP = os.path.join(MODEL_OUTPUT_BASE, "checkpoints")

In [4]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from misc.io import open_write
from model.datagenerator import create_train_test
from model.transformer_embed import (
    get_epoch_and_load,
    limit_epoch_data,
    limit_epoch_data,
    get_model_filename,
)
from system.namespace.store import get_namespace

In [5]:
import torch

is_cuda = torch.cuda.is_available()
is_cuda

True

In [6]:
set_redis_slow_mode("never")
ns_test = get_namespace("test")
ns_train = get_namespace("train")
now = pd.Timestamp("2022-12-17", tz="UTC")
train_plan = [
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 50,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 50,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    },
     {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 5,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 5,
        "last_epoch": None,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 15,
        "last_epoch": None,
        "weight": 50,
    }
]
eval_plan = [
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
]
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=10000 if is_cuda else 1000,
    test_size=10000 if is_cuda else 1000,
    test_val_size=10000 if is_cuda else 1000,
    compute_batch_size=100 if is_cuda else 100,
    now=now)

In [7]:
import torch.nn as nn
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertModel

In [8]:
device = torch.device("cuda") if is_cuda else torch.device("cpu")
device

device(type='cuda')

In [9]:
ProviderRole = Literal["child", "parent"]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
EMBED_SIZE = 768

TokenizedInput = TypedDict('TokenizedInput', {
    "input_ids": torch.Tensor,
    "attention_mask": torch.Tensor,
})


AggType = Literal["cls", "mean"]
AGG_CLS: AggType = "cls"
AGG_MEAN: AggType = "mean"


def tokens(texts: list[str]) -> TokenizedInput:
    res = tokenizer(texts.tolist(), return_tensors="pt", padding=True, truncation=True)
    return {k: v.to(device) for k, v in res.items()}


class Noise(nn.Module):
    def __init__(self, std: float = 1.0, p: float = 0.5) -> None:
        super().__init__()
        self._std = std
        self._p = p
        self._dhold = nn.Parameter(torch.Tensor([0.0]), requires_grad=False)

    def set_std(self, std: float) -> None:
        self._std = std

    def get_std(self) -> float:
        return self._std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return x
        prob = torch.rand(size=x.shape, device=self._dhold.device) < self._p
        gauss = torch.normal(
            mean=0.0, std=self._std, size=x.shape, device=self._dhold.device)
        return x + prob * gauss


class Model(nn.Module):
    def __init__(self, version: int) -> None:
        super().__init__()
        self._bert_parent = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        self._bert_child = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        if version in (1, 3, 4, 6):
            self._pdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
            self._cdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
        else:
            self._pdense = None
            self._cdense = None
        if version < 4 or version > 5:
            self._noise = None
        else:
            self._noise = Noise(std=1.0, p=0.2)
        if version < 2 or version > 4:
            self._cos = None
        else:
            self._cos = torch.nn.CosineSimilarity()
        if version < 6:
            self._agg = AGG_CLS
        else:
            self._agg = AGG_MEAN
        self._version = version

    def set_epoch(self, epoch: int) -> None:
        noise = self._noise
        if noise is not None:
            noise.set_std(1 / (1.2 ** epoch))

    def get_version(self) -> int:
        return self._version
    
    def get_agg(self, lhs: torch.Tensor) -> torch.Tensor:
        if self._agg == AGG_CLS:
            return lhs[:, 0]
        if self._agg == AGG_MEAN:
            return torch.mean(lhs, dim=1)
        raise ValueError(f"unknown aggregation: {self._agg}")

    def get_parent_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        outputs_parent = self._bert_parent(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_parent.last_hidden_state)
        if self._pdense is not None:
            out = self._pdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def get_child_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        outputs_child = self._bert_child(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_child.last_hidden_state)
        if self._cdense is not None:
            out = self._cdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def forward(self, x: dict[ProviderRole, TokenizedInput]) -> torch.Tensor:
        parent_cls = self.get_parent_embed(
            input_ids=x["parent"]["input_ids"],
            attention_mask=x["parent"]["attention_mask"])
        child_cls = self.get_child_embed(
            input_ids=x["child"]["input_ids"],
            attention_mask=x["child"]["attention_mask"])
        if self._cos is not None:
            return self._cos(parent_cls, child_cls).reshape([-1, 1])
        batch_size = parent_cls.shape[0]
        return torch.bmm(
            parent_cls.reshape([batch_size, 1, -1]),
            child_cls.reshape([batch_size, -1, 1])).reshape([-1, 1])


class TrainingHarness(nn.Module):
    def __init__(self, model: Model) -> None:
        super().__init__()
        self._model = model
        self._softmax = nn.Softmax(dim=1)
        self._loss = nn.BCELoss()

    def get_version(self) -> int:
        return self._model.get_version()

    def forward(
            self,
            left: TokenizedInput,
            right: TokenizedInput,
            labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_left = self._model(left)
        out_right = self._model(right)
        preds = self._softmax(torch.hstack((out_left, out_right)))
        return preds, self._loss(preds, labels)

In [10]:
FORCE_RESTART = False

In [11]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate
import time


def compute(harness, df):
    plefts = tokens(df["parent_left"])
    clefts = tokens(df["child_left"])
    prights = tokens(df["parent_right"])
    crights = tokens(df["child_right"])
    labels = torch.tensor(
        [~df["correct_is_right"], df["correct_is_right"]],
        dtype=torch.float32).T.to(device)
    return harness(
        left={"parent": plefts, "child": clefts},
        right={"parent": prights, "child": crights},
        labels=labels)


def run_training(num_epochs, version):
    model = Model(version=version)
    model.to(device)
    harness = TrainingHarness(model)
    harness.to(device)

    mprev, epoch_offset = get_epoch_and_load(
        harness,
        MODEL_OUTPUT_CP,
        ftype="harness",
        is_cuda=is_cuda,
        device=device,
        force_restart=FORCE_RESTART)

    optimizer = AdamW(harness.parameters(), lr=5e-5)
    print(mprev, epoch_offset)
    
    num_epochs -= epoch_offset
    if num_epochs <= 0:
        print("already computed all epochs. nothing to do!")
        return model, harness, optimizer
    
    num_training_steps = num_epochs * ttgen.get_epoch_train_size()
    warmup = 10000 if is_cuda else 10
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=num_training_steps - warmup)
    ttgen.set_epoch(epoch_offset)

    for _ in range(num_epochs):
        epoch = ttgen.get_epoch()
        print(f"epoch {epoch} version: {harness.get_version()}")
        real_time = time.monotonic()

        model.train()
        harness.train()
        model.set_epoch(epoch)
        metric_train = evaluate.load("accuracy")
        train_loss = []
        first = True
        with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
            for train_df in ttgen.train_dfs():
                preds, loss = compute(harness, train_df)
                train_loss.append(loss.item())
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(train_df.shape[0])

                predictions = torch.argmax(preds, dim=-1)
                metric_train.add_batch(
                    predictions=predictions,
                    references=train_df["correct_is_right"].astype(int))
                if first:
                    # display(train_df)
                    first = False

        model_fname = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="harness",
            epoch=epoch)
        torch.save(harness.state_dict(), model_fname)

        model.eval()
        harness.eval()
        with torch.no_grad():
            metric_val_train = evaluate.load("accuracy")
            train_val_loss = []
            with tqdm(desc="train val", total=ttgen.get_epoch_train_validation_size()) as progress_bar:
                for train_validation_df in ttgen.train_validation_dfs():
                    preds, loss = compute(harness, train_validation_df)
                    train_val_loss.append(loss.item())
                    predictions = torch.argmax(preds, dim=-1)
                    metric_val_train.add_batch(
                        predictions=predictions,
                        references=train_validation_df["correct_is_right"].astype(int))
                    progress_bar.update(train_validation_df.shape[0])

            metric_test = evaluate.load("accuracy")
            test_loss = []
            with tqdm(desc="test", total=ttgen.get_epoch_test_size()) as progress_bar:
                for test_df in ttgen.test_dfs():
                    preds, loss = compute(harness, test_df)
                    test_loss.append(loss.item())
                    predictions = torch.argmax(preds, dim=-1)
                    metric_test.add_batch(
                        predictions=predictions,
                        references=test_df["correct_is_right"].astype(int))
                    progress_bar.update(test_df.shape[0])
            stats = {
                "epoch": int(epoch),
                "train_acc": float(metric_train.compute()['accuracy']),
                "train_loss": float(np.mean(train_loss)),
                "train_val_acc": float(metric_val_train.compute()['accuracy']),
                "train_val_loss": float(np.mean(train_val_loss)),
                "test_acc": float(metric_test.compute()['accuracy']),
                "test_loss": float(np.mean(test_loss)),
                "time": 0.0,
                "version": harness.get_version(),
                "fname": model_fname,
            }

        print(f"train: {stats['train_acc']} loss: {stats['train_loss']}")
        print(f"train val: {stats['train_val_acc']} loss: {stats['train_val_loss']}")
        print(f"test: {stats['test_acc']} loss: {stats['test_loss']}")
        ttgen.advance_epoch()
        stats["time"] = float((time.monotonic() - real_time) / 60.0)
        print(f"epoch time: {stats['time']:.2f}min")
        stats_fn = get_model_filename(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            epoch=epoch,
            ext=".json")
        with open_write(stats_fn, text=True) as fout:
            print(json.dumps(stats, indent=2, sort_keys=True), file=fout)
            
        limit_epoch_data(
            harness,
            MODEL_OUTPUT_CP,
            is_cuda=is_cuda,
            ftype="stats",
            ext=".json",
            count=5)
    return model, harness, optimizer

In [12]:
model_final, harness_final, optimizer_final = run_training(120, 7)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- T

('harness_v7_lg_74.pkl', 74) 75
epoch 75 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87725 loss: 0.24688319284539606
train val: 0.7725 loss: 0.4387943536286475
test: 0.5912 loss: 0.6844999013572931
epoch time: 475.77min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_0.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_0.pkl
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_3.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_3.pkl
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_6.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_6.pkl
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_1.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_1.pkl
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_9.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_9.pkl
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_10.json
removing /mnt/d/workspace/clotho/not

train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87635 loss: 0.25244936092717707
train val: 0.7763 loss: 0.43755752810058185
test: 0.5892 loss: 0.6646364462137222
epoch time: 47.66min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_69.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_69.pkl
epoch 77 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8761 loss: 0.25424676765135285
train val: 0.7701 loss: 0.43542541802176277
test: 0.596 loss: 0.6690252234339714
epoch time: 40.65min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_71.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_71.pkl
epoch 78 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8778 loss: 0.2523328201648331
train val: 0.7705 loss: 0.44131886511325136
test: 0.6026 loss: 0.6695084605455398
epoch time: 38.07min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_64.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_64.pkl
epoch 79 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8725 loss: 0.2625613520741983
train val: 0.7597 loss: 0.45352330187449696
test: 0.5754 loss: 0.6992761222839355
epoch time: 34.22min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_2.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_2.pkl
epoch 80 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8722 loss: 0.26530608701907005
train val: 0.769 loss: 0.44649667372216933
test: 0.6043 loss: 0.675029324349761
epoch time: 31.37min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_79.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_79.pkl
epoch 81 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87295 loss: 0.25997082502843316
train val: 0.7658 loss: 0.46186705019338986
test: 0.5974 loss: 0.7008182034432888
epoch time: 31.08min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_70.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_70.pkl
epoch 82 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8722 loss: 0.26160743072850184
train val: 0.7623 loss: 0.44856447543819084
test: 0.606 loss: 0.6810746724814176
epoch time: 30.40min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_81.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_81.pkl
epoch 83 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87125 loss: 0.2627661601345455
train val: 0.7626 loss: 0.4402225275948178
test: 0.5913 loss: 0.6718370355010033
epoch time: 29.97min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_82.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_82.pkl
epoch 84 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87355 loss: 0.26206800372047073
train val: 0.7611 loss: 0.46427033837419585
test: 0.5987 loss: 0.697297301711142
epoch time: 29.14min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_83.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_83.pkl
epoch 85 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85835 loss: 0.28622868821649916
train val: 0.765 loss: 0.46972758333991516
test: 0.6099 loss: 0.6808240371406078
epoch time: 29.68min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_84.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_84.pkl
epoch 86 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8695 loss: 0.26622423289322755
train val: 0.7652 loss: 0.4519058105667413
test: 0.594 loss: 0.7056208173274994
epoch time: 28.65min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_85.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_85.pkl
epoch 87 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8733 loss: 0.2569895461262684
train val: 0.7641 loss: 0.4479742307905108
test: 0.6108 loss: 0.6704469347596168
epoch time: 28.66min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_86.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_86.pkl
epoch 88 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87055 loss: 0.26610853414593805
train val: 0.7636 loss: 0.4446240754237864
test: 0.6032 loss: 0.665110944211483
epoch time: 27.63min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_87.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_87.pkl
epoch 89 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87375 loss: 0.2630965968995374
train val: 0.7561 loss: 0.45062036986213644
test: 0.5814 loss: 0.6870256122112274
epoch time: 28.15min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_88.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_88.pkl
epoch 90 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8754 loss: 0.25624658458809424
train val: 0.7704 loss: 0.4358079427268589
test: 0.6053 loss: 0.667498233628273
epoch time: 28.05min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_89.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_89.pkl
epoch 91 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87985 loss: 0.2479885182559723
train val: 0.7647 loss: 0.45600684347779025
test: 0.5914 loss: 0.7143267314553261
epoch time: 27.39min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_80.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_80.pkl
epoch 92 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8788 loss: 0.2506342374682619
train val: 0.7704 loss: 0.4520729850175951
test: 0.6124 loss: 0.6624875483214855
epoch time: 27.03min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_91.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_91.pkl
epoch 93 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87805 loss: 0.25051926591781737
train val: 0.7674 loss: 0.44046729241551363
test: 0.586 loss: 0.6847463314771652
epoch time: 27.49min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_77.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_77.pkl
epoch 94 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87755 loss: 0.24928802771865777
train val: 0.7695 loss: 0.4417042987279856
test: 0.5955 loss: 0.6861962861299514
epoch time: 26.89min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_93.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_93.pkl
epoch 95 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8762 loss: 0.2582805481324741
train val: 0.7708 loss: 0.4517718021281296
test: 0.6003 loss: 0.6907281194210052
epoch time: 27.42min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_94.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_94.pkl
epoch 96 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8817 loss: 0.24657185001686158
train val: 0.7744 loss: 0.44253671838219744
test: 0.6065 loss: 0.6805795713663101
epoch time: 25.86min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_90.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_90.pkl
epoch 97 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8828 loss: 0.2513664064303293
train val: 0.7645 loss: 0.4531347836035071
test: 0.6081 loss: 0.6727110549867154
epoch time: 27.32min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_92.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_92.pkl
epoch 98 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8754 loss: 0.25829368786395934
train val: 0.7677 loss: 0.45566284569605925
test: 0.5867 loss: 0.7068396338045597
epoch time: 26.55min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_97.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_97.pkl
epoch 99 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8764 loss: 0.25544844394708927
train val: 0.7741 loss: 0.4416697927334637
test: 0.608 loss: 0.6697600438058376
epoch time: 26.67min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_98.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_98.pkl
epoch 100 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8788 loss: 0.24969321130800234
train val: 0.7731 loss: 0.44840499472402295
test: 0.6212 loss: 0.6849669998645782
epoch time: 28.94min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_78.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_78.pkl
epoch 101 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8773 loss: 0.2511974701792875
train val: 0.771 loss: 0.4349648530759383
test: 0.6075 loss: 0.6764966611862182
epoch time: 29.65min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_95.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_95.pkl
epoch 102 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8853 loss: 0.24083072516087722
train val: 0.7714 loss: 0.43221013193243707
test: 0.6206 loss: 0.6752425873011351
epoch time: 28.95min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_101.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_101.pkl
epoch 103 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8825 loss: 0.23998856845258196
train val: 0.7748 loss: 0.4448689769520599
test: 0.6013 loss: 0.695455900478363
epoch time: 29.01min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_102.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_102.pkl
epoch 104 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8805 loss: 0.2488610867795762
train val: 0.7753 loss: 0.42810474714923474
test: 0.5821 loss: 0.6937290334641933
epoch time: 29.80min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_75.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_75.pkl
epoch 105 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8826 loss: 0.239970328234123
train val: 0.7769 loss: 0.4244856853481557
test: 0.591 loss: 0.6857752368927001
epoch time: 30.28min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_100.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_100.pkl
epoch 106 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.885 loss: 0.23656932126848004
train val: 0.7796 loss: 0.4347489410510665
test: 0.6193 loss: 0.6697787057220936
epoch time: 30.16min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_99.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_99.pkl
epoch 107 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88285 loss: 0.24064298320482266
train val: 0.7738 loss: 0.4319780615868862
test: 0.5972 loss: 0.6828130581438542
epoch time: 29.56min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_96.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_96.pkl
epoch 108 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88285 loss: 0.24225432667944224
train val: 0.7815 loss: 0.42000210899477824
test: 0.6075 loss: 0.6715290419638157
epoch time: 30.61min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_107.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_107.pkl
epoch 109 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88445 loss: 0.2839179772841275
train val: 0.7827 loss: 0.42252744458148955
test: 0.6056 loss: 0.6843017982423305
epoch time: 30.90min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_103.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_103.pkl
epoch 110 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88765 loss: 0.23067306373101054
train val: 0.7823 loss: 0.425790573313387
test: 0.6053 loss: 0.6718434534907342
epoch time: 33.89min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_104.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_104.pkl
epoch 111 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8825 loss: 0.24012364380836262
train val: 0.7835 loss: 0.4234614977559482
test: 0.5987 loss: 0.6894623760938644
epoch time: 32.33min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_76.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_76.pkl
epoch 112 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8851 loss: 0.24219948006386635
train val: 0.7763 loss: 0.4253914969589387
test: 0.5916 loss: 0.675720176833868
epoch time: 31.25min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_105.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_105.pkl
epoch 113 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88625 loss: 0.2442739518488721
train val: 0.7823 loss: 0.4218938563510659
test: 0.6132 loss: 0.6673646843135357
epoch time: 31.58min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_112.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_112.pkl
epoch 114 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88875 loss: 0.23533925600915317
train val: 0.7864 loss: 0.42567152839284683
test: 0.5995 loss: 0.6751633437901735
epoch time: 40.18min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_106.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_106.pkl
epoch 115 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88605 loss: 0.26067570026401954
train val: 0.7829 loss: 0.5178981874867633
test: 0.6071 loss: 0.7001509920060635
epoch time: 59.90min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_108.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_108.pkl
epoch 116 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8862 loss: 0.2669936276885247
train val: 0.7841 loss: 0.4300037252131704
test: 0.6006 loss: 0.6978086390912532
epoch time: 52.64min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_110.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_110.pkl
epoch 117 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88795 loss: 0.23544685338197177
train val: 0.786 loss: 0.42369362017799866
test: 0.6195 loss: 0.6682990686833858
epoch time: 41.77min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_113.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_113.pkl
epoch 118 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.88495 loss: 0.2377157407395076
train val: 0.7876 loss: 0.41121565336245985
test: 0.6055 loss: 0.6637235795557499
epoch time: 38.51min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_109.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_109.pkl
epoch 119 version: 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.89205 loss: 0.22627422869708816
train val: 0.7835 loss: 0.4157945149488747
test: 0.625 loss: 0.6556196522563696
epoch time: 35.52min
removing /mnt/d/workspace/clotho/notebooks/checkpoints/stats_v7_lg_115.json
removing /mnt/d/workspace/clotho/notebooks/checkpoints/harness_v7_lg_115.pkl


In [13]:
torch.save(model_final.state_dict(), get_model_filename(
    harness_final,
    MODEL_OUTPUT_BASE,
    is_cuda=is_cuda,
    ftype="model",
    epoch=None))
torch.save(harness_final.state_dict(), get_model_filename(
    harness_final,
    MODEL_OUTPUT_BASE,
    is_cuda=is_cuda,
    ftype="harness",
    epoch=None))
torch.save(optimizer_final.state_dict(), get_model_filename(
    harness_final,
    MODEL_OUTPUT_BASE,
    is_cuda=is_cuda,
    ftype="optimizer",
    epoch=None))

In [14]:
ttgen.reset()
model_final.eval()
harness_final.eval()
dfs = []
with torch.no_grad():
    metric_val_test = evaluate.load("accuracy")
    test_val_loss = []
    with tqdm(desc="test val", total=ttgen.get_epoch_test_validation_size()) as progress_bar:
        for test_val_df in ttgen.test_validation_dfs():
            preds, loss = compute(harness_final, test_val_df)
            test_val_loss.append(loss.item())
            predictions = torch.argmax(preds, dim=-1)
            metric_val_test.add_batch(
                predictions=predictions,
                references=test_val_df["correct_is_right"].astype(int))
            cur_df = test_val_df.copy()
            cur_df["logit_left"] = preds[:, 0].cpu()
            cur_df["logit_right"] = preds[:, 1].cpu()
            cur_df["preds"] = predictions.cpu()
            cur_df["truth"] = test_val_df["correct_is_right"].astype(int)
            dfs.append(cur_df)
            progress_bar.update(test_val_df.shape[0])
print(f"test val: {metric_val_test.compute()} loss: {np.mean(test_val_loss)}")
validation_df = pd.concat(dfs)

test val:   0%|          | 0/10000 [00:00<?, ?it/s]

test val: {'accuracy': 0.6303} loss: 0.6586172372460365


In [15]:
validation_df.to_csv(get_model_filename(
    harness_final,
    MODEL_OUTPUT_BASE,
    is_cuda=is_cuda,
    ftype="validation",
    epoch=None,
    ext=".csv"))

In [16]:
validation_df[validation_df["preds"] == validation_df["truth"]].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
0,*valid--!copy;(mtl:20);(sw);(st),I once knew this family where the dad had to w...,What blows my mind about this the most is this...,What blows my mind about this the most is this...,I once knew this family where the dad had to w...,1.0,1.6458109999999999e-38,False,0.7046,0.2954,0,0
1,random--valid;(sw);(st),I'll take the roast. That was hilarious,"Take care of your finances, your teeth, your b...",Would you be in favor of removing ‚ÄúOne Nation ...,You don‚Äôt need to remove ‚ÄúOne Nation.‚Äù Removin...,0.119203,0.8807971,True,0.228511,0.771489,1,1
2,random--valid;(sw);(st),Totally would. I think the more exposure you h...,No clue how i didnt see it on the first couple...,What‚Äôs something weird that porn has normalised?,Women hissing,0.047426,0.9525741,True,0.251375,0.748625,1,1
3,random--valid;(sw);(st),I really enjoyed Malcolm in the Middle as a ki...,Probably a young Steve buscemi at best.,"I‚Äôm not sure what this is called, but I have a...",I don't put the blame for bad designs like Poc...,0.047426,0.9525741,True,0.37145,0.62855,1,1
0,random--valid;(sw);(st),I edited it almost instantly,Because of those teachers in highschool my tea...,What dark secret are you hiding from everyone?,I won‚Äôt go on a throw away account. I became e...,0.047426,0.9525741,True,0.283972,0.716028,1,1


In [17]:
validation_df[validation_df["preds"] != validation_df["truth"]].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
3,*valid--random;(sw);(st),If I could have guaranteed my daughter would i...,Dude I thought of the exact same thing! My dau...,so then whats the difference between OP (and g...,Only if it comes with enforcing not driving li...,1.0,7.753691e-96,False,0.363922,0.636078,1,0
0,*valid--random;(mtl:20);(sw);(st),If you could rename America what would it be c...,Just name it all Florida,Cargo cults exist mostly because of the societ...,When your stomachs trap air then make a fart n...,1.0,0.0,False,0.465955,0.534045,1,0
2,*valid--random;(sw);(st),I followed all the rules like a good little bo...,I think that a big part of it...,call her mayonnaise and tell her that if she f...,We'll never know until we get there. You've go...,0.880797,0.1192029,False,0.331302,0.668698,1,0
3,*valid--random;(sw);(st),Nice scumbag defence. Are you their legal coun...,Where do you see me defending any of these peo...,Girlfriend of 10 years convinced me to get her...,"No problem, bud! üëç",0.997527,0.002472623,False,0.485748,0.514252,1,0
2,random--valid;(sw);(st),"exactly!! and even if they get a F, you should...",It was worse‚Ä¶ much much worse. Open racism and...,"After reading the comments, TIL a lot of peopl...","And also that americans call it ""tamale"" inste...",0.401312,0.5986877,True,0.530227,0.469773,0,1


In [18]:
ttgen.reset()
model_final.eval()
harness_final.eval()
with torch.no_grad():
    count = 0
    for test_val_df in ttgen.test_validation_dfs():
        plefts = tokens(test_val_df["parent_left"])
        clefts = tokens(test_val_df["child_left"])
        prights = tokens(test_val_df["parent_right"])
        crights = tokens(test_val_df["child_right"])
        display(model_final.get_child_embed(
            clefts["input_ids"],
            clefts["attention_mask"]).cpu().numpy())
        display(model_final.get_child_embed(
            crights["input_ids"],
            crights["attention_mask"]).cpu().numpy())
        count += 1
        if count >= 5:
            break

array([[-0.06550326,  0.02709419,  0.21112214, ...,  0.11420967,
        -0.38496572, -0.05269245],
       [-0.06571657,  0.02793262,  0.2312173 , ...,  0.11194899,
        -0.38510418, -0.05252533],
       [-0.06566565,  0.02773809,  0.2264941 , ...,  0.11248553,
        -0.38508695, -0.05256777],
       [-0.06568879,  0.02782691,  0.22864635, ...,  0.11224143,
        -0.38509583, -0.05254867]], dtype=float32)

array([[-0.06535928,  0.02648313,  0.19693018, ...,  0.11577048,
        -0.38476798, -0.05278942],
       [-0.06574368,  0.02803508,  0.23371796, ...,  0.11166368,
        -0.38510934, -0.05250207],
       [-0.0657125 ,  0.02791727,  0.23084284, ...,  0.11199162,
        -0.3851029 , -0.05252872],
       [-0.06563538,  0.02762094,  0.22366752, ...,  0.11280513,
        -0.38507187, -0.05259226]], dtype=float32)

array([[-0.0656829 ,  0.02780432,  0.22809815, ...,  0.11230366,
        -0.3850937 , -0.05255355],
       [-0.06569798,  0.02786201,  0.2294989 , ...,  0.11214454,
        -0.3850989 , -0.05254098],
       [-0.06571472,  0.0279257 ,  0.23104848, ...,  0.11196826,
        -0.3851037 , -0.05252688],
       [-0.06575096,  0.02806242,  0.23438704, ...,  0.1115872 ,
        -0.3851104 , -0.05249573]], dtype=float32)

array([[-0.06533301,  0.02636671,  0.19427145, ...,  0.11605955,
        -0.38472182, -0.05280573],
       [-0.06574761,  0.02804977,  0.23407796, ...,  0.11162256,
        -0.3851099 , -0.05249863],
       [-0.06567356,  0.02776845,  0.22722957, ...,  0.11240222,
        -0.38509   , -0.05256129],
       [-0.06568024,  0.02779406,  0.22784998, ...,  0.11233185,
        -0.38509274, -0.0525558 ]], dtype=float32)

array([[-0.06561287,  0.02753301,  0.22155438, ...,  0.11304329,
        -0.3850586 , -0.05261013],
       [-0.06574942,  0.02805658,  0.23424454, ...,  0.11160348,
        -0.3851102 , -0.05249709],
       [-0.06573675,  0.02800887,  0.2330777 , ...,  0.11173683,
        -0.38510838, -0.05250804],
       [-0.06563716,  0.02762791,  0.22383493, ...,  0.11278623,
        -0.38507292, -0.05259083]], dtype=float32)

array([[-0.06572815,  0.02797643,  0.23228577, ...,  0.11182722,
        -0.38510668, -0.05251546],
       [-0.06568072,  0.02779597,  0.22789657, ...,  0.11232647,
        -0.38509297, -0.05255538],
       [-0.06574763,  0.02804986,  0.23407973, ...,  0.11162236,
        -0.38511008, -0.05249863],
       [-0.06575137,  0.0280639 ,  0.23442353, ...,  0.11158305,
        -0.3851104 , -0.05249535]], dtype=float32)

array([[-0.0656904 ,  0.02783306,  0.22879606, ...,  0.11222443,
        -0.3850964 , -0.05254735],
       [-0.06578612,  0.02819377,  0.23761076, ...,  0.11121783,
        -0.38511246, -0.05246475],
       [-0.06580459,  0.02826248,  0.23930213, ...,  0.11102349,
        -0.3851116 , -0.05244821],
       [-0.06556657,  0.02735009,  0.2171814 , ...,  0.11353415,
        -0.38502502, -0.0526458 ]], dtype=float32)

array([[-0.06561741,  0.0275508 ,  0.22198111, ...,  0.11299524,
        -0.38506138, -0.05260655],
       [-0.06569718,  0.02785897,  0.22942495, ...,  0.11215296,
        -0.38509858, -0.05254164],
       [-0.06573445,  0.02800029,  0.23286769, ...,  0.11176081,
        -0.385108  , -0.05251   ],
       [-0.06568942,  0.02782934,  0.22870544, ...,  0.11223471,
        -0.38509616, -0.05254814]], dtype=float32)

array([[-0.06559324,  0.0274559 ,  0.21970677, ...,  0.11325098,
        -0.3850454 , -0.05262538],
       [-0.06577191,  0.02814086,  0.23631021, ...,  0.11136699,
        -0.38511208, -0.05247736],
       [-0.06566589,  0.027739  ,  0.22651641, ...,  0.11248302,
        -0.3850869 , -0.05256756],
       [-0.06571781,  0.02793735,  0.23133253, ...,  0.11193586,
        -0.38510448, -0.05252425]], dtype=float32)

array([[-0.06564257,  0.0276489 ,  0.22434081, ...,  0.11272905,
        -0.38507572, -0.0525865 ],
       [-0.06568698,  0.02781999,  0.2284786 , ...,  0.11226046,
        -0.38509533, -0.05255016],
       [-0.06566387,  0.02773128,  0.22632957, ...,  0.11250415,
        -0.38508606, -0.05256923],
       [-0.06555977,  0.02732293,  0.21653499, ...,  0.11360647,
        -0.38501945, -0.05265088]], dtype=float32)