In [None]:
import datetime
import glob
import math
import os
import random
import time
from dataclasses import dataclass

import datasets
from datasets import DatasetDict, load_dataset, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    T5Tokenizer,
    T5ForSequenceClassification,
    T5Config,
    BertTokenizer, 
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from load_set import *
from epoch_stats import EpochStats#, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    train_bern_model, 
    mask_tokens, 
    preprocess_with_given_labels, 
    num_parameters, 
    num_trainable_parameters, 
    preprocess_for_causallm, 
    preprocess_for_multiple_choice,
    preprocess_for_seq2seq_swag,
    preprocess_with_given_labels_train_test_wrap
)
import coin_i3C_modeling as ci3C

In [None]:
# Training hyperparameters and run configuration.
TRAIN_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
# Checkpoint directory; None skips creating one (and is passed as-is to
# train_bern_model — presumably disables checkpointing there, confirm).
# Example value: "hyp_cls/BiomedBERT-base_all-labels/"
CHECKPOINT_PATH = None
SHUFFLE_CUSTOM_DATALOADER = True
LEARNING_RATE = 1e-5
EPS = 1e-8  # AdamW epsilon
EPOCHS = 10

# Seed the RNGs the run may draw from (Python `random`, NumPy, torch on all
# devices) before any model is built, so the randomly initialized
# classification head and (presumably) BatchBuffer shuffling are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Model-size settings. They build the custom COIN model; MAX_POSITION_EMBEDDINGS
# is also the preprocessing length and VOCAB_SIZE is forwarded to train_bern_model.
MAX_POSITION_EMBEDDINGS = 512
VOCAB_SIZE = 30_522
HIDDEN_SIZE = 1024

# Behaviour flags read by later cells.
IS_HF_MODEL = False  # the Hugging Face model-selection cells switch this to True
GENERIC_OUTPUT_CLASS = True
DOC_PAD_TOKENS = False

# Width of the classification head; should equal the number of label columns
# in the chosen dataset (30 for the "all labels" CSVs).
NUM_LABELS = 30

Choose the model: set the `if` condition of exactly one of the cells below to `True` (leave all others `False`), then run all cells.

In [None]:
if True:  # DeBERTa-base from the Hugging Face hub
    IS_HF_MODEL = True
    tokenizer = transformers.DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    model = transformers.DebertaForSequenceClassification.from_pretrained(
        "microsoft/deberta-base", num_labels=NUM_LABELS
    )

In [None]:
if False:  # custom COIN-i3C model with a locally trained WordPiece tokenizer
    # Reset explicitly: an HF selection cell run earlier in this kernel (or also
    # enabled above) leaves IS_HF_MODEL = True, which would pick the HF
    # forward_args ("token_type_ids") and is_hf_model path for this model.
    IS_HF_MODEL = False
    tokenizer = BertTokenizer.from_pretrained("tmp_models/COIN-i3C_default_tokenizer/wordpiece_tokenizer/")
    model = ci3C.COINForSequenceClassification(
        config=ci3C.COINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            forward_method="parallel",
            apply_decay=True,
            num_decay_parts=1,
            hidden_retention_act="relu",
            hidden_pos_offset=False,
            rope_dim=32,
            num_query_heads=1,

            decoder_output="none",
            revert_decoder=False,
            decoder_schema=[1] * 8,
            cross_encoder_schema=[0] * 8,
            experts_schema=None,  # alternative tried: [2, 2]
            block_io_schema=None,  # alternative tried: [[1024, 1024*4, 1024]]
        )
    )

In [None]:
if False:  # RoBERTa-base from the Hugging Face hub
    IS_HF_MODEL = True
    tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
    model = transformers.RobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=NUM_LABELS
    )

In [None]:
if False:  # BiomedBERT (uncased, abstracts) from the Hugging Face hub
    IS_HF_MODEL = True
    tokenizer = transformers.BertTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
    model = transformers.BertForSequenceClassification.from_pretrained(
        "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract", num_labels=NUM_LABELS
    )

In [None]:
if False:  # BERT-base (uncased) from the Hugging Face hub
    IS_HF_MODEL = True
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    model = transformers.BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=NUM_LABELS
    )

In [None]:
# Available training CSVs and the number of label columns each carries
# (keep NUM_LABELS in sync with the chosen pair of files):
#   all labels      -> datasets/wordpiece_abstracts_train_all_labels.csv   (30 labels)
#   2 labels        -> datasets/wordpiece_abstracts_train_side_label_1.csv (20 labels)
#   main label only -> datasets/wordpiece_abstracts_train.csv              (10 labels)
DS_TRAIN_PATH = "datasets/wordpiece_abstracts_train_all_labels.csv"
DS_TEST_PATH = "datasets/wordpiece_abstracts_test_all_labels.csv"

In [None]:
# Load both splits; the listed raw columns are passed to load_set as unused
# (presumably dropped there — confirm in load_set).
dataset = DatasetDict({
    "train": load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"]),
    "test": load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"]),
})

# Every remaining train column other than "text" is treated as a label column.
labels = [label for label in dataset["train"].features.keys() if label != "text"]
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}
print(label2id)
print(id2label)
print(labels)

# The model head was already built with NUM_LABELS outputs in an earlier cell;
# fail here with a clear message if the chosen CSVs carry a different count,
# instead of a shape mismatch deep inside training.
assert len(labels) == NUM_LABELS, (
    f"dataset has {len(labels)} label columns but NUM_LABELS = {NUM_LABELS}; "
    "update NUM_LABELS and re-run the model-selection cell"
)

In [None]:
# Encode both splits with the selected tokenizer and the label mapping built above.
encoded_dataset = preprocess_with_given_labels_train_test_wrap(
    dataset,
    tokenizer,
    labels,
    label2id,
    MAX_POSITION_EMBEDDINGS,
    False,  # NOTE(review): meaning of this positional flag is defined in model_training — confirm
    remove_columns=dataset["train"].column_names,
    default_teacher_forcing=False,
    teacher_forcing_prefix=None,
    doc_pad_tokens=DOC_PAD_TOKENS,
)

In [None]:
# Column names of the encoded train split; handed to the training loop as batch_schema.
batch_schema = list(encoded_dataset["train"].features)
print(batch_schema)

train_dataloader = BatchBuffer(encoded_dataset["train"], TRAIN_BATCH_SIZE)
if SHUFFLE_CUSTOM_DATALOADER:
    train_dataloader.shuffle()

test_dataloader = BatchBuffer(encoded_dataset["test"], TEST_BATCH_SIZE)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

# The scheduler counts optimizer steps, so use an integer step count; plain `/`
# would give a fractional total whenever the split size is not a multiple of
# TRAIN_BATCH_SIZE.
# NOTE(review): assumes len(BatchBuffer) is the number of examples (hence the
# division by the batch size), not the number of batches — confirm in model_training.
steps_per_epoch = math.ceil(len(train_dataloader) / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)  # linear warmup over the first 5% of steps

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

In [None]:
if CHECKPOINT_PATH is not None:
    # makedirs also creates missing parent directories (e.g. "hyp_cls/") and is a
    # no-op when the directory already exists. Any other OS error (permissions,
    # read-only FS, ...) is raised here rather than printed and ignored, so the
    # run does not start training with an unusable checkpoint location.
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)

In [None]:
# Batch fields passed to train_bern_model as forward_args (presumably the
# model.forward keyword arguments). HF encoders take "token_type_ids" in the
# third slot; the custom COIN model takes "decoder_input_ids" instead.
forward_args = [
    "input_ids",
    "attention_mask",
    "token_type_ids" if IS_HF_MODEL else "decoder_input_ids",
    "labels",
]

In [None]:
%%time
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    id2label=id2label,
    batch_schema=batch_schema,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    one_label_only=False,
    generic_output_class=True,
    calc_metrics=True,
    per_class_f1=True,

    forward_args=forward_args
)