In [1]:
import glob
import time
import os
import math
from datasets import DatasetDict, load_dataset, Dataset, concatenate_datasets, load_from_disk
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    BertTokenizer, 
    RobertaTokenizer,
    XLNetTokenizer,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np

from load_set import load_set, load_moses_set
from epoch_stats import EpochStats, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    mask_tokens, 
    train_bern_model, 
    preprocess_for_maskedlm, 
    preprocess_for_causallm, 
    preprocess_for_monologe, 
    preprocess_for_sparselm, 
    preprocess_for_binary_sparselm,
    num_parameters, 
    num_trainable_parameters,
    preprocess_for_translation,
)

import coin_i2C_modeling as ci2C
import coin_i2D_modeling as ci2D
import coin_i3A_modeling as ci3A
import coin_i3C_modeling as ci3C

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
REBATCH = False
TRAIN_BATCH_SIZE = 6
TEST_BATCH_SIZE = 6
BASE_PATH = "tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_parallel_group-exp_congen-head_B6_multi-query-2_switch-ii/"
#BASE_PATH = "tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2/"

if REBATCH:
    if BASE_PATH[-1] == "/":
        BASE_PATH = BASE_PATH[:-1]
    BASE_PATH += "_rebatch/"
DATASET_JSON_PATH = f"{BASE_PATH}/dataset/"
CHECKPOINT_PATH = f"{BASE_PATH}/model/"
WORDPIECE_TOKENIZER_DIR = f"{BASE_PATH}/wordpiece_tokenizer/"
BPE_TOKENIZER_DIR = f"{BASE_PATH}/bpe_tokenizer/"
SENTENCE_PIECE_TOKENIZER_DIR = f"{BASE_PATH}/sentence_piece_tokenizer/"
USE_CUSTOM_DATALOADER = False
LEARNING_RATE = 1e-5
EPS = 1e-8
EPOCHS = 10

In [3]:
SHUFFLE_TRAIN_DATA = False
SHUFFLE_TEST_DATA = False

In [4]:
CAUSAL_LM = False
ENCODE_CAUSAL_LM = False
GROUP_TEXTS = True
SPARSIFY = True
MASK_TOKEN = None
PAD_TOKEN = None
PREFIX = None#"Translate the following text:"
#PREFIX = "Replace all of the mask-tokens: "
#PREFIX = "This sentence is completely obsolete "
SWITCH_II_DECODER_II = True

In [5]:
f"out dir: {BASE_PATH}"

'out dir: tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2/'

In [6]:
VOCAB_SIZE = 30_522
#VOCAB_SIZE = 32_000
#VOCAB_SIZE = 52_000
MAX_POSITION_EMBEDDINGS = 512
#MAX_POSITION_EMBEDDINGS = 516
#MAX_POSITION_EMBEDDINGS = 768
IS_HF_MODEL = False
IS_ENCODER_DECODER_MODEL = False
EPOCH_I = 0

In [7]:
CI3C_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            forward_method="chunkwise",
            apply_decay=False,
            num_decay_parts=1,
            hidden_retention_act="relu",
            hidden_pos_offset=True,
            rope_dim=16,
            num_query_heads=2,

            decoder_output="none",
            revert_decoder=False,
            decoder_schema=[1, 1],
            cross_encoder_schema=[0, 0],
            block_io_schema=None,#[[1024, 1024*4, 1024], [1024, 1024*2, 1024]],
)

In [8]:
if 1:
    model = ci3C.COINForConditionalGeneration(
        CI3C_CONFIG
    )

In [9]:
if 0:
    model = ci3C.COINForMaskedLM(
        CI3C_CONFIG
    )

num layers: 2
gamma schema: [[0.96875, 0.9875984191894531], [0.995078444480896, 0.998046875]]
layer 0 num experts: 1
layer 1 num experts: 1


In [10]:
CI3A_CONFIG = ci3A.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    forward_method="chunkwise",
    apply_decay=False,
    num_decay_parts=1,
    hidden_retention_act="relu",
    apply_hidden_pos_offset=True,

    decoder_output="none",
    revert_decoder=False,
    decoder_schema=[1] * 2,
    cross_encoder_schema=[0] * 2,
    experts_schema=None,

    switch_ii_decoder_ii=SWITCH_II_DECODER_II,
)

In [11]:
if 0:
    model = ci3A.COINForMaskedLM(
        CI3A_CONFIG
    )

In [12]:
if 0:
    model = ci3A.COINForConditionalGeneration(
        CI3A_CONFIG
    )

In [13]:
NUM_REGIONS = 1
COIN_CONFIG = ci2D.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    hidden_retention_act="relu",
    #hidden_out_act=None,
    forward_method="parallel",
    apply_decay=False,
    reverse_decay=False,
    num_decay_parts=1,
    decoder_output="strict",
    #rope_dim=16,
    
    num_regions=NUM_REGIONS,
    decoder_schema=      [1],
    cross_encoder_schema=[0],
    
    share_S=False,
    
    #layer_norm_eps=1e-12,
    #retention_group_norm_eps=1e-8,
    #rms_norm_eps=1e-12,
    switch_ii_decoder_ii=SWITCH_II_DECODER_II,
    disable_teacher_forcing=False,
)

In [14]:
if 0:
    model = ci2D.COINForMaskedLM(COIN_CONFIG)

In [15]:
NUM_REGIONS = 1
COIN_CONFIG = ci2C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    hidden_retention_act="relu",
    #hidden_out_act=None,
    forward_method="parallel",
    apply_decay=True,
    #fixed_decay_value=None,
    num_decay_parts=1,
    #reverse_decay=False,
    #chunkwise_num_chunks=1,
    #apply_chunking_globally=False,
    #apply_hidden_pos_offset=False,
    decoder_output="none",
    
    num_regions=NUM_REGIONS,
    decoder_schema=[0, 0],
    cross_encoder_schema=[0] * 2,
    multi_head_qkv=False,
    #num_heads=16,
    #share_S=False,
    num_repetitions=1,

    #rms_norm_eps=1e-8,

    disable_teacher_forcing=False,
    switch_ii_decoder_ii=SWITCH_II_DECODER_II,
)

In [16]:
if 0:
    model = ci2C.COINForMaskedLM(
        config=COIN_CONFIG
    )

In [17]:
if 0:
    model = ci2C.COINForConditionalGeneration(
        config=COIN_CONFIG
    )

In [18]:
if 0:
    IS_HF_MODEL = True
    model = transformers.XLMWithLMHeadModel(
        config=transformers.XLMConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        )
    )

In [19]:
print("{:,}\n{:,}".format(num_parameters(model), num_trainable_parameters(model)))

148,148,058
148,148,026


In [20]:
# %%
try:
    os.mkdir(BASE_PATH)
except OSError as err:
    print(err)
try:
    os.mkdir(CHECKPOINT_PATH)
except OSError as err:
    print(err)
try:
    os.mkdir(WORDPIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(BPE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(SENTENCE_PIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(DATASET_JSON_PATH)
except OSError as err:
    print(err)


[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2/'
[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2//model/'
[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2//wordpiece_tokenizer/'
[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2//bpe_tokenizer/'
[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2//sentence_piece_tokenizer/'
[Errno 17] File exists: 'tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2//dataset/'


In [21]:
# %%
#dataset = DatasetDict({
#    "train": load_dataset("wikitext", name="wikitext-103-raw-v1", split="train[0:10000]"),
#    "test":  load_dataset("wikitext", name="wikitext-103-raw-v1", split="validation[:1500]")
#})

#dataset


In [22]:
# oasst1, aaabdon, mcca, translation_mcca
DATASET = "translation_mcca"
#DATASET = "oasst1"

#HF_TRAIN_ROWS = 25_000
HF_TRAIN_ROWS = 500_000
#HF_TRAIN_ROWS = -1
HF_TRAIN_FROM = 0#10_000

HF_TEST_ROWS = 1_500#5000
#HF_TEST_ROWS = -1
HF_TEST_FROM = 0

CUSTOM_BASE_DS_PATH = "../datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/"
CUSTOM_DS_TO_FILE = 2

In [23]:
if DATASET == "mcca":
    train_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-2][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "translation_mcca":
    train_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-9][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-9][0-9]",
        ]
    })
elif DATASET == "aaabdon":
    train_ds = [f"{CUSTOM_BASE_DS_PATH}/train/train_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    test_ds = [f"{CUSTOM_BASE_DS_PATH}/validation/validation_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    train_dataset = load_set(train_ds)#.select(list(range(HF_TRAIN_ROWS)))
    test_dataset = load_set(test_ds)#.select(list(range(HF_TEST_ROWS)))
    tok_dataset = load_set([f"{CUSTOM_BASE_DS_PATH}/train/train_*.csv"])
elif DATASET == "oasst1":
    #train_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"train[0:{HF_TRAIN_ROWS}]")
    #test_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"validation[:{HF_TEST_ROWS}]")
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    
    #train_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"train[0:{HF_TRAIN_ROWS}]")  .rename_column("instruction", "text").rename_column("output", "target")
    #test_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"test[0:{HF_TEST_ROWS}]")     .rename_column("instruction", "text").rename_column("output", "target")
    #tok_dataset = load_dataset("QingyiSi/Alpaca-CoT")                                       .rename_column("instruction", "text").rename_column("output", "target")

    train_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    test_dataset = load_dataset("OpenAssistant/oasst1", split="validation").filter(lambda e: e["lang"] == "en")
    tok_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    #train_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_from_disk("../datasets/oasst1/validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))
    #tok_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en")

    #train_dataset = load_dataset("OpenAssistant/oasst2", split="train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_dataset("OpenAssistant/oasst2", split="validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))

if SHUFFLE_TRAIN_DATA:
    print("shuffle")
    train_dataset = train_dataset.shuffle()
if HF_TRAIN_ROWS > 0:
    train_dataset =  train_dataset.select(list(range(HF_TRAIN_FROM, HF_TRAIN_ROWS)))
if SHUFFLE_TEST_DATA:
    print("shuffle")
    test_dataset = test_dataset.shuffle()
if HF_TEST_ROWS > 0:
    test_dataset = test_dataset.select(list(range(HF_TEST_FROM, HF_TEST_ROWS)))

print(train_dataset)
print(test_dataset)
print(tok_dataset)

Dataset({
    features: ['message_id', 'parent_id', 'user_id', 'created_date', 'text', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels'],
    num_rows: 25000
})
Dataset({
    features: ['message_id', 'parent_id', 'user_id', 'created_date', 'text', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels'],
    num_rows: 1500
})
Dataset({
    features: ['message_id', 'parent_id', 'user_id', 'created_date', 'text', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels'],
    num_rows: 39283
})


In [24]:
train_dataset.to_json(f"{DATASET_JSON_PATH}/train.json")
test_dataset.to_json(f"{DATASET_JSON_PATH}/test.json")

Creating json from Arrow format: 100%|██████████| 25/25 [00:01<00:00, 18.74ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 26.05ba/s]


2309275

In [25]:
labels = [label for label in train_dataset.features.keys() if label not in ["text"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)

{'message_id': 0, 'parent_id': 1, 'user_id': 2, 'created_date': 3, 'role': 4, 'lang': 5, 'review_count': 6, 'review_result': 7, 'deleted': 8, 'rank': 9, 'synthetic': 10, 'model_name': 11, 'detoxify': 12, 'message_tree_id': 13, 'tree_state': 14, 'emojis': 15, 'labels': 16}
{0: 'message_id', 1: 'parent_id', 2: 'user_id', 3: 'created_date', 4: 'role', 5: 'lang', 6: 'review_count', 7: 'review_result', 8: 'deleted', 9: 'rank', 10: 'synthetic', 11: 'model_name', 12: 'detoxify', 13: 'message_tree_id', 14: 'tree_state', 15: 'emojis', 16: 'labels'}
['message_id', 'parent_id', 'user_id', 'created_date', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels']


In [26]:
from tokenizers import ByteLevelBPETokenizer, BertWordPieceTokenizer, SentencePieceBPETokenizer, SentencePieceUnigramTokenizer

In [27]:
# %%
if 0:
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    #tok_dataset = load_dataset("glue", name="sst2", split="train")
    #tok_dataset = tok_dataset.rename_column("sentence", "text")
    
    tokenizer = SentencePieceUnigramTokenizer()

    tokenizer.train_from_iterator(
        iterator=tok_dataset["text"], 
        vocab_size=VOCAB_SIZE,
        #min_frequency=2,
        show_progress=True,
        #limit_alphabet=500,
        special_tokens=[
            "<PAD>", 
            "<UNK>", 
            "<CLS>", 
            "<SEP>", 
            "<DOC>",
            "<MASK>"
        ])

    tokenizer = PreTrainedTokenizer(
        tokenizer_object=tokenizer
    )
    tokenizer.save_model(SENTENCE_PIECE_TOKENIZER_DIR)
    #assert False
    #tokenizer = AutoTokenizer.from_pretrained(SENTENCE_PIECE_TOKENIZER_DIR)


In [28]:
# %%
if 1:
    tokenizer = BertWordPieceTokenizer(clean_text=True, handle_chinese_chars=True,
                                        strip_accents=True, lowercase=True)

    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, special_tokens=[
        "[PAD]", 
        "[UNK]", 
        "[CLS]", 
        "[SEP]", 
        "[DOC]",
    #    "[UDOC]",
        "[MASK]"
    ])
    tokenizer.save_model(WORDPIECE_TOKENIZER_DIR)
    #assert False
    tokenizer = BertTokenizer.from_pretrained(WORDPIECE_TOKENIZER_DIR)







In [29]:
# %%
if 0:
    tokenizer = ByteLevelBPETokenizer(lowercase=True)
    
    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, length=MAX_POSITION_EMBEDDINGS, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<doc>",
        "<mask>",
        
        #"<pad>",
        #"<unk>",
        #"<cls>",
        #"<sep>",
        #"<doc>",
        #"<mask>",
    ])

    # Save files to disk
    tokenizer.save_model(BPE_TOKENIZER_DIR)
    #assert False
    tokenizer = RobertaTokenizer.from_pretrained(BPE_TOKENIZER_DIR)
    #tokenizer = AutoTokenizer.from_pretrained(BPE_TOKENIZER_DIR)


In [30]:
# %%
def encode_and_batch(dataset, tokenizer, max_position_embeddings, batch_size, shuffle=False):
    if DATASET in ("translation_mcca"):
        encoded = preprocess_for_translation(dataset, tokenizer, max_position_embeddings, source_lang="src", target_lang="tgt", prefix=PREFIX, num_proc=4, remove_columns=["src", "tgt"], switch_ii_decoder_ii=SWITCH_II_DECODER_II)
    elif ENCODE_CAUSAL_LM:
        encoded = preprocess_for_causallm(dataset, tokenizer, block_size=MAX_POSITION_EMBEDDINGS, remove_columns=dataset.column_names, shift_right=False)
    else:
        encoded = preprocess_for_maskedlm(dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names, to_mask=.15, chance_rand_token=.2, 
                                          group_texts=GROUP_TEXTS, mask_token=MASK_TOKEN, pad_token=PAD_TOKEN, sparsify=SPARSIFY, prefix=PREFIX, switch_ii_decoder_ii=SWITCH_II_DECODER_II)
        #encoded = preprocess_for_cot(dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names, group_texts=GROUP_TEXTS, pad_token=PAD_TOKEN, sparsify=SPARSIFY, prefix=PREFIX)
       
       
        #encoded = preprocess_for_monologe(dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names)
        #encoded = preprocess_for_sparselm(dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names)
        #encoded = preprocess_for_binary_sparselm(dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names)
        #encoded = preprocess_for_maskedlm(encoded, tokenizer, max_position_embeddings, to_mask=.15, chance_rand_token=.2, group_texts=GROUP_TEXTS, mask_token=MASK_TOKEN)
    print(encoded)
    batched = BatchBuffer(encoded, batch_size)
    if shuffle:
        batched.shuffle()
    print("  finished")
    return batched


In [31]:
# %%
train_loader_call = lambda: encode_and_batch(train_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TRAIN_BATCH_SIZE, True)
test_loader_call = lambda: encode_and_batch(test_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TEST_BATCH_SIZE)

if not REBATCH:
    train_dataloader = train_loader_call()
    test_dataloader = test_loader_call()


Map (num_proc=4): 100%|██████████| 25000/25000 [00:07<00:00, 3440.59 examples/s]


Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'decoder_input_ids'],
    num_rows: 5844
})
  finished


Map (num_proc=4): 100%|██████████| 1500/1500 [00:01<00:00, 1421.97 examples/s]


Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'decoder_input_ids'],
    num_rows: 371
})
  finished


In [32]:
# %%
#encoded_train_dataset = preprocess_for_maskedlm(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=train_dataset.column_names)


In [33]:
# %%
#train_dataloader = BatchBuffer(encoded_dataset["train"], BATCH_SIZE).shuffle()
#test_dataloader = BatchBuffer(encoded_dataset["test"], BATCH_SIZE)


In [34]:
# %%
#batch_schema = list(encoded_dataset["train"].features.keys())
#batch_schema


In [35]:
# %%
def count_item(inp, item):
    count = 0
    total = 0
    for n in inp:
        for r in n:
            i = r
            if not i < 4:
                total += 1
            if i == item:
                count += 1
            #if i != 0 and i != item:
            #    print(i)
    return f"{count} / {total} ; {count/total}"


In [36]:
# %%
print("masked tokens [input_ids]:", count_item(train_dataloader.ds["input_ids"], tokenizer.mask_token_id))
print("masked tokens [labels]:", count_item(test_dataloader.ds["labels"], tokenizer.mask_token_id))


masked tokens [input_ids]: 350322 / 2942304 ; 0.11906383568795066
masked tokens [labels]: 0 / 186934 ; 0.0


In [37]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataset) / TRAIN_BATCH_SIZE * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)


In [38]:
# %%
loss_function = nn.CrossEntropyLoss()


In [39]:
# %%
train_dataloader.schema


['labels',
 'input_ids',
 'token_type_ids',
 'attention_mask',
 'decoder_input_ids']

In [40]:
# %%
print(train_dataloader.schema)
its = 0
for i, n in enumerate(train_dataloader):
    if i >= its:
        break
    for k in n:
        print(i, "############")
        for l in k:
            print(len(l))
        print("##############")


['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'decoder_input_ids']


In [41]:
# %%
model

COINForMaskedLM(
  (coin): COINModel(
    (encoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (residual_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (stack): COINStack(
      (layers): ModuleList(
        (0-1): 2 x COINLayer(
          (blocks): ModuleList(
            (0): COINBlock(
              (cross_qkv): MultiQueryQKV(
                (rope): RotaryEmbedding()
                (xpos): XPOS()
                (act): 

In [42]:
%%time
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    #train_dataloader,
    #test_dataloader,
    #batch_schema,
    device,
    loss_function,
    id2label,
    train_dataloader=train_dataloader if not REBATCH else None,
    test_dataloader=test_dataloader if not REBATCH else None,
    create_train_dataloader=train_loader_call if REBATCH else None,
    create_test_dataloader=test_loader_call if REBATCH else None,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    only_save_core=False,
    epoch_i=EPOCH_I,
    is_encoder_decoder_model=IS_ENCODER_DECODER_MODEL,
    causal_lm=CAUSAL_LM,

    #forward_args=["input_ids", "attention_mask", "token_type_ids", "labels"],

    masked_lm_task=True,
    electra_task=False,
    mlm_decode_n=0,
    #mlm_decode_n=.0075,
    #mlm_decode_n=.1,
    mlm_decode_max_chars=200,
    tokenizer=tokenizer,

    dump_coin_regions=False,
    generic_output_class=True,
    #coin_region_lambda=lambda model: model.coin.core.regions
)




Generated batch schema as ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'decoder_input_ids']

Training...
  Batch    24  of    974.    Elapsed:  0:00:11, Remaining:  0:07:15.
  Batch    48  of    974.    Elapsed:  0:00:23, Remaining:  0:07:04.
  Batch    72  of    974.    Elapsed:  0:00:34, Remaining:  0:06:53.
  Batch    96  of    974.    Elapsed:  0:00:45, Remaining:  0:06:42.
  Batch   120  of    974.    Elapsed:  0:00:57, Remaining:  0:06:31.
  Batch   144  of    974.    Elapsed:  0:01:08, Remaining:  0:06:55.
  Batch   168  of    974.    Elapsed:  0:01:20, Remaining:  0:06:43.
  Batch   192  of    974.    Elapsed:  0:01:31, Remaining:  0:05:58.
  Batch   216  of    974.    Elapsed:  0:01:43, Remaining:  0:05:47.
  Batch   240  of    974.    Elapsed:  0:01:54, Remaining:  0:05:36.
  Batch   264  of    974.    Elapsed:  0:02:05, Remaining:  0:05:25.
  Batch   288  of    974.    Elapsed:  0:02:17, Remaining:  0:05:14.
  Batch   312  of    974.    Elapsed:  0:02:28, Re

In [None]:
print_stat_tuples(stats)

NameError: name 'stats' is not defined