In [1]:
import glob
import time
import os
import math
from datasets import DatasetDict, load_dataset, Dataset, concatenate_datasets, load_from_disk
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    BertTokenizer, 
    RobertaTokenizer,
    XLNetTokenizer,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np

from load_set import load_set, load_moses_set
from epoch_stats import EpochStats, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    mask_tokens, 
    train_bern_model, 
    preprocess_for_maskedlm, 
    preprocess_for_causallm, 
    preprocess_for_monologe, 
    preprocess_for_sparselm, 
    preprocess_for_binary_sparselm,
    num_parameters, 
    num_trainable_parameters,
    preprocess_for_translation,
    preprocess_for_key_masking
)

import coin_i2C_modeling as ci2C
import coin_i2D_modeling as ci2D
import coin_i3A_modeling as ci3A
import coin_i3C_modeling as ci3C
import coin_i4_modeling as ci4
import transformer_modeling as tm

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:
REBATCH = False
TRAIN_BATCH_SIZE = 3
TEST_BATCH_SIZE = 3
#BASE_PATH = "benchmark_models/Llama_config_B3/"
BASE_PATH = "benchmark_models/ppl_play/10k_play/"

#BASE_PATH = "pretrained_models/COIN-i3C_inbio_mask_no-decoder"
#BASE_PATH = "tmp_models/COIN-i3C_oasst1_25k_mlm_1x2_1dec-none_no-revert_parallel_mlm-head_B6_no-shuffle_A-T_multi-query-2/"

if REBATCH:
    if BASE_PATH[-1] == "/":
        BASE_PATH = BASE_PATH[:-1]
    BASE_PATH += "_rebatch/"
DATASET_JSON_PATH = f"{BASE_PATH}/dataset/"
CHECKPOINT_PATH = f"{BASE_PATH}/model/"
WORDPIECE_TOKENIZER_DIR = f"{BASE_PATH}/wordpiece_tokenizer/"
BPE_TOKENIZER_DIR = f"{BASE_PATH}/bpe_tokenizer/"
SENTENCE_PIECE_TOKENIZER_DIR = f"{BASE_PATH}/sentence_piece_tokenizer/"
USE_CUSTOM_DATALOADER = False
LEARNING_RATE = 1e-5
EPS = 1e-8
EPOCHS = 25

In [3]:
SHUFFLE_TRAIN_DATA = True
SHUFFLE_TEST_DATA = False

In [4]:
CAUSAL_LM = False
ENCODE_CAUSAL_LM = False
GROUP_TEXTS = True
SPARSIFY = True
MASK_TOKEN = None
PAD_TOKEN = None
PREFIX = None#"Translate the following text:"
#PREFIX = "Replace all of the mask-tokens: "
#PREFIX = "This sentence is completely obsolete "
SWITCH_II_DECODER_II = True

In [5]:
f"out dir: {BASE_PATH}"

'out dir: benchmark_models/ppl_play/10k_play/'

In [6]:
#VOCAB_SIZE = 30_522
#VOCAB_SIZE = 32_000
VOCAB_SIZE = 50257 #+ 2
MAX_POSITION_EMBEDDINGS = 512
#MAX_POSITION_EMBEDDINGS = 516
#MAX_POSITION_EMBEDDINGS = 768
IS_HF_MODEL = False
IS_ENCODER_DECODER_MODEL = False
EPOCH_I = 0

In [7]:
MCCA_SAVE_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    forward_method="parallel",
    apply_decay=True,
    num_decay_parts=1,
    hidden_retention_act="relu",
    hidden_pos_offset=False,
    rope_dim=16,
    num_query_heads=1,

    decoder_output="adaptive",
    revert_decoder=True,
    decoder_schema=[1] * 4,
    cross_encoder_schema=[1] * 4,
    block_io_schema=None,#[[1024, 1024*4, 1024], [1024, 1024*2, 1024]],
)

In [8]:
if 0:
    model = tm.TransformerForCausalLM(
        config=tm.TransformerConfig(
            vocab_size=VOCAB_SIZE,
            hidden_size=1024,
            intermediate_size=1536,
            num_layers=4,
            rms_norm_eps=1e-6,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        )
    )

In [9]:
if 1:
    model = ci4.COINForCausalLM(
        config=ci4.COINConfig(
            vocab_size=VOCAB_SIZE,
            hidden_size=1024,
            intermediate_size=1536,
            forward_method="parallel",
            num_layers=4,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            rms_norm_eps=1e-6,
            hidden_dropout_prob=0.0,
            training_chunk_size=None,
            inference_chunk_size=1,
            reset_hidden_states=True,
            apply_decay_mask=True,
            apply_attention_mask=False,
            apply_group_mask=False,
            gamma=(1 - 1e-6),
            num_heads=1,
            num_key_value_heads=None,
        )
    )

gamma: 0.999999
gamma: 0.999999
gamma: 0.999999
gamma: 0.999999


In [10]:
if 0:
    IS_HF_MODEL = True
    model = transformers.LlamaForCausalLM(
        config=transformers.LlamaConfig(
            vocab_size=VOCAB_SIZE,
            hidden_size=1024,
            intermediate_size=1536,
            num_hidden_layers=4,
            num_attention_heads=1,
            num_key_value_heads=1,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            rms_norm_eps=1e-6,
            hidden_dropout_prob=0.0,
            use_cache=False,
            _attn_implementation="eager",
            mlp_bias=True,
            attention_bias=True,
        )
    )

In [11]:
CI3C_CONFIG = ci3C.COINConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    hidden_size=1024,
    forward_method="parallel",
    apply_decay=True,
    num_decay_parts=1,
    hidden_retention_act="relu",
    hidden_pos_offset=False,
    rope_dim=16,
    num_query_heads=1,

    decoder_output="none",
    revert_decoder=False,
    decoder_schema=[0] * 2,
    cross_encoder_schema=[0] * 2,
    experts_schema=None,#[2, 2],
    block_io_schema=None,#[[1024, 1024*4, 1024]],
)

In [12]:
if 0:
    model = ci3C.COINForConditionalGeneration(
        CI3C_CONFIG
    )

In [13]:
if 0:
    model = ci3C.COINForMaskedLM(
        CI3C_CONFIG
    )

In [14]:
print("{:,}\n{:,}".format(num_parameters(model), num_trainable_parameters(model)))

297,126,176
297,126,144


In [15]:
# %%
try:
    os.mkdir(BASE_PATH)
except OSError as err:
    print(err)
try:
    os.mkdir(CHECKPOINT_PATH)
except OSError as err:
    print(err)
try:
    os.mkdir(WORDPIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(BPE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(SENTENCE_PIECE_TOKENIZER_DIR)
except OSError as err:
    print(err)
try:
    os.mkdir(DATASET_JSON_PATH)
except OSError as err:
    print(err)


[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play/'
[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play//model/'
[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play//wordpiece_tokenizer/'
[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play//bpe_tokenizer/'
[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play//sentence_piece_tokenizer/'
[Errno 17] File exists: 'benchmark_models/ppl_play/10k_play//dataset/'


In [16]:
# %%
#dataset = DatasetDict({
#    "train": load_dataset("wikitext", name="wikitext-103-raw-v1", split="train[0:10000]"),
#    "test":  load_dataset("wikitext", name="wikitext-103-raw-v1", split="validation[:1500]")
#})

#dataset


In [17]:
# oasst1, aaabdon, mcca, translation_mcca
#DATASET = "translation_mcca"
#DATASET = "slim_pajama"
#DATASET = "oasst1"
DATASET = "wikitext_ppl"

#HF_TRAIN_ROWS = 25_000
#HF_TRAIN_ROWS = 10_000
HF_TRAIN_ROWS = 10_000
HF_TRAIN_FROM = 0#10_000

#HF_TEST_ROWS = 1_500
#HF_TEST_ROWS = 1000
HF_TEST_ROWS = -1
HF_TEST_FROM = 0

CUSTOM_BASE_DS_PATH = "../datasets/big_AAABDON_Nmax_st200_s0_a10_tvsplit.1_no_norm/"
CUSTOM_DS_TO_FILE = 2

In [18]:
if DATASET in ("wikitext_ppl", "wikitext_mlm"):
    train_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
    test_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "slim_pajama":
    CACHE_DIR = "/mnt/pushshift/slim_lajama_627B/"
    train_dataset = load_dataset("cerebras/SlimPajama-627B", split=f"train[{HF_TRAIN_FROM}:{HF_TRAIN_ROWS}]", cache_dir=CACHE_DIR)
    test_dataset = load_dataset("cerebras/SlimPajama-627B", split=f"validation[{HF_TEST_FROM}:{HF_TEST_ROWS}]", cache_dir=CACHE_DIR)
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "inbio_mask":
    inbio = load_set(["INBIO.csv"], unused_fields=("Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,capable of,characteristic for,completed invasion phase,contained in,contains,contributes to,contributor,created by,decreases,decreases effort in,derives from,derives into,determines,don't use concept,editor note,enabled by,ends,ends during,ends with,enhance,facilitate,has alien range,has amount of closely related species,has amount of species,has area,has component,has decreased effort level by,has distribution,has growth,has habitat,has increased effort level by,has increased levels of,has index,has input,has invasion success likelihood,has level of,has measurement,has measurement unit label,has measurement value,has mortality,has natality,has native range,has number of individuals,has output,has part,has part structure that is capable of,has participant,has propagule pressure,has quality,has range,has recruitment,has role,has spatial occupant at some time,has specific name,has status,has value,http://data.bioontology.org/metadata/obo/part_of,http://data.bioontology.org/metadata/prefixIRI,http://data.bioontology.org/metadata/treeView,http://purl.obolibrary.org/obo/IAO_0000111,http://purl.obolibrary.org/obo/IAO_0000112,http://purl.obolibrary.org/obo/IAO_0000114,http://purl.obolibrary.org/obo/IAO_0000115,http://purl.obolibrary.org/obo/IAO_0000118,http://purl.obolibrary.org/obo/IAO_0000119,http://purl.obolibrary.org/obo/IAO_0000232,http://purl.obolibrary.org/obo/IAO_0000412,http://purl.obolibrary.org/obo/ncbitaxon#has_rank,http://purl.obolibrary.org/obo/NCIT_A8,http://purl.obolibrary.org/obo/NCIT_NHC0,http://purl.obolibrary.org/obo/NCIT_P106,http://purl.obolibrary.org/obo/NCIT_P107,http://purl.obolibrary.org/obo/NCIT_P108,http://purl.obolibrary.org/obo/NCIT_P207,http://purl.obolibrary.org/obo/NCIT_P322,http://purl.obolibrary.org/obo/NCIT_P325,http://purl.obolibrary.org/obo/NCIT_P366,http://purl.obolibrary.org/obo/OBI_0001886,http://purl.obolibrary.org/obo/RO_0001900,http://purl.org/dc/elements/1.1/source,http://purl.org/dc/terms/creator,http://www.geneontology.org/formats/oboInOwl#creation_date,http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,http://www.geneontology.org/formats/oboInOwl#hasDbXref,http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,http://www.geneontology.org/formats/oboInOwl#hasSynonymType,http://www.geneontology.org/formats/oboInOwl#id,http://www.geneontology.org/formats/oboInOwl#inSubset,http://www.w3.org/2000/01/rdf-schema#comment,http://www.w3.org/2000/01/rdf-schema#label,http://www.w3.org/2002/07/owl#deprecated,http://www.w3.org/2004/02/skos/core#altLabel,http://www.w3.org/2004/02/skos/core#definition,http://www.w3.org/2004/02/skos/core#notation,https://w3id.org/inbio#_000130,https://w3id.org/inbio#_000132,increases,increases effort in,interacts with,is absent,is affected by,is against,is aggregate of,is alien range to,is characteristic of,is characterized by,is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,is in invasion phase,is mortality of,is natality of,is native range to,is part of,is prey of,is range of,is recruitment of,is similar to,is status of,license,license,license,license,located in,location of,occupies spatial region at some time,occurs in,output of,overlaps,part of,participates in,produced by,produces,quality of,role of,shows changes in species trait,spatially coextensive with,surrounded by,surrounds,title,TODO,license.1,license.2,license.3".split(",")))
    bio2def = dict(zip(inbio["Preferred Label"], inbio["Definitions"]))
    mask_keys = inbio["Preferred Label"]

    DS_TRAIN_PATH = "datasets/abstracts_all_labels_train.csv"
    DS_TEST_PATH = "datasets/abstracts_all_labels_test.csv"

    train_dataset = load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"])
    test_dataset = load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"])

    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "mcca":
    train_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-2][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = concatenate_datasets([train_dataset, test_dataset])
elif DATASET == "translation_mcca":
    train_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x00[0-5][0-9]",
            "../datasets/multi_cc_aligned_en-de/en/x00[0-5][0-9]",
        ]
    })
    test_dataset = load_moses_set({
        "src": [
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
        ],
        "tgt": [
            "../datasets/multi_cc_aligned_en-de/de/x800[0-1]",
            "../datasets/multi_cc_aligned_en-de/en/x800[0-1]",
        ]
    })
    tok_dataset = load_moses_set({
        "text": [
            "../datasets/multi_cc_aligned_en-de/en/x00[0-9][0-9]",
            "../datasets/multi_cc_aligned_en-de/de/x00[0-9][0-9]",
        ]
    })
elif DATASET == "aaabdon":
    train_ds = [f"{CUSTOM_BASE_DS_PATH}/train/train_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    test_ds = [f"{CUSTOM_BASE_DS_PATH}/validation/validation_00[0-{CUSTOM_DS_TO_FILE}].csv"]
    train_dataset = load_set(train_ds)#.select(list(range(HF_TRAIN_ROWS)))
    test_dataset = load_set(test_ds)#.select(list(range(HF_TEST_ROWS)))
    tok_dataset = load_set([f"{CUSTOM_BASE_DS_PATH}/train/train_*.csv"])
elif DATASET == "oasst1":
    #train_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"train[0:{HF_TRAIN_ROWS}]")
    #test_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split=f"validation[:{HF_TEST_ROWS}]")
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    
    #train_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"train[0:{HF_TRAIN_ROWS}]")  .rename_column("instruction", "text").rename_column("output", "target")
    #test_dataset = load_dataset("QingyiSi/Alpaca-CoT", split=f"test[0:{HF_TEST_ROWS}]")     .rename_column("instruction", "text").rename_column("output", "target")
    #tok_dataset = load_dataset("QingyiSi/Alpaca-CoT")                                       .rename_column("instruction", "text").rename_column("output", "target")

    train_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    test_dataset = load_dataset("OpenAssistant/oasst1", split="validation").filter(lambda e: e["lang"] == "en")
    tok_dataset = load_dataset("OpenAssistant/oasst1", split="train").filter(lambda e: e["lang"] == "en")
    #train_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_from_disk("../datasets/oasst1/validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))
    #tok_dataset = load_from_disk("../datasets/oasst1/train").filter(lambda e: e["lang"] == "en")

    #train_dataset = load_dataset("OpenAssistant/oasst2", split="train").filter(lambda e: e["lang"] == "en").select(list(range(HF_TRAIN_ROWS)))
    #test_dataset = load_dataset("OpenAssistant/oasst2", split="validation").filter(lambda e: e["lang"] == "en").select(list(range(HF_TEST_ROWS)))

if SHUFFLE_TRAIN_DATA:
    print("shuffle")
    train_dataset = train_dataset.shuffle()
if HF_TRAIN_ROWS > 0:
    train_dataset =  train_dataset.select(list(range(HF_TRAIN_FROM, HF_TRAIN_ROWS)))
if SHUFFLE_TEST_DATA:
    print("shuffle")
    test_dataset = test_dataset.shuffle()
if HF_TEST_ROWS > 0:
    test_dataset = test_dataset.select(list(range(HF_TEST_FROM, HF_TEST_ROWS)))

print(train_dataset)
print(test_dataset)
print(tok_dataset)

shuffle
Dataset({
    features: ['text'],
    num_rows: 10000
})
Dataset({
    features: ['text'],
    num_rows: 4358
})
Dataset({
    features: ['text'],
    num_rows: 1805708
})


In [19]:
#train_dataset.to_json(f"{DATASET_JSON_PATH}/train.json")
#test_dataset.to_json(f"{DATASET_JSON_PATH}/test.json")

In [20]:
labels = [label for label in train_dataset.features.keys() if label not in ["text"]]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)

{}
{}
[]


In [21]:
from tokenizers import ByteLevelBPETokenizer, BertWordPieceTokenizer, SentencePieceBPETokenizer, SentencePieceUnigramTokenizer

In [22]:
# %%
if 0:
    #tok_dataset = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train")
    #tok_dataset = load_dataset("glue", name="sst2", split="train")
    #tok_dataset = tok_dataset.rename_column("sentence", "text")
    
    tokenizer = SentencePieceUnigramTokenizer()

    tokenizer.train_from_iterator(
        iterator=tok_dataset["text"], 
        vocab_size=VOCAB_SIZE,
        #min_frequency=2,
        show_progress=True,
        #limit_alphabet=500,
        special_tokens=[
            "<PAD>", 
            "<UNK>", 
            "<CLS>", 
            "<SEP>", 
            "<DOC>",
            "<MASK>"
        ])

    tokenizer = PreTrainedTokenizer(
        tokenizer_object=tokenizer
    )
    tokenizer.save_model(SENTENCE_PIECE_TOKENIZER_DIR)
    #assert False
    #tokenizer = AutoTokenizer.from_pretrained(SENTENCE_PIECE_TOKENIZER_DIR)


In [23]:
ACCESS_TOKEN = "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"

In [24]:
def batch_iterator():
    for i in range(0, len(tok_dataset), TRAIN_BATCH_SIZE):
        yield tok_dataset[i : i + TRAIN_BATCH_SIZE]["text"]

In [25]:
if 1:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "gpt2", 
    #    vocab_size=VOCAB_SIZE - 2,
    #    pad_token="<|pad|>",
    #    unk_token="<|unk|>",
    #    bos_token="<|doc|>",
    #    eos_token="<|udoc|>",
    )
    #tokenizer.eos_token_id = 4



In [26]:
VOCAB_SIZE, tokenizer.vocab_size

(50257, 50257)

In [27]:
if 0:
    old_tokenizer = transformers.AutoTokenizer.from_pretrained("AdithyaSK/LLama3Tokenizer", token=ACCESS_TOKEN)
    tokenizer = old_tokenizer.train_new_from_iterator(batch_iterator(), vocab_size=VOCAB_SIZE)

In [28]:
# %%
if 0:
    tokenizer = BertWordPieceTokenizer(clean_text=True, handle_chinese_chars=True,
                                        strip_accents=True, lowercase=True)
    #tokenizer = transformers.LlamaTokenizer()

    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, special_tokens=[
        "[PAD]", 
        "[UNK]", 
        "[CLS]", 
        "[SEP]", 
        "[DOC]",
    #    "[UDOC]",
        "[MASK]"
    ])
    tokenizer.save_model(WORDPIECE_TOKENIZER_DIR)
    #assert False
    tokenizer = BertTokenizer.from_pretrained(WORDPIECE_TOKENIZER_DIR)


In [29]:
# %%
if 0:
    tokenizer = ByteLevelBPETokenizer(lowercase=True)
    
    tokenizer.train_from_iterator(iterator=tok_dataset["text"], vocab_size=VOCAB_SIZE, min_frequency=2, length=MAX_POSITION_EMBEDDINGS, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<doc>",
        "<mask>",
        
        #"<pad>",
        #"<unk>",
        #"<cls>",
        #"<sep>",
        #"<doc>",
        #"<mask>",
    ])

    # Save files to disk
    tokenizer.save_model(BPE_TOKENIZER_DIR)
    #assert False
    tokenizer = RobertaTokenizer.from_pretrained(BPE_TOKENIZER_DIR)
    #tokenizer = AutoTokenizer.from_pretrained(BPE_TOKENIZER_DIR)


In [30]:
# %%
def encode_and_batch(dataset, tokenizer, max_position_embeddings, batch_size, shuffle=False):
    if DATASET == "inbio_mask":
        encoded = preprocess_for_key_masking(mask_keys, dataset, tokenizer, max_position_embeddings, remove_columns=dataset.column_names, to_mask=.15, chance_rand_token=.2, 
                                          group_texts=GROUP_TEXTS, pad_token=PAD_TOKEN, sparsify=SPARSIFY, prefix=PREFIX, switch_ii_decoder_ii=SWITCH_II_DECODER_II)
    elif DATASET in ("translation_mcca"):
        encoded = preprocess_for_translation(dataset, tokenizer, max_position_embeddings, source_lang="src", target_lang="tgt", prefix=PREFIX, num_proc=4, remove_columns=["src", "tgt"], switch_ii_decoder_ii=SWITCH_II_DECODER_II)
    elif DATASET in ("wikitext_ppl"):
        encoded = preprocess_for_causallm(
            dataset, 
            tokenizer, 
            block_size=MAX_POSITION_EMBEDDINGS, 
            remove_columns=dataset.column_names, 
            shift_right=False,
            pad_token_id=tokenizer.pad_token_id,
            doc_token_id=tokenizer.bos_token_id,
            udoc_token_id=tokenizer.eos_token_id,    
        )
    elif DATASET in ("wikitext_mlm"):
        encoded = preprocess_for_maskedlm(
            dataset, 
            tokenizer, 
            max_position_embeddings, 
            remove_columns=dataset.column_names, 
            to_mask=.15, 
            chance_rand_token=.2, 
            group_texts=GROUP_TEXTS, 
            mask_token=MASK_TOKEN, 
            pad_token=PAD_TOKEN, 
            sparsify=SPARSIFY, 
            prefix=PREFIX, 
            switch_ii_decoder_ii=SWITCH_II_DECODER_II
        )
        
    print(encoded)
    batched = BatchBuffer(encoded, batch_size)
    if shuffle:
        batched.shuffle()
    print("  finished")
    return batched


In [31]:
train_dataset[0]

{'text': ''}

In [32]:
test_dataset[1]

{'text': ' = Robert Boulter = \n'}

In [33]:
# %%
train_loader_call = lambda: encode_and_batch(train_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TRAIN_BATCH_SIZE, True)
test_loader_call = lambda: encode_and_batch(test_dataset, tokenizer, MAX_POSITION_EMBEDDINGS, TEST_BATCH_SIZE)

if not REBATCH:
    train_dataloader = train_loader_call()
    test_dataloader = test_loader_call()


Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  10%|█         | 1000/10000 [00:00<00:02, 3665.50 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  50%|█████     | 5000/10000 [00:00<00:00, 14943.38 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4):  90%|█████████ | 9000/10000 [00:00<00:00, 22132.77 examples/s]

num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])
num keys: dict_keys(['input_ids', 'group_mask', 'attention_mask'])


Map (num_proc=4): 100%|██████████| 10000/10000 [00:00<00:00, 16060.46 examples/s]

Dataset({
    features: ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels'],
    num_rows: 1320
})
  finished
Dataset({
    features: ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels'],
    num_rows: 557
})
  finished





In [34]:
test_dataloader.ds[0]["attention_mask"]

[4,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 4,
 4,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 4,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [35]:
test_dataloader.ds[0]["group_mask"]

[0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [36]:
# %%
#encoded_train_dataset = preprocess_for_maskedlm(dataset, tokenizer, MAX_POSITION_EMBEDDINGS, remove_columns=train_dataset.column_names)


In [37]:
# %%
#train_dataloader = BatchBuffer(encoded_dataset["train"], BATCH_SIZE).shuffle()
#test_dataloader = BatchBuffer(encoded_dataset["test"], BATCH_SIZE)


In [38]:
# %%
#batch_schema = list(encoded_dataset["train"].features.keys())
#batch_schema


In [39]:
# %%
def count_item(inp, item):
    count = 0
    total = 0
    for n in inp:
        for r in n:
            i = r
            if not i < 4:
                total += 1
            if i == item:
                count += 1
            #if i != 0 and i != item:
            #    print(i)
    return f"{count} / {total} ; {count/total}"


In [40]:
# %%
print("masked tokens [input_ids]:", count_item(train_dataloader.ds["input_ids"], tokenizer.mask_token_id))
print("masked tokens [labels]:", count_item(test_dataloader.ds["labels"], tokenizer.mask_token_id))


masked tokens [input_ids]: 0 / 675840 ; 0.0
masked tokens [labels]: 0 / 285184 ; 0.0


In [41]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_dataset) / TRAIN_BATCH_SIZE * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)


In [42]:
# %%
loss_function = nn.CrossEntropyLoss()


In [43]:
# %%
train_dataloader.schema


['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels']

In [44]:
# %%
print(train_dataloader.schema)
its = 0
for i, n in enumerate(train_dataloader):
    if i >= its:
        break
    for k in n:
        print(i, "############")
        for l in k:
            print(len(l))
        print("##############")


['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels']


In [45]:
# %%
model

COINForCausalLM(
  (coin): COINModel(
    (encoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(50257, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (decoder_embeddings): COINEmbeddings(
      (word_embeddings): Embedding(50257, 1024, padding_idx=0)
      (norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (sampler): COINSampler(
      (layers): COINMultiLayerBlock(
        (layers): ModuleList(
          (0-3): 4 x COINLayer(
            (block): i4Block(
              (Ub_inner): Linear(in_features=1024, out_features=1024, bias=True)
              (Wb_r): Linear(in_features=1024, out_features=1024, bias=True)
              (Ub_r): Linear(in_features=1024, out_features=1024, bias=True)
              (Wb_z): Linear(in_features=1024, out_features=1024, bias=True)
              (Ub_z): Linear(in_feature

In [46]:
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [47]:
torch.set_printoptions(threshold=100_000_000)

In [48]:
%%time
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    #train_dataloader,
    #test_dataloader,
    #batch_schema,
    device,
    loss_function,
    id2label,
    train_dataloader=train_dataloader if not REBATCH else None,
    test_dataloader=test_dataloader if not REBATCH else None,
    create_train_dataloader=train_loader_call if REBATCH else None,
    create_test_dataloader=test_loader_call if REBATCH else None,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    only_save_core=False,
    epoch_i=EPOCH_I,
    is_encoder_decoder_model=IS_ENCODER_DECODER_MODEL,
    causal_lm=CAUSAL_LM,

    forward_args=["input_ids", "attention_mask", "labels"],
    #forward_args=["input_ids", "decoder_input_ids", "group_mask", "labels"],

    masked_lm_task=True,
    electra_task=False,
    mlm_decode_n=0,
    #mlm_decode_n=.0075,
    #mlm_decode_n=.1,
    mlm_decode_max_chars=200,
    tokenizer=tokenizer,

    dump_coin_regions=False,
    generic_output_class=True,
    #coin_region_lambda=lambda model: model.coin.core.regions

    evaluate_autoregressively=False,
)



Generated batch schema as ['input_ids', 'group_mask', 'attention_mask', 'decoder_input_ids', 'labels']

Training...
  Batch     8  of    440.    Elapsed:  0:00:03, Remaining:  0:02:42.
  Batch    16  of    440.    Elapsed:  0:00:05, Remaining:  0:01:46.
  Batch    24  of    440.    Elapsed:  0:00:07, Remaining:  0:01:44.
  Batch    32  of    440.    Elapsed:  0:00:10, Remaining:  0:01:42.
  Batch    40  of    440.    Elapsed:  0:00:12, Remaining:  0:01:40.
  Batch    48  of    440.    Elapsed:  0:00:14, Remaining:  0:01:38.
  Batch    56  of    440.    Elapsed:  0:00:17, Remaining:  0:01:36.
  Batch    64  of    440.    Elapsed:  0:00:19, Remaining:  0:01:34.
  Batch    72  of    440.    Elapsed:  0:00:22, Remaining:  0:01:32.
  Batch    80  of    440.    Elapsed:  0:00:24, Remaining:  0:01:30.
  Batch    88  of    440.    Elapsed:  0:00:26, Remaining:  0:01:28.
  Batch    96  of    440.    Elapsed:  0:00:29, Remaining:  0:01:26.
  Batch   104  of    440.    Elapsed:  0:00:31, Remain

In [None]:
print_stat_tuples(stats)

Epoch 1:
  Train:
    loss: 10.622668109169268
  Test:
    loss: 9.689842956130569
    perplexity: 16375.289287004804

Epoch 2:
  Train:
    loss: 8.951763831778031
  Test:
    loss: 8.411100485518173
    perplexity: 4574.336782624653

Epoch 3:
  Train:
    loss: 7.877769417839137
  Test:
    loss: 7.55449881424775
    perplexity: 1969.7468266255503

Epoch 4:
  Train:
    loss: 7.267687891386194
  Test:
    loss: 7.14455542693267
    perplexity: 1314.0777654034669

Epoch 5:
  Train:
    loss: 6.962168479675014
  Test:
    loss: 6.94379976375683
    perplexity: 1078.4973929683547

Epoch 6:
  Train:
    loss: 6.7870453075085955
  Test:
    loss: 6.82009051297162
    perplexity: 955.021693094451

Epoch 7:
  Train:
    loss: 6.655083815620475
  Test:
    loss: 6.7224352604634054
    perplexity: 867.798382941681

Epoch 8:
  Train:
    loss: 6.534788072791198
  Test:
    loss: 6.637134670566868
    perplexity: 797.8378715540773

Epoch 9:
  Train:
    loss: 6.415604601190074
  Test:
    loss: