In [1]:
# in a terminal run
# > make run-redis NS=train
# > make run-redis NS=test
# to allow access to the train and test namespaces

In [2]:
import os
import sys
import pandas as pd
import numpy as np

In [3]:
# Make the parent directory importable so the project packages used in the
# next cell resolve. Guarded so re-running the cell does not append duplicates.
if ".." not in sys.path:
    sys.path.append("..")

# Root of the local checkout. Defaults to the location this notebook was
# originally run from; set CLOTHO_ROOT in the environment to run elsewhere.
CLOTHO_ROOT = os.environ.get("CLOTHO_ROOT", "/mnt/d/workspace/clotho")

# USER_PATH is exported before the project modules are imported
# (presumably they read it at import/use time -- TODO confirm).
# The empty last component keeps the trailing path separator.
os.environ["USER_PATH"] = os.path.join(CLOTHO_ROOT, "userdata", "")
# Final artefacts go to MODEL_OUTPUT_BASE, per-epoch checkpoints to MODEL_OUTPUT_CP.
MODEL_OUTPUT_BASE = os.path.join(CLOTHO_ROOT, "notebooks")
MODEL_OUTPUT_CP = os.path.join(MODEL_OUTPUT_BASE, "checkpoints")

In [4]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from model.datagenerator import create_train_test
from system.namespace.store import get_namespace

In [5]:
import torch

# True when torch reports a usable CUDA device. Later cells use this flag to
# pick the torch device and to choose between the large and small run sizes.
is_cuda = torch.cuda.is_available()
is_cuda

True

In [6]:
def _side(mode, flip_pc=0.0):
    """Return a fresh settings dict for one side (left or right) of a plan entry."""
    return {"mode": mode, "flip_pc": flip_pc}


def _plan_entry(left, right, *, min_text_length, weight, flip_lr=0.5, epochs=None):
    """Return one learning-plan entry as a fresh dict.

    `epochs` is a `(first_epoch, last_epoch)` pair; when it is None the two
    epoch keys are left out entirely (as in the evaluation plans). Key order
    matches the hand-written dictionaries this helper replaces.
    """
    entry = {
        "left": left,
        "right": right,
        "min_text_length": min_text_length,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": flip_lr,
    }
    if epochs is not None:
        entry["first_epoch"], entry["last_epoch"] = epochs
    entry["weight"] = weight
    return entry


set_redis_slow_mode("never")
ns_test = get_namespace("test")
ns_train = get_namespace("train")
# Fixed reference time handed to the generator so runs are comparable.
now = pd.Timestamp("2022-12-17", tz="UTC")

# Training curriculum: entries with a first_epoch only apply from that epoch on
# (presumably; the exact semantics live in create_train_test -- TODO confirm).
train_plan = [
    _plan_entry(
        _side("valid", 0.5), _side("valid"),
        min_text_length=20, epochs=(10, None), weight=100),
    _plan_entry(
        _side("valid", 0.5), _side("valid"),
        min_text_length=None, epochs=(10, None), weight=100),
    _plan_entry(
        _side("random"), _side("path"),
        min_text_length=20, epochs=(None, None), weight=60),
    _plan_entry(
        None, _side("path"),
        min_text_length=20, epochs=(None, None), weight=40),
    _plan_entry(
        _side("random"), _side("path"),
        min_text_length=None, epochs=(5, None), weight=60),
    _plan_entry(
        None, _side("path"),
        min_text_length=None, epochs=(5, None), weight=40),
    _plan_entry(
        _side("random"), _side("valid"),
        min_text_length=None, epochs=(None, None), weight=60),
    _plan_entry(
        None, _side("valid"),
        min_text_length=None, epochs=(None, None), weight=40),
    _plan_entry(
        _side("valid"), _side("valid"),
        min_text_length=None, flip_lr=0, epochs=(15, None), weight=50),
]

# Evaluation plans carry no epoch window; the same list object is shared by
# the three evaluation generators below.
eval_plan = [
    _plan_entry(_side("random"), _side("valid"), min_text_length=20, weight=60),
    _plan_entry(None, _side("valid"), min_text_length=20, weight=40),
    _plan_entry(_side("random"), _side("valid"), min_text_length=None, weight=60),
    _plan_entry(None, _side("valid"), min_text_length=None, weight=40),
]

# Large sizes on GPU, a reduced smoke-test configuration on CPU.
_eval_rows = 10000 if is_cuda else 1000
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=_eval_rows,
    test_size=_eval_rows,
    test_val_size=_eval_rows,
    compute_batch_size=100,
    now=now)

In [7]:
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if is_cuda else "cpu")
device

device(type='cuda')

In [9]:
from typing import Literal, TypedDict

# Keys used to address the two texts of a pair when they are fed to the model.
ProviderRole = Literal["child", "parent"]


class TokenizedInput(TypedDict):
    """Tokenizer output fields that the model consumes."""
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


# Feature width used for the dense heads applied to the DistilBERT output.
EMBED_SIZE = 768
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokens(texts: "pd.Series | list[str]") -> TokenizedInput:
    """Tokenize a batch of texts and move the resulting tensors to `device`.

    Args:
        texts: The texts of one batch. Either a pandas Series (as passed by
            the training cells) or any plain sequence of strings. The
            previous version called `.tolist()` unconditionally and therefore
            failed for the `list[str]` its own annotation advertised.

    Returns:
        The tokenizer output as a dict of tensors, padded to a common length,
        truncated, and placed on the module-level `device`.
    """
    # Series / ndarray expose tolist(); fall back to list() for everything else.
    text_list = texts.tolist() if hasattr(texts, "tolist") else list(texts)
    encoded = tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True)
    return {key: value.to(device) for key, value in encoded.items()}


class Noise(nn.Module):
    """Adds randomly masked Gaussian noise to its input while in training mode.

    Each element independently receives noise with probability `p`; the noise
    is drawn from a normal distribution with mean 0 and standard deviation
    `std`. In evaluation mode the input is returned untouched.
    """

    def __init__(self, std: float = 1.0, p: float = 0.5) -> None:
        super().__init__()
        self._std = std
        self._p = p
        # Frozen dummy parameter: it moves with the module on `.to(...)`, so
        # its device tells forward() where to create the random tensors.
        self._dhold = nn.Parameter(torch.Tensor([0.0]), requires_grad=False)

    def get_std(self) -> float:
        """Return the current noise standard deviation."""
        return self._std

    def set_std(self, std: float) -> None:
        """Replace the noise standard deviation."""
        self._std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` plus masked Gaussian noise (training) or `x` itself (eval)."""
        if self.training:
            target = self._dhold.device
            shape = x.shape
            # Draw the mask first, then the noise, to keep the RNG order stable.
            mask = torch.rand(size=shape, device=target) < self._p
            noise = torch.normal(
                mean=0.0, std=self._std, size=shape, device=target)
            return x + mask * noise
        return x


class Model(nn.Module):
    """Two-tower DistilBERT model that scores a (parent, child) text pair.

    A separate DistilBERT encoder is used for the parent and for the child
    text. `version` selects the architecture variant:

    * dense head after each encoder: versions 1, 3, 4 and 6
    * training-time `Noise` on the embeddings: versions 4 and 5
    * cosine similarity as the score: versions 2, 3 and 4
      (all other versions use the plain dot product)
    * token aggregation: first token ("cls") below version 6, mean from 6 on
    """

    def __init__(self, version: int) -> None:
        super().__init__()
        self._bert_parent = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        self._bert_child = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        if version in (1, 3, 4, 6):
            self._pdense: nn.Sequential | None = self._make_dense()
            self._cdense: nn.Sequential | None = self._make_dense()
        else:
            self._pdense = None
            self._cdense = None
        # The original condition read `noise > 5`, but `noise` is not defined
        # in this scope, so every version >= 4 raised a NameError on a fresh
        # kernel. `version > 5` mirrors the cosine condition below.
        # NOTE(review): this disables noise for version 6 -- confirm that is
        # the intended behaviour for that variant.
        if version < 4 or version > 5:
            self._noise = None
        else:
            self._noise = Noise(std=1.0, p=0.2)
        if version < 2 or version > 4:
            self._cos = None
        else:
            self._cos = torch.nn.CosineSimilarity()
        if version < 6:
            self._agg = "cls"
        else:
            self._agg = "mean"
        self._version = version

    @staticmethod
    def _make_dense() -> nn.Sequential:
        """Build one dense head: Linear -> Dropout -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(EMBED_SIZE, EMBED_SIZE),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.Linear(EMBED_SIZE, EMBED_SIZE))

    def set_epoch(self, epoch: int) -> None:
        """Decay the noise level: std becomes 1 / 1.2**epoch (no-op without noise)."""
        noise = self._noise
        if noise is not None:
            noise.set_std(1 / (1.2 ** epoch))

    def get_version(self) -> int:
        """Return the architecture version this model was built with."""
        return self._version

    def get_agg(self, lhs: torch.Tensor) -> torch.Tensor:
        """Collapse the token axis (dim 1) of the encoder's hidden states.

        "cls" takes the first token, "mean" averages over dim 1.

        Raises:
            ValueError: if the configured aggregation is unknown.
        """
        if self._agg == "cls":
            return lhs[:, 0]
        if self._agg == "mean":
            # NOTE(review): averages over every position and does not consult
            # the attention mask, so padded positions are included.
            return torch.mean(lhs, dim=1)
        raise ValueError(f"unknown aggregation: {self._agg}")

    def _embed(
            self,
            bert: nn.Module,
            dense: nn.Sequential | None,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode, aggregate, then apply the optional dense head and noise."""
        hidden = bert(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state
        out = self.get_agg(hidden)
        if dense is not None:
            out = dense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def get_parent_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the embedding of the parent texts."""
        return self._embed(
            self._bert_parent, self._pdense, input_ids, attention_mask)

    def get_child_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the embedding of the child texts."""
        return self._embed(
            self._bert_child, self._cdense, input_ids, attention_mask)

    def forward(self, x: dict[ProviderRole, TokenizedInput]) -> torch.Tensor:
        """Score each (parent, child) pair of the batch.

        Args:
            x: Tokenized inputs keyed by "parent" and "child".

        Returns:
            One score per pair as a column tensor (reshaped to `[-1, 1]`):
            the cosine similarity when configured, otherwise the dot product
            of the two embeddings.
        """
        parent_cls = self.get_parent_embed(
            input_ids=x["parent"]["input_ids"],
            attention_mask=x["parent"]["attention_mask"])
        child_cls = self.get_child_embed(
            input_ids=x["child"]["input_ids"],
            attention_mask=x["child"]["attention_mask"])
        if self._cos is not None:
            return self._cos(parent_cls, child_cls).reshape([-1, 1])
        batch_size = parent_cls.shape[0]
        # (batch, 1, dim) @ (batch, dim, 1) is a per-row dot product.
        return torch.bmm(
            parent_cls.reshape([batch_size, 1, -1]),
            child_cls.reshape([batch_size, -1, 1])).reshape([-1, 1])


class TrainingHarness(nn.Module):
    """Wraps a `Model` for pairwise training: scores two candidates and
    computes the loss of choosing between them."""

    def __init__(self, model: Model) -> None:
        super().__init__()
        self._model = model
        self._softmax = nn.Softmax(dim=1)
        self._loss = nn.BCELoss()

    def get_version(self) -> int:
        """Return the architecture version of the wrapped model."""
        return self._model.get_version()

    def forward(
            self,
            left: dict[ProviderRole, TokenizedInput],
            right: dict[ProviderRole, TokenizedInput],
            labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Score the left and the right pair and compare against `labels`.

        Args:
            left: Tokenized "parent"/"child" inputs of the left candidate.
                Passed unchanged to the wrapped model, hence the role-keyed
                dict type (the previous `TokenizedInput` annotation did not
                match what the model's forward accepts).
            right: Same for the right candidate.
            labels: Target probabilities with one column per candidate, in
                the order (left, right).

        Returns:
            `(preds, loss)`: the softmax over the two scores of each row and
            the binary cross-entropy of `preds` against `labels`.
        """
        out_left = self._model(left)
        out_right = self._model(right)
        # Each score is a column tensor; hstack puts left and right side by
        # side so the softmax over dim 1 compares the two candidates.
        preds = self._softmax(torch.hstack((out_left, out_right)))
        return preds, self._loss(preds, labels)

In [None]:
def get_filename(
        harness: TrainingHarness,
        *,
        is_final: bool,
        ftype: str,
        epoch: int | None,
        ext: str = ".pkl") -> tuple[str, str, str, str]:
    """Build the output path for a model file and the parts it is made of.

    Args:
        harness: Harness whose version number becomes part of the name.
        is_final: True selects `MODEL_OUTPUT_BASE`, False the checkpoint
            folder `MODEL_OUTPUT_CP`.
        ftype: Leading part of the file name (e.g. "harness").
        epoch: Epoch number placed between prefix and extension; None leaves
            that slot empty (useful when only the prefix/postfix are needed).
        ext: File extension including the dot.

    Returns:
        `(path, prefix, postfix, folder)` where `path` is
        `folder/prefix + epoch + postfix`. The annotation previously claimed
        three elements although four are returned.
    """
    folder = MODEL_OUTPUT_BASE if is_final else MODEL_OUTPUT_CP
    postfix = "_lg" if is_cuda else ""
    out_pre = f"{ftype}_v{harness.get_version()}{postfix}_"
    # Computed outside the f-string: reusing double quotes inside a
    # replacement field is a SyntaxError before Python 3.12.
    epoch_tag = "" if epoch is None else f"{epoch}"
    return (
        os.path.join(folder, f"{out_pre}{epoch_tag}{ext}"),
        out_pre,
        ext,
        folder,
    )

In [10]:
from torch.optim import AdamW

model = Model(version=6)
model.to(device)
harness = TrainingHarness(model)
harness.to(device)

# Set to True to ignore existing checkpoints and start training from scratch.
FORCE_RESTART = False

# epoch=None: only prefix, postfix and folder are needed to scan for checkpoints.
_, spre, spost, sfolder = get_filename(harness, is_final=False, ftype="harness", epoch=None)
# A fresh checkout has no checkpoint folder yet; os.listdir would raise without it.
os.makedirs(sfolder, exist_ok=True)
# highest_number presumably yields (file name, number) of the newest matching
# checkpoint, or None when there is none -- see misc.util.
mprev = highest_number(os.listdir(sfolder), prefix=spre, postfix=spost)
if not FORCE_RESTART and mprev is not None:
    prev_fname, prev_epoch = mprev
    # Fixed: the path was built from the undefined name `folder` instead of `sfolder`.
    # torch.load unpickles the file, so only load checkpoints written by this notebook.
    harness.load_state_dict(torch.load(os.path.join(sfolder, prev_fname), map_location=device))
    epoch_offset = prev_epoch + 1
else:
    epoch_offset = 0

optimizer = AdamW(harness.parameters(), lr=5e-5)
mprev, epoch_offset

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias']
- T

(None, 0)

In [11]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate
import time

def compute(df):
    """Run the harness on one batch frame and return `(preds, loss)`.

    The frame must provide the text columns "parent_left", "child_left",
    "parent_right", "child_right" and the column "correct_is_right".
    """
    left_inputs = {
        "parent": tokens(df["parent_left"]),
        "child": tokens(df["child_left"]),
    }
    right_inputs = {
        "parent": tokens(df["parent_right"]),
        "child": tokens(df["child_right"]),
    }
    is_right = df["correct_is_right"]
    # Two target columns in (left, right) order; the transpose turns the two
    # stacked rows into one row per sample.
    targets = torch.tensor([~is_right, is_right], dtype=torch.float32).T.to(device)
    return harness(left_inputs, right_inputs, targets)

# Train for the epochs remaining up to the target count, but at least 3.
num_epochs = max((120 if is_cuda else 10) - epoch_offset, 3)
num_training_steps = num_epochs * ttgen.get_epoch_train_size()
warmup = 10000 if is_cuda else 10
# NOTE(review): lr_scheduler.step() below runs once per batch, while the
# progress bar advances by train_df.shape[0] towards get_epoch_train_size().
# If that size counts rows rather than batches, the schedule horizon is in
# different units than the steps actually taken -- TODO confirm.
# NOTE(review): when resuming (epoch_offset > 0) the optimizer and scheduler
# start fresh; only the harness weights are restored.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=warmup,
    num_training_steps=num_training_steps - warmup)
ttgen.set_epoch(epoch_offset)

for _ in range(num_epochs):
    epoch = ttgen.get_epoch()
    print(f"epoch {epoch}")
    real_time = time.monotonic()

    # --- training pass ---
    model.train()
    harness.train()
    model.set_epoch(epoch)
    metric_train = evaluate.load("accuracy")
    train_loss = []
    with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
        for train_df in ttgen.train_dfs():
            preds, loss = compute(train_df)
            train_loss.append(loss.item())
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(train_df.shape[0])

            predictions = torch.argmax(preds, dim=-1)
            metric_train.add_batch(predictions=predictions, references=train_df["correct_is_right"].astype(int))

    # Fixed: get_filename returns (path, prefix, postfix, folder); torch.save
    # needs the path only, not the whole tuple.
    checkpoint_path = get_filename(harness, is_final=False, ftype="harness", epoch=epoch)[0]
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(harness.state_dict(), checkpoint_path)

    # --- evaluation passes (no gradients, eval mode) ---
    model.eval()
    harness.eval()
    with torch.no_grad():
        metric_val_train = evaluate.load("accuracy")
        train_val_loss = []
        with tqdm(desc="train val", total=ttgen.get_epoch_train_validation_size()) as progress_bar:
            for train_validation_df in ttgen.train_validation_dfs():
                preds, loss = compute(train_validation_df)
                train_val_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_val_train.add_batch(
                    predictions=predictions, references=train_validation_df["correct_is_right"].astype(int))
                progress_bar.update(train_validation_df.shape[0])

        metric_test = evaluate.load("accuracy")
        test_loss = []
        with tqdm(desc="test", total=ttgen.get_epoch_test_size()) as progress_bar:
            for test_df in ttgen.test_dfs():
                preds, loss = compute(test_df)
                test_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_test.add_batch(
                    predictions=predictions, references=test_df["correct_is_right"].astype(int))
                progress_bar.update(test_df.shape[0])

        print(f"train: {metric_train.compute()} loss: {np.mean(train_loss)}")
        print(f"train val: {metric_val_train.compute()} loss: {np.mean(train_val_loss)}")
        print(f"test: {metric_test.compute()} loss: {np.mean(test_loss)}")
    ttgen.advance_epoch()
    print(f"epoch time: {(time.monotonic() - real_time) / 60.0:.2f}min")

epoch 0


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.51025} loss: 22.655790788561532
train val: {'accuracy': 0.5311} loss: 0.6934446441173554
test: {'accuracy': 0.4985} loss: 0.7083818842530251
epoch time: 84.80min
epoch 1


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49875} loss: 22.049154417168726
train val: {'accuracy': 0.457} loss: 0.6953360244274139
test: {'accuracy': 0.4934} loss: 0.6928981063842773
epoch time: 15.74min
epoch 2


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.503} loss: 19.41779045229834
train val: {'accuracy': 0.4772} loss: 0.6936146483898162
test: {'accuracy': 0.532} loss: 0.692195683169365
epoch time: 15.89min
epoch 3


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.4932} loss: 17.575540262147335
train val: {'accuracy': 0.4563} loss: 0.6933389202833176
test: {'accuracy': 0.495} loss: 0.6932094643115997
epoch time: 17.29min
epoch 4


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.496} loss: 15.725588903693628
train val: {'accuracy': 0.6181} loss: 0.6926936051607132
test: {'accuracy': 0.5023} loss: 0.6928451885938645
epoch time: 16.13min
epoch 5


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5017} loss: 13.572738920020157
train val: {'accuracy': 0.393} loss: 0.6944171579837799
test: {'accuracy': 0.502} loss: 0.6931974095582962
epoch time: 16.39min
epoch 6


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5033} loss: 11.754992915035524
train val: {'accuracy': 0.488} loss: 0.6892236686468124
test: {'accuracy': 0.5397} loss: 0.6926834681272507
epoch time: 15.80min
epoch 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50435} loss: 10.009444292305366
train val: {'accuracy': 0.5755} loss: 0.6926531358003616
test: {'accuracy': 0.4881} loss: 0.6931538519620896
epoch time: 15.72min
epoch 8


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50225} loss: 8.494893871269761
train val: {'accuracy': 0.5969} loss: 0.6927846732616425
test: {'accuracy': 0.4756} loss: 0.6932169940948486
epoch time: 15.37min
epoch 9


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49465} loss: 7.190322187525271
train val: {'accuracy': 0.4014} loss: 0.6932524302005768
test: {'accuracy': 0.544} loss: 0.6931128970384598
epoch time: 16.08min
epoch 10


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49655} loss: 5.890023180220956
train val: {'accuracy': 0.5118} loss: 0.6931351244211197
test: {'accuracy': 0.4878} loss: 0.693153289604187
epoch time: 19.60min
epoch 11


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50265} loss: 4.856784200347271
train val: {'accuracy': 0.5531} loss: 0.6931240751743316
test: {'accuracy': 0.486} loss: 0.6931501620054245
epoch time: 15.78min
epoch 12


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50195} loss: 3.7240962373980566
train val: {'accuracy': 0.5286} loss: 0.6931394065856934
test: {'accuracy': 0.4948} loss: 0.693148525261879
epoch time: 15.73min
epoch 13


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50375} loss: 2.9978565034908238
train val: {'accuracy': 0.4934} loss: 0.6931497947692871
test: {'accuracy': 0.5014} loss: 0.6931455955743789
epoch time: 16.00min
epoch 14


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50245} loss: 2.418153400265667
train val: {'accuracy': 0.4784} loss: 0.6931614898920059
test: {'accuracy': 0.4959} loss: 0.6931492459535599
epoch time: 15.93min
epoch 15


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.4979} loss: 2.0935098392324756
train val: {'accuracy': 0.5463} loss: 0.6931271878004074
test: {'accuracy': 0.4952} loss: 0.6931490626811981
epoch time: 16.22min
epoch 16


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49475} loss: 1.710819789071195
train val: {'accuracy': 0.5126} loss: 0.6931424585819245
test: {'accuracy': 0.4935} loss: 0.6931490379571915
epoch time: 15.67min
epoch 17


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.4958} loss: 1.4872780491665938
train val: {'accuracy': 0.5406} loss: 0.693048208141327
test: {'accuracy': 0.4985} loss: 0.6931447621107102
epoch time: 15.87min
epoch 18


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50465} loss: 2.881600488134119
train val: {'accuracy': 0.5955} loss: 0.7231557749271392
test: {'accuracy': 0.5184} loss: 0.8494689993262291
epoch time: 15.52min
epoch 19


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.4973} loss: 1.7868125592696247
train val: {'accuracy': 0.4943} loss: 0.702419086766243
test: {'accuracy': 0.4962} loss: 0.7322219789385795
epoch time: 15.81min
epoch 20


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.4968} loss: 1.3294474137319252
train val: {'accuracy': 0.5948} loss: 0.6953262700796128
test: {'accuracy': 0.5061} loss: 0.7069197916030884
epoch time: 15.93min
epoch 21


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50025} loss: 1.0306217489995062
train val: {'accuracy': 0.5483} loss: 0.6931062350511551
test: {'accuracy': 0.4869} loss: 0.6931647469282151
epoch time: 15.78min
epoch 22


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49715} loss: 0.9765520164430141
train val: {'accuracy': 0.5583} loss: 0.6924959521532059
test: {'accuracy': 0.4848} loss: 0.6931851873397827
epoch time: 15.74min
epoch 23


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.49615} loss: 0.9324944511801004
train val: {'accuracy': 0.4385} loss: 0.6932542025089264
test: {'accuracy': 0.501} loss: 0.6931529967546463
epoch time: 16.01min
epoch 24


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50145} loss: 0.8778518315240741
train val: {'accuracy': 0.5689} loss: 0.6930576257228851
test: {'accuracy': 0.482} loss: 0.6931562332391739
epoch time: 16.01min
epoch 25


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5014} loss: 0.84261955152452
train val: {'accuracy': 0.4398} loss: 0.6930391448497772
test: {'accuracy': 0.5181} loss: 0.6918151673078538
epoch time: 15.78min
epoch 26


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5005} loss: 0.8216227032482624
train val: {'accuracy': 0.6058} loss: 0.6570997936964035
test: {'accuracy': 0.4789} loss: 0.7341879445791244
epoch time: 16.14min
epoch 27


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50635} loss: 0.850709900663793
train val: {'accuracy': 0.4729} loss: 0.6883407987117768
test: {'accuracy': 0.4986} loss: 0.6937460000991821
epoch time: 15.90min
epoch 28


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.50465} loss: 0.7543868218481541
train val: {'accuracy': 0.6133} loss: 0.6613406552791595
test: {'accuracy': 0.5017} loss: 0.6977592961311341
epoch time: 15.95min
epoch 29


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5091} loss: 0.7595095428735018
train val: {'accuracy': 0.6189} loss: 0.6516192964315415
test: {'accuracy': 0.5183} loss: 0.6913807391405106
epoch time: 15.74min
epoch 30


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.54985} loss: 1.4770074505329132
train val: {'accuracy': 0.6389} loss: 0.6136519784927368
test: {'accuracy': 0.4989} loss: 0.7080599735975266
epoch time: 16.69min
epoch 31


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.54545} loss: 1.4483389072270598
train val: {'accuracy': 0.6098} loss: 0.6309693775773049
test: {'accuracy': 0.5132} loss: 0.7067063385009765
epoch time: 15.65min
epoch 32


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.57485} loss: 0.7202538917973638
train val: {'accuracy': 0.6384} loss: 0.642916285777092
test: {'accuracy': 0.5305} loss: 0.69418661454916
epoch time: 16.07min
epoch 33


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.58405} loss: 0.6609686676174402
train val: {'accuracy': 0.6478} loss: 0.6315303629040718
test: {'accuracy': 0.5104} loss: 0.6917122703909874
epoch time: 15.89min
epoch 34


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.59065} loss: 0.6650344816684723
train val: {'accuracy': 0.6408} loss: 0.5896816130876541
test: {'accuracy': 0.5155} loss: 0.6889377160787582
epoch time: 15.96min
epoch 35


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.58785} loss: 0.6689139528974891
train val: {'accuracy': 0.6391} loss: 0.6152003086686134
test: {'accuracy': 0.5011} loss: 0.6926184389710426
epoch time: 15.76min
epoch 36


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5878} loss: 0.660241052659601
train val: {'accuracy': 0.6393} loss: 0.6157842682957649
test: {'accuracy': 0.5142} loss: 0.6895527868032455
epoch time: 16.31min
epoch 37


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5928} loss: 0.6551466562304064
train val: {'accuracy': 0.6348} loss: 0.6131461641013622
test: {'accuracy': 0.4979} loss: 0.6917044654369354
epoch time: 16.00min
epoch 38


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6053} loss: 0.6469511500640772
train val: {'accuracy': 0.643} loss: 0.6051840082764626
test: {'accuracy': 0.5356} loss: 0.6923486887812614
epoch time: 16.03min
epoch 39


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6046} loss: 0.6545786474909634
train val: {'accuracy': 0.6483} loss: 0.6137043646931648
test: {'accuracy': 0.4428} loss: 0.7338751566171646
epoch time: 15.85min
epoch 40


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6016} loss: 0.6580289980528876
train val: {'accuracy': 0.6473} loss: 0.6018846308112145
test: {'accuracy': 0.5072} loss: 0.7368719680190087
epoch time: 15.60min
epoch 41


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5223} loss: 6.757183684678376
train val: {'accuracy': 0.6243} loss: 0.6555771126627922
test: {'accuracy': 0.4903} loss: 0.7065268615961074
epoch time: 15.58min
epoch 42


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.59595} loss: 0.655377830016613
train val: {'accuracy': 0.6464} loss: 0.5826511860013008
test: {'accuracy': 0.5194} loss: 0.6935315600514412
epoch time: 16.18min
epoch 43


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.62035} loss: 0.6295690436445177
train val: {'accuracy': 0.6422} loss: 0.6164606383204461
test: {'accuracy': 0.5129} loss: 0.6983289807677269
epoch time: 15.81min
epoch 44


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.58005} loss: 0.6655556664966047
train val: {'accuracy': 0.6345} loss: 0.6619166282057762
test: {'accuracy': 0.535} loss: 0.6948422470808029
epoch time: 15.85min
epoch 45


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.58985} loss: 0.6601145118787884
train val: {'accuracy': 0.6661} loss: 0.616151810836792
test: {'accuracy': 0.5188} loss: 0.6994537524580956
epoch time: 16.00min
epoch 46


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6026} loss: 0.6422524961389601
train val: {'accuracy': 0.6423} loss: 0.6364057011842728
test: {'accuracy': 0.5286} loss: 0.7088630456328392
epoch time: 15.81min
epoch 47


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.60105} loss: 0.6415217233832926
train val: {'accuracy': 0.6681} loss: 0.6449330246567726
test: {'accuracy': 0.5404} loss: 0.6884417070984841
epoch time: 16.35min
epoch 48


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5996} loss: 0.6421306029036641
train val: {'accuracy': 0.6653} loss: 0.5912803902924061
test: {'accuracy': 0.5433} loss: 0.7270296947658061
epoch time: 16.23min
epoch 49


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.60475} loss: 0.6344280487731099
train val: {'accuracy': 0.6683} loss: 0.6075473522245884
test: {'accuracy': 0.5174} loss: 0.6858900582551957
epoch time: 15.78min
epoch 50


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.60485} loss: 0.6448180771771819
train val: {'accuracy': 0.6541} loss: 0.6275166255354881
test: {'accuracy': 0.5136} loss: 0.7196429426193237
epoch time: 15.84min
epoch 51


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6027} loss: 0.6385970738455653
train val: {'accuracy': 0.6745} loss: 0.5990153449296951
test: {'accuracy': 0.5036} loss: 0.6910688071012497
epoch time: 16.30min
epoch 52


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5915} loss: 0.6532795864060521
train val: {'accuracy': 0.656} loss: 0.6354117130994796
test: {'accuracy': 0.5179} loss: 0.6912091602683067
epoch time: 16.03min
epoch 53


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.59935} loss: 0.6364879228144884
train val: {'accuracy': 0.6674} loss: 0.5789474238485098
test: {'accuracy': 0.546} loss: 0.6996633095026016
epoch time: 16.60min
epoch 54


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.61695} loss: 0.629201526542753
train val: {'accuracy': 0.6778} loss: 0.6103274715423584
test: {'accuracy': 0.535} loss: 0.6866570870637894
epoch time: 16.27min
epoch 55


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.62985} loss: 0.6172159304603935
train val: {'accuracy': 0.6747} loss: 0.5904367057859897
test: {'accuracy': 0.5714} loss: 0.6716911962628365
epoch time: 16.47min
epoch 56


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6223} loss: 0.6317610348634422
train val: {'accuracy': 0.6568} loss: 0.607513680678606
test: {'accuracy': 0.5856} loss: 0.6796952055692673
epoch time: 16.29min
epoch 57


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.62495} loss: 0.6280092351421713
train val: {'accuracy': 0.6544} loss: 0.6146561507105828
test: {'accuracy': 0.5675} loss: 0.6830918902039528
epoch time: 16.26min
epoch 58


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.61575} loss: 0.6317306707914918
train val: {'accuracy': 0.6436} loss: 0.661312414765358
test: {'accuracy': 0.5088} loss: 0.7035441834926606
epoch time: 16.54min
epoch 59


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.61605} loss: 0.6287128928717226
train val: {'accuracy': 0.6514} loss: 0.6224469376385212
test: {'accuracy': 0.5281} loss: 0.7049710429787636
epoch time: 16.53min
epoch 60


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6108} loss: 0.6344250080289974
train val: {'accuracy': 0.6585} loss: 0.616494692826271
test: {'accuracy': 0.5226} loss: 0.6962281279683113
epoch time: 16.54min
epoch 61


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6201} loss: 0.6204642332702875
train val: {'accuracy': 0.6697} loss: 0.5956827125668526
test: {'accuracy': 0.517} loss: 0.6968556618094445
epoch time: 16.36min
epoch 62


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.63865} loss: 0.6084218690313399
train val: {'accuracy': 0.6693} loss: 0.5864376405894757
test: {'accuracy': 0.5216} loss: 0.6898326637387275
epoch time: 16.53min
epoch 63


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6302} loss: 0.613984053948149
train val: {'accuracy': 0.6723} loss: 0.5870052537441254
test: {'accuracy': 0.5251} loss: 0.7158285625934601
epoch time: 16.30min
epoch 64


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6422} loss: 0.6041946912914514
train val: {'accuracy': 0.6717} loss: 0.5812879116475582
test: {'accuracy': 0.5379} loss: 0.6941475687146187
epoch time: 15.87min
epoch 65


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.618} loss: 0.6301868735997938
train val: {'accuracy': 0.6519} loss: 0.6342424039721489
test: {'accuracy': 0.5322} loss: 0.7120558575749397
epoch time: 16.39min
epoch 66


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5927} loss: 0.6621572444036603
train val: {'accuracy': 0.5948} loss: 0.735285979104042
test: {'accuracy': 0.5018} loss: 0.8372150959298015
epoch time: 16.56min
epoch 67


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.5458} loss: 0.6950307304918766
train val: {'accuracy': 0.4756} loss: 0.7558771532535553
test: {'accuracy': 0.5106} loss: 0.726904780638218
epoch time: 16.39min
epoch 68


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.57925} loss: 0.66919618011415
train val: {'accuracy': 0.6648} loss: 0.6232781063675881
test: {'accuracy': 0.5007} loss: 0.6977160790085792
epoch time: 16.09min
epoch 69


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6357} loss: 0.6158823770347983
train val: {'accuracy': 0.6188} loss: 0.5800273397728801
test: {'accuracy': 0.5373} loss: 0.6948786136746407
epoch time: 17.12min
epoch 70


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.64805} loss: 0.6022546653017402
train val: {'accuracy': 0.6711} loss: 0.6084521661102772
test: {'accuracy': 0.5149} loss: 0.6867942536234856
epoch time: 16.50min
epoch 71


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6545} loss: 0.5942322601433844
train val: {'accuracy': 0.6699} loss: 0.6098112634062767
test: {'accuracy': 0.5383} loss: 0.6831989003300667
epoch time: 16.23min
epoch 72


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.65865} loss: 0.5770669899897585
train val: {'accuracy': 0.6589} loss: 0.6103768048524857
test: {'accuracy': 0.5567} loss: 0.6855943791508675
epoch time: 16.03min
epoch 73


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6572} loss: 0.5812114943534136
train val: {'accuracy': 0.67} loss: 0.5799970477670431
test: {'accuracy': 0.6122} loss: 0.6659101774930954
epoch time: 17.06min
epoch 74


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6745} loss: 0.5664511268235743
train val: {'accuracy': 0.6741} loss: 0.6059746125936508
test: {'accuracy': 0.5984} loss: 0.6795380562186241
epoch time: 16.18min
epoch 75


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.67435} loss: 0.5649134400932119
train val: {'accuracy': 0.677} loss: 0.5706403456777335
test: {'accuracy': 0.5966} loss: 0.6797803143501282
epoch time: 16.82min
epoch 76


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6781} loss: 0.5625772573251278
train val: {'accuracy': 0.6793} loss: 0.5546775067657232
test: {'accuracy': 0.5969} loss: 0.6709079030752182
epoch time: 16.34min
epoch 77


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.67725} loss: 0.5624359168002732
train val: {'accuracy': 0.6785} loss: 0.5646795721322299
test: {'accuracy': 0.605} loss: 0.6774349558591842
epoch time: 16.89min
epoch 78


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68115} loss: 0.5566164871193469
train val: {'accuracy': 0.6762} loss: 0.5726637514412403
test: {'accuracy': 0.645} loss: 0.6379768812537193
epoch time: 16.45min
epoch 79


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6826} loss: 0.5546887269832194
train val: {'accuracy': 0.6783} loss: 0.5908378523766995
test: {'accuracy': 0.6311} loss: 0.6687863513708114
epoch time: 16.61min
epoch 80


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6901} loss: 0.5512915297485887
train val: {'accuracy': 0.6825} loss: 0.5635703306853771
test: {'accuracy': 0.5974} loss: 0.6641324471354485
epoch time: 16.23min
epoch 81


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68415} loss: 0.5510016732804477
train val: {'accuracy': 0.6833} loss: 0.5706749406635762
test: {'accuracy': 0.6232} loss: 0.6648777266860009
epoch time: 16.21min
epoch 82


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6831} loss: 0.5563803624021472
train val: {'accuracy': 0.6826} loss: 0.5720634600430727
test: {'accuracy': 0.6275} loss: 0.6562074196577072
epoch time: 16.45min
epoch 83


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68445} loss: 0.5536327035784722
train val: {'accuracy': 0.6807} loss: 0.5886469828009605
test: {'accuracy': 0.6112} loss: 0.6774987057805061
epoch time: 16.65min
epoch 84


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6843} loss: 0.551475761084631
train val: {'accuracy': 0.6827} loss: 0.5744986033678054
test: {'accuracy': 0.6267} loss: 0.6649392602920532
epoch time: 17.00min
epoch 85


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6955} loss: 0.5427440061021597
train val: {'accuracy': 0.6888} loss: 0.5613787086218596
test: {'accuracy': 0.6174} loss: 0.6623060029745101
epoch time: 16.55min
epoch 86


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69605} loss: 0.5397577558114194
train val: {'accuracy': 0.6893} loss: 0.5616906631320715
test: {'accuracy': 0.6081} loss: 0.6602988770365715
epoch time: 16.02min
epoch 87


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69585} loss: 0.5378912607646321
train val: {'accuracy': 0.6935} loss: 0.5513443793207407
test: {'accuracy': 0.613} loss: 0.6631573091387749
epoch time: 16.72min
epoch 88


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6957} loss: 0.5344540726298467
train val: {'accuracy': 0.6859} loss: 0.5716922341048718
test: {'accuracy': 0.6452} loss: 0.642309991300106
epoch time: 16.23min
epoch 89


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69815} loss: 0.5366030247587944
train val: {'accuracy': 0.6928} loss: 0.5568987752318382
test: {'accuracy': 0.6415} loss: 0.6466519674062728
epoch time: 16.42min
epoch 90


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70385} loss: 0.5341723330974579
train val: {'accuracy': 0.696} loss: 0.5565834381759167
test: {'accuracy': 0.6353} loss: 0.6442320507884025
epoch time: 16.25min
epoch 91


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70295} loss: 0.5318304283905775
train val: {'accuracy': 0.693} loss: 0.5563340867429972
test: {'accuracy': 0.6454} loss: 0.6503355922222137
epoch time: 16.28min
epoch 92


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69745} loss: 0.5320420153267682
train val: {'accuracy': 0.6916} loss: 0.5823991558074951
test: {'accuracy': 0.6458} loss: 0.64983454246521
epoch time: 16.87min
epoch 93


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7093} loss: 0.5307276310182643
train val: {'accuracy': 0.6943} loss: 0.5723373819649219
test: {'accuracy': 0.6287} loss: 0.6543218973755837
epoch time: 17.64min
epoch 94


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7053} loss: 0.5301671376655577
train val: {'accuracy': 0.6918} loss: 0.5739717519998551
test: {'accuracy': 0.649} loss: 0.6422775353312492
epoch time: 19.17min
epoch 95


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70205} loss: 0.5337336874946952
train val: {'accuracy': 0.6899} loss: 0.570110905945301
test: {'accuracy': 0.6019} loss: 0.6674351117134094
epoch time: 18.99min
epoch 96


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70585} loss: 0.5287456960555166
train val: {'accuracy': 0.6962} loss: 0.5602055155992508
test: {'accuracy': 0.6149} loss: 0.6498424701333047
epoch time: 20.23min
epoch 97


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.711} loss: 0.5275163939048536
train val: {'accuracy': 0.6932} loss: 0.5673470091223717
test: {'accuracy': 0.6192} loss: 0.6572561732530594
epoch time: 19.49min
epoch 98


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7097} loss: 0.5258638797266408
train val: {'accuracy': 0.6913} loss: 0.5614567510306835
test: {'accuracy': 0.6159} loss: 0.6636211219191551
epoch time: 19.28min
epoch 99


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70885} loss: 0.5282410882648081
train val: {'accuracy': 0.6928} loss: 0.560701883161068
test: {'accuracy': 0.6391} loss: 0.6381455819368362
epoch time: 19.34min


In [12]:
# Persist the final weights/state of each trainable component. The target path
# is whatever get_filename() returns first for the given file type.
for kind, component in (("model", model), ("harness", harness), ("optimizer", optimizer)):
    state = component.state_dict()
    out_path = get_filename(harness, is_final=True, ftype=kind, epoch=None)[0]
    torch.save(state, out_path)

In [13]:
# Final pass over the test-validation split: score every batch, accumulate
# accuracy and loss, and keep the per-example outputs for later inspection.
ttgen.reset()
model.eval()
harness.eval()

dfs = []
test_val_loss = []
metric_val_test = evaluate.load("accuracy")
with torch.no_grad(), tqdm(
        desc="test val", total=ttgen.get_epoch_test_validation_size()) as progress_bar:
    for test_val_df in ttgen.test_validation_dfs():
        preds, loss = compute(test_val_df)
        test_val_loss.append(loss.item())
        pred_labels = torch.argmax(preds, dim=-1)
        truth = test_val_df["correct_is_right"].astype(int)
        metric_val_test.add_batch(predictions=pred_labels, references=truth)
        # Copy of the batch with the model outputs appended as extra columns.
        # NOTE(review): in the recorded output the two "logit_*" values sum to 1
        # per row, so they look like probabilities rather than raw logits -
        # confirm against compute().
        dfs.append(test_val_df.assign(
            logit_left=preds[:, 0].cpu(),
            logit_right=preds[:, 1].cpu(),
            preds=pred_labels.cpu(),
            truth=truth,
        ))
        progress_bar.update(test_val_df.shape[0])

print(f"test val: {metric_val_test.compute()} loss: {np.mean(test_val_loss)}")
# Batches keep their own row labels, so the combined index is not unique.
validation_df = pd.concat(dfs)

test val:   0%|          | 0/10000 [00:00<?, ?it/s]

test val: {'accuracy': 0.6371} loss: 0.6438971089482307


In [14]:
# Write the per-example validation results to the final "validation" CSV.
validation_csv_path = get_filename(
    harness, is_final=True, ftype="validation", epoch=None, ext=".csv")[0]
validation_df.to_csv(validation_csv_path)

In [15]:
# First few validation rows where the predicted side matches the ground truth.
is_correct = validation_df["preds"] == validation_df["truth"]
validation_df[is_correct].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
2,random--valid;(sw),Dungeons and Dragons,r/Kamikazebywords,"Question for the men, what's the worst part of...",When sitting wrong on your bike seat and cutti...,0.119203,0.8807971,True,0.386937,0.613063,1,1
3,random--valid;(sw),I always brush my teeth but need to floss more...,As an american all I can say is: why?,[removed],User profile checks out. Bet you're on a list ...,0.119203,0.8807971,True,0.115898,0.884102,1,1
0,*valid--random;(sw),And they taste delicious when chopped into hal...,"In Spain/Portugal we also cook them ""a feira"" ...",Seinfeld seems to be the one that paved the wa...,Wiping down gym equipment.,0.993307,0.006692851,False,0.539984,0.460016,0,0
1,*valid--!copy;(mtl:20),Yeah that's a thing. I spend my summer vacatio...,you ever done on going maintenance on a boat? ...,you ever done on going maintenance on a boat? ...,Yeah that's a thing. I spend my summer vacatio...,1.0,1.5628820000000001e-18,False,0.547646,0.452354,0,0
2,*valid--!copy;(sw),She actually has a song where this is the poin...,"Yep, this one came to mind. She is incredible ...","Yep, this one came to mind. She is incredible ...",She actually has a song where this is the poin...,0.997527,0.002472623,False,0.627038,0.372962,0,0


In [16]:
# First few validation rows where the predicted side differs from the ground truth.
is_wrong = validation_df["preds"] != validation_df["truth"]
validation_df[is_wrong].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
0,!copy--valid;(mtl:20),Gonna ride the baloney pony all the way to Was...,I'd literally take the law into my own hands.,I'd literally take the law into my own hands.,Gonna ride the baloney pony all the way to Was...,4.662937e-15,1.0,True,0.527256,0.472744,0,1
1,*valid--!copy;(mtl:20),Men are still plenty selective about who they ...,"Men are selective, as in, they'll select from ...","Men are selective, as in, they'll select from ...",Men are still plenty selective about who they ...,0.9998766,0.000123,False,0.480113,0.519887,1,0
3,*valid--random;(mtl:20),The cigarette industry social lied about cigar...,covid vaccines can protect from covid,“Women get sex so easy!” “Man I bet no one wan...,You guys watch porn with sound on??? What a pr...,0.8807971,0.119203,False,0.184878,0.815122,1,0
1,*valid--!copy;(mtl:20),"Yeah that's a fair point, which is why I belie...",>The idea was for God's people to be such good...,>The idea was for God's people to be such good...,"Yeah that's a fair point, which is why I belie...",0.7310586,0.268941,False,0.462959,0.537041,1,0
1,!copy--valid;(mtl:20),I remember learning in 2nd grade that pizza wa...,The food pyramid -brought to you by big grain,The food pyramid -brought to you by big grain,I remember learning in 2nd grade that pizza wa...,0.0,1.0,True,0.654011,0.345989,0,1


In [17]:
# Debug aid: display the child embeddings of the left and right candidates for
# the first few test-validation batches.
# NOTE(review): the output recorded below this cell shows identical rows for
# every example (and a CUDA tensor repr that does not match `.cpu().numpy()`),
# so it may be stale or the embedding may not depend on its input - re-run and
# confirm.
max_batches = 5  # number of batches to inspect

ttgen.reset()
model.eval()
harness.eval()
with torch.no_grad():
    for batch_ix, test_val_df in enumerate(ttgen.test_validation_dfs()):
        # Only the child texts are embedded here; the parent columns were
        # previously tokenized as well but never used, so that work is dropped.
        clefts = tokens(test_val_df["child_left"])
        crights = tokens(test_val_df["child_right"])
        display(model.get_child_embed(
            clefts["input_ids"],
            clefts["attention_mask"]).cpu().numpy())
        display(model.get_child_embed(
            crights["input_ids"],
            crights["attention_mask"]).cpu().numpy())
        # Stop after the last wanted batch without pulling another one from
        # the generator.
        if batch_ix + 1 >= max_batches:
            break

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')

tensor([[-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537],
        [-0.3967,  0.3497,  0.7303,  ...,  0.3481,  0.4669,  0.4537]],
       device='cuda:0')