In [1]:
import glob
import time
import os
import datetime
import math
import datasets
from datasets import DatasetDict, load_dataset, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    T5Tokenizer,
    T5ForSequenceClassification,
    T5Config,
    BertTokenizer, 
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass

from load_set import *
from epoch_stats import EpochStats#, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    train_bern_model, 
    mask_tokens, 
    preprocess_with_given_labels, 
    num_parameters, 
    num_trainable_parameters, 
    preprocess_for_causallm, 
    preprocess_for_multiple_choice,
    preprocess_for_seq2seq_swag
)
import bert_i1_1_modeling as bert_i1_1
import coin_i2C_modeling as ci2C
import coin_i2D_modeling as ci2D
import coin_i3A_modeling as ci3A
import coin_i3B_modeling as ci3B
import coin_i3C_modeling as ci3C
import rnn_modeling as ci_rnn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TO_FILE = 1
TRAIN_BATCH_SIZE = 2
TEST_BATCH_SIZE = 2
CHECKPOINT_PATH = None #"hyp_cls/bert-base-uncased_main-label/" # datetime.datetime.now().strftime("tmp_models/rann_sffn/run_part_load_%Y-%m-%d_%H:%M:%S")
USE_CUSTOM_DATALOADER = True
SHUFFLE_CUSTOM_DATALOADER = True
LEARNING_RATE = 1e-5
EPS = 1e-8
EPOCHS = 10

In [3]:
CHECKPOINT_PATH

In [4]:
VOCAB_SIZE = 30_522
#VOCAB_SIZE = 32_000
#VOCAB_SIZE = 52_000
MAX_POSITION_EMBEDDINGS = 512#24**2
HIDDEN_SIZE = 1024
IS_HF_MODEL = False
GENERIC_OUTPUT_CLASS = True

In [5]:
# ALWAYS CHECK num_labels, RFFN doesn't throw an error on a wrong parameter
rffn_base_model_path = "tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_chunkwise_group-exp_congen-head_B10_multi-query-2_switch-ii_mpe576_no-cross-att/"
#rffn_base_model_path = "tmp_models/COIN-i2B_oasst1_25k_1x3_000dec_none-decoder-revert-out_chunkwise_nt-case-3_2-decay-parts_allow-enc-tf/"
#rffn_base_model_path = "tmp_models/RRB_oasst1_25k_2-2-encoder_0-1-decoder_decay_maskedLM.15_.2share_docx1_wtf/"
#rffn_tokenizer_path = "pretrained_models/rffn_wikitext_516_tokenizer"
rffn_tokenizer_path = rffn_base_model_path

In [6]:
# default, sst2, swag, uni-main-hyp, uni-side-hyp, bucket-sort, duplicate-string, parity-check
TEST_METHOD = "uni-side-hyp"
ILOC_LIMIT = None
DEFAULT_TEACHER_FORCING = False
DOC_PAD_TOKENS = False

In [7]:
if TEST_METHOD == "default":
    NUM_LABELS = 7
elif TEST_METHOD in ("parity-check", "sst2"):
    NUM_LABELS = 2
elif TEST_METHOD == "swag":
    NUM_LABELS = 4
elif TEST_METHOD == "uni-main-hyp":
    NUM_LABELS = 10
elif TEST_METHOD == "uni-side-hyp":
    NUM_LABELS = 20
else:
    NUM_LABELS = 2

TEST_SST2 = TEST_METHOD == "sst2"
ONE_LABEL_ONLY = TEST_SST2

In [8]:
if TEST_METHOD == "parity-check":
    VOCAB_SIZE = 2

In [9]:
CHECK_RUN = False

In [10]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci_rnn.COINForSequenceClassification(
        config=ci_rnn.RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            num_hidden_layers=2,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            layer_norm_eps=1e-12,
            rope_dim=16,
        )
    )

In [11]:
class GenConfig:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

In [12]:
if 0:
    ONE_LABEL_ONLY = True
    tokenizer = None
    class ParityLSTM(nn.Module):
        def __init__(self, hidden_size=HIDDEN_SIZE):
            super().__init__()
            self.config = None
            self.hidden_size = hidden_size
            self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
            self.L1 = nn.Linear(hidden_size, 128)
            self.L2 = nn.Linear(128, 2)
        
        def forward(self, X):
            N = len()

            y = F.relu(self.L1(l_out))
            y = F.dropout(y, 0.5)
            y = self.L2(y)
            y = F.sigmoid(y)
            return y

In [13]:
if 0:
    ONE_LABEL_ONLY = True
    tokenizer = None
    class LSTMForParityCheck(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.n_layers = config.num_hidden_layers
            self.lstm = nn.LSTM(2, config.hidden_size, batch_first=True, num_layers=self.n_layers, dropout=config.hidden_dropout_prob)
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.classifier = nn.Sequential(
                #nn.Dropout(config.hidden_dropout_prob, 
                nn.Linear(config.hidden_size, config.num_labels)
            )
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            logits = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            #logits = self.embeddings(logits)
            B, T, C = logits.shape
            hidden = (torch.randn(self.n_layers, B, C, device=logits.device), torch.randn(self.n_layers, B, C, device=logits.device))
            #logits, hidden = self.lstm(logits, hidden)
            logits, hidden = self.lstm(logits)
            #for i in range(T):
            #    out, hidden = self.lstm(logits[:, i:i+1, :], hidden)
            #r_logits = self.pooler(logits[:, -1, :])
            #r_logits = F.tanh(logits)
            #r_logits = out[:, -1]
            #print(logits.shape, logits.view(T, -1).shape)
            r_logits = self.classifier(logits[:, -1])
            #r_logits = F.log_softmax(r_logits, 1)
            loss = self.loss_fn(r_logits, labels)
            return ci3C.COINOutputClass(
                logits=r_logits,
                loss=loss
            )

    model = LSTMForParityCheck(
        config=GenConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.5,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [14]:
if 0:
    IS_HF_MODEL = False
    ONE_LABEL_ONLY = True
    tokenizer = None
    class BertForParityCheck(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            #self.bert = transformers.BertModel(config)
            self.bert = transformers.BertForSequenceClassification(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = self.embeddings(F.one_hot(input_ids, self.config.vocab_size).float())
            B, T, C = emb.shape
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).logits
            loss = self.loss_fn(logits, labels)
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    model = BertForParityCheck(
        config=transformers.BertConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_hidden_layers=5,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=1,
            num_labels=NUM_LABELS,
        )
    )

In [15]:
if 0:
    IS_HF_MODEL = True
    ONE_LABEL_ONLY = True
    tokenizer = None
    class BertForBucketSort(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.bert = transformers.BertModel(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = self.embeddings(F.one_hot(input_ids, self.config.vocab_size).float())
            B, T, C = emb.shape
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).last_hidden_state
            logits = self.lm_head(logits)
            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    model = BertForBucketSort(
        config=transformers.BertConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_hidden_layers=5,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=1,
        )
    )

In [16]:
ci3C_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    hidden_size=HIDDEN_SIZE,
    num_labels=NUM_LABELS,
    forward_method="parallel",
    apply_decay=False,
    num_decay_parts=1,
    hidden_retention_act="relu",
    hidden_pos_offset=False,
    rope_dim=16,
    num_query_heads=1,

    decoder_output="none",
    revert_decoder=False,
    decoder_schema=[0] * 6,
    cross_encoder_schema=[0] * 6,
    experts_schema=None,#[2, 2],
    block_io_schema=None,#[[1024, 1024*4, 1024]],
)

In [17]:
if 0:
    tokenizer = None
    ONE_LABEL_ONLY = True
    model = ci3C.COINForParityCheck(
        config=ci3C_CONFIG
    )

In [18]:
if 0:
    ONE_LABEL_ONLY = True
    tokenizer = None
    model = ci3C.COINForBucketSort(
        config=ci3C_CONFIG
    )

In [19]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3C.COINForSequenceClassification(
        config=ci3C_CONFIG
    )

In [20]:
if 1:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3C.COINForHierachicalClassification(
        config=ci3C_CONFIG
    )

num layers: 6
gamma schema: [0.96875, 0.9820516109466553, 0.9896913170814514, 0.9940792322158813, 0.9965994358062744, 0.998046875]
layer 0 num experts: 1
layer 1 num experts: 1
layer 2 num experts: 1
layer 3 num experts: 1
layer 4 num experts: 1
layer 5 num experts: 1


In [21]:
if 0:
    N_ITER = 4
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci3C.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [22]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3B.COINForSequenceClassification(
        config=ci3B.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            forward_method="parallel",
            apply_decay=False,
            num_decay_parts=1,
            hidden_retention_act="relu",

            decoder_output="strict",
            decoder_schema=[0, 1],
            cross_encoder_schema=[0, 0],
            experts_schema=None,
        )
    )

In [23]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = ci3A.COINForSequenceClassification(
        config=ci3A.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            forward_method="parallel",
            apply_decay=False,
            num_decay_parts=1,
            hidden_retention_act="relu",
            apply_hidden_pos_offset=False,
            #fuzed_decay_attention_mask=False,

            decoder_output="none",
            revert_decoder=False,
            decoder_schema=[1, 1],
            cross_encoder_schema=[0] * 2,
            experts_schema=None,
        )
    )

In [24]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci3A.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [25]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    NUM_REGIONS = 1
    model = ci2D.COINForSequenceClassification(
        config=ci2D.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            hidden_retention_act="relu",
            #hidden_out_act=None,
            forward_method="parallel",
            apply_decay=False,
            reverse_decay=False,
            num_decay_parts=1,
            decoder_output="strict",
            #rope_dim=16,
            
            num_regions=NUM_REGIONS,
            decoder_schema=      [0, 1],
            cross_encoder_schema=[0, 0],
            
            share_S=False,
            
            #layer_norm_eps=1e-12,
            #retention_group_norm_eps=1e-8,
            #rms_norm_eps=1e-12,

            disable_teacher_forcing=False,
        )
    )

In [26]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci2D.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [27]:
if 0:
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    NUM_REGIONS = 1
    model = ci2C.COINForSequenceClassification(
        config=ci2C.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_labels=NUM_LABELS,
            hidden_retention_act="relu",
            #hidden_out_act=None,
            forward_method="parallel",
            apply_decay=False,
            #fixed_decay_value=None,
            num_decay_parts=1,
            #reverse_decay=False,
            chunkwise_num_chunks=4,
            apply_chunking_globally=False,
            #apply_hidden_pos_offset=False,
            decoder_output="none",
            
            num_regions=NUM_REGIONS,
            decoder_schema=      [0],
            cross_encoder_schema=[0],
            multi_head_qkv=False,
            num_heads=16,
            share_S=False,
            #num_repetitions=1,
            add_residual_query_skip=False,

            #layer_norm_eps=1e-12,
            #retention_group_norm_eps=1e-8,
            #rms_norm_eps=1e-12,

            print_checks=CHECK_RUN,
            reset_S_n_state=False,
            disable_teacher_forcing=False,

            apply_selective_attention_params=False,
            #selective_param_Ns=(2, 2),
        )
    )

In [28]:
if 0:
    N_ITER = 0
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    #tokenizer = RobertaTokenizer.from_pretrained(f"{rffn_base_model_path}/bpe_tokenizer/")
    model = ci2C.COINForSequenceClassification.from_pretrained(
        f"{rffn_base_model_path}/model/epoch_{N_ITER}/model",
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_labels=NUM_LABELS,
    )

In [29]:
if 0:
    tokenizer = transformers.AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    class MambaOutput:
        def __init__(self, out):
            self.logits = out
            self.encoder_hidden_state = None
            self.S = None
            self.C = None
            self.loss = None
            self.aux_loss = None

    class MambaForSequenceClassification(nn.Module):
        def __init__(self, path, **kwargs):
            super().__init__()
            self.mamba = transformers.MambaModel.from_pretrained(path, num_hidden_layers=12, **kwargs)
            self.config = self.mamba.config
            self.dense = nn.Linear(768, 768)
            self.act = nn.Tanh()
            self.cls = nn.Linear(768, NUM_LABELS)

        def forward(self, **kwargs):
            logits = self.mamba(**kwargs).last_hidden_state
            out = self.act(self.dense(logits[:, 0, :]))
            out = self.cls(out)
            return MambaOutput(out)

        
    #model = transformers.MambaModel.from_pretrained("state-spaces/mamba-130m-hf", num_labels=NUM_LABELS)
    model = MambaForSequenceClassification("state-spaces/mamba-130m-hf")

In [30]:
if 0:
    IS_HF_MODEL = True
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    model = transformers.BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=NUM_LABELS
    )

In [31]:
if 0:
    IS_HF_MODEL = True
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    #tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")
    model = transformers.BertForSequenceClassification(
        config=transformers.BertConfig(
            num_labels=NUM_LABELS,
            #num_hidden_layers=2,
        )
    )

In [32]:
print("{:,}\n{:,}".format(num_parameters(model), num_trainable_parameters(model)))

116,884,548
116,884,500


In [33]:
base_ds_path = "../datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/"

train_ds = [
    f"{base_ds_path}/train/train_00[0-{TO_FILE}].csv"

    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/train/train_00[0-9].csv",
    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/train/train_01[0-9].csv"
]
test_ds = [
    f"{base_ds_path}/validation/validation_00[0-{TO_FILE}].csv"

    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/validation/validation_00[0-9].csv",
    #f"datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/validation/validation_01[0-9].csv"
]

In [34]:
if TEST_METHOD == "sst2":
    dataset = DatasetDict({
        "train": load_dataset("glue", name="sst2", split="train[:10000]").rename_column("sentence", "text"),
        "test": load_dataset("glue", name="sst2", split="validation").rename_column("sentence", "text")
    })
elif TEST_METHOD == "default":
    dataset = DatasetDict({
        "train": load_set(train_ds, unused_fields=["author", "subreddit", "style"], iloc_limit=ILOC_LIMIT),
        "test":  load_set(test_ds, unused_fields=["author", "subreddit", "style"], iloc_limit=ILOC_LIMIT)

        #"train": load_dataset("glue", name="mnli", split="train[0:10000]"),
        #"test": load_dataset("glue", name="mnli", split="validation_matched[0:1500]")

        #"train": load_dataset("squad_v2", split="train[0:10000]"),
        #"test": load_dataset("squad_v2", split="test[0:1500]")
    })
elif TEST_METHOD == "swag":
    dataset = DatasetDict({
        "train": load_dataset("Rowan/hellaswag", split="train"),
        "test": load_dataset("Rowan/hellaswag", split="validation")
    })
elif TEST_METHOD in ("uni-main-hyp", "uni-side-hyp"):
    if TEST_METHOD == "uni-main-hyp":
        DS_TRAIN_PATH = "../uni-hyp-class/wordpiece_abstracts_train.csv"
        DS_TEST_PATH = "../uni-hyp-class/wordpiece_abstracts_test.csv"
    elif TEST_METHOD == "uni-side-hyp":
        DS_TRAIN_PATH = "../uni-hyp-class/wordpiece_abstracts_train_side_label_1.csv"
        DS_TEST_PATH = "../uni-hyp-class/wordpiece_abstracts_test_side_label_1.csv"
    dataset = DatasetDict({
        "train": load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"]),
        "test": load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"]),
    })
elif TEST_METHOD == "bucket-sort":
    #dataset = generate_bucket_sort_set(B=100, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE)
    dataset = DatasetDict({
        #"train": Dataset.from_dict(generate_bucket_sort_set(B=100_000, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE)),
        #"test": Dataset.from_dict(generate_bucket_sort_set(B=1000, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE))
        "train": generate_uniform_batches(generate_bucket_sort_set, B=TRAIN_BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=10_000),
        "test": generate_uniform_batches(generate_bucket_sort_set, B=TEST_BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=1000)
    })
elif TEST_METHOD == "duplicate-string":
    dataset = DatasetDict({
        "train": generate_duplicate_string_set(B=100_000, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE),
        "test": generate_duplicate_string_set(B=1000, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE)
    })
elif TEST_METHOD == "parity-check":
    dataset = DatasetDict({
        #"train": generate_parity_check_set(B=10_000, T=MAX_POSITION_EMBEDDINGS),
        #"test": generate_parity_check_set(B=1000, T=MAX_POSITION_EMBEDDINGS)
        "train": generate_uniform_batches(generate_parity_check_set, B=TRAIN_BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=10_000),
        "test": generate_uniform_batches(generate_parity_check_set, B=TEST_BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=1000)
    })
else:
    raise ValueError(TEST_METHOD)
dataset

loading files
    ../uni-hyp-class/wordpiece_abstracts_train_side_label_1.csv
loading files
    ../uni-hyp-class/wordpiece_abstracts_test_side_label_1.csv


DatasetDict({
    train: Dataset({
        features: ['text', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9'],
        num_rows: 862
    })
    test: Dataset({
        features: ['text', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9'],
        num_rows: 92
    })
})

In [35]:
# %%
if ONE_LABEL_ONLY:
    #labels = np.unique(train_df["label"]).tolist()
    labels = np.unique(dataset["train"]["label"]).tolist()
else:
    labels = [label for label in dataset['train'].features.keys() if label not in ["text"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)


{'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5, 'a6': 6, 'a7': 7, 'a8': 8, 'a9': 9, 'b0': 10, 'b1': 11, 'b2': 12, 'b3': 13, 'b4': 14, 'b5': 15, 'b6': 16, 'b7': 17, 'b8': 18, 'b9': 19}
{0: 'a0', 1: 'a1', 2: 'a2', 3: 'a3', 4: 'a4', 5: 'a5', 6: 'a6', 7: 'a7', 8: 'a8', 9: 'a9', 10: 'b0', 11: 'b1', 12: 'b2', 13: 'b3', 14: 'b4', 15: 'b5', 16: 'b6', 17: 'b7', 18: 'b8', 19: 'b9'}
['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9']


In [36]:
# %%
if TEST_METHOD in ("default", "sst2", "uni-main-hyp", "uni-side-hyp"):
    encoded_dataset = preprocess_with_given_labels(dataset, tokenizer, labels, label2id, MAX_POSITION_EMBEDDINGS, ONE_LABEL_ONLY, remove_columns=dataset["train"].column_names, 
                                                   default_teacher_forcing=DEFAULT_TEACHER_FORCING, doc_pad_tokens=DOC_PAD_TOKENS)
elif TEST_METHOD == "swag":
    #encoded_dataset = preprocess_for_multiple_choice(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=dataset["train"].column_names, num_proc=4)
    encoded_dataset = preprocess_for_seq2seq_swag(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=dataset["train"].column_names, num_proc=4)
elif TEST_METHOD in ("bucket-sort", "duplicate-string", "parity-check"):
    encoded_dataset = dataset
encoded_dataset


Map (num_proc=4): 100%|██████████| 862/862 [00:01<00:00, 659.44 examples/s]


Map (num_proc=4): 100%|██████████| 92/92 [00:00<00:00, 134.24 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'decoder_input_ids'],
        num_rows: 862
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'decoder_input_ids'],
        num_rows: 92
    })
})

In [37]:
# %%
batch_schema = list(encoded_dataset["train"].features.keys())
batch_schema


['input_ids',
 'token_type_ids',
 'attention_mask',
 'labels',
 'decoder_input_ids']

In [38]:
# %%
if USE_CUSTOM_DATALOADER:
    #train_dataloader = create_dataloader(encoded_dataset["train"])
    #test_dataloader = create_dataloader(encoded_dataset["test"])
    train_dataloader = BatchBuffer(encoded_dataset["train"], TRAIN_BATCH_SIZE)
    if SHUFFLE_CUSTOM_DATALOADER:
        train_dataloader.shuffle()
    test_dataloader = BatchBuffer(encoded_dataset["test"], TEST_BATCH_SIZE)
else:
    USE_TOKEN_TYPE_IDS = "token_type_ids" in encoded_dataset["train"].features
    USE_DEC_II = "decoder_input_ids" in encoded_dataset["train"].features
    # Load input data into tensors
    train_input_ids = torch.tensor(encoded_dataset["train"]["input_ids"])
    if USE_TOKEN_TYPE_IDS:
        train_token_type_ids = torch.tensor(encoded_dataset["train"]["token_type_ids"])
    train_masks = torch.tensor(encoded_dataset["train"]["attention_mask"])
    train_labels = torch.tensor(encoded_dataset["train"]["labels"])
    if USE_DEC_II:
        train_dec_ii = torch.tensor(encoded_dataset["train"]["decoder_input_ids"])

    test_input_ids = torch.tensor(encoded_dataset["test"]["input_ids"])
    if USE_TOKEN_TYPE_IDS:
        test_token_type_ids = torch.tensor(encoded_dataset["test"]["token_type_ids"])
    test_masks = torch.tensor(encoded_dataset["test"]["attention_mask"])
    test_labels = torch.tensor(encoded_dataset["test"]["labels"])
    if USE_DEC_II:
        test_dec_ii = torch.tensor(encoded_dataset["test"]["decoder_input_ids"])

    # Create the DataLoader and Sampler for both sets.
    if USE_TOKEN_TYPE_IDS:
        train_data = TensorDataset(train_input_ids, train_token_type_ids, train_masks, train_labels)
    else:
        train_data = TensorDataset(train_input_ids, train_masks, train_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, 
        sampler=train_sampler, 
        batch_size=BATCH_SIZE)

    if USE_TOKEN_TYPE_IDS:
        test_data = TensorDataset(test_input_ids, test_token_type_ids, test_masks, test_labels)
    else:
        test_data = TensorDataset(test_input_ids, test_masks, test_labels)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, 
        sampler=test_sampler, 
        batch_size=BATCH_SIZE)


In [39]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataloader) / TRAIN_BATCH_SIZE * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

if len(labels) <= 2:
    loss_function = nn.CrossEntropyLoss()
else:
    loss_function = nn.BCEWithLogitsLoss()


In [40]:
# %%
print(loss_function)
print(total_steps)
print(warmup_steps)

BCEWithLogitsLoss()
2155.0
108


In [41]:
# %%
if CHECKPOINT_PATH is not None:
    try:
        os.mkdir(CHECKPOINT_PATH)
    except OSError as err:
        print(err)


In [42]:
# %%
#len(encoded_dataset["train"]["input_ids"]), len(encoded_dataset["train"]["input_ids"][0])


In [43]:
# %%
batch_schema


['input_ids',
 'token_type_ids',
 'attention_mask',
 'labels',
 'decoder_input_ids']

In [44]:
# %%
model


COINForHierachicalClassification(
  (coin): COINModel(
    (encoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (residual_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (stack): COINStack(
      (layers): ModuleList(
        (0-5): 6 x COINLayer(
          (blocks): ModuleList(
            (0): COINBlock(
              (qkv): SingleHeadQKV(
                (rope): RotaryEmbedding()
                (xpos): XPOS()
            

In [45]:
model.config

COINConfig {
  "allow_encoder_teacher_forcing": false,
  "apply_decay": true,
  "apply_decoder_heads": true,
  "apply_ffn": true,
  "apply_hidden_pos_offset": false,
  "apply_softmax_gate": true,
  "block_io_schema": null,
  "cross_encoder_schema": [
    0,
    0,
    0,
    0,
    0,
    0
  ],
  "decoder_output": "none",
  "decoder_schema": [
    0,
    0,
    0,
    0,
    0,
    0
  ],
  "decoder_start_token_id": 0,
  "disable_teacher_forcing": false,
  "experts_schema": null,
  "ffn_intermediate_factor": 4,
  "ffn_intermediate_size": 4096,
  "fixed_decay_value": null,
  "fixed_ffn_intermediate_size": false,
  "fixed_intermediate_size": false,
  "forward_method": "parallel",
  "global_recurrence_check": false,
  "group_norm_channels": 1024,
  "group_norm_num": 32,
  "hidden_act": "tanh",
  "hidden_dropout_prob": 0.1,
  "hidden_out_act": "relu",
  "hidden_pos_offset": false,
  "hidden_retention_act": "relu",
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_

In [46]:
torch.set_printoptions(threshold=100_000_000)

In [47]:
%%time
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    loss_function,
    id2label,
    batch_schema=batch_schema,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    only_save_core=False,
    one_label_only=ONE_LABEL_ONLY,
    mixed_lm_task=False,
    mixed_lm_loss_function=nn.CrossEntropyLoss(),
    
    generic_output_class=True,
    
    #forward_args=["input_ids", "token_type_ids", "attention_mask", "labels"],
    #forward_args=["input_ids"],

    add_layers_on_stagnation=False,
    num_layers_to_add=1,
    add_layers_threshold=0.01, #0.005,
    plot_k_topics=False,

    batch_hack_train=True,
    mlm_decode_n=0,#.0075,
    tokenizer=tokenizer,

    masked_lm_task=False,
    check_run=CHECK_RUN,
    retain_graph=False,

    calc_metrics=True
)



Training...


  Batch     8  of    431.    Elapsed:  0:00:01, Remaining:  0:00:53.
  Batch    16  of    431.    Elapsed:  0:00:01, Remaining:  0:00:00.
  Batch    24  of    431.    Elapsed:  0:00:01, Remaining:  0:00:00.
  Batch    32  of    431.    Elapsed:  0:00:02, Remaining:  0:00:00.
  Batch    40  of    431.    Elapsed:  0:00:02, Remaining:  0:00:00.
  Batch    48  of    431.    Elapsed:  0:00:02, Remaining:  0:00:00.
  Batch    56  of    431.    Elapsed:  0:00:03, Remaining:  0:00:00.
  Batch    64  of    431.    Elapsed:  0:00:03, Remaining:  0:00:00.
  Batch    72  of    431.    Elapsed:  0:00:03, Remaining:  0:00:00.
  Batch    80  of    431.    Elapsed:  0:00:04, Remaining:  0:00:00.
  Batch    88  of    431.    Elapsed:  0:00:04, Remaining:  0:00:00.
  Batch    96  of    431.    Elapsed:  0:00:04, Remaining:  0:00:00.
  Batch   104  of    431.    Elapsed:  0:00:05, Remaining:  0:00:00.
  Batch   112  of    431.    Elapsed:  0:00:05, Remaining:  0:00:00.
  Batch   120  of    431.    Elaps

In [None]:
# %%
## SST 2 test
## [loss] / [acc/ham]

# 35k
# .61 / .79 epoch 4 ; .59 / .78 epoch 3 ; .52 / .776 epoch 2 15k pretraining, num_labels=2

# .48 / .80 epoch 3


# 10k
# .51 / .758 epoch 2 empty
# .57 / .778 epoch 3 empty


In [None]:
# %%
# hf bert (empty): batch_size=6, time per epoch=6:30min, 5734MiB VRAM, 0.756 0.786 epoch 6
# hf t5 (empty): batch_size=6, time per epoch=5:17min, 5306MiB VRAM, 0.758 0.773 epoch 5
# hf t5 (empty) num_layers=12 num_heads=12: batch_size=3, time per epoch=11:45min, 7416MiB VRAM, 0.758 0.787 epoch 5