In [1]:
# in a terminal run
# > make run-redis NS=train
# > make run-redis NS=test
# to allow access to the train and test namespaces

In [2]:
import os
import sys
import pandas as pd
import numpy as np

In [3]:
sys.path.append("..")
os.environ["USER_PATH"] = "../userdata"

In [4]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from model.datagenerator import create_train_test
from system.namespace.store import get_namespace

In [5]:
import torch

is_cuda = torch.cuda.is_available()
is_cuda

True

In [6]:
set_redis_slow_mode("never")
ns_test = get_namespace("test")
ns_train = get_namespace("train")
now = pd.Timestamp("2022-12-17", tz="UTC")
train_plan = [
    {
        "left": {"mode": "valid", "flip_pc": 0.5},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 100,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": 5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": 5,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    }
]
eval_plan = [
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "flip_lr": 0.5,
        "weight": 40,
    }
]
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=10000 if is_cuda else 1000,
    test_size=10000 if is_cuda else 1000,
    test_val_size=10000 if is_cuda else 1000,
    compute_batch_size=100 if is_cuda else 100,
    now=now)

In [7]:
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel

In [8]:
device = torch.device("cuda") if is_cuda else torch.device("cpu")
device

device(type='cuda')

In [9]:
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def tokens(texts):
    res = tokenizer(texts.tolist(), return_tensors="pt", padding=True, truncation=True)
    return {k: v.to(device) for k, v in res.items()}

class Model(nn.Module):
    def __init__(self, version=0):
        super().__init__()
        self._bert_parent = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self._bert_child = DistilBertModel.from_pretrained("distilbert-base-uncased")
        if version > 0:
            embed_size = 768
            self._pdense = nn.Sequential(
                nn.Linear(embed_size, embed_size),
                nn.Dropout(p=0.5),
                nn.ReLU(),
                nn.Linear(embed_size, embed_size))
            self._cdense = nn.Sequential(
                nn.Linear(embed_size, embed_size),
                nn.Dropout(p=0.5),
                nn.ReLU(),
                nn.Linear(embed_size, embed_size))
        else:
            self._pdense = None
            self._cdense = None
        self._version = version
    
    def get_version(self):
        return self._version
    
    def get_parent_embed(self, input_ids, attention_mask):
        outputs_parent = self._bert_parent(input_ids=input_ids, attention_mask=attention_mask)
        out = outputs_parent.last_hidden_state[:, 0]
        if self._pdense is not None:
            out = self._pdense(out)
        return out
    
    def get_child_embed(self, input_ids, attention_mask):
        outputs_child = self._bert_child(input_ids=input_ids, attention_mask=attention_mask)
        out = outputs_child.last_hidden_state[:, 0]
        if self._cdense is not None:
            out = self._cdense(out)
        return out
        
    def forward(self, x):
        parent_cls = self.get_parent_embed(
            input_ids=x["parent"]["input_ids"],
            attention_mask=x["parent"]["attention_mask"])
        child_cls = self.get_child_embed(
            input_ids=x["child"]["input_ids"],
            attention_mask=x["child"]["attention_mask"])
        batch_size = parent_cls.shape[0]
        return torch.bmm(
            parent_cls.reshape([batch_size, 1, -1]),
            child_cls.reshape([batch_size, -1, 1])).reshape([-1, 1])
    
class TrainingHarness(nn.Module):
    def __init__(self, model):
        super().__init__()
        self._model = model
        self._softmax = nn.Softmax(dim=1)
        self._loss = nn.BCELoss()
        
    def get_version(self):
        return self._model.get_version()
        
    def forward(self, left, right, labels):
        out_left = self._model(left)
        out_right = self._model(right)
        preds = self._softmax(torch.hstack((out_left, out_right)))
        return preds, self._loss(preds, labels)

In [10]:
from torch.optim import AdamW

model = Model(version=0)
model.to(device)
harness = TrainingHarness(model)
harness.to(device)

folder = "checkpoints"
postfix = "_lg" if is_cuda else ""
version_tag = "" if harness.get_version() == 0 else f"_v{harness.get_version()}"
mprev = highest_number(os.listdir(folder), prefix=f"harness{version_tag}{postfix}_", postfix=".pkl")
if mprev is not None:
    prev_fname, prev_epoch = mprev
    harness.load_state_dict(torch.load(os.path.join(folder, prev_fname), map_location=device))
    epoch_offset = prev_epoch + 1
else:
    epoch_offset = 0

optimizer = AdamW(harness.parameters(), lr=5e-5)
mprev, epoch_offset

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias']
- T

(None, 0)

In [11]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate

def compute(df):
    plefts = tokens(df["parent_left"])
    clefts = tokens(df["child_left"])
    prights = tokens(df["parent_right"])
    crights = tokens(df["child_right"])
    labels = torch.tensor([~df["correct_is_right"], df["correct_is_right"]], dtype=torch.float32).T.to(device)
    return harness({"parent": plefts, "child": clefts}, {"parent": prights, "child": crights}, labels)

num_epochs = max((50 if is_cuda else 10) - epoch_offset, 3)
num_training_steps = num_epochs * ttgen.get_epoch_train_size()
warmup = 10000 if is_cuda else 10
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=warmup,
    num_training_steps=num_training_steps - warmup)
ttgen.set_epoch(epoch_offset)

for _ in range(num_epochs):
    epoch = ttgen.get_epoch()
    print(f"epoch {epoch}")
    
    model.train()
    harness.train()
    metric_train = evaluate.load("accuracy")
    train_loss = []
    with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
        for train_df in ttgen.train_dfs():
            preds, loss = compute(train_df)
            train_loss.append(loss.item())
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(train_df.shape[0])
            
            predictions = torch.argmax(preds, dim=-1)
            metric_train.add_batch(predictions=predictions, references=train_df["correct_is_right"].astype(int))

    folder = "checkpoints"
    postfix = "_lg" if is_cuda else ""
    version_tag = "" if harness.get_version() == 0 else f"_v{harness.get_version()}"
    torch.save(harness.state_dict(), os.path.join(folder, f"harness{version_tag}{postfix}_{epoch}.pkl"))
            
    model.eval()
    harness.eval()
    with torch.no_grad():
        metric_val_train = evaluate.load("accuracy")
        train_val_loss = []
        with tqdm(desc="train val", total=ttgen.get_epoch_train_validation_size()) as progress_bar:
            for train_validation_df in ttgen.train_validation_dfs():
                preds, loss = compute(train_validation_df)
                train_val_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_val_train.add_batch(
                    predictions=predictions, references=train_validation_df["correct_is_right"].astype(int))
                progress_bar.update(train_validation_df.shape[0])
        
        metric_test = evaluate.load("accuracy")
        test_loss = []
        with tqdm(desc="test", total=ttgen.get_epoch_test_size()) as progress_bar:
            for test_df in ttgen.test_dfs():
                preds, loss = compute(test_df)
                test_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_test.add_batch(
                    predictions=predictions, references=test_df["correct_is_right"].astype(int))
                progress_bar.update(test_df.shape[0])
        
        print(f"train: {metric_train.compute()} loss: {np.mean(train_loss)}")
        print(f"train val: {metric_val_train.compute()} loss: {np.mean(train_val_loss)}")
        print(f"test: {metric_test.compute()} loss: {np.mean(test_loss)}")
    ttgen.advance_epoch()

epoch 0


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.59725} loss: 3.215016347665597
train val: {'accuracy': 0.6037} loss: 1.2676223382135388
test: {'accuracy': 0.5362} loss: 1.640066008257866
epoch 1


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.55215} loss: 1.356880630963482
train val: {'accuracy': 0.6645} loss: 0.5935506270743907
test: {'accuracy': 0.5108} loss: 0.7949401561260223
epoch 2


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.63845} loss: 0.8638567955555627
train val: {'accuracy': 0.6867} loss: 0.5703577803210356
test: {'accuracy': 0.6221} loss: 0.684007783344388
epoch 3


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70515} loss: 0.5761682349750772
train val: {'accuracy': 0.6978} loss: 0.531854347718507
test: {'accuracy': 0.6391} loss: 0.6373122554540634
epoch 4


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.67145} loss: 0.7820889577804483
train val: {'accuracy': 0.6929} loss: 0.5516026649753563
test: {'accuracy': 0.5701} loss: 0.7105053052544594
epoch 5


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6976} loss: 0.6615889593004715
train val: {'accuracy': 0.7077} loss: 0.5430639030538499
test: {'accuracy': 0.6471} loss: 0.6503112724244594
epoch 6


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68105} loss: 0.5598640273844357
train val: {'accuracy': 0.7193} loss: 0.4853554356001317
test: {'accuracy': 0.6184} loss: 0.6416449435353279
epoch 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68765} loss: 0.5306327194914571
train val: {'accuracy': 0.7255} loss: 0.4841609169074334
test: {'accuracy': 0.6269} loss: 0.6338545219361782
epoch 8


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68845} loss: 0.5270150064961822
train val: {'accuracy': 0.7291} loss: 0.48135250300148036
test: {'accuracy': 0.6689} loss: 0.6086499064147473
epoch 9


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7038} loss: 0.5085206098142109
train val: {'accuracy': 0.7242} loss: 0.48195396753028036
test: {'accuracy': 0.6615} loss: 0.6295013772845268
epoch 10


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.66595} loss: 0.6008616559633985
train val: {'accuracy': 0.724} loss: 0.5410975919216872
test: {'accuracy': 0.6276} loss: 0.6586541964888573
epoch 11


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.67815} loss: 0.5800808531184215
train val: {'accuracy': 0.7267} loss: 0.5203870826303959
test: {'accuracy': 0.6724} loss: 0.637415790438652
epoch 12


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68615} loss: 0.5572254863353912
train val: {'accuracy': 0.7309} loss: 0.5148589306041599
test: {'accuracy': 0.6485} loss: 0.6505198984384537
epoch 13


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6259} loss: 1.1679691581056453
train val: {'accuracy': 0.7274} loss: 0.49587097853571177
test: {'accuracy': 0.6725} loss: 0.599519294911623
epoch 14


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.64035} loss: 0.6540469693265856
train val: {'accuracy': 0.7295} loss: 0.5012899103015661
test: {'accuracy': 0.663} loss: 0.6361900060296058
epoch 15


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.60805} loss: 13.46977786630094
train val: {'accuracy': 0.6799} loss: 11.654206052066385
test: {'accuracy': 0.6207} loss: 14.613930555030704
epoch 16


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6383} loss: 7.764325430430659
train val: {'accuracy': 0.7258} loss: 0.5106975454479455
test: {'accuracy': 0.7093} loss: 0.5691628396570683
epoch 17


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.66965} loss: 0.5981618671566248
train val: {'accuracy': 0.7309} loss: 0.5114579101443291
test: {'accuracy': 0.6925} loss: 0.604181079351902
epoch 18


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.66955} loss: 0.5948995905816555
train val: {'accuracy': 0.7313} loss: 0.4963505600064993
test: {'accuracy': 0.7} loss: 0.6039020679712296
epoch 19


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6708} loss: 1.0217318786270917
train val: {'accuracy': 0.7319} loss: 0.4988401906579733
test: {'accuracy': 0.718} loss: 0.5938542392492294
epoch 20


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.66905} loss: 0.5983502360880375
train val: {'accuracy': 0.7368} loss: 0.5386134144604207
test: {'accuracy': 0.6688} loss: 0.6261400929450989
epoch 21


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69025} loss: 0.5615797987293452
train val: {'accuracy': 0.7339} loss: 0.49715623763650657
test: {'accuracy': 0.6768} loss: 0.6181771191775799
epoch 22


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6949} loss: 0.6318552572019398
train val: {'accuracy': 0.7376} loss: 0.4935537839204073
test: {'accuracy': 0.6728} loss: 0.6103708609461784
epoch 23


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69305} loss: 0.5600816511372919
train val: {'accuracy': 0.7354} loss: 0.5119815194547176
test: {'accuracy': 0.6618} loss: 0.6289896616697311
epoch 24


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6914} loss: 0.6539482101985487
train val: {'accuracy': 0.7355} loss: 0.4943741336926818
test: {'accuracy': 0.6423} loss: 0.6373547699928284
epoch 25


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.68765} loss: 0.5523007018150762
train val: {'accuracy': 0.7394} loss: 0.5106113257631659
test: {'accuracy': 0.661} loss: 0.6170630213141441
epoch 26


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69075} loss: 0.7281559211090207
train val: {'accuracy': 0.726} loss: 0.5215313596203923
test: {'accuracy': 0.6064} loss: 0.6775986192584038
epoch 27


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69125} loss: 0.552034846326895
train val: {'accuracy': 0.736} loss: 0.5062943329572678
test: {'accuracy': 0.6464} loss: 0.6329363282322884
epoch 28


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70115} loss: 0.5454018932269886
train val: {'accuracy': 0.7395} loss: 0.49617588239908217
test: {'accuracy': 0.6839} loss: 0.6157445001244545
epoch 29


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.69795} loss: 0.5468665424466133
train val: {'accuracy': 0.7418} loss: 0.48380647393241527
test: {'accuracy': 0.6598} loss: 0.6401602240443229
epoch 30


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7078} loss: 0.5422624142155051
train val: {'accuracy': 0.7489} loss: 0.48084971620738504
test: {'accuracy': 0.6781} loss: 0.6218306931078434
epoch 31


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.66385} loss: 3.6176857344951947
train val: {'accuracy': 0.7396} loss: 0.8405964349493384
test: {'accuracy': 0.6693} loss: 0.8321590461909771
epoch 32


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.6987} loss: 0.570101989775151
train val: {'accuracy': 0.747} loss: 0.4846293524473906
test: {'accuracy': 0.6972} loss: 0.6160382192254067
epoch 33


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7012} loss: 0.542573688723892
train val: {'accuracy': 0.7455} loss: 0.48555241042673586
test: {'accuracy': 0.7106} loss: 0.5883459811031818
epoch 34


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70205} loss: 0.5426218068041839
train val: {'accuracy': 0.7493} loss: 0.4869929105088115
test: {'accuracy': 0.671} loss: 0.6212029481530189
epoch 35


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7046} loss: 0.6172896729204804
train val: {'accuracy': 0.7483} loss: 0.6501070472463966
test: {'accuracy': 0.6876} loss: 0.5958385620355606
epoch 36


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70545} loss: 0.5946243202528451
train val: {'accuracy': 0.7511} loss: 0.6332554916448891
test: {'accuracy': 0.6888} loss: 0.5924023243904114
epoch 37


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7026} loss: 0.7398373082470745
train val: {'accuracy': 0.7502} loss: 0.48757561110779646
test: {'accuracy': 0.6727} loss: 0.6127846806585788
epoch 38


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7079} loss: 0.5386074834896252
train val: {'accuracy': 0.7599} loss: 0.46419881380572914
test: {'accuracy': 0.6999} loss: 0.5849647282361984
epoch 39


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70765} loss: 0.5381948151288554
train val: {'accuracy': 0.7575} loss: 0.4674645638793707
test: {'accuracy': 0.6786} loss: 0.6061113358557224
epoch 40


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7096} loss: 0.5343287027924176
train val: {'accuracy': 0.756} loss: 0.47215787317454816
test: {'accuracy': 0.6842} loss: 0.5918662914663553
epoch 41


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.71105} loss: 0.5343817627638578
train val: {'accuracy': 0.7582} loss: 0.4612866611883044
test: {'accuracy': 0.6887} loss: 0.5961028850913048
epoch 42


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70655} loss: 0.5334180312976241
train val: {'accuracy': 0.7516} loss: 0.4852803424045444
test: {'accuracy': 0.6815} loss: 0.60403291939497
epoch 43


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7051} loss: 0.5384068049414084
train val: {'accuracy': 0.7572} loss: 0.4949209113910794
test: {'accuracy': 0.6678} loss: 0.62088674197793
epoch 44


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70935} loss: 0.5355684824734926
train val: {'accuracy': 0.7577} loss: 0.4665052006259561
test: {'accuracy': 0.6904} loss: 0.6193836279273033
epoch 45


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.71535} loss: 0.5278326802529104
train val: {'accuracy': 0.7568} loss: 0.4837630094602704
test: {'accuracy': 0.6841} loss: 0.5998912817835808
epoch 46


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.7177} loss: 0.532636313408066
train val: {'accuracy': 0.7542} loss: 0.48723518887013195
test: {'accuracy': 0.6738} loss: 0.6139456552147865
epoch 47


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.71255} loss: 0.5429256294395775
train val: {'accuracy': 0.7464} loss: 0.4943097806155682
test: {'accuracy': 0.6896} loss: 0.6046874371170997
epoch 48


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70785} loss: 0.540644616303593
train val: {'accuracy': 0.7507} loss: 0.4898062232732773
test: {'accuracy': 0.6726} loss: 0.6229939701914787
epoch 49


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: {'accuracy': 0.70455} loss: 0.5558582721833721
train val: {'accuracy': 0.7494} loss: 0.48485377307534216
test: {'accuracy': 0.6364} loss: 0.6390970001280307


In [12]:
folder = "."
postfix = "_lg" if is_cuda else ""
version_tag = "" if harness.get_version() == 0 else f"_v{harness.get_version()}"
torch.save(model.state_dict(), os.path.join(folder, f"model{version_tag}{postfix}.pkl"))
torch.save(harness.state_dict(), os.path.join(folder, f"harness{version_tag}{postfix}.pkl"))
torch.save(optimizer.state_dict(), os.path.join(folder, f"optimizer{version_tag}{postfix}.pkl"))

In [13]:
ttgen.reset()
model.eval()
harness.eval()
dfs = []
with torch.no_grad():
    metric_val_test = evaluate.load("accuracy")
    test_val_loss = []
    with tqdm(desc="test val", total=ttgen.get_epoch_test_validation_size()) as progress_bar:
        for test_val_df in ttgen.test_validation_dfs():
            preds, loss = compute(test_val_df)
            test_val_loss.append(loss.item())
            predictions = torch.argmax(preds, dim=-1)
            metric_val_test.add_batch(
                predictions=predictions, references=test_val_df["correct_is_right"].astype(int))
            cur_df = test_val_df.copy()
            cur_df["logit_left"] = preds[:, 0].cpu()
            cur_df["logit_right"] = preds[:, 1].cpu()
            cur_df["preds"] = predictions.cpu()
            cur_df["truth"] = test_val_df["correct_is_right"].astype(int)
            dfs.append(cur_df)
            progress_bar.update(test_val_df.shape[0])
print(f"test val: {metric_val_test.compute()} loss: {np.mean(test_val_loss)}")
validation_df = pd.concat(dfs)

test val:   0%|          | 0/10000 [00:00<?, ?it/s]

test val: {'accuracy': 0.6347} loss: 0.6420372705578804


In [14]:
validation_df.to_csv(os.path.join(folder, "validation.csv"))

In [15]:
validation_df[validation_df["preds"] == validation_df["truth"]].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
0,random--valid,Gonna ride the baloney pony all the way to Was...,anyone using “i’m a bad bitch” in an argument ...,Destiny’s Child was better than Beyoncé on her...,"The last iteration of them yeah, just thanks t...",0.268941,0.731059,True,0.364254,0.635746,1,1
1,random--valid,User profile checks out. Bet you're on a list ...,I don’t know why. From what I can gather from ...,[Serious] How would you react if your partner ...,Best kept as a fantasy.,0.268941,0.731059,True,0.251961,0.748039,1,1
2,random--valid,Me too! Guess I'm alone in liking them.,“Accidental” pregnancy. If she’s “accidentally...,Why is it terrible? If you sold a painting for...,As a programmer I sell my time and skill to a ...,0.268941,0.731059,True,0.159981,0.840019,1,1
3,*valid--!copy,[TIL that owls and eagles hate each other.They...,Day-bird aaaAAAAaaaa Fighter of the Night-bird...,Day-bird aaaAAAAaaaa Fighter of the Night-bird...,[TIL that owls and eagles hate each other.They...,0.731059,0.268941,False,0.650724,0.349276,0,0
0,*valid--random,“Leslie I typed your symptoms into the compute...,"""Leslie I tried to make ramen in the coffee ma...",Seriously. Threads like this are really predic...,Playing games !!! DuH,0.999998,2e-06,False,0.613233,0.386767,0,0


In [16]:
validation_df[validation_df["preds"] != validation_df["truth"]].head()

Unnamed: 0,gen_name,parent_left,child_left,parent_right,child_right,sway_left,sway_right,correct_is_right,logit_left,logit_right,preds,truth
1,*valid--random,Oh? You sat down wrong? *immense pain*,Another one: have you ever been running in lik...,Incredible that he did it by trying to cut a w...,Just Spaghetti with red sauce from a jar. Not ...,0.952574,0.047426,False,0.458417,0.541583,1,0
3,*valid--!copy,My husband was astounded when he found men's s...,Please name brands. I didn’t know they made st...,Please name brands. I didn’t know they made st...,My husband was astounded when he found men's s...,0.731059,0.268941,False,0.272118,0.727882,1,0
0,*valid--random,That would definitely raise red flags! That’s ...,It’s Jeffrey Dahmer vibes,"In Spain/Portugal we also cook them ""a feira"" ...",I was always confused how she became the more ...,0.731059,0.268941,False,0.318889,0.681111,1,0
2,random--valid,"Yep, this one came to mind. She is incredible ...","The more money I make, the more frugal and les...",What is the most egregious display of wealth y...,2 Chainz has a show called the most expensives...,0.268941,0.731059,True,0.568019,0.431981,0,1
3,*valid--!copy,BlackBerry without the keyboard. Fell flat on ...,Didn't the CEO turn his focus to owning a hock...,Didn't the CEO turn his focus to owning a hock...,BlackBerry without the keyboard. Fell flat on ...,0.880797,0.119203,False,0.430271,0.569729,1,0
