In [1]:
import glob
import time
import os
import math
from datasets import DatasetDict, load_dataset, Dataset, concatenate_datasets, load_from_disk
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    BertTokenizer, 
    RobertaTokenizer,
    XLNetTokenizer,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np

from load_set import load_set, load_moses_set
from epoch_stats import EpochStats, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    mask_tokens, 
    train_bern_model, 
    preprocess_for_maskedlm, 
    preprocess_for_causallm, 
    preprocess_for_monologe, 
    preprocess_for_sparselm, 
    preprocess_for_binary_sparselm,
    num_parameters, 
    num_trainable_parameters,
    preprocess_for_translation,
    preprocess_for_key_masking
)

import coin_i2C_modeling as ci2C
import coin_i2D_modeling as ci2D
import coin_i3A_modeling as ci3A
import coin_i3C_modeling as ci3C
import coin_i4_modeling as ci4
import transformer_modeling as tm
import coin_rnn_modeling as ciR
import perceiver_modeling as pm

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:
REBATCH = False
TRAIN_BATCH_SIZE = 6
TEST_BATCH_SIZE = 6
#BASE_PATH = "benchmark_models/Llama_config_B3/"
#BASE_PATH = "benchmark_models/ppl_play/10k_play/"
#BASE_PATH = "tmp_models/ci3c_mlm_wikitext/"
BASE_PATH = None

#BASE_PATH = "pretrained_models/COIN-i3C_inbio_mask_no-decoder"
#BASE_PATH = "tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2/"

if REBATCH:
    if BASE_PATH[-1] == "/":
        BASE_PATH = BASE_PATH[:-1]
    BASE_PATH += "_rebatch/"
DATASET_JSON_PATH = f"{BASE_PATH}/dataset/"
CHECKPOINT_PATH = f"{BASE_PATH}/model/" if BASE_PATH is not None else None
WORDPIECE_TOKENIZER_DIR = f"{BASE_PATH}/wordpiece_tokenizer/"
BPE_TOKENIZER_DIR = f"{BASE_PATH}/bpe_tokenizer/"
SENTENCE_PIECE_TOKENIZER_DIR = f"{BASE_PATH}/sentence_piece_tokenizer/"
USE_CUSTOM_DATALOADER = False
LEARNING_RATE = 4e-4
EPS = 1e-8
EPOCHS = 25

In [3]:
SHUFFLE_TRAIN_DATA = True
SHUFFLE_TEST_DATA = False

In [4]:
CAUSAL_LM = False
ENCODE_CAUSAL_LM = False
GROUP_TEXTS = True
SPARSIFY = True
MASK_TOKEN = None
PAD_TOKEN = None
PREFIX = None#"Translate the following text:"
#PREFIX = "Replace all of the mask-tokens: "
#PREFIX = "This sentence is completely obsolete "
SWITCH_II_DECODER_II = True

In [5]:
f"out dir: {BASE_PATH}"

'out dir: None'

In [6]:
#VOCAB_SIZE = 30_522
#VOCAB_SIZE = 32_000
VOCAB_SIZE = 50257 #+ 2
MAX_POSITION_EMBEDDINGS = 512
IS_HF_MODEL = False
IS_ENCODER_DECODER_MODEL = False
EPOCH_I = 0

In [7]:
MCCA_SAVE_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    forward_method="parallel",
    apply_decay=True,
    num_decay_parts=1,
    hidden_retention_act="relu",
    hidden_pos_offset=False,
    rope_dim=16,
    num_query_heads=1,

    decoder_output="adaptive",
    revert_decoder=True,
    decoder_schema=[1] * 4,
    cross_encoder_schema=[1] * 4,
    block_io_schema=None,#[[1024, 1024*4, 1024], [1024, 1024*2, 1024]],
)

In [8]:
TM_CONFIG = tm.TransformerConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=1024,
    intermediate_size=1536,
    num_layers=1,
    rms_norm_eps=1e-6,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
)

In [9]:
if 0:
    model = tm.TransformerForCausalLM(
        config=TM_CONFIG
    )

In [10]:
if 0:
    model = tm.TransformerForMaskedLM(
        config=TM_CONFIG
    )

In [11]:
if 0:
    model = ciR.COINForCausalLM(
        config=ciR.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_size=1024,
            num_layers=4,
            rms_norm_eps=1e-6
        )
    )

In [12]:
PERCEIVER_CONFIG = pm.PerceiverConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=1024,
    intermediate_size=1536,
    num_layers=1,
    rms_norm_eps=1e-6,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    num_heads=1,
    num_kv_heads=1,
    #conv_kernel_size=4,
)

In [13]:
if 1:
    model = pm.PerceiverForCausalLM(
        config=PERCEIVER_CONFIG
    )

In [14]:
if 0:
    model = pm.PerceiverForMaskedLM(
        config=PERCEIVER_CONFIG
    )

In [15]:
if 0:
    model = ci4.COINForCausalLM(
        config=ci4.COINConfig(
            vocab_size=VOCAB_SIZE,
            hidden_size=1024,
            intermediate_size=1536,
            forward_method="llama",
            num_layers=4,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            rms_norm_eps=1e-6,
            hidden_dropout_prob=0.0,
            training_chunk_size=None,
            inference_chunk_size=None,
            reset_hidden_states=True,
            apply_decay_mask=True,
            apply_attention_mask=False,
            apply_group_mask=False,
            gamma=(1 - 1e-6),
            num_heads=1,
            num_key_value_heads=None,
            conv_kernel_size=8,
        )
    )

In [16]:
if 0:
    IS_HF_MODEL = True
    model = transformers.LlamaForCausalLM(
        config=transformers.LlamaConfig(
            vocab_size=VOCAB_SIZE,
            hidden_size=1024,
            intermediate_size=1536,
            num_hidden_layers=4,
            num_attention_heads=1,
            num_key_value_heads=1,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            rms_norm_eps=1e-6,
            #hidden_dropout_prob=0.0,
            #use_cache=False,
            #_attn_implementation="eager",
            #mlp_bias=True,
            #attention_bias=True,
        )
    )

In [17]:
CI3C_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    hidden_size=1024,
    forward_method="parallel",
    apply_decay=True,
    num_decay_parts=1,
    hidden_retention_act="relu",
    hidden_pos_offset=False,
    rope_dim=16,
    num_query_heads=1,

    decoder_output="none",
    revert_decoder=False,
    decoder_schema=[0] * 2,
    cross_encoder_schema=[0] * 2,
    experts_schema=None,#[2, 2],
    block_io_schema=None,#[[1024, 1024*4, 1024]],
)

In [18]:
if 0:
    model = ci3C.COINForConditionalGeneration(
        CI3C_CONFIG
    )

In [19]:
if 0:
    model = ci3C.COINForMaskedLM(
        CI3C_CONFIG
    )

In [20]:
if 0:
    model = ci3C.COINForCausalLM(
        CI3C_CONFIG
    )

In [21]:
if 0:
    CHECKPOINT = "tmp_models/ci3c_mlm_wikitext/"
    I = 5
    tokenizer = BertTokenizer.from_pretrained(f"{CHECKPOINT}/wordpiece_tokenizer/")
    model = ci3C.COINForCausalLM.from_pretrained(f"{CHECKPOINT}/model/epoch_{I}/model")

In [22]:
print("{:,}\n{:,}".format(num_parameters(model), num_trainable_parameters(model)))

128,659,584
128,658,560


In [23]:
# %%
if BASE_PATH is not None:
    try:
        os.mkdir(BASE_PATH)
    except OSError as err:
        print(err)
if CHECKPOINT_PATH is not None:
    try:
        os.mkdir(CHECKPOINT_PATH)
    except OSError as err:
        print(err)
try:
    os.mkdir(WORDPIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(BPE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(SENTENCE_PIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(DATASET_JSON_PATH)
except OSError as err:
    print(err)


[Errno 17] File exists: 'None/wordpiece_tokenizer/'
[Errno 17] File exists: 'None/bpe_tokenizer/'
[Errno 17] File exists: 'None/sentence_piece_tokenizer/'
[Errno 17] File exists: 'None/dataset/'


In [24]:
# %%
#dataset = DatasetDict({
#    "train": load_dataset("wikitext", name="wikitext-103-raw-v1", split="train[0:10000]"),
#    "test":  load_dataset("wikitext", name="wikitext-103-raw-v1", split="validation[:1500]")
#})

#dataset


In [25]:
# oasst1, aaabdon, mcca, translation_mcca
DATASET = "wikitext_ppl"
#DATASET = "wikitext_mlm"

#HF_TRAIN_ROWS = 25_000
#HF_TRAIN_ROWS = 10_000
HF_TRAIN_ROWS = 25_000
HF_TRAIN_FROM = 0#10_000

#HF_TEST_ROWS = 1_500
#HF_TEST_ROWS = 1000
HF_TEST_ROWS = -1
HF_TEST_FROM = 0

CUSTOM_BASE_DS_PATH = "../datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/"
CUSTOM_DS_TO_FILE = 2

In [26]:
if DATASET in ("wikitext_ppl", "wikitext_mlm"):
    train_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
    test_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "slim_pajama":
    CACHE_DIR = "/mnt/pushshift/slim_lajama_627B/"
    train_dataset = load_dataset("cerebras/SlimPajama-627B", split=f"train[{HF_TRAIN_FROM}:{HF_TRAIN_ROWS}]", cache_dir=CACHE_DIR)
    test_dataset = load_dataset("cerebras/SlimPajama-627B", split=f"validation[{HF_TEST_FROM}:{HF_TEST_ROWS}]", cache_dir=CACHE_DIR)
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "inbio_mask":
    inbio = load_set(["INBIO.csv"], unused_fields=("Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,capable of,characteristic for,completed invasion phase,contained in,contains,contributes to,contributor,created by,decreases,decreases effort in,derives from,derives into,determines,don't use concept,editor note,enabled by,ends,ends during,ends with,enhance,facilitate,has alien range,has amount of closely related species,has amount of species,has area,has component,has decreased effort level by,has distribution,has growth,has habitat,has increased effort level by,has increased levels of,has index,has input,has invasion success likelihood,has level of,has measurement,has measurement unit label,has measurement value,has mortality,has natality,has native range,has number of individuals,has output,has part,has part structure that is capable of,has participant,has propagule pressure,has quality,has range,has recruitment,has role,has spatial occupant at some time,has specific name,has status,has value,http://data.bioontology.org/metadata/obo/part_of,http://data.bioontology.org/metadata/prefixIRI,http://data.bioontology.org/metadata/treeView,http://purl.obolibrary.org/obo/IAO_0000111,http://purl.obolibrary.org/obo/IAO_0000112,http://purl.obolibrary.org/obo/IAO_0000114,http://purl.obolibrary.org/obo/IAO_0000115,http://purl.obolibrary.org/obo/IAO_0000118,http://purl.obolibrary.org/obo/IAO_0000119,http://purl.obolibrary.org/obo/IAO_0000232,http://purl.obolibrary.org/obo/IAO_0000412,http://purl.obolibrary.org/obo/ncbitaxon#has_rank,http://purl.obolibrary.org/obo/NCIT_A8,http://purl.obolibrary.org/obo/NCIT_NHC0,http://purl.obolibrary.org/obo/NCIT_P106,http://purl.obolibrary.org/obo/NCIT_P107,http://purl.obolibrary.org/obo/NCIT_P108,http://purl.obolibrary.org/obo/NCIT_P207,http://purl.obolibrary.org/obo/NCIT_P322,http://purl.obolibrary.org/obo/NCIT_P325,http://purl.obolibrary.org/obo/NCIT_P366,http://purl.obolibrary.org/obo/OBI_0001886,http://purl.obolibrary.org/obo/RO_0001900,http://purl.org/dc/elements/1.1/source,http://purl.org/dc/terms/creator,http://www.geneontology.org/formats/oboInOwl#creation_date,http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,http://www.geneontology.org/formats/oboInOwl#hasDbXref,http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,http://www.geneontology.org/formats/oboInOwl#hasSynonymType,http://www.geneontology.org/formats/oboInOwl#id,http://www.geneontology.org/formats/oboInOwl#inSubset,http://www.w3.org/2000/01/rdf-schema#comment,http://www.w3.org/2000/01/rdf-schema#label,http://www.w3.org/2002/07/owl#deprecated,http://www.w3.org/2004/02/skos/core#altLabel,http://www.w3.org/2004/02/skos/core#definition,http://www.w3.org/2004/02/skos/core#notation,https://w3id.org/inbio#_000130,https://w3id.org/inbio#_000132,increases,increases effort in,interacts with,is absent,is affected by,is against,is aggregate of,is alien range to,is characteristic of,is characterized by,is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,is in invasion phase,is mortality of,is natality of,is native range to,is part of,is prey of,is range of,is recruitment of,is similar to,is status of,license,license,license,license,located in,location of,occupies spatial region at some time,occurs in,output of,overlaps,part of,participates in,produced by,produces,quality of,role of,shows changes in species trait,spatially coextensive with,surrounded by,surrounds,title,TODO,license.1,license.2,license.3".split(",")))
    bio2def = dict(zip(inbio["Preferred Label"], inbio["Definitions"]))
    mask_keys = inbio["Preferred Label"]

    DS_TRAIN_PATH = "datasets/abstracts_all_labels_train.csv"
    DS_TEST_PATH = "datasets/abstracts_all_labels_test.csv"

    train_dataset = load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"])
    test_dataset = load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"])

    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "mcca":
    train_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-2][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "translation_mcca":
    train_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-9][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-9][0-9]",
        ]
    })
elif DATASET == "aaabdon":
    train_ds = [f"{CUSTOM_BASE_DS_PATH}/train/train_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    test_ds = [f"{CUSTOM_BASE_DS_PATH}/validation/validation_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    train_dataset = load_set(train_ds)#.select(list(range(HF_TRAIN_ROWS)))
    test_dataset = load_set(test_ds)#.select(list(range(HF_TEST_ROWS)))
    tok_dataset = load_set([f"{CUSTOM_BASE_DS_PATH}/train/train_*.csv"])
elif DATASET == "oasst1":
    #train_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"train[0:{HF_TRAIN_ROWS}]")
    #test_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"validation[:{HF_TEST_ROWS}]")
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    
    #train_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"train[0:{HF_TRAIN_ROWS}]")  .rename_column("instruction", "text").rename_column("output", "target")
    #test_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"test[0:{HF_TEST_ROWS}]")     .rename_column("instruction", "text").rename_column("output", "target")
    #tok_dataset = load_dataset("QingyiSi/Alpaca-CoT")                                       .rename_column("instruction", "text").rename_column("output", "target")

    train_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    test_dataset = load_dataset("OpenAssistant/oasst1", split="validation").filter(lambda e: e["lang"] == "en")
    tok_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    #train_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_from_disk("../datasets/oasst1/validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))
    #tok_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en")

    #train_dataset = load_dataset("OpenAssistant/oasst2", split="train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_dataset("OpenAssistant/oasst2", split="validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))

if SHUFFLE_TRAIN_DATA:
    print("shuffle")
    train_dataset = train_dataset.shuffle()
if HF_TRAIN_ROWS > 0:
    train_dataset =  train_dataset.select(list(range(HF_TRAIN_FROM, HF_TRAIN_ROWS)))
if SHUFFLE_TEST_DATA:
    print("shuffle")
    test_dataset = test_dataset.shuffle()
if HF_TEST_ROWS > 0:
    test_dataset = test_dataset.select(list(range(HF_TEST_FROM, HF_TEST_ROWS)))

print(train_dataset)
print(test_dataset)
print(tok_dataset)

shuffle
Dataset({
    features: ['text'],
    num_rows: 25000
})
Dataset({
    features: ['text'],
    num_rows: 4358
})
Dataset({
    features: ['text'],
    num_rows: 1805708
})


In [27]:
#train_dataset.to_json(f"{DATASET_JSON_PATH}/train.json")
#test_dataset.to_json(f"{DATASET_JSON_PATH}/test.json")

In [28]:
labels = [label for label in train_dataset.features.keys() if label not in ["text"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)

{}
{}
[]


In [29]:
from tokenizers import ByteLevelBPETokenizer, BertWordPieceTokenizer, SentencePieceBPETokenizer, SentencePieceUnigramTokenizer

In [30]:
# %%
if 0:
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    #tok_dataset = load_dataset("glue", name="sst2", split="train")
    #tok_dataset = tok_dataset.rename_column("sentence", "text")
    
    tokenizer = SentencePieceUnigramTokenizer()

    tokenizer.train_from_iterator(
        iterator=tok_dataset["text"], 
        vocab_size=VOCAB_SIZE,
        #min_frequency=2,
        show_progress=True,
        #limit_alphabet=500,
        special_tokens=[
            "<PAD>", 
            "<UNK>", 
            "<CLS>", 
            "<SEP>", 
            "<DOC>",
            "<MASK>"
        ])

    tokenizer = PreTrainedTokenizer(
        tokenizer_object=tokenizer
    )
    tokenizer.save_model(SENTENCE_PIECE_TOKENIZER_DIR)
    #assert False
    #tokenizer = AutoTokenizer.from_pretrained(SENTENCE_PIECE_TOKENIZER_DIR)


In [32]:
def batch_iterator():
    for i in range(0, len(tok_dataset), TRAIN_BATCH_SIZE):
        yield tok_dataset[i : i + TRAIN_BATCH_SIZE]["text"]

In [33]:
if 1:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "gpt2", 
    #    vocab_size=VOCAB_SIZE - 2,
    #    pad_token="<|pad|>",
    #    unk_token="<|unk|>",
    #    bos_token="<|doc|>",
    #    eos_token="<|udoc|>",
    )
    #tokenizer.eos_token_id = 4



In [34]:
if 0:
    old_tokenizer = transformers.AutoTokenizer.from_pretrained("AdithyaSK/LLama3Tokenizer", token=ACCESS_TOKEN)
    tokenizer = old_tokenizer.train_new_from_iterator(batch_iterator(), vocab_size=VOCAB_SIZE)

In [35]:
# %%
if 0:
    tokenizer = BertWordPieceTokenizer(clean_text=True, handle_chinese_chars=True,
                                        strip_accents=True, lowercase=True)
    #tokenizer = transformers.LlamaTokenizer()

    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, special_tokens=[
        "[PAD]", 
        "[UNK]", 
        "[CLS]", 
        "[SEP]", 
    #    "[DOC]",
    #    "[UDOC]",
        "[MASK]"
    ])
    tokenizer.save_model(WORDPIECE_TOKENIZER_DIR)
    #assert False
    tokenizer = BertTokenizer.from_pretrained(WORDPIECE_TOKENIZER_DIR)


In [36]:
# %%
if 0:
    tokenizer = ByteLevelBPETokenizer(lowercase=True)
    
    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, length=MAX_POSITION_EMBEDDINGS, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<doc>",
        "<mask>",
        
        #"<pad>",
        #"<unk>",
        #"<cls>",
        #"<sep>",
        #"<doc>",
        #"<mask>",
    ])

    # Save files to disk
    tokenizer.save_model(BPE_TOKENIZER_DIR)
    #assert False
    tokenizer = RobertaTokenizer.from_pretrained(BPE_TOKENIZER_DIR)
    #tokenizer = AutoTokenizer.from_pretrained(BPE_TOKENIZER_DIR)


In [37]:
# %%
def encode_and_batch(dataset, tokenizer, max_position_embeddings, batch_size, shuffle=False):
    if DATASET == "inbio_mask":
        encoded = preprocess_for_key_masking(mask_keys, dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names, to_mask=.15, chance_rand_token=.2, 
                                          group_texts=GROUP_TEXTS, pad_token=PAD_TOKEN, sparsify=SPARSIFY, prefix=PREFIX, switch_ii_decoder_ii=SWITCH_II_DECODER_II)
    elif DATASET in ("translation_mcca"):
        encoded = preprocess_for_translation(dataset, tokenizer, max_position_embeddings, source_lang="src", target_lang="tgt", prefix=PREFIX, num_proc=4, remove_columns=["src", "tgt"], switch_ii_decoder_ii=SWITCH_II_DECODER_II)
    elif DATASET in ("wikitext_ppl"):
        encoded = preprocess_for_causallm(
            dataset, 
            tokenizer, 
            block_size=MAX_POSITION_EMBEDDINGS, 
            remove_columns=dataset.column_names, 
            shift_right=False,
            pad_token_id=tokenizer.pad_token_id,
            doc_token_id=tokenizer.bos_token_id,
            udoc_token_id=tokenizer.eos_token_id
        )
    elif DATASET in ("wikitext_mlm"):
        encoded = preprocess_for_maskedlm(
            dataset, 
            tokenizer, 
            max_position_embeddings, 
            remove_columns=dataset.column_names, 
            to_mask=.15,
            chance_rand_token=.2, 
            group_texts=GROUP_TEXTS, 
            mask_token=MASK_TOKEN, 
            pad_token=PAD_TOKEN, 
            sparsify=SPARSIFY, 
            prefix=PREFIX, 
            switch_ii_decoder_ii=False
        )
        
    print(encoded)
    batched = BatchBuffer(encoded, batch_size)
    if shuffle:
        batched.shuffle()
    print("  finished")
    return batched


In [38]:
train_dataset[0]

{'text': ''}

In [39]:
test_dataset[1]

{'text': ' = Robert Boulter = \n'}

In [40]:
# %%
train_loader_call = lambda: encode_and_batch(train_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TRAIN_BATCH_SIZE, True)
test_loader_call = lambda: encode_and_batch(test_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TEST_BATCH_SIZE)

if not REBATCH:
    train_dataloader = train_loader_call()
    test_dataloader = test_loader_call()


Map (num_proc=4):   0%|          | 0/25000 [00:00<?, ? examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):   4%|▍         | 1000/25000 [00:00<00:05, 4040.98 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  20%|██        | 5000/25000 [00:00<00:01, 14709.10 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  36%|███▌      | 9000/25000 [00:00<00:00, 20076.84 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  52%|█████▏    | 13000/25000 [00:00<00:00, 22452.30 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  68%|██████▊   | 17000/25000 [00:00<00:00, 24214.20 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  84%|████████▍ | 21000/25000 [00:00<00:00, 26187.46 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  98%|█████████▊| 24500/25000 [00:01<00:00, 28379.30 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4): 100%|██████████| 25000/25000 [00:01<00:00, 21366.92 examples/s]

Dataset({
    features: ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels'],
    num_rows: 3254
})
  finished
Dataset({
    features: ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels'],
    num_rows: 557
})
  finished





In [41]:
test_dataloader.ds[0]["input_ids"]

[4,
 796,
 5199,
 347,
 2852,
 353,
 796,
 220,
 198,
 4,
 4,
 5199,
 347,
 2852,
 353,
 318,
 281,
 3594,
 2646,
 837,
 5581,
 290,
 21421,
 8674,
 764,
 679,
 550,
 257,
 8319,
 2488,
 12,
 31,
 20495,
 2597,
 319,
 262,
 5581,
 2168,
 383,
 3941,
 287,
 4751,
 764,
 770,
 373,
 3940,
 416,
 257,
 20495,
 2597,
 287,
 262,
 711,
 2332,
 684,
 3194,
 416,
 11288,
 37072,
 837,
 543,
 373,
 6157,
 287,
 5878,
 379,
 262,
 8111,
 3078,
 15752,
 764,
 679,
 550,
 257,
 8319,
 2597,
 287,
 262,
 5581,
 2168,
 8974,
 1757,
 1024,
 276,
 287,
 6244,
 764,
 554,
 5472,
 347,
 2852,
 353,
 11406,
 257,
 2597,
 355,
 366,
 13854,
 366,
 287,
 262,
 4471,
 366,
 29345,
 705,
 82,
 8362,
 366,
 286,
 262,
 5581,
 2168,
 383,
 5882,
 31623,
 2162,
 339,
 31636,
 7848,
 10544,
 2940,
 13535,
 290,
 20893,
 12806,
 72,
 764,
 679,
 373,
 3350,
 287,
 262,
 5075,
 21421,
 32260,
 286,
 262,
 14576,
 39616,
 711,
 21673,
 22384,
 837,
 543,
 373,
 6157,
 379,
 262,
 25331,
 15752,
 287,
 42125,
 290,

In [42]:
# %%
#encoded_train_dataset = preprocess_for_maskedlm(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=train_dataset.column_names)


In [43]:
# %%
#train_dataloader = BatchBuffer(encoded_dataset["train"], BATCH_SIZE).shuffle()
#test_dataloader = BatchBuffer(encoded_dataset["test"], BATCH_SIZE)


In [44]:
# %%
#batch_schema = list(encoded_dataset["train"].features.keys())
#batch_schema


In [45]:
# %%
def count_item(inp, item):
    count = 0
    total = 0
    for n in inp:
        for r in n:
            i = r
            if not i < 4:
                total += 1
            if i == item:
                count += 1
            #if i != 0 and i != item:
            #    print(i)
    return f"{count} / {total} ; {count/total}"


In [46]:
tokenizer.mask_token_id

In [47]:
# %%
print("masked tokens [input_ids]:", count_item(train_dataloader.ds["input_ids"], tokenizer.mask_token_id))
print("masked tokens [labels]:", count_item(test_dataloader.ds["labels"], tokenizer.mask_token_id))


masked tokens [input_ids]: 0 / 1666048 ; 0.0
masked tokens [labels]: 0 / 285184 ; 0.0


In [48]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataset) / TRAIN_BATCH_SIZE * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)


In [49]:
# %%
loss_function = nn.CrossEntropyLoss()


In [50]:
# %%
train_dataloader.schema


['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels']

In [51]:
# %%
model

PerceiverForCausalLM(
  (model): PerceiverModel(
    (embeddings): PerceiverEmbeddings(
      (word_embeddings): Embedding(50257, 1024, padding_idx=0)
    )
    (model): PerceiverEncoder(
      (layers): ModuleList(
        (0): PerceiverLayer(
          (cross_attn): PCI5Attention(
            (softmax): PartialSoftmax()
            (Q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (K_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (V_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (rope): RotaryEmbedding()
            (fc): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (cross_rms): RMSNorm()
          (latent_rms): RMSNorm()
          (glu_rms): RMSNorm()
          (glu): GLU(
            (fi): Linear(in_features=1024, out_features=3072, bias=True)
            (fc): Linear(in_features=1536, out_fea

In [52]:
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [53]:
torch.set_printoptions(threshold=5_000_000)

In [54]:
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    #train_dataloader,
    #test_dataloader,
    #batch_schema,
    device,
    loss_function,
    id2label,
    train_dataloader=train_dataloader if not REBATCH else None,
    test_dataloader=test_dataloader if not REBATCH else None,
    create_train_dataloader=train_loader_call if REBATCH else None,
    create_test_dataloader=test_loader_call if REBATCH else None,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    only_save_core=False,
    epoch_i=EPOCH_I,
    is_encoder_decoder_model=IS_ENCODER_DECODER_MODEL,
    causal_lm=CAUSAL_LM,

    forward_args=["input_ids", "attention_mask", "labels"],
    #forward_args=["input_ids", "decoder_input_ids", "group_mask", "labels"],

    masked_lm_task=True,
    electra_task=False,
    mlm_decode_n=0,
    #mlm_decode_n=.0075,
    #mlm_decode_n=.1,
    mlm_decode_max_chars=200,
    tokenizer=tokenizer,

    dump_coin_regions=False,
    generic_output_class=True,
    #coin_region_lambda=lambda model: model.coin.core.regions

    evaluate_autoregressively=False,
)



Generated batch schema as ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels']

Training...
  Batch    10  of    542.    Elapsed:  0:00:06, Remaining:  0:05:19.
  Batch    20  of    542.    Elapsed:  0:00:11, Remaining:  0:05:13.
  Batch    30  of    542.    Elapsed:  0:00:17, Remaining:  0:05:07.
  Batch    40  of    542.    Elapsed:  0:00:23, Remaining:  0:05:01.
  Batch    50  of    542.    Elapsed:  0:00:28, Remaining:  0:04:55.
  Batch    60  of    542.    Elapsed:  0:00:34, Remaining:  0:04:49.
  Batch    70  of    542.    Elapsed:  0:00:39, Remaining:  0:04:43.
  Batch    80  of    542.    Elapsed:  0:00:45, Remaining:  0:04:37.
  Batch    90  of    542.    Elapsed:  0:00:51, Remaining:  0:04:31.
  Batch   100  of    542.    Elapsed:  0:00:56, Remaining:  0:04:25.
  Batch   110  of    542.    Elapsed:  0:01:02, Remaining:  0:04:19.
  Batch   120  of    542.    Elapsed:  0:01:07, Remaining:  0:04:13.
  Batch   130  of    542.    Elapsed:  0:01:13, Remain

In [None]:
print_stat_tuples(stats)