In [1]:
# in a terminal run
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=train
# > USER_PATH=/mnt/d/workspace/clotho/userdata/ make run-redis NS=test
# to allow access to the train and test namespaces

In [2]:
import os
import sys
import json
import pandas as pd
import numpy as np

In [3]:
sys.path.append("..")
os.environ["USER_PATH"] = "/mnt/d/workspace/clotho/userdata/"
MODEL_OUTPUT_BASE = "/mnt/d/workspace/clotho/notebooks"
MODEL_OUTPUT_CP = os.path.join(MODEL_OUTPUT_BASE, "checkpoints")

In [4]:
from misc.redis import set_redis_slow_mode
from misc.util import highest_number
from model.datagenerator import create_train_test
from system.namespace.store import get_namespace

In [5]:
import torch

is_cuda = torch.cuda.is_available()
is_cuda

True

In [6]:
set_redis_slow_mode("never")
ns_test = get_namespace("test")
ns_train = get_namespace("train")
now = pd.Timestamp("2022-12-17", tz="UTC")
train_plan = [
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 50,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 10,
        "last_epoch": None,
        "weight": 50,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    },
     {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 5,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "path", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 5,
        "last_epoch": None,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": None,
        "last_epoch": None,
        "weight": 40,
    },
    {
        "left": {"mode": "valid", "flip_pc": 1.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "first_epoch": 15,
        "last_epoch": None,
        "weight": 50,
    }
]
eval_plan = [
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": 20,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
    {
        "left": {"mode": "random", "flip_pc": 0.0},
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 60,
    },
    {
        "left": None,
        "right": {"mode": "valid", "flip_pc": 0.0},
        "min_text_length": None,
        "skip_weak": True,
        "skip_topics": True,
        "flip_lr": 0.5,
        "weight": 40,
    },
]
ttgen = create_train_test(
    train_ns=ns_train,
    train_validation_ns=ns_train,
    test_ns=ns_test,
    test_validation_ns=ns_test,
    train_learning_plan=train_plan,
    train_val_learning_plan=eval_plan,
    test_learning_plan=eval_plan,
    test_val_learning_plan=eval_plan,
    batch_size=4 if is_cuda else 8,
    epoch_batches=5000 if is_cuda else 500,
    train_val_size=10000 if is_cuda else 1000,
    test_size=10000 if is_cuda else 1000,
    test_val_size=10000 if is_cuda else 1000,
    compute_batch_size=100 if is_cuda else 100,
    now=now)

In [7]:
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel

In [8]:
device = torch.device("cuda") if is_cuda else torch.device("cpu")
device

device(type='cuda')

In [9]:
from typing import Literal, TypedDict

ProviderRole = Literal["child", "parent"]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
EMBED_SIZE = 768

TokenizedInput = TypedDict('TokenizedInput', {
    "input_ids": torch.Tensor,
    "attention_mask": torch.Tensor,
})


AggType = Literal["cls", "mean"]
AGG_CLS: AggType = "cls"
AGG_MEAN: AggType = "mean"


def tokens(texts: list[str]) -> TokenizedInput:
    res = tokenizer(texts.tolist(), return_tensors="pt", padding=True, truncation=True)
    return {k: v.to(device) for k, v in res.items()}


class Noise(nn.Module):
    def __init__(self, std: float = 1.0, p: float = 0.5) -> None:
        super().__init__()
        self._std = std
        self._p = p
        self._dhold = nn.Parameter(torch.Tensor([0.0]), requires_grad=False)

    def set_std(self, std: float) -> None:
        self._std = std

    def get_std(self) -> float:
        return self._std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return x
        prob = torch.rand(size=x.shape, device=self._dhold.device) < self._p
        gauss = torch.normal(
            mean=0.0, std=self._std, size=x.shape, device=self._dhold.device)
        return x + prob * gauss


class Model(nn.Module):
    def __init__(self, version: int) -> None:
        super().__init__()
        self._bert_parent = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        self._bert_child = DistilBertModel.from_pretrained(
            "distilbert-base-uncased")
        if version in (1, 3, 4, 6):
            self._pdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
            self._cdense: nn.Sequential | None = nn.Sequential(
                nn.Linear(EMBED_SIZE, EMBED_SIZE),
                nn.Dropout(p=0.2),
                nn.ReLU(),
                nn.Linear(EMBED_SIZE, EMBED_SIZE))
        else:
            self._pdense = None
            self._cdense = None
        if version < 4 or version > 5:
            self._noise = None
        else:
            self._noise = Noise(std=1.0, p=0.2)
        if version < 2 or version > 4:
            self._cos = None
        else:
            self._cos = torch.nn.CosineSimilarity()
        if version < 6:
            self._agg = AGG_CLS
        else:
            self._agg = AGG_MEAN
        self._version = version

    def set_epoch(self, epoch: int) -> None:
        noise = self._noise
        if noise is not None:
            noise.set_std(1 / (1.2 ** epoch))

    def get_version(self) -> int:
        return self._version
    
    def get_agg(self, lhs: torch.Tensor) -> torch.Tensor:
        if self._agg == AGG_CLS:
            return lhs[:, 0]
        if self._agg == AGG_MEAN:
            return torch.mean(lhs, dim=1)
        raise ValueError(f"unknown aggregation: {self._agg}")

    def get_parent_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        outputs_parent = self._bert_parent(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_parent.last_hidden_state)
        if self._pdense is not None:
            out = self._pdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def get_child_embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        outputs_child = self._bert_child(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self.get_agg(outputs_child.last_hidden_state)
        if self._cdense is not None:
            out = self._cdense(out)
        if self._noise is not None:
            out = self._noise(out)
        return out

    def forward(self, x: dict[ProviderRole, TokenizedInput]) -> torch.Tensor:
        parent_cls = self.get_parent_embed(
            input_ids=x["parent"]["input_ids"],
            attention_mask=x["parent"]["attention_mask"])
        child_cls = self.get_child_embed(
            input_ids=x["child"]["input_ids"],
            attention_mask=x["child"]["attention_mask"])
        if self._cos is not None:
            return self._cos(parent_cls, child_cls).reshape([-1, 1])
        batch_size = parent_cls.shape[0]
        return torch.bmm(
            parent_cls.reshape([batch_size, 1, -1]),
            child_cls.reshape([batch_size, -1, 1])).reshape([-1, 1])


class TrainingHarness(nn.Module):
    def __init__(self, model: Model) -> None:
        super().__init__()
        self._model = model
        self._softmax = nn.Softmax(dim=1)
        self._loss = nn.BCELoss()

    def get_version(self) -> int:
        return self._model.get_version()

    def forward(
            self,
            left: TokenizedInput,
            right: TokenizedInput,
            labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_left = self._model(left)
        out_right = self._model(right)
        preds = self._softmax(torch.hstack((out_left, out_right)))
        return preds, self._loss(preds, labels)

In [10]:
def get_filename(
        harness: TrainingHarness,
        *,
        is_final: bool,
        ftype: str,
        epoch: int | None,
        ext: str = ".pkl") -> tuple[str, str, str]:
    folder = MODEL_OUTPUT_BASE if is_final else MODEL_OUTPUT_CP
    postfix = "_lg" if is_cuda else ""
    version_tag = f"_v{harness.get_version()}"
    out_pre = f"{ftype}{version_tag}{postfix}_"
    out_post = ext
    return (
        os.path.join(
            folder,
            f"{out_pre}{'' if epoch is None else epoch}{out_post}"),
        out_pre,
        out_post,
        folder,
    )

In [11]:
from torch.optim import AdamW

model = Model(version=7)
model.to(device)
harness = TrainingHarness(model)
harness.to(device)

FORCE_RESTART = False

_, spre, spost, sfolder = get_filename(harness, is_final=False, ftype="harness", epoch=None)
mprev = highest_number(os.listdir(sfolder), prefix=spre, postfix=spost)
if not FORCE_RESTART and mprev is not None:
    prev_fname, prev_epoch = mprev
    harness.load_state_dict(torch.load(
        os.path.join(MODEL_OUTPUT_CP, prev_fname), map_location=device))
    epoch_offset = prev_epoch + 1
else:
    epoch_offset = 0

optimizer = AdamW(harness.parameters(), lr=5e-5)
mprev, epoch_offset

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- T

(None, 0)

In [None]:
from transformers import get_scheduler
# from tqdm.notebook import tqdm
from tqdm.auto import tqdm
import evaluate
import time

def compute(df):
    plefts = tokens(df["parent_left"])
    clefts = tokens(df["child_left"])
    prights = tokens(df["parent_right"])
    crights = tokens(df["child_right"])
    labels = torch.tensor([~df["correct_is_right"], df["correct_is_right"]], dtype=torch.float32).T.to(device)
    return harness({"parent": plefts, "child": clefts}, {"parent": prights, "child": crights}, labels)

num_epochs = max((120 if is_cuda else 10) - epoch_offset, 3)
num_training_steps = num_epochs * ttgen.get_epoch_train_size()
warmup = 10000 if is_cuda else 10
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=warmup,
    num_training_steps=num_training_steps - warmup)
ttgen.set_epoch(epoch_offset)

for _ in range(num_epochs):
    epoch = ttgen.get_epoch()
    print(f"epoch {epoch}")
    real_time = time.monotonic()
    
    model.train()
    harness.train()
    model.set_epoch(epoch)
    metric_train = evaluate.load("accuracy")
    train_loss = []
    first = True
    with tqdm(desc="train", total=ttgen.get_epoch_train_size()) as progress_bar:
        for train_df in ttgen.train_dfs():
            preds, loss = compute(train_df)
            train_loss.append(loss.item())
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(train_df.shape[0])
            
            predictions = torch.argmax(preds, dim=-1)
            metric_train.add_batch(predictions=predictions, references=train_df["correct_is_right"].astype(int))
            if first:
                # display(train_df)
                first = False

    model_fname = get_filename(harness, is_final=False, ftype="harness", epoch=epoch)[0]
    torch.save(harness.state_dict(), model_fname)
            
    model.eval()
    harness.eval()
    with torch.no_grad():
        metric_val_train = evaluate.load("accuracy")
        train_val_loss = []
        with tqdm(desc="train val", total=ttgen.get_epoch_train_validation_size()) as progress_bar:
            for train_validation_df in ttgen.train_validation_dfs():
                preds, loss = compute(train_validation_df)
                train_val_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_val_train.add_batch(
                    predictions=predictions, references=train_validation_df["correct_is_right"].astype(int))
                progress_bar.update(train_validation_df.shape[0])
        
        metric_test = evaluate.load("accuracy")
        test_loss = []
        with tqdm(desc="test", total=ttgen.get_epoch_test_size()) as progress_bar:
            for test_df in ttgen.test_dfs():
                preds, loss = compute(test_df)
                test_loss.append(loss.item())
                predictions = torch.argmax(preds, dim=-1)
                metric_test.add_batch(
                    predictions=predictions, references=test_df["correct_is_right"].astype(int))
                progress_bar.update(test_df.shape[0])
        stats = {
            "epoch": int(epoch),
            "train_acc": float(metric_train.compute()['accuracy']),
            "train_loss": float(np.mean(train_loss)),
            "train_val_acc": float(metric_val_train.compute()['accuracy']),
            "train_val_loss": float(np.mean(train_val_loss)),
            "test_acc": float(metric_test.compute()['accuracy']),
            "test_loss": float(np.mean(test_loss)),
            "time": 0.0,
            "version": harness.get_version(),
            "fname": model_fname,
        }
    
    print(f"train: {stats['train_acc']} loss: {stats['train_loss']}")
    print(f"train val: {stats['train_val_acc']} loss: {stats['train_val_loss']}")
    print(f"test: {stats['test_acc']} loss: {stats['test_loss']}")
    ttgen.advance_epoch()
    stats["time"] = float((time.monotonic() - real_time) / 60.0)
    print(f"epoch time: {stats['time']:.2f}min")
    stats_fn = get_filename(harness, is_final=False, ftype="stats", epoch=epoch, ext=".json")[0]
    with open(stats_fn, mode="w", encoding="utf-8") as fout:
        print(json.dumps(stats, indent=2, sort_keys=True), file=fout)

epoch 0


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.58425 loss: 1.7016048232006662
train val: 0.6266 loss: 0.8229065572973341
test: 0.5691 loss: 0.9697318340606987
epoch time: 333.69min
epoch 1


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.7201 loss: 0.5844099140023813
train val: 0.7331 loss: 0.5326953510738909
test: 0.6489 loss: 0.7034539060860873
epoch time: 54.17min
epoch 2


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.7901 loss: 0.44431922408516983
train val: 0.7632 loss: 0.562996527594747
test: 0.6576 loss: 0.7744384820755571
epoch time: 49.65min
epoch 3


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.7963 loss: 0.4367930789816659
train val: 0.7182 loss: 0.6399536329229828
test: 0.6259 loss: 0.8038391570530832
epoch time: 78.72min
epoch 4


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.7915 loss: 0.4177847496665956
train val: 0.7374 loss: 0.4800995312042534
test: 0.6152 loss: 0.6480094968914986
epoch time: 83.26min
epoch 5


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.831 loss: 0.348318674525588
train val: 0.7343 loss: 0.5157822916574776
test: 0.6149 loss: 0.6988879308015108
epoch time: 57.24min
epoch 6


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8435 loss: 0.3217397753254656
train val: 0.7262 loss: 0.5227155487713869
test: 0.607 loss: 0.7052130870759488
epoch time: 52.15min
epoch 7


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8443 loss: 0.3465921436331999
train val: 0.7351 loss: 0.5247397595138289
test: 0.6301 loss: 0.6805898647159337
epoch time: 50.64min
epoch 8


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.84995 loss: 0.3495053067108052
train val: 0.738 loss: 0.5889556806535576
test: 0.6272 loss: 0.6719326809406281
epoch time: 51.42min
epoch 9


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85775 loss: 0.3033361676988074
train val: 0.7252 loss: 0.5177842367846519
test: 0.6068 loss: 0.6844373150289059
epoch time: 50.28min
epoch 10


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8334 loss: 0.35053460838331013
train val: 0.7267 loss: 0.5008562006552587
test: 0.6001 loss: 0.6632093757987022
epoch time: 52.26min
epoch 11


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8187 loss: 0.3371719292360709
train val: 0.6969 loss: 0.6869142449943348
test: 0.6124 loss: 0.6658628508746623
epoch time: 52.93min
epoch 12


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.82195 loss: 0.3486729122198949
train val: 0.7073 loss: 0.5263039422301575
test: 0.6185 loss: 0.6792970645070076
epoch time: 54.24min
epoch 13


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8302 loss: 0.3322572812802618
train val: 0.7112 loss: 0.5105196523564518
test: 0.6261 loss: 0.6508558141112327
epoch time: 54.61min
epoch 14


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.791 loss: 0.46797306476727096
train val: 0.5944 loss: 0.6277429662108421
test: 0.5368 loss: 0.7014953180909157
epoch time: 53.59min
epoch 15


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.7394 loss: 2.8661329350242157
train val: 0.5846 loss: 0.6715769128799438
test: 0.542 loss: 0.7173304299235344
epoch time: 53.34min
epoch 16


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8134 loss: 0.35445660410274293
train val: 0.7171 loss: 0.4979819887930789
test: 0.6229 loss: 0.647068657529354
epoch time: 52.00min
epoch 17


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.83005 loss: 0.3332631499380779
train val: 0.7122 loss: 0.5177623570443189
test: 0.5988 loss: 0.6788009871423244
epoch time: 51.14min
epoch 18


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8293 loss: 0.3284079703728157
train val: 0.7147 loss: 0.6760426752098064
test: 0.6171 loss: 0.7993671876266599
epoch time: 52.09min
epoch 19


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8336 loss: 0.36098291223074547
train val: 0.728 loss: 0.5908226754452102
test: 0.6023 loss: 0.6668382709980011
epoch time: 51.08min
epoch 20


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.83755 loss: 0.3741421945009738
train val: 0.7291 loss: 0.5078853075744584
test: 0.6056 loss: 0.6896208468854428
epoch time: 51.47min
epoch 21


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8418 loss: 0.32885100786178917
train val: 0.7299 loss: 0.49385302046687574
test: 0.6299 loss: 0.6495373114287853
epoch time: 52.07min
epoch 22


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.84375 loss: 0.3118581853313943
train val: 0.7287 loss: 0.4980291385424556
test: 0.6026 loss: 0.6669254047691822
epoch time: 51.54min
epoch 23


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8401 loss: 0.37434244716501636
train val: 0.7255 loss: 0.4895509145147633
test: 0.5934 loss: 0.6758750948131085
epoch time: 51.40min
epoch 24


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8312 loss: 0.4690802422107467
train val: 0.7226 loss: 0.5146097725519445
test: 0.574 loss: 0.7015061695933342
epoch time: 51.43min
epoch 25


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8482 loss: 0.30103933455690857
train val: 0.7251 loss: 0.49173135534108153
test: 0.5959 loss: 0.6681152654111385
epoch time: 51.53min
epoch 26


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.84455 loss: 0.3145650933628349
train val: 0.7307 loss: 0.490730106620444
test: 0.6074 loss: 0.6631485035598278
epoch time: 51.82min
epoch 27


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8501 loss: 0.3036091169115069
train val: 0.7274 loss: 0.4862791467270465
test: 0.6269 loss: 0.6414907908320427
epoch time: 51.65min
epoch 28


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.852 loss: 0.29487244062708523
train val: 0.7374 loss: 0.47689622142451116
test: 0.6318 loss: 0.6373275725662708
epoch time: 55.00min
epoch 29


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.84705 loss: 0.29884650976791016
train val: 0.7358 loss: 0.4983939809558622
test: 0.6082 loss: 0.6678283493161201
epoch time: 68.25min
epoch 30


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85625 loss: 0.289861863061733
train val: 0.7328 loss: 0.49625183301592335
test: 0.6104 loss: 0.6754354777574539
epoch time: 114.25min
epoch 31


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85005 loss: 0.293192453554967
train val: 0.7331 loss: 0.537525435579855
test: 0.5932 loss: 0.7388740806907416
epoch time: 64.75min
epoch 32


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8555 loss: 0.2917489332969206
train val: 0.7392 loss: 0.4896272788348375
test: 0.6055 loss: 0.6798496955513954
epoch time: 59.79min
epoch 33


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8522 loss: 0.28844024254819467
train val: 0.737 loss: 0.47475037404680626
test: 0.5971 loss: 0.6639123582005501
epoch time: 55.64min
epoch 34


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8466 loss: 0.30799672887503693
train val: 0.7321 loss: 0.5027280655072304
test: 0.5863 loss: 0.6907152214407921
epoch time: 56.15min
epoch 35


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8594 loss: 0.28535218964936465
train val: 0.7427 loss: 0.47883324090929236
test: 0.6042 loss: 0.6680307479023934
epoch time: 55.33min
epoch 36


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85885 loss: 0.2864082611145015
train val: 0.7419 loss: 0.46911744398618177
test: 0.6032 loss: 0.6585608354747295
epoch time: 54.74min
epoch 37


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86135 loss: 0.2831154358044152
train val: 0.7455 loss: 0.4833774195801467
test: 0.6158 loss: 0.6543838939905167
epoch time: 53.42min
epoch 38


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8652 loss: 0.2760298643457743
train val: 0.7444 loss: 0.47859468688251217
test: 0.6026 loss: 0.6566339807510376
epoch time: 52.38min
epoch 39


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8479 loss: 0.30739739841232805
train val: 0.7407 loss: 0.4817006700458936
test: 0.5993 loss: 0.6797730654120445
epoch time: 52.43min
epoch 40


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8568 loss: 0.286102233939078
train val: 0.7485 loss: 0.46472560033368643
test: 0.6102 loss: 0.654618021607399
epoch time: 59.84min
epoch 41


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85875 loss: 0.28458999247050815
train val: 0.7459 loss: 0.47277563805626704
test: 0.5792 loss: 0.7032617125809193
epoch time: 54.83min
epoch 42


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.85955 loss: 0.283212071798907
train val: 0.7455 loss: 0.4631311293406412
test: 0.6036 loss: 0.6805636981248856
epoch time: 54.67min
epoch 43


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86405 loss: 0.2781021273562939
train val: 0.7442 loss: 0.4983799922336824
test: 0.6195 loss: 0.7220169494271278
epoch time: 53.70min
epoch 44


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86615 loss: 0.27280620534088523
train val: 0.7459 loss: 0.47208247802078257
test: 0.6086 loss: 0.6714882531285286
epoch time: 53.78min
epoch 45


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86805 loss: 0.26982001117168336
train val: 0.7401 loss: 0.46770215828985146
test: 0.5788 loss: 0.6770075541198254
epoch time: 52.47min
epoch 46


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86595 loss: 0.2711243385814366
train val: 0.7458 loss: 0.46509824596667604
test: 0.595 loss: 0.6763152911007404
epoch time: 52.01min
epoch 47


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.63895 loss: 8.845353162932
train val: 0.5868 loss: 1.0870840163886548
test: 0.5404 loss: 1.2051032441258431
epoch time: 54.04min
epoch 48


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.68155 loss: 0.9157308402884475
train val: 0.5832 loss: 0.8250112376511097
test: 0.5441 loss: 0.8061847987651825
epoch time: 52.73min
epoch 49


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.69555 loss: 0.7411252787271542
train val: 0.5898 loss: 0.7826881280094385
test: 0.5426 loss: 0.8088748101472855
epoch time: 52.58min
epoch 50


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.72115 loss: 0.6636455060861918
train val: 0.6999 loss: 0.5165067438177764
test: 0.5436 loss: 0.7066317106187343
epoch time: 53.34min
epoch 51


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.84405 loss: 0.30785670894012074
train val: 0.7356 loss: 0.4772800112699275
test: 0.5818 loss: 0.6714881829500199
epoch time: 52.14min
epoch 52


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8478 loss: 0.3036768703407142
train val: 0.7193 loss: 0.5085716403590981
test: 0.5572 loss: 0.7202702605187893
epoch time: 52.90min
epoch 53


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8537 loss: 0.3000787537625094
train val: 0.7296 loss: 0.4890913811423816
test: 0.6055 loss: 0.6731235923469067
epoch time: 52.71min
epoch 54


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8541 loss: 0.29334776586494415
train val: 0.7355 loss: 0.48810932410694075
test: 0.5966 loss: 0.6808747551679611
epoch time: 52.64min
epoch 55


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86075 loss: 0.2840742835275881
train val: 0.7373 loss: 0.48643714387230574
test: 0.6051 loss: 0.6803384600579738
epoch time: 123.61min
epoch 56


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8597 loss: 0.2799607008858802
train val: 0.7451 loss: 0.483679813411762
test: 0.6102 loss: 0.6711524618446827
epoch time: 78.39min
epoch 57


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8624 loss: 0.27126667379977737
train val: 0.7466 loss: 0.4707179947748256
test: 0.5819 loss: 0.6689758400082588
epoch time: 72.54min
epoch 58


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86445 loss: 0.2732171893467239
train val: 0.7515 loss: 0.4604816907688335
test: 0.5955 loss: 0.6676182787895203
epoch time: 61.54min
epoch 59


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.868 loss: 0.27016546068248826
train val: 0.7559 loss: 0.5087072968559805
test: 0.6091 loss: 0.7190360497549176
epoch time: 56.51min
epoch 60


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8677 loss: 0.2710957050444278
train val: 0.7536 loss: 0.4581650145623833
test: 0.6122 loss: 0.6632554172813893
epoch time: 55.77min
epoch 61


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8669 loss: 0.2664857329281571
train val: 0.755 loss: 0.45633287740759554
test: 0.6219 loss: 0.6596266486644745
epoch time: 54.48min
epoch 62


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8755 loss: 0.25450448877428355
train val: 0.7509 loss: 0.47119444829378915
test: 0.6005 loss: 0.6793373610854149
epoch time: 53.80min
epoch 63


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8678 loss: 0.26596218562953144
train val: 0.7581 loss: 0.45500611742852487
test: 0.6111 loss: 0.6635620620667935
epoch time: 54.69min
epoch 64


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86925 loss: 0.2625599035356042
train val: 0.7628 loss: 0.46279072239208036
test: 0.5983 loss: 0.7024875634849072
epoch time: 54.01min
epoch 65


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.86875 loss: 0.2601352408942393
train val: 0.7602 loss: 0.4714574421689709
test: 0.5929 loss: 0.6814809662878514
epoch time: 53.82min
epoch 66


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87485 loss: 0.25497195980991744
train val: 0.7572 loss: 0.47229601384980924
test: 0.5958 loss: 0.7110398691147566
epoch time: 52.36min
epoch 67


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8698 loss: 0.2671685783052791
train val: 0.7502 loss: 0.45718999668704347
test: 0.5844 loss: 0.6862625468313694
epoch time: 53.84min
epoch 68


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87395 loss: 0.25859209363247354
train val: 0.7555 loss: 0.4620616571449049
test: 0.587 loss: 0.7002534094750881
epoch time: 52.47min
epoch 69


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87265 loss: 0.2579258313120132
train val: 0.7622 loss: 0.46328385009944906
test: 0.5848 loss: 0.7007347873359918
epoch time: 53.04min
epoch 70


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.8746 loss: 0.25362286223055736
train val: 0.7658 loss: 0.4552601112201024
test: 0.6083 loss: 0.6854997290492058
epoch time: 59.32min
epoch 71


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87665 loss: 0.25719869095718956
train val: 0.7624 loss: 0.46288597749505206
test: 0.5966 loss: 0.6973749386787415
epoch time: 56.61min
epoch 72


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.87475 loss: 0.25486205479264024
train val: 0.7602 loss: 0.4784235259650566
test: 0.5948 loss: 0.6973274705350399
epoch time: 54.62min
epoch 73


train:   0%|          | 0/20000 [00:00<?, ?it/s]

train val:   0%|          | 0/10000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train: 0.876 loss: 0.2544052758469386
train val: 0.7612 loss: 0.4645390307061549
test: 0.5927 loss: 0.6972030085593462
epoch time: 53.64min
epoch 74


train:   0%|          | 0/20000 [00:00<?, ?it/s]

In [None]:
torch.save(model.state_dict(), get_filename(harness, is_final=True, ftype="model", epoch=None)[0])
torch.save(harness.state_dict(), get_filename(harness, is_final=True, ftype="harness", epoch=None)[0])
torch.save(optimizer.state_dict(), get_filename(harness, is_final=True, ftype="optimizer", epoch=None)[0])

In [None]:
ttgen.reset()
model.eval()
harness.eval()
dfs = []
with torch.no_grad():
    metric_val_test = evaluate.load("accuracy")
    test_val_loss = []
    with tqdm(desc="test val", total=ttgen.get_epoch_test_validation_size()) as progress_bar:
        for test_val_df in ttgen.test_validation_dfs():
            preds, loss = compute(test_val_df)
            test_val_loss.append(loss.item())
            predictions = torch.argmax(preds, dim=-1)
            metric_val_test.add_batch(
                predictions=predictions, references=test_val_df["correct_is_right"].astype(int))
            cur_df = test_val_df.copy()
            cur_df["logit_left"] = preds[:, 0].cpu()
            cur_df["logit_right"] = preds[:, 1].cpu()
            cur_df["preds"] = predictions.cpu()
            cur_df["truth"] = test_val_df["correct_is_right"].astype(int)
            dfs.append(cur_df)
            progress_bar.update(test_val_df.shape[0])
print(f"test val: {metric_val_test.compute()} loss: {np.mean(test_val_loss)}")
validation_df = pd.concat(dfs)

In [None]:
validation_df.to_csv(
    get_filename(harness, is_final=True, ftype="validation", epoch=None, ext=".csv")[0])

In [None]:
validation_df[validation_df["preds"] == validation_df["truth"]].head()

In [None]:
validation_df[validation_df["preds"] != validation_df["truth"]].head()

In [None]:
ttgen.reset()
model.eval()
harness.eval()
with torch.no_grad():
    count = 0
    for test_val_df in ttgen.test_validation_dfs():
        plefts = tokens(test_val_df["parent_left"])
        clefts = tokens(test_val_df["child_left"])
        prights = tokens(test_val_df["parent_right"])
        crights = tokens(test_val_df["child_right"])
        display(model.get_child_embed(
            clefts["input_ids"],
            clefts["attention_mask"]).cpu().numpy())
        display(model.get_child_embed(
            crights["input_ids"],
            crights["attention_mask"]).cpu().numpy())
        count += 1
        if count >= 5:
            break