In [1]:
import math
import numpy as np
import random
import os
import torch
from torch.optim import AdamW
from torch import nn
from datasets import Dataset, DatasetDict, concatenate_datasets
import transformers
from transformers import (
    get_linear_schedule_with_warmup,
    BertTokenizer,
    RobertaTokenizer,
    DebertaTokenizer,
    BertForSequenceClassification,
    RobertaForSequenceClassification,
    DebertaForSequenceClassification,
    DebertaForMaskedLM
)
from tokenizers import (
    BertWordPieceTokenizer
)
from hybrid_embedding_modeling import (
    preprocess_for_hybrid_embeddings,
    preprocess_for_inclusive_embeddings,
    HybridEmbeddingModel
)
from activated_attention_modeling import (
    ActivatedAttentionForSequenceClassification, 
    AAConfig,
    ActivatedAttentionForMaskedLM
)
from utils import (
    get_act_func, 
    Dataloader, 
    train_run, 
    preprocess_with_given_labels, 
    preprocess_with_given_labels_train_test_wrap,
    num_parameters,
    num_trainable_parameters,
    preprocess_for_key_masking,
    preprocess_for_maskedlm,
    #prn_fn
)
from load_set import load_set

In [2]:
# Training hyperparameters.
TRAIN_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
CHECKPOINT_PATH = None
SHUFFLE_CUSTOM_DATALOADER = True
# use slightly higher lr on mlm
LEARNING_RATE = 1e-5
EPS = 1e-8
EPOCHS = 20
EMPTY_CACHE = False

# Seed every RNG library the notebook imports so that Restart & Run All is reproducible.
# The model (with its freshly initialised classifier head) is created after this cell.
# NOTE(review): assumes the custom Dataloader shuffles via `random`/NumPy — confirm in utils.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [3]:
# Model geometry and task flags.
# Alternative vocabulary sizes kept for reference: 30_522 and 50_257.
VOCAB_SIZE = 50_265
MAX_POSITION_EMBEDDINGS = 514  # also passed as the tokenizer max_length when encoding
HIDDEN_SIZE = 1_024
GENERIC_OUTPUT_CLASS = True
DOC_PAD_TOKENS = False

# One output per label column (a0-a9, b0-b9, c0-c9 in the all-labels dataset).
NUM_LABELS = 30
CALC_METRICS = True

In [4]:
# Batch keys handed to train_run for the model call; the encoded batches also carry
# token_type_ids, which is intentionally not listed here.
forward_args = "input_ids attention_mask labels".split()

In [5]:
# Directory holding the MLM-pretrained checkpoint and its WordPiece tokenizer.
CHECKPOINT_PATH = "trained_models/deberta_inbio_def_mlm_lr1e-5/"
WORDPIECE_TOKENIZER_DIR = f"{CHECKPOINT_PATH}/tokenizer_wordpiece/"
# makedirs creates CHECKPOINT_PATH and the tokenizer subdirectory in one call and tolerates
# either of them already existing. The previous mkdir pair under a bare `except:` never
# created the tokenizer directory once CHECKPOINT_PATH existed, and it silently swallowed
# every other error (permissions, invalid path, KeyboardInterrupt).
os.makedirs(WORDPIECE_TOKENIZER_DIR, exist_ok=True)

In [6]:
if False:
    # Disabled experiment: ActivatedAttention masked-LM built from a fresh AAConfig and
    # tokenised with the DeBERTa BPE tokenizer; metric computation is switched off.
    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    # WordPiece alternative, kept for reference:
    #model_path = "../cc-phoebe/tmp_models/COIN-i3C_default_tokenizer/"
    #tokenizer = BertTokenizer.from_pretrained(f"{model_path}/wordpiece_tokenizer/")
    #VOCAB_SIZE = tokenizer.vocab_size
    model = ActivatedAttentionForMaskedLM(
        config=AAConfig(
            hidden_act="tanh",
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            num_layers=2,
            num_norm_channels=HIDDEN_SIZE,
            num_norm_groups=8,
            group_norm_eps=1e-8,
        )
    )
    CALC_METRICS = False

In [7]:
if False:
    # Disabled experiment: ActivatedAttention sequence classifier built from a fresh
    # AAConfig (no pretrained weights), tokenised with the DeBERTa BPE tokenizer.
    #tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    print(tokenizer.vocab_size)
    # model_path is only referenced by the commented-out WordPiece alternative below.
    model_path = "../cc-phoebe/tmp_models/COIN-i3C_default_tokenizer/"
    #tokenizer = BertTokenizer.from_pretrained(f"{model_path}/wordpiece_tokenizer/")
    #VOCAB_SIZE = tokenizer.vocab_size
    model = ActivatedAttentionForSequenceClassification(
        config=AAConfig(
            hidden_act="tanh",
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            num_layers=2,
            num_norm_channels=HIDDEN_SIZE,
            num_norm_groups=8,
            group_norm_eps=1e-8,
        )
    )
    # Presumably disables checkpoint saving in train_run for this run — TODO confirm.
    CHECKPOINT_PATH = None

In [8]:
if False:
    # Disabled experiment: ActivatedAttention classifier loaded from checkpoint iter_04
    # under CHECKPOINT_PATH, paired with the WordPiece tokenizer saved next to it.
    #tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    #tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    #model_path = "trained_models/aa_mlm_lr1e-4_debertatok/"
    model_path = CHECKPOINT_PATH
    tokenizer = BertTokenizer.from_pretrained(f"{model_path}/tokenizer_wordpiece/")
    #VOCAB_SIZE = tokenizer.vocab_size
    model = ActivatedAttentionForSequenceClassification.from_pretrained(
        f"{model_path}/iter_04/", num_labels=NUM_LABELS,
    )
    CHECKPOINT_PATH = None

In [9]:
if False:
    # Disabled experiment: masked-LM training starting from pretrained DeBERTa-base;
    # metric computation is switched off.
    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    model = DebertaForMaskedLM.from_pretrained("microsoft/deberta-base")
    CALC_METRICS = False

In [10]:
if False:
    # Disabled experiment: fine-tune DeBERTa-base from the hub as a multi-label classifier.
    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    model = DebertaForSequenceClassification.from_pretrained(
        "microsoft/deberta-base",
        problem_type="multi_label_classification",  # loss is BCEWithLogitsLoss
        num_labels=NUM_LABELS,
    )

In [11]:
if True:
    # Active path: multi-label classifier initialised from the iter_07 checkpoint under
    # CHECKPOINT_PATH. The pooler and classifier weights are not in that checkpoint and are
    # freshly initialised (as the loading warning printed by this cell reports).
    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    model = DebertaForSequenceClassification.from_pretrained(
        f"{CHECKPOINT_PATH}/iter_07/",
        problem_type="multi_label_classification",
        num_labels=NUM_LABELS,
    )
    # Cleared before training; presumably so train_run does not write checkpoints into the
    # pretraining directory — TODO confirm against train_run.
    CHECKPOINT_PATH = None

Some weights of DebertaForSequenceClassification were not initialized from the model checkpoint at trained_models/deberta_inbio_def_mlm_lr1e-5//iter_07/ and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Total and trainable parameter counts, one per line with thousands separators.
print(f"{num_parameters(model):,}\n{num_trainable_parameters(model):,}")

139,215,390
139,215,390


In [13]:
# Load the INBIO ontology table. `unused_fields` is built by splitting one long comma-separated
# string into column names; presumably load_set discards those columns, leaving (at least)
# "Preferred Label" and "Definitions", which the next line reads.
# NOTE(review): an identical load_set call on INBIO.csv appears again in a later cell, which
# also rebuilds bio2def, so this cell's results are overwritten. One of the two copies can
# likely be removed.
# NOTE(review): the field string appears to be line-wrapped by the notebook export; in the
# executable cell it must be a single line (a plain "..." literal cannot span lines).
inbio = load_set(["INBIO.csv"], unused_fields=("Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,capable of,characteristic for,completed invasion phase,contained in,contains,contributes to,contributor,created by,decreases,decreases effort in,derives from,derives into,determines,don't use concept,editor note,enabled by,ends,ends during,ends with,enhance,facilitate,has alien range,has amount of closely related species,has amount of species,has area,has component,has decreased effort level by,has distribution,has growth,has habitat,has increased effort level by,has increased levels of,has index,has input,has invasion success likelihood,has level of,has measurement,has measurement unit label,has measurement value,has mortality,has natality,has native range,has number of individuals,has output,has part,has part structure that is capable of,has participant,has propagule pressure,has quality,has range,has recruitment,has role,has spatial occupant at some time,has specific name,has status,has 
value,http://data.bioontology.org/metadata/obo/part_of,http://data.bioontology.org/metadata/prefixIRI,http://data.bioontology.org/metadata/treeView,http://purl.obolibrary.org/obo/IAO_0000111,http://purl.obolibrary.org/obo/IAO_0000112,http://purl.obolibrary.org/obo/IAO_0000114,http://purl.obolibrary.org/obo/IAO_0000115,http://purl.obolibrary.org/obo/IAO_0000118,http://purl.obolibrary.org/obo/IAO_0000119,http://purl.obolibrary.org/obo/IAO_0000232,http://purl.obolibrary.org/obo/IAO_0000412,http://purl.obolibrary.org/obo/ncbitaxon#has_rank,http://purl.obolibrary.org/obo/NCIT_A8,http://purl.obolibrary.org/obo/NCIT_NHC0,http://purl.obolibrary.org/obo/NCIT_P106,http://purl.obolibrary.org/obo/NCIT_P107,http://purl.obolibrary.org/obo/NCIT_P108,http://purl.obolibrary.org/obo/NCIT_P207,http://purl.obolibrary.org/obo/NCIT_P322,http://purl.obolibrary.org/obo/NCIT_P325,http://purl.obolibrary.org/obo/NCIT_P366,http://purl.obolibrary.org/obo/OBI_0001886,http://purl.obolibrary.org/obo/RO_0001900,http://purl.org/dc/elements/1.1/source,http://purl.org/dc/terms/creator,http://www.geneontology.org/formats/oboInOwl#creation_date,http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,http://www.geneontology.org/formats/oboInOwl#hasDbXref,http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,http://www.geneontology.org/formats/oboInOwl#hasSynonymType,http://www.geneontology.org/formats/oboInOwl#id,http://www.geneontology.org/formats/oboInOwl#inSubset,http://www.w3.org/2000/01/rdf-schema#comment,http://www.w3.org/2000/01/rdf-schema#label,http://www.w3.org/2002/07/owl#deprecated,http://www.w3.org/2004/02/skos/core#altLabel,http://www.w3.org/2004/02/skos/core#definition,http://www.w3.org/2004/02/skos/core#notation,https://w3id.org/inbio
#_000130,https://w3id.org/inbio#_000132,increases,increases effort in,interacts with,is absent,is affected by,is against,is aggregate of,is alien range to,is characteristic of,is characterized by,is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,is in invasion phase,is mortality of,is natality of,is native range to,is part of,is prey of,is range of,is recruitment of,is similar to,is status of,license,license,license,license,located in,location of,occupies spatial region at some time,occurs in,output of,overlaps,part of,participates in,produced by,produces,quality of,role of,shows changes in species trait,spatially coextensive with,surrounded by,surrounds,title,TODO,license.1,license.2,license.3".split(",")))
# Map each INBIO preferred label to its definition (a repeated label keeps its last definition).
bio2def = {label: definition for label, definition in zip(inbio["Preferred Label"], inbio["Definitions"])}

loading files
    INBIO.csv


In [14]:
# Dataset variants and how many label columns each carries:
#   all labels        datasets/wordpiece_abstracts_train_all_labels.csv   30 labels
#   2 labels          datasets/wordpiece_abstracts_train_side_label_1.csv 20 labels
#   main label only   datasets/wordpiece_abstracts_train.csv              10 labels
# The wordpiece_* all-labels files are kept as commented alternatives; the active run uses
# the all-labels abstracts split below.
#DS_TRAIN_PATH = "datasets/wordpiece_abstracts_train_all_labels.csv"
#DS_TEST_PATH = "datasets/wordpiece_abstracts_test_all_labels.csv"
DS_TRAIN_PATH, DS_TEST_PATH = (
    "datasets/abstracts_all_labels_train.csv",
    "datasets/abstracts_all_labels_test.csv",
)

In [15]:
if True:
    # Load the train/test abstract splits; the head, body and strlabels columns are passed
    # as unused_fields (presumably dropped by load_set).
    dataset = DatasetDict({
        split: load_set([path], unused_fields=["head", "body", "strlabels"])
        for split, path in (("train", DS_TRAIN_PATH), ("test", DS_TEST_PATH))
    })

loading files
    datasets/abstracts_all_labels_train.csv
loading files
    datasets/abstracts_all_labels_test.csv


In [16]:
# Reload the INBIO ontology table (an earlier cell makes the same load_set call). The lines
# after this one derive bio2def, the definition corpus and the mask keys from it.
# `unused_fields` is a comma-separated string split into column names that load_set
# presumably discards.
# NOTE(review): the field string appears to be line-wrapped by the notebook export; in the
# executable cell it must be a single line.
inbio = load_set(["INBIO.csv"], unused_fields=("Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,capable of,characteristic for,completed invasion phase,contained in,contains,contributes to,contributor,created by,decreases,decreases effort in,derives from,derives into,determines,don't use concept,editor note,enabled by,ends,ends during,ends with,enhance,facilitate,has alien range,has amount of closely related species,has amount of species,has area,has component,has decreased effort level by,has distribution,has growth,has habitat,has increased effort level by,has increased levels of,has index,has input,has invasion success likelihood,has level of,has measurement,has measurement unit label,has measurement value,has mortality,has natality,has native range,has number of individuals,has output,has part,has part structure that is capable of,has participant,has propagule pressure,has quality,has range,has recruitment,has role,has spatial occupant at some time,has specific name,has status,has 
value,http://data.bioontology.org/metadata/obo/part_of,http://data.bioontology.org/metadata/prefixIRI,http://data.bioontology.org/metadata/treeView,http://purl.obolibrary.org/obo/IAO_0000111,http://purl.obolibrary.org/obo/IAO_0000112,http://purl.obolibrary.org/obo/IAO_0000114,http://purl.obolibrary.org/obo/IAO_0000115,http://purl.obolibrary.org/obo/IAO_0000118,http://purl.obolibrary.org/obo/IAO_0000119,http://purl.obolibrary.org/obo/IAO_0000232,http://purl.obolibrary.org/obo/IAO_0000412,http://purl.obolibrary.org/obo/ncbitaxon#has_rank,http://purl.obolibrary.org/obo/NCIT_A8,http://purl.obolibrary.org/obo/NCIT_NHC0,http://purl.obolibrary.org/obo/NCIT_P106,http://purl.obolibrary.org/obo/NCIT_P107,http://purl.obolibrary.org/obo/NCIT_P108,http://purl.obolibrary.org/obo/NCIT_P207,http://purl.obolibrary.org/obo/NCIT_P322,http://purl.obolibrary.org/obo/NCIT_P325,http://purl.obolibrary.org/obo/NCIT_P366,http://purl.obolibrary.org/obo/OBI_0001886,http://purl.obolibrary.org/obo/RO_0001900,http://purl.org/dc/elements/1.1/source,http://purl.org/dc/terms/creator,http://www.geneontology.org/formats/oboInOwl#creation_date,http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,http://www.geneontology.org/formats/oboInOwl#hasDbXref,http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,http://www.geneontology.org/formats/oboInOwl#hasSynonymType,http://www.geneontology.org/formats/oboInOwl#id,http://www.geneontology.org/formats/oboInOwl#inSubset,http://www.w3.org/2000/01/rdf-schema#comment,http://www.w3.org/2000/01/rdf-schema#label,http://www.w3.org/2002/07/owl#deprecated,http://www.w3.org/2004/02/skos/core#altLabel,http://www.w3.org/2004/02/skos/core#definition,http://www.w3.org/2004/02/skos/core#notation,https://w3id.org/inbio
#_000130,https://w3id.org/inbio#_000132,increases,increases effort in,interacts with,is absent,is affected by,is against,is aggregate of,is alien range to,is characteristic of,is characterized by,is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,is in invasion phase,is mortality of,is natality of,is native range to,is part of,is prey of,is range of,is recruitment of,is similar to,is status of,license,license,license,license,located in,location of,occupies spatial region at some time,occurs in,output of,overlaps,part of,participates in,produced by,produces,quality of,role of,shows changes in species trait,spatially coextensive with,surrounded by,surrounds,title,TODO,license.1,license.2,license.3".split(",")))
# Derived INBIO views:
#   bio2def        preferred label -> definition (a repeated label keeps its last definition)
#   inbio_def_list {"text": definition} records for every non-missing definition
#   mask_keys      the preferred labels, used by the key-masking preprocessing
bio2def = {label: definition for label, definition in zip(inbio["Preferred Label"], inbio["Definitions"])}
inbio_def_list = [dict(text=definition) for definition in inbio["Definitions"] if definition is not None]
mask_keys = inbio["Preferred Label"]

loading files
    INBIO.csv


In [17]:
if False:
    # Disabled: append the INBIO definitions to the training split as extra rows that carry
    # only a "text" column.
    # NOTE(review): the train split also has label columns; confirm concatenate_datasets
    # accepts the mismatched schema before enabling.
    inbio_dataset = Dataset.from_list(inbio_def_list)
    dataset["train"] = concatenate_datasets(
        [dataset["train"], inbio_dataset]
    )

In [18]:
if 0:
    # Disabled: standalone path that reloads INBIO and the training abstracts to build
    # `tok_dataset`, the text corpus read by the (also disabled) WordPiece tokenizer-training cell.
    # NOTE(review): the unused_fields string appears to be line-wrapped by the notebook export;
    # in the executable cell it must be a single line.
    inbio = load_set(["INBIO.csv"], unused_fields=("Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,capable of,characteristic for,completed invasion phase,contained in,contains,contributes to,contributor,created by,decreases,decreases effort in,derives from,derives into,determines,don't use concept,editor note,enabled by,ends,ends during,ends with,enhance,facilitate,has alien range,has amount of closely related species,has amount of species,has area,has component,has decreased effort level by,has distribution,has growth,has habitat,has increased effort level by,has increased levels of,has index,has input,has invasion success likelihood,has level of,has measurement,has measurement unit label,has measurement value,has mortality,has natality,has native range,has number of individuals,has output,has part,has part structure that is capable of,has participant,has propagule pressure,has quality,has range,has recruitment,has role,has spatial occupant at some time,has specific name,has status,has 
value,http://data.bioontology.org/metadata/obo/part_of,http://data.bioontology.org/metadata/prefixIRI,http://data.bioontology.org/metadata/treeView,http://purl.obolibrary.org/obo/IAO_0000111,http://purl.obolibrary.org/obo/IAO_0000112,http://purl.obolibrary.org/obo/IAO_0000114,http://purl.obolibrary.org/obo/IAO_0000115,http://purl.obolibrary.org/obo/IAO_0000118,http://purl.obolibrary.org/obo/IAO_0000119,http://purl.obolibrary.org/obo/IAO_0000232,http://purl.obolibrary.org/obo/IAO_0000412,http://purl.obolibrary.org/obo/ncbitaxon#has_rank,http://purl.obolibrary.org/obo/NCIT_A8,http://purl.obolibrary.org/obo/NCIT_NHC0,http://purl.obolibrary.org/obo/NCIT_P106,http://purl.obolibrary.org/obo/NCIT_P107,http://purl.obolibrary.org/obo/NCIT_P108,http://purl.obolibrary.org/obo/NCIT_P207,http://purl.obolibrary.org/obo/NCIT_P322,http://purl.obolibrary.org/obo/NCIT_P325,http://purl.obolibrary.org/obo/NCIT_P366,http://purl.obolibrary.org/obo/OBI_0001886,http://purl.obolibrary.org/obo/RO_0001900,http://purl.org/dc/elements/1.1/source,http://purl.org/dc/terms/creator,http://www.geneontology.org/formats/oboInOwl#creation_date,http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,http://www.geneontology.org/formats/oboInOwl#hasDbXref,http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,http://www.geneontology.org/formats/oboInOwl#hasSynonymType,http://www.geneontology.org/formats/oboInOwl#id,http://www.geneontology.org/formats/oboInOwl#inSubset,http://www.w3.org/2000/01/rdf-schema#comment,http://www.w3.org/2000/01/rdf-schema#label,http://www.w3.org/2002/07/owl#deprecated,http://www.w3.org/2004/02/skos/core#altLabel,http://www.w3.org/2004/02/skos/core#definition,http://www.w3.org/2004/02/skos/core#notation,https://w3id.org/inbio
#_000130,https://w3id.org/inbio#_000132,increases,increases effort in,interacts with,is absent,is affected by,is against,is aggregate of,is alien range to,is characteristic of,is characterized by,is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,is in invasion phase,is mortality of,is natality of,is native range to,is part of,is prey of,is range of,is recruitment of,is similar to,is status of,license,license,license,license,located in,location of,occupies spatial region at some time,occurs in,output of,overlaps,part of,participates in,produced by,produces,quality of,role of,shows changes in species trait,spatially coextensive with,surrounded by,surrounds,title,TODO,license.1,license.2,license.3".split(",")))
    bio2def = dict(zip(inbio["Preferred Label"], inbio["Definitions"]))
    mask_keys = inbio["Preferred Label"]

    # Same paths as the dataset-path cell; presumably repeated so this cell can run on its own.
    DS_TRAIN_PATH = "datasets/abstracts_all_labels_train.csv"
    DS_TEST_PATH = "datasets/abstracts_all_labels_test.csv"

    train_dataset = load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"])
    #test_dataset = load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"])

    # Only the training split feeds the tokenizer corpus; the train+test variant is commented
    # out (presumably to keep test text out of the vocabulary).
    #tok_dataset = concatenate_datasets([train_dataset, test_dataset])
    tok_dataset = train_dataset

In [19]:
# Every column of the training split other than "text" is a label; build both lookup tables.
labels = [name for name in dataset["train"].features if name != "text"]
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
print(label2id)
print(id2label)
print(labels)

{'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5, 'a6': 6, 'a7': 7, 'a8': 8, 'a9': 9, 'b0': 10, 'b1': 11, 'b2': 12, 'b3': 13, 'b4': 14, 'b5': 15, 'b6': 16, 'b7': 17, 'b8': 18, 'b9': 19, 'c0': 20, 'c1': 21, 'c2': 22, 'c3': 23, 'c4': 24, 'c5': 25, 'c6': 26, 'c7': 27, 'c8': 28, 'c9': 29}
{0: 'a0', 1: 'a1', 2: 'a2', 3: 'a3', 4: 'a4', 5: 'a5', 6: 'a6', 7: 'a7', 8: 'a8', 9: 'a9', 10: 'b0', 11: 'b1', 12: 'b2', 13: 'b3', 14: 'b4', 15: 'b5', 16: 'b6', 17: 'b7', 18: 'b8', 19: 'b9', 20: 'c0', 21: 'c1', 22: 'c2', 23: 'c3', 24: 'c4', 25: 'c5', 26: 'c6', 27: 'c7', 28: 'c8', 29: 'c9'}
['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 'c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']


In [20]:
if False:
    # Disabled: train a WordPiece vocabulary on `tok_dataset` (built only in the disabled
    # tokenizer-corpus cell), save it to WORDPIECE_TOKENIZER_DIR and reload it as a
    # transformers BertTokenizer.
    tokenizer = BertWordPieceTokenizer(
        clean_text=True,
        handle_chinese_chars=True,
        strip_accents=True,
        lowercase=True,
    )
    tokenizer.train_from_iterator(
        iterator=tok_dataset["text"],
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    )
    tokenizer.save_model(WORDPIECE_TOKENIZER_DIR)
    tokenizer = BertTokenizer.from_pretrained(WORDPIECE_TOKENIZER_DIR)

In [21]:
if False:
    # Disabled: hybrid-embedding preprocessing with the INBIO definitions passed as incontext_dict.
    encoded_dataset = preprocess_for_hybrid_embeddings(
        dataset,
        tokenizer,
        labels,
        MAX_POSITION_EMBEDDINGS,
        incontext_dict=bio2def,
        remove_columns=dataset["train"].column_names,
    )
    encoded_dataset

In [22]:
if False:
    # Disabled: label-aware preprocessing via the utils train/test wrapper, with the INBIO
    # definitions passed as incontext_dict and teacher forcing switched off.
    encoded_dataset = preprocess_with_given_labels_train_test_wrap(
        dataset,
        tokenizer,
        labels,
        label2id,
        MAX_POSITION_EMBEDDINGS,
        one_label_only=False,
        remove_columns=dataset["train"].column_names,
        default_teacher_forcing=False,
        teacher_forcing_prefix=None,
        doc_pad_tokens=DOC_PAD_TOKENS,
        incontext_dict=bio2def,
        move_incontext_to_decoder=False,
    )
    encoded_dataset

In [23]:
def prn_fn(dataset, tokenizer, labels, max_length, remove_columns, text_field="text", num_proc=4):
    """Tokenize a text dataset and attach a multi-hot label vector to every example.

    Args:
        dataset: object with a Hugging Face ``Dataset.map``-style method.
        tokenizer: callable invoked on each batch of texts with
            ``padding="max_length"``, ``truncation=True`` and ``max_length``.
        labels: ordered label column names; entry ``i`` of an example's
            ``"labels"`` vector is read from column ``labels[i]``.
        max_length: length every sequence is padded/truncated to.
        remove_columns: columns dropped from the mapped output.
        text_field: column holding the input text.
        num_proc: worker processes for ``dataset.map`` (default 4, the value that was
            previously hard-coded).

    Returns:
        The mapped dataset: the tokenizer outputs plus a ``"labels"`` column of float
        lists, one entry per name in ``labels`` and in that order.

    Raises:
        KeyError: if a batch lacks ``text_field`` or any column named in ``labels``.
    """
    def proc(examples):
        text = examples[text_field]
        encoding = tokenizer(text, padding="max_length", truncation=True, max_length=max_length)

        # (batch, num_labels) float matrix. Floats are what the multi_label_classification
        # heads used in this notebook expect for their BCE loss. Columns are read directly
        # instead of first copying every label column into an intermediate dict.
        labels_matrix = np.zeros((len(text), len(labels)))
        for idx, label in enumerate(labels):
            labels_matrix[:, idx] = examples[label]
        encoding["labels"] = labels_matrix.tolist()
        return encoding

    return dataset.map(proc, batched=True, num_proc=num_proc, remove_columns=remove_columns)

In [24]:
# Selected preprocessing function.
# NOTE(review): the encoding cells call prn_fn / preprocess_for_maskedlm directly, so this
# constant is currently unused. The masked-LM alternative takes
# (dataset, tokenizer, max_length, remove_columns=...), not prn_fn's signature. Its imported
# name is `preprocess_for_maskedlm`; the old commented line said `preprocess_for_masked_lm`
# and would raise NameError if uncommented.
PREPROCESS_FN = prn_fn
#PREPROCESS_FN = preprocess_for_maskedlm

In [25]:
if True:
    # Active path: tokenise both splits with prn_fn, dropping every original column so only
    # the tokenizer outputs and the multi-hot "labels" column remain.
    encoded_dataset = DatasetDict({
        split: prn_fn(dataset[split], tokenizer, labels, MAX_POSITION_EMBEDDINGS, dataset[split].column_names)
        for split in ("train", "test")
    })

Map (num_proc=4):   0%|          | 0/722 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/140 [00:00<?, ? examples/s]

In [26]:
if False:
    # Disabled: masked-LM preprocessing of both splits.
    encoded_dataset = DatasetDict({
        split: preprocess_for_maskedlm(
            dataset[split], tokenizer, MAX_POSITION_EMBEDDINGS,
            remove_columns=dataset[split].column_names,
        )
        for split in ("train", "test")
    })

In [27]:
if False:
    # Disabled: key-masking preprocessing of both splits, driven by the INBIO preferred
    # labels in mask_keys.
    encoded_dataset = DatasetDict({
        split: preprocess_for_key_masking(
            mask_keys, dataset[split], tokenizer, MAX_POSITION_EMBEDDINGS,
            remove_columns=dataset[split].column_names,
        )
        for split in ("train", "test")
    })

In [28]:
if False:
    # Disabled: encode only the training split, then carve out a random 10% test split.
    # NOTE(review): no `seed` is passed to train_test_split, so the split depends on global RNG state.
    encoded_dataset = prn_fn(
        dataset["train"],
        tokenizer,
        labels,
        MAX_POSITION_EMBEDDINGS,
        dataset["train"].column_names,
    ).train_test_split(test_size=0.1)

In [29]:
if False:
    # Disabled: masked-LM preprocessing of the training split only, then a random 90/10
    # train/test split (no `seed` passed to train_test_split).
    encoded_dataset = preprocess_for_maskedlm(
        dataset["train"],
        tokenizer,
        MAX_POSITION_EMBEDDINGS,
        remove_columns=dataset["train"].column_names,
    ).train_test_split(test_size=0.1)

In [30]:
# Column names of an encoded training example, printed for a quick sanity check.
batch_schema = [*encoded_dataset["train"].features]
print(batch_schema)

# Custom dataloaders from utils. The training loader is shuffled once here, when enabled.
train_dataloader = Dataloader(encoded_dataset["train"], TRAIN_BATCH_SIZE)
if SHUFFLE_CUSTOM_DATALOADER:
    train_dataloader.shuffle()
test_dataloader = Dataloader(encoded_dataset["test"], TEST_BATCH_SIZE)

['input_ids', 'token_type_ids', 'attention_mask', 'labels']


In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

# get_linear_schedule_with_warmup takes integer step counts, so round each epoch up. That way
# a final partial batch still counts as a step, and total_steps is an int rather than a float.
# NOTE(review): this assumes len(train_dataloader) counts examples. If the custom Dataloader's
# __len__ already returns the number of batches, drop the division. Both readings give
# identical results at the current TRAIN_BATCH_SIZE of 1.
steps_per_epoch = math.ceil(len(train_dataloader) / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
# Linear warm-up over the first 5% of training steps.
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

In [32]:
# Train for EPOCHS epochs, evaluating on the test loader; train_run returns the train and
# test statistics as a pair.
train_stats, test_stats = train_run(
    model,
    device,
    train_dataloader,
    test_dataloader,
    id2label,
    forward_args,
    optimizer,
    scheduler,
    EPOCHS,
    EMPTY_CACHE,
    calc_metrics=CALC_METRICS,
    use_tqdm=False,
    checkpoint_path=CHECKPOINT_PATH,  # None in the active setup (cleared after loading the checkpoint)
)


  Batch    14  of    722.    Elapsed:  0:00:02, Remaining:  0:01:41.
  Batch    28  of    722.    Elapsed:  0:00:03, Remaining:  0:00:50.
  Batch    42  of    722.    Elapsed:  0:00:04, Remaining:  0:00:49.
  Batch    56  of    722.    Elapsed:  0:00:06, Remaining:  0:00:48.
  Batch    70  of    722.    Elapsed:  0:00:07, Remaining:  0:00:47.
  Batch    84  of    722.    Elapsed:  0:00:08, Remaining:  0:00:46.
  Batch    98  of    722.    Elapsed:  0:00:10, Remaining:  0:00:45.
  Batch   112  of    722.    Elapsed:  0:00:11, Remaining:  0:00:44.
  Batch   126  of    722.    Elapsed:  0:00:12, Remaining:  0:00:43.
  Batch   140  of    722.    Elapsed:  0:00:14, Remaining:  0:00:42.
  Batch   154  of    722.    Elapsed:  0:00:15, Remaining:  0:00:41.
  Batch   168  of    722.    Elapsed:  0:00:17, Remaining:  0:00:40.
  Batch   182  of    722.    Elapsed:  0:00:18, Remaining:  0:00:39.
  Batch   196  of    722.    Elapsed:  0:00:19, Remaining:  0:00:38.
  Batch   210  of    722.    Elap