In [1]:
import glob
import time
import os
import datetime
import math
import datasets
from datasets import DatasetDict, load_dataset, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    T5Tokenizer,
    T5ForSequenceClassification,
    T5Config,
    BertTokenizer, 
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from load_set import load_set
from epoch_stats import EpochStats#, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    train_bern_model, 
    mask_tokens, 
    preprocess_with_given_labels, 
    num_parameters, 
    num_trainable_parameters, 
    preprocess_for_causallm, 
    preprocess_for_multiple_choice,
    preprocess_for_seq2seq_swag
)
import bert_i1_1_modeling as bert_i1_1
import coin_i2C_modeling as ci2C
import coin_i2D_modeling as ci2D
import coin_i3A_modeling as ci3A
import coin_i3B_modeling as ci3B
import coin_i3C_modeling as ci3C

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TO_FILE = 1
TRAIN_BATCH_SIZE = 8
TEST_BATCH_SIZE = 8
CHECKPOINT_PATH = None # datetime.datetime.now().strftime("tmp_models/rann_sffn/run_part_load_%Y-%m-%d_%H:%M:%S")
USE_CUSTOM_DATALOADER = True
SHUFFLE_CUSTOM_DATALOADER = True
LEARNING_RATE = 1e-5
EPS = 1e-8
EPOCHS = 10

In [3]:
CHECKPOINT_PATH

In [4]:
VOCAB_SIZE = 30_522
#VOCAB_SIZE = 32_000
#VOCAB_SIZE = 52_000
MAX_POSITION_EMBEDDINGS = 128#24**2#512
IS_HF_MODEL = False
GENERIC_OUTPUT_CLASS = True

In [5]:
# ALWAYS CHECK num_labels, RFFN doesn't throw an error on a wrong parameter
rffn_base_model_path = "tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_chunkwise_group-exp_congen-head_B10_multi-query-2_switch-ii_mpe128/"
#rffn_base_model_path = "tmp_models/COIN-i2B_oasst1_25k_1x3_000dec_none-decoder-revert-out_chunkwise_nt-case-3_2-decay-parts_allow-enc-tf/"
#rffn_base_model_path = "tmp_models/RRB_oasst1_25k_2-2-encoder_0-1-decoder_decay_maskedLM.15_.2share_docx1_wtf/"
#rffn_tokenizer_path = "pretrained_models/rffn_wikitext_516_tokenizer"
rffn_tokenizer_path = rffn_base_model_path

In [6]:
# default, sst2, swag, uni-main-hyp
TEST_METHOD = "sst2"
ILOC_LIMIT = None
DEFAULT_TEACHER_FORCING = False
DOC_PAD_TOKENS = False

In [7]:
if TEST_METHOD == "default":
    NUM_LABELS = 7
elif TEST_METHOD == "sst2":
    NUM_LABELS = 2
elif TEST_METHOD == "swag":
    NUM_LABELS = 4
elif TEST_METHOD == "uni-main-hyp":
    NUM_LABELS = 10

TEST_SST2 = TEST_METHOD == "sst2"
ONE_LABEL_ONLY = TEST_SST2

In [8]:
CHECK_RUN = False

In [9]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3C.COINForSequenceClassification(
        config=ci3C.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            forward_method="chunkwise",
            apply_decay=False,
            num_decay_parts=2,
            hidden_retention_act="relu",
            hidden_pos_offset=True,
            rope_dim=16,
            num_query_heads=2,

            decoder_output="none",
            revert_decoder=False,
            decoder_schema=[1, 1],
            cross_encoder_schema=[0, 0],
            experts_schema=None,#[2, 2],
            block_io_schema=None,#[[1024, 1024*4, 1024], [1024, 1024*2, 1024]],
        )
    )

In [10]:
if 1:
    N_ITER = 9
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci3C.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

num layers: 2
gamma schema: [[0.96875, 0.9875984191894531], [0.995078444480896, 0.998046875]]
layer 0 num experts: 1
layer 1 num experts: 1


Some weights of COINForSequenceClassification were not initialized from the model checkpoint at tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_chunkwise_group-exp_congen-head_B10_multi-query-2_switch-ii_mpe128//model/epoch_9/model and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3B.COINForSequenceClassification(
        config=ci3B.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            forward_method="parallel",
            apply_decay=False,
            num_decay_parts=1,
            hidden_retention_act="relu",

            decoder_output="strict",
            decoder_schema=[0, 1],
            cross_encoder_schema=[0, 0],
            experts_schema=None,
        )
    )

In [12]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3A.COINForSequenceClassification(
        config=ci3A.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            forward_method="parallel",
            apply_decay=False,
            num_decay_parts=1,
            hidden_retention_act="relu",
            apply_hidden_pos_offset=False,
            #fuzed_decay_attention_mask=False,

            decoder_output="none",
            revert_decoder=False,
            decoder_schema=[1, 1],
            cross_encoder_schema=[0] * 2,
            experts_schema=None,
        )
    )

In [13]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci3A.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [14]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    NUM_REGIONS = 1
    model = ci2D.COINForSequenceClassification(
        config=ci2D.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            hidden_retention_act="relu",
            #hidden_out_act=None,
            forward_method="parallel",
            apply_decay=False,
            reverse_decay=False,
            num_decay_parts=1,
            decoder_output="strict",
            #rope_dim=16,
            
            num_regions=NUM_REGIONS,
            decoder_schema=      [0, 1],
            cross_encoder_schema=[0, 0],
            
            share_S=False,
            
            #layer_norm_eps=1e-12,
            #retention_group_norm_eps=1e-8,
            #rms_norm_eps=1e-12,

            disable_teacher_forcing=False,
        )
    )

In [15]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci2D.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [16]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    NUM_REGIONS = 1
    model = ci2C.COINForSequenceClassification(
        config=ci2C.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            hidden_retention_act="relu",
            #hidden_out_act=None,
            forward_method="parallel",
            apply_decay=False,
            #fixed_decay_value=None,
            num_decay_parts=1,
            #reverse_decay=False,
            chunkwise_num_chunks=4,
            apply_chunking_globally=False,
            #apply_hidden_pos_offset=False,
            decoder_output="none",
            
            num_regions=NUM_REGIONS,
            decoder_schema=      [0],
            cross_encoder_schema=[0],
            multi_head_qkv=False,
            num_heads=16,
            share_S=False,
            #num_repetitions=1,
            add_residual_query_skip=False,

            #layer_norm_eps=1e-12,
            #retention_group_norm_eps=1e-8,
            #rms_norm_eps=1e-12,

            print_checks=CHECK_RUN,
            reset_S_n_state=False,
            disable_teacher_forcing=False,

            apply_selective_attention_params=False,
            #selective_param_Ns=(2, 2),
        )
    )

In [17]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci2C.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [18]:
if 0:
    tokenizer = transformers.AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    class MambaOutput:
        def __init__(self, out):
            self.logits = out
            self.encoder_hidden_state = None
            self.S = None
            self.C = None
            self.loss = None
            self.aux_loss = None

    class MambaForSequenceClassification(nn.Module):
        def __init__(self, path, **kwargs):
            super().__init__()
            self.mamba = transformers.MambaModel.from_pretrained(path, num_hidden_layers=12, **kwargs)
            self.config = self.mamba.config
            self.dense = nn.Linear(768, 768)
            self.act = nn.Tanh()
            self.cls = nn.Linear(768, NUM_LABELS)

        def forward(self, **kwargs):
            logits = self.mamba(**kwargs).last_hidden_state
            out = self.act(self.dense(logits[:, 0, :]))
            out = self.cls(out)
            return MambaOutput(out)

        
    #model = transformers.MambaModel.from_pretrained("state-spaces/mamba-130m-hf", num_labels=NUM_LABELS)
    model = MambaForSequenceClassification("state-spaces/mamba-130m-hf")

In [19]:
print("{:,}\n{:,}".format(num_parameters(model), num_trainable_parameters(model)))

115,813,410
115,813,378


In [20]:
base_ds_path = "../datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/"

train_ds = [
    f"{base_ds_path}/train/train_00[0-{TO_FILE}].csv"

    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/train/train_00[0-9].csv",
    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/train/train_01[0-9].csv"
]
test_ds = [
    f"{base_ds_path}/validation/validation_00[0-{TO_FILE}].csv"

    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/validation/validation_00[0-9].csv",
    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/validation/validation_01[0-9].csv"
]

In [21]:
if TEST_METHOD == "sst2":
    dataset = DatasetDict({
        "train": load_dataset("glue", name="sst2", split="train[:10000]").rename_column("sentence", "text"),
        "test": load_dataset("glue", name="sst2", split="validation").rename_column("sentence", "text")
    })
elif TEST_METHOD == "default":
    dataset = DatasetDict({
        "train": load_set(train_ds, unused_fields=["author", "subreddit", "style"], iloc_limit=ILOC_LIMIT),
        "test":  load_set(test_ds, unused_fields=["author", "subreddit", "style"], iloc_limit=ILOC_LIMIT)

        #"train": load_dataset("glue", name="mnli", split="train[0:10000]"),
        #"test": load_dataset("glue", name="mnli", split="validation_matched[0:1500]")

        #"train": load_dataset("squad_v2", split="train[0:10000]"),
        #"test": load_dataset("squad_v2", split="test[0:1500]")
    })
elif TEST_METHOD == "swag":
    dataset = DatasetDict({
        "train": load_dataset("Rowan/hellaswag", split="train"),
        "test": load_dataset("Rowan/hellaswag", split="validation")
    })
elif TEST_METHOD == "uni-main-hyp":
    DS_PATH = "../uni-hyp-class/wordpiece_abstracts_train_side_label_1.csv"
    dataset = DatasetDict({
        "train": load_set([DS_PATH], unused_fields=["head", "body", "strlabels"]),
        "test": load_set([DS_PATH], unused_fields=["head", "body", "strlabels"]),
    })
else:
    raise ValueError(TEST_METHOD)
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'idx'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['text', 'label', 'idx'],
        num_rows: 872
    })
})

In [22]:
# %%
if ONE_LABEL_ONLY:
    #labels = np.unique(train_df["label"]).tolist()
    labels = np.unique(dataset["train"]["label"]).tolist()
else:
    labels = [label for label in dataset['train'].features.keys() if label not in ["text"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)


{0: 0, 1: 1}
{0: 0, 1: 1}
[0, 1]


In [23]:
# %%
if TEST_METHOD in ("default", "sst2", "uni-main-hyp"):
    encoded_dataset = preprocess_with_given_labels(dataset, tokenizer, labels, label2id, MAX_POSITION_EMBEDDINGS, ONE_LABEL_ONLY, remove_columns=dataset["train"].column_names, 
                                                   default_teacher_forcing=DEFAULT_TEACHER_FORCING, doc_pad_tokens=DOC_PAD_TOKENS)
elif TEST_METHOD == "swag":
    #encoded_dataset = preprocess_for_multiple_choice(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=dataset["train"].column_names, num_proc=4)
    encoded_dataset = preprocess_for_seq2seq_swag(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=dataset["train"].column_names, num_proc=4)
encoded_dataset


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'decoder_input_ids'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'decoder_input_ids'],
        num_rows: 872
    })
})

In [24]:
# %%
batch_schema = list(encoded_dataset["train"].features.keys())
batch_schema


['input_ids',
 'token_type_ids',
 'attention_mask',
 'labels',
 'decoder_input_ids']

In [25]:
# %%
if USE_CUSTOM_DATALOADER:
    #train_dataloader = create_dataloader(encoded_dataset["train"])
    #test_dataloader = create_dataloader(encoded_dataset["test"])
    train_dataloader = BatchBuffer(encoded_dataset["train"], TRAIN_BATCH_SIZE)
    if SHUFFLE_CUSTOM_DATALOADER:
        train_dataloader.shuffle()
    test_dataloader = BatchBuffer(encoded_dataset["test"], TEST_BATCH_SIZE)
else:
    USE_TOKEN_TYPE_IDS = "token_type_ids" in encoded_dataset["train"].features
    USE_DEC_II = "decoder_input_ids" in encoded_dataset["train"].features
    # Load input data into tensors
    train_input_ids = torch.tensor(encoded_dataset["train"]["input_ids"])
    if USE_TOKEN_TYPE_IDS:
        train_token_type_ids = torch.tensor(encoded_dataset["train"]["token_type_ids"])
    train_masks = torch.tensor(encoded_dataset["train"]["attention_mask"])
    train_labels = torch.tensor(encoded_dataset["train"]["labels"])
    if USE_DEC_II:
        train_dec_ii = torch.tensor(encoded_dataset["train"]["decoder_input_ids"])

    test_input_ids = torch.tensor(encoded_dataset["test"]["input_ids"])
    if USE_TOKEN_TYPE_IDS:
        test_token_type_ids = torch.tensor(encoded_dataset["test"]["token_type_ids"])
    test_masks = torch.tensor(encoded_dataset["test"]["attention_mask"])
    test_labels = torch.tensor(encoded_dataset["test"]["labels"])
    if USE_DEC_II:
        test_dec_ii = torch.tensor(encoded_dataset["test"]["decoder_input_ids"])

    # Create the DataLoader and Sampler for both sets.
    if USE_TOKEN_TYPE_IDS:
        train_data = TensorDataset(train_input_ids, train_token_type_ids, train_masks, train_labels)
    else:
        train_data = TensorDataset(train_input_ids, train_masks, train_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, 
        sampler=train_sampler, 
        batch_size=BATCH_SIZE)

    if USE_TOKEN_TYPE_IDS:
        test_data = TensorDataset(test_input_ids, test_token_type_ids, test_masks, test_labels)
    else:
        test_data = TensorDataset(test_input_ids, test_masks, test_labels)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, 
        sampler=test_sampler, 
        batch_size=BATCH_SIZE)


In [26]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataloader) / TRAIN_BATCH_SIZE * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

if len(labels) <= 2:
    loss_function = nn.CrossEntropyLoss()
else:
    loss_function = nn.BCEWithLogitsLoss()


In [27]:
# %%
print(loss_function)
print(total_steps)
print(warmup_steps)

CrossEntropyLoss()
1562.5
79


In [28]:
# %%
if CHECKPOINT_PATH is not None:
    try:
        os.mkdir(CHECKPOINT_PATH)
    except OSError as err:
        print(err)


In [29]:
# %%
#len(encoded_dataset["train"]["input_ids"]), len(encoded_dataset["train"]["input_ids"][0])


In [30]:
# %%
batch_schema


['input_ids',
 'token_type_ids',
 'attention_mask',
 'labels',
 'decoder_input_ids']

In [31]:
# %%
model


COINForSequenceClassification(
  (coin): COINModel(
    (encoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (residual_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (stack): COINStack(
      (layers): ModuleList(
        (0-1): 2 x COINLayer(
          (blocks): ModuleList(
            (0): COINBlock(
              (cross_qkv): MultiQueryQKV(
                (rope): RotaryEmbedding()
                (xpos): XPOS()
         

In [32]:
model.config

COINConfig {
  "_name_or_path": "tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_chunkwise_group-exp_congen-head_B10_multi-query-2_switch-ii_mpe128//model/epoch_9/model",
  "allow_encoder_teacher_forcing": false,
  "apply_decay": false,
  "apply_decoder_heads": true,
  "apply_ffn": true,
  "apply_hidden_pos_offset": false,
  "apply_softmax_gate": true,
  "architectures": [
    "COINForConditionalGeneration"
  ],
  "block_io_schema": null,
  "cross_encoder_schema": [
    0,
    0
  ],
  "decoder_output": "none",
  "decoder_schema": [
    1,
    1
  ],
  "decoder_start_token_id": 0,
  "disable_teacher_forcing": false,
  "experts_schema": null,
  "ffn_intermediate_factor": 4,
  "ffn_intermediate_size": 4096,
  "fixed_decay_value": null,
  "fixed_ffn_intermediate_size": false,
  "fixed_intermediate_size": false,
  "forward_method": "chunkwise",
  "global_recurrence_check": false,
  "group_norm_channels": 1024,
  "group_norm_num": 32,
  "hidden_act": "tanh",
  "

In [33]:
torch.set_printoptions(threshold=100000_000)

In [34]:
%%time
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    loss_function,
    id2label,
    batch_schema=batch_schema,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    only_save_core=False,
    one_label_only=ONE_LABEL_ONLY,
    mixed_lm_task=False,
    mixed_lm_loss_function=nn.CrossEntropyLoss(),
    
    generic_output_class=True,
    
    #forward_args=["input_ids", "token_type_ids", "attention_mask", "labels"],
    #forward_args=["input_ids"],

    add_layers_on_stagnation=False,
    num_layers_to_add=1,
    add_layers_threshold=0.01, #0.005,
    plot_k_topics=False,

    batch_hack_train=True,
    mlm_decode_n=0,#.0075,
    tokenizer=tokenizer,

    masked_lm_task=False,
    check_run=CHECK_RUN,
    retain_graph=False,
)



Training...


  Batch    31  of  1,250.    Elapsed:  0:00:02, Remaining:  0:01:19.
  Batch    62  of  1,250.    Elapsed:  0:00:03, Remaining:  0:00:38.
  Batch    93  of  1,250.    Elapsed:  0:00:04, Remaining:  0:00:37.
  Batch   124  of  1,250.    Elapsed:  0:00:06, Remaining:  0:00:36.
  Batch   155  of  1,250.    Elapsed:  0:00:07, Remaining:  0:00:35.
  Batch   186  of  1,250.    Elapsed:  0:00:09, Remaining:  0:00:34.
  Batch   217  of  1,250.    Elapsed:  0:00:10, Remaining:  0:00:33.
  Batch   248  of  1,250.    Elapsed:  0:00:11, Remaining:  0:00:32.
  Batch   279  of  1,250.    Elapsed:  0:00:13, Remaining:  0:00:31.
  Batch   310  of  1,250.    Elapsed:  0:00:14, Remaining:  0:00:30.
  Batch   341  of  1,250.    Elapsed:  0:00:15, Remaining:  0:00:29.
  Batch   372  of  1,250.    Elapsed:  0:00:17, Remaining:  0:00:28.
  Batch   403  of  1,250.    Elapsed:  0:00:18, Remaining:  0:00:27.
  Batch   434  of  1,250.    Elapsed:  0:00:19, Remaining:  0:00:26.
  Batch   465  of  1,250.    Elaps

In [None]:
# %%
## SST 2 test
## [loss] / [acc/ham]

# 35k
# .61 / .79 epoch 4 ; .59 / .78 epoch 3 ; .52 / .776 epoch 2 15k pretraining, num_labels=2

# .48 / .80 epoch 3


# 10k
# .51 / .758 epoch 2 empty
# .57 / .778 epoch 3 empty


In [None]:
# %%
# hf bert (empty): batch_size=6, time per epoch=6:30min, 5734MiB VRAM, 0.756 0.786 epoch 6
# hf t5 (empty): batch_size=6, time per epoch=5:17min, 5306MiB VRAM, 0.758 0.773 epoch 5
# hf t5 (empty) num_layers=12 num_heads=12: batch_size=3, time per epoch=11:45min, 7416MiB VRAM, 0.758 0.787 epoch 5



