In [1]:
import glob
import time
import os
import datetime
import math
import datasets
from datasets import DatasetDict, load_dataset, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    T5Tokenizer,
    T5ForSequenceClassification,
    T5Config,
    BertTokenizer, 
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass

from load_set import *
from epoch_stats import EpochStats#, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    train_bern_model, 
    mask_tokens, 
    preprocess_with_given_labels, 
    num_parameters, 
    num_trainable_parameters, 
    preprocess_for_causallm, 
    preprocess_for_multiple_choice,
    preprocess_for_seq2seq_swag,
    preprocess_with_given_labels_train_test_wrap
)
import coin_i3C_modeling as ci3C
import coin_i4_modeling as ci4
from embedding_hybrid_modeling import HybridEmbeddingModel, preprocess_for_hybrid_embeddings

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:
# --- Run configuration -------------------------------------------------------

# Optimisation settings: EPOCHS scales the scheduler's step count, the other two
# are handed to AdamW as `lr` and `eps`.
EPOCHS = 10
LEARNING_RATE = 1e-4
EPS = 1e-8

# Batch sizes used when building the train / test BatchBuffer objects.
TRAIN_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1

# When True, `shuffle()` is called on the training BatchBuffer after it is built.
SHUFFLE_CUSTOM_DATALOADER = True

# Checkpoint directory forwarded to train_bern_model; None skips directory creation.
# Example value: "hyp_cls/BiomedBERT-base_all-labels/"
CHECKPOINT_PATH = None

In [3]:
# --- Model dimensions --------------------------------------------------------

# Passed to the COIN configs and to train_bern_model as `vocab_size`.
# NOTE(review): 30522 looks like a BERT word-piece vocabulary size; confirm it matches
# the tokenizer of whichever model cell is enabled.
VOCAB_SIZE = 30522
#VOCAB_SIZE = 50257 + 1

# Maximum sequence length used by the preprocessing helpers and the COIN configs.
MAX_POSITION_EMBEDDINGS = 512  #- 10
HIDDEN_SIZE = 1024
NUM_LABELS = 30

# --- Behaviour switches ------------------------------------------------------

# Default; the Hugging Face model cells set this to True when they are enabled.
IS_HF_MODEL = False
GENERIC_OUTPUT_CLASS = True
# Forwarded as `doc_pad_tokens` to the label-based preprocessing helper.
DOC_PAD_TOKENS = False

Choose the model: set the `if` condition of exactly one of the model cells below to `1` (true), leave all others at `0`, and then run all cells.

In [4]:
# Enabled model: DeBERTa tokenizer plus a HybridEmbeddingModel whose classifier is
# described by a default-constructed DebertaConfig.
if True:
    tokenizer = transformers.DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    deberta_config = transformers.DebertaConfig()
    model = HybridEmbeddingModel(classifier_config=deberta_config)

Some weights of DebertaForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Disabled alternative: ci4 COIN sequence classifier built from a new COINConfig
# (no from_pretrained call); only the word-piece tokenizer is loaded from disk.
if False:
    model_path = "pretrained_models/COIN-i3C_inbio_mask"
    tokenizer = BertTokenizer.from_pretrained(f"{model_path}/wordpiece_tokenizer/")
    coin_i4_config = ci4.COINConfig(
        hidden_size=HIDDEN_SIZE,
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        num_layers=4,
        forward_method="llama",
        training_chunk_size=1,
        inference_chunk_size=1,
        rms_norm_eps=1e-06,
        num_heads=1,
        num_labels=NUM_LABELS,
    )
    model = ci4.COINForSequenceClassification(config=coin_i4_config)

In [6]:
# Disabled alternative: plain Hugging Face DeBERTa sequence classifier.
if False:
    IS_HF_MODEL = True
    hf_checkpoint = "microsoft/deberta-base"
    tokenizer = transformers.DebertaTokenizer.from_pretrained(hf_checkpoint)
    model = transformers.DebertaForSequenceClassification.from_pretrained(
        hf_checkpoint, num_labels=NUM_LABELS
    )

In [7]:
# Disabled alternative: ci3C COIN classifier restored from a saved epoch directory.
if False:
    model_path = "pretrained_models/COIN-i3C_inbio_mask"
    epoch_i = 13
    tokenizer = BertTokenizer.from_pretrained(f"{model_path}/wordpiece_tokenizer")
    coin_checkpoint_dir = f"{model_path}/model/epoch_{epoch_i}/model/"
    model = ci3C.COINForSequenceClassification.from_pretrained(
        coin_checkpoint_dir,
        num_labels=NUM_LABELS,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    )

In [8]:
# Disabled alternative: ci3C COIN classifier built from an explicit COINConfig
# (no from_pretrained call for the model; only the tokenizer is loaded from disk).
if False:
    tokenizer = BertTokenizer.from_pretrained("tmp_models/COIN-i3C_default_tokenizer/wordpiece_tokenizer/")
    #tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|pad|>")
    coin_i3c_config = ci3C.COINConfig(
        # Sizes taken from the notebook-level constants.
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        hidden_size=HIDDEN_SIZE,
        num_labels=NUM_LABELS,
        forward_method="parallel",
        apply_decay=True,
        num_decay_parts=1,
        hidden_retention_act="relu",
        hidden_pos_offset=False,
        rope_dim=16,
        num_query_heads=1,
        # Layer layout: four entries per schema.
        decoder_output="none",
        revert_decoder=False,
        decoder_schema=[0, 0, 0, 0],
        cross_encoder_schema=[0, 0, 0, 0],
        experts_schema=None,  #[2, 2],
        block_io_schema=None,  #[[1024, 1024*4, 1024]],
    )
    model = ci3C.COINForSequenceClassification(config=coin_i3c_config)

In [9]:
# Disabled alternative: Hugging Face RoBERTa sequence classifier.
if False:
    IS_HF_MODEL = True
    hf_checkpoint = "roberta-base"
    tokenizer = transformers.RobertaTokenizer.from_pretrained(hf_checkpoint)
    model = transformers.RobertaForSequenceClassification.from_pretrained(
        hf_checkpoint, num_labels=NUM_LABELS
    )

In [10]:
# Disabled alternative: Hugging Face BiomedBERT (uncased, abstracts) sequence classifier.
if False:
    IS_HF_MODEL = True
    hf_checkpoint = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
    tokenizer = transformers.BertTokenizer.from_pretrained(hf_checkpoint)
    model = transformers.BertForSequenceClassification.from_pretrained(
        hf_checkpoint, num_labels=NUM_LABELS
    )

In [11]:
# Disabled alternative: Hugging Face BERT (base, uncased) sequence classifier.
if False:
    IS_HF_MODEL = True
    hf_checkpoint = "bert-base-uncased"
    tokenizer = transformers.BertTokenizer.from_pretrained(hf_checkpoint)
    model = transformers.BertForSequenceClassification.from_pretrained(
        hf_checkpoint, num_labels=NUM_LABELS
    )

In [12]:
# Comma-separated names of the INBIO.csv columns that are dropped on load.
# Written as adjacent string literals so the value stays a single valid string;
# the pieces concatenate without separators, so every piece but the last ends in ",".
INBIO_UNUSED_FIELDS = (
    "Synonyms,Obsolete,CUI,Semantic Types,Parents,achieves,adjacent to,affects,allocates,"
    "capable of,characteristic for,completed invasion phase,contained in,contains,"
    "contributes to,contributor,created by,decreases,decreases effort in,derives from,"
    "derives into,determines,don't use concept,editor note,enabled by,ends,ends during,"
    "ends with,enhance,facilitate,has alien range,has amount of closely related species,"
    "has amount of species,has area,has component,has decreased effort level by,"
    "has distribution,has growth,has habitat,has increased effort level by,"
    "has increased levels of,has index,has input,has invasion success likelihood,"
    "has level of,has measurement,has measurement unit label,has measurement value,"
    "has mortality,has natality,has native range,has number of individuals,has output,"
    "has part,has part structure that is capable of,has participant,has propagule pressure,"
    "has quality,has range,has recruitment,has role,has spatial occupant at some time,"
    "has specific name,has status,has value,"
    "http://data.bioontology.org/metadata/obo/part_of,"
    "http://data.bioontology.org/metadata/prefixIRI,"
    "http://data.bioontology.org/metadata/treeView,"
    "http://purl.obolibrary.org/obo/IAO_0000111,"
    "http://purl.obolibrary.org/obo/IAO_0000112,"
    "http://purl.obolibrary.org/obo/IAO_0000114,"
    "http://purl.obolibrary.org/obo/IAO_0000115,"
    "http://purl.obolibrary.org/obo/IAO_0000118,"
    "http://purl.obolibrary.org/obo/IAO_0000119,"
    "http://purl.obolibrary.org/obo/IAO_0000232,"
    "http://purl.obolibrary.org/obo/IAO_0000412,"
    "http://purl.obolibrary.org/obo/ncbitaxon#has_rank,"
    "http://purl.obolibrary.org/obo/NCIT_A8,"
    "http://purl.obolibrary.org/obo/NCIT_NHC0,"
    "http://purl.obolibrary.org/obo/NCIT_P106,"
    "http://purl.obolibrary.org/obo/NCIT_P107,"
    "http://purl.obolibrary.org/obo/NCIT_P108,"
    "http://purl.obolibrary.org/obo/NCIT_P207,"
    "http://purl.obolibrary.org/obo/NCIT_P322,"
    "http://purl.obolibrary.org/obo/NCIT_P325,"
    "http://purl.obolibrary.org/obo/NCIT_P366,"
    "http://purl.obolibrary.org/obo/OBI_0001886,"
    "http://purl.obolibrary.org/obo/RO_0001900,"
    "http://purl.org/dc/elements/1.1/source,"
    "http://purl.org/dc/terms/creator,"
    "http://www.geneontology.org/formats/oboInOwl#creation_date,"
    "http://www.geneontology.org/formats/oboInOwl#hasAlternativeId,"
    "http://www.geneontology.org/formats/oboInOwl#hasBroadSynonym,"
    "http://www.geneontology.org/formats/oboInOwl#hasDbXref,"
    "http://www.geneontology.org/formats/oboInOwl#hasExactSynonym,"
    "http://www.geneontology.org/formats/oboInOwl#hasNarrowSynonym,"
    "http://www.geneontology.org/formats/oboInOwl#hasOBONamespace,"
    "http://www.geneontology.org/formats/oboInOwl#hasRelatedSynonym,"
    "http://www.geneontology.org/formats/oboInOwl#hasSynonymType,"
    "http://www.geneontology.org/formats/oboInOwl#id,"
    "http://www.geneontology.org/formats/oboInOwl#inSubset,"
    "http://www.w3.org/2000/01/rdf-schema#comment,"
    "http://www.w3.org/2000/01/rdf-schema#label,"
    "http://www.w3.org/2002/07/owl#deprecated,"
    "http://www.w3.org/2004/02/skos/core#altLabel,"
    "http://www.w3.org/2004/02/skos/core#definition,"
    "http://www.w3.org/2004/02/skos/core#notation,"
    "https://w3id.org/inbio#_000130,"
    "https://w3id.org/inbio#_000132,"
    "increases,increases effort in,interacts with,is absent,is affected by,is against,"
    "is aggregate of,is alien range to,is characteristic of,is characterized by,"
    "is closely related to,is enemy of,is enhanced by,is growth of,is habitat of,"
    "is in invasion phase,is mortality of,is natality of,is native range to,is part of,"
    "is prey of,is range of,is recruitment of,is similar to,is status of,"
    "license,license,license,license,located in,location of,"
    "occupies spatial region at some time,occurs in,output of,overlaps,part of,"
    "participates in,produced by,produces,quality of,role of,"
    "shows changes in species trait,spatially coextensive with,surrounded by,surrounds,"
    "title,TODO,license.1,license.2,license.3"
)

# Load the INBIO table without the columns listed above.
inbio = load_set(["INBIO.csv"], unused_fields=INBIO_UNUSED_FIELDS.split(","))

# Map each preferred label to its definition; passed later as `incontext_dict`.
bio2def = dict(zip(inbio["Preferred Label"], inbio["Definitions"]))

loading files
    INBIO.csv


In [13]:
# Display the loaded INBIO dataset summary (remaining columns and row count).
inbio

Dataset({
    features: ['Class ID', 'Preferred Label', 'Definitions'],
    num_rows: 457
})

In [14]:
#bio2def = None

In [15]:
# Label granularities of the word-piece training CSVs:
#   all labels      -> datasets/wordpiece_abstracts_train_all_labels.csv    (30 labels)
#   2 labels        -> datasets/wordpiece_abstracts_train_side_label_1.csv  (20 labels)
#   main label only -> datasets/wordpiece_abstracts_train.csv               (10 labels)

# Word-piece variant of the all-labels split, kept for reference:
#DS_TRAIN_PATH = "datasets/wordpiece_abstracts_train_all_labels.csv"
#DS_TEST_PATH = "datasets/wordpiece_abstracts_test_all_labels.csv"

# CSV files loaded into the "train" and "test" splits.
DS_TEST_PATH = "datasets/abstracts_all_labels_test.csv"
DS_TRAIN_PATH = "datasets/abstracts_all_labels_train.csv"

In [16]:
# Load the train split first, then the test split, dropping the same three columns from each.
dataset = DatasetDict(
    {
        split_name: load_set([csv_path], unused_fields=["head", "body", "strlabels"])
        for split_name, csv_path in (("train", DS_TRAIN_PATH), ("test", DS_TEST_PATH))
    }
)

# Every remaining column except "text" is treated as a label column.
labels = [column for column in dataset["train"].features if column != "text"]

# Index <-> label lookups, numbered in column order.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}

print(label2id)
print(id2label)
print(labels)

loading files
    datasets/abstracts_all_labels_train.csv
loading files
    datasets/abstracts_all_labels_test.csv
{'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5, 'a6': 6, 'a7': 7, 'a8': 8, 'a9': 9, 'b0': 10, 'b1': 11, 'b2': 12, 'b3': 13, 'b4': 14, 'b5': 15, 'b6': 16, 'b7': 17, 'b8': 18, 'b9': 19, 'c0': 20, 'c1': 21, 'c2': 22, 'c3': 23, 'c4': 24, 'c5': 25, 'c6': 26, 'c7': 27, 'c8': 28, 'c9': 29}
{0: 'a0', 1: 'a1', 2: 'a2', 3: 'a3', 4: 'a4', 5: 'a5', 6: 'a6', 7: 'a7', 8: 'a8', 9: 'a9', 10: 'b0', 11: 'b1', 12: 'b2', 13: 'b3', 14: 'b4', 15: 'b5', 16: 'b6', 17: 'b7', 18: 'b8', 19: 'b9', 20: 'c0', 21: 'c1', 22: 'c2', 23: 'c3', 24: 'c4', 25: 'c5', 26: 'c6', 27: 'c7', 28: 'c8', 29: 'c9'}
['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 'c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']


In [17]:
# Enabled: preprocessing for the hybrid-embedding model.
if True:
    encoded_dataset = preprocess_for_hybrid_embeddings(
        dataset,
        tokenizer,
        MAX_POSITION_EMBEDDINGS,
        incontext_dict=bio2def,
        remove_columns=dataset["train"].column_names,
    )

# Disabled: label-based preprocessing for the other model cells.
if False:
    encoded_dataset = preprocess_with_given_labels_train_test_wrap(
        dataset,
        tokenizer,
        labels,
        label2id,
        MAX_POSITION_EMBEDDINGS,
        False,
        remove_columns=dataset["train"].column_names,
        default_teacher_forcing=False,
        teacher_forcing_prefix=None,
        doc_pad_tokens=DOC_PAD_TOKENS,
        incontext_dict=bio2def,
        move_incontext_to_decoder=True,
    )

# Show the encoded splits as the cell output.
encoded_dataset

Map (num_proc=4): 100%|██████████| 722/722 [00:02<00:00, 288.35 examples/s]
Map (num_proc=4): 100%|██████████| 140/140 [00:01<00:00, 70.79 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'pos_input_ids', 'neg_input_ids', 'pos_idcs', 'neg_idcs'],
        num_rows: 722
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'pos_input_ids', 'neg_input_ids', 'pos_idcs', 'neg_idcs'],
        num_rows: 140
    })
})

In [18]:
# Column order of the encoded training split; passed to train_bern_model as `batch_schema`.
batch_schema = [*encoded_dataset["train"].features]
print(batch_schema)

# Training buffer, optionally shuffled right after construction.
train_dataloader = BatchBuffer(encoded_dataset["train"], TRAIN_BATCH_SIZE)
if SHUFFLE_CUSTOM_DATALOADER:
    train_dataloader.shuffle()

# Evaluation buffer (never shuffled here).
test_dataloader = BatchBuffer(encoded_dataset["test"], TEST_BATCH_SIZE)

['input_ids', 'token_type_ids', 'attention_mask', 'pos_input_ids', 'neg_input_ids', 'pos_idcs', 'neg_idcs']


In [19]:
# Use the GPU when one is available and move the model there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

# Whole number of optimizer steps per epoch: a trailing partial batch still costs one step,
# and the scheduler receives an integer rather than a float.
# NOTE(review): assumes len(train_dataloader) counts examples, not batches — confirm
# against BatchBuffer.__len__.
steps_per_epoch = math.ceil(len(train_dataloader) / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS

# Linear warm-up over the first 5% of all steps.
warmup_steps = math.ceil(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [20]:
# Make sure the checkpoint directory exists before training starts.
if CHECKPOINT_PATH is not None:
    try:
        # makedirs also creates missing parent directories and, with exist_ok=True,
        # treats an already existing directory as success.
        os.makedirs(CHECKPOINT_PATH, exist_ok=True)
    except OSError as err:
        # Best effort: report the problem and let the notebook continue.
        print(err)

In [21]:
# Names of the batch fields for the model's forward call: Hugging Face models take
# `token_type_ids`, the other models take `decoder_input_ids` in the third position.
forward_args = (
    ["input_ids", "attention_mask", "token_type_ids", "labels"]
    if IS_HF_MODEL
    else ["input_ids", "attention_mask", "decoder_input_ids", "labels"]
)

In [22]:
# Train the model via train_bern_model and keep its return value in `stats`.
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    id2label=id2label,
    batch_schema=batch_schema,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    one_label_only=False,
    # Follow the notebook-level switch instead of a hard-coded literal.
    generic_output_class=GENERIC_OUTPUT_CLASS,
    calc_metrics=True,
    per_class_f1=True,
    # NOTE(review): the `forward_args` list built in the previous cell is not used;
    # None is passed instead — confirm this is intended for the enabled model.
    forward_args=None,
)



Training...
torch.Size([1, 514, 768])


../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [2,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [3,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [4,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [5,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [6,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
