In [1]:
import glob
import time
import os
import datetime
import math
import datasets
from datasets import DatasetDict, load_dataset, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch import nn
import transformers
from transformers import (
    AutoTokenizer,
    T5Tokenizer,
    T5ForSequenceClassification,
    T5Config,
    BertTokenizer, 
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass

from load_set import *
from epoch_stats import EpochStats#, print_stat_tuples
import model_training
from model_training import (
    BatchBuffer, 
    train_bern_model, 
    mask_tokens, 
    preprocess_with_given_labels, 
    num_parameters, 
    num_trainable_parameters, 
    preprocess_for_causallm, 
    preprocess_for_multiple_choice,
    preprocess_for_seq2seq_swag,
    preprocess_with_given_labels_train_test_wrap
)
import coin_i3C_modeling as ci3C

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training hyper-parameters.
TRAIN_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
# Directory for checkpoints; None skips creating it. Example value used before:
# "hyp_cls/BERT-base_all-labels/"
CHECKPOINT_PATH = None
SHUFFLE_CUSTOM_DATALOADER = True
LEARNING_RATE = 1e-5
EPS = 1e-8  # AdamW epsilon
EPOCHS = 10

# Seed Python `random`, numpy and torch so weight init, dropout and data
# shuffling are reproducible across Restart & Run All.
# NOTE(review): assumes BatchBuffer.shuffle draws from one of these RNGs — confirm.
SEED = 42
transformers.set_seed(SEED)

In [3]:
# Model architecture sizes.
VOCAB_SIZE, MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE = 30_522, 512, 1024

# Pipeline switches.
IS_HF_MODEL = False           # True only when a Hugging Face model is selected below
GENERIC_OUTPUT_CLASS = True
DOC_PAD_TOKENS = False

# Must equal the number of label columns in the chosen dataset.
NUM_LABELS = 30

In [4]:
# Disabled alternative: Hugging Face DeBERTa-base baseline. To use it, flip this
# to True AND disable the COIN cell below; otherwise that cell reassigns
# `tokenizer` and `model`.
if False:
    IS_HF_MODEL = True
    deberta_checkpoint = "microsoft/deberta-base"
    tokenizer = transformers.DebertaTokenizer.from_pretrained(deberta_checkpoint)
    model = transformers.DebertaForSequenceClassification.from_pretrained(
        deberta_checkpoint, num_labels=NUM_LABELS
    )

In [5]:
# Active model: a COIN-i3C sequence classifier. Only the WordPiece tokenizer is
# loaded from `rffn_base_model_path`; the model itself is constructed from a
# fresh config here. NOTE(review): confirm no pretrained weights are meant to be loaded.
if True:
    rffn_base_model_path = "tmp_models/COIN-i3C_mcca-translation-en-de_0029-500k_1x2_1dec-none_no-revert_chunkwise_group-exp_congen-head_B10_multi-query-2_switch-ii_mpe576_no-cross-att/"
    # Alternative tokenizer that was tried: BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer = BertTokenizer.from_pretrained(f"{rffn_base_model_path}/wordpiece_tokenizer/")

    coin_config = ci3C.COINConfig(
        # sizes shared with the config cell
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        hidden_size=HIDDEN_SIZE,
        num_labels=NUM_LABELS,
        # retention / positional settings
        forward_method="parallel",
        apply_decay=True,
        num_decay_parts=1,
        hidden_retention_act="relu",
        hidden_pos_offset=False,
        rope_dim=32,
        num_query_heads=1,
        # decoder / layer layout (4 layers)
        decoder_output="none",
        revert_decoder=False,
        decoder_schema=[0] * 4,
        cross_encoder_schema=[0] * 4,
        experts_schema=None,    # e.g. [2, 2]
        block_io_schema=None,   # e.g. [[1024, 1024*4, 1024]]
    )
    model = ci3C.COINForSequenceClassification(config=coin_config)



num layers: 4
gamma schema: [0.96875, 0.9875984191894531, 0.995078444480896, 0.998046875]
layer 0 num experts: 1
layer 1 num experts: 1
layer 2 num experts: 1
layer 3 num experts: 1


In [6]:
# Available label sets (train CSVs; pair DS_TEST_PATH with the matching test file):
#   all labels      -> datasets/wordpiece_abstracts_train_all_labels.csv   (30 labels)
#   two labels      -> datasets/wordpiece_abstracts_train_side_label_1.csv (20 labels)
#   main label only -> datasets/wordpiece_abstracts_train.csv              (10 labels)
# NUM_LABELS in the model-config cell must match the chosen set.
DS_TRAIN_PATH, DS_TEST_PATH = (
    "datasets/wordpiece_abstracts_train_all_labels.csv",
    "datasets/wordpiece_abstracts_test_all_labels.csv",
)

In [7]:
# Load both splits. NOTE(review): `unused_fields` presumably drops those CSV
# columns inside load_set — confirm there.
dataset = DatasetDict({
    "train": load_set([DS_TRAIN_PATH], unused_fields=["head", "body", "strlabels"]),
    "test": load_set([DS_TEST_PATH], unused_fields=["head", "body", "strlabels"]),
})

# Every column except "text" is treated as a label; the order of the train
# features fixes the label <-> id mapping.
labels = [label for label in dataset["train"].features.keys() if label != "text"]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

# NUM_LABELS is hard-coded in the config cell and sizes the classifier head;
# fail fast if it disagrees with the dataset that was actually loaded.
assert len(labels) == NUM_LABELS, (
    f"dataset has {len(labels)} label columns but NUM_LABELS={NUM_LABELS}; "
    "update NUM_LABELS to match DS_TRAIN_PATH"
)

print(label2id)
print(id2label)
print(labels)

loading files
    datasets/wordpiece_abstracts_train_all_labels.csv
loading files
    datasets/wordpiece_abstracts_test_all_labels.csv
{'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5, 'a6': 6, 'a7': 7, 'a8': 8, 'a9': 9, 'b0': 10, 'b1': 11, 'b2': 12, 'b3': 13, 'b4': 14, 'b5': 15, 'b6': 16, 'b7': 17, 'b8': 18, 'b9': 19, 'c0': 20, 'c1': 21, 'c2': 22, 'c3': 23, 'c4': 24, 'c5': 25, 'c6': 26, 'c7': 27, 'c8': 28, 'c9': 29}
{0: 'a0', 1: 'a1', 2: 'a2', 3: 'a3', 4: 'a4', 5: 'a5', 6: 'a6', 7: 'a7', 8: 'a8', 9: 'a9', 10: 'b0', 11: 'b1', 12: 'b2', 13: 'b3', 14: 'b4', 15: 'b5', 16: 'b6', 17: 'b7', 18: 'b8', 19: 'b9', 20: 'c0', 21: 'c1', 22: 'c2', 23: 'c3', 24: 'c4', 25: 'c5', 26: 'c6', 27: 'c7', 28: 'c8', 29: 'c9'}
['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 'c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']


In [8]:
# Tokenise both splits and build label targets. All raw columns (taken from the
# train split) are removed, so only the encoded fields remain.
raw_columns = dataset["train"].column_names
encoded_dataset = preprocess_with_given_labels_train_test_wrap(
    dataset,
    tokenizer,
    labels,
    label2id,
    MAX_POSITION_EMBEDDINGS,
    False,  # NOTE(review): meaning of this positional flag is defined in model_training — confirm
    remove_columns=raw_columns,
    default_teacher_forcing=False,
    teacher_forcing_prefix=None,
    doc_pad_tokens=DOC_PAD_TOKENS,
)

Map (num_proc=4): 100%|██████████| 862/862 [00:01<00:00, 663.62 examples/s]
Map (num_proc=4): 100%|██████████| 92/92 [00:00<00:00, 131.19 examples/s]


In [9]:
# The encoded feature names (in order) describe each batch's fields.
train_split = encoded_dataset["train"]
batch_schema = list(train_split.features)
print(batch_schema)

# Only the training buffer is (optionally) shuffled; evaluation order stays fixed.
train_dataloader = BatchBuffer(train_split, TRAIN_BATCH_SIZE)
if SHUFFLE_CUSTOM_DATALOADER:
    train_dataloader.shuffle()

test_dataloader = BatchBuffer(encoded_dataset["test"], TEST_BATCH_SIZE)

['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'decoder_input_ids']


In [10]:
# Move the model to the GPU when available, then set up AdamW with a linear
# warmup followed by linear decay.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

# Fraction of all optimisation steps spent in linear LR warmup.
WARMUP_RATIO = 0.05

# The scheduler counts optimizer steps, so use an integer. ceil() counts a
# trailing partial batch as a step. NOTE(review): assumes len(train_dataloader)
# is the number of examples, not batches. Both agree while TRAIN_BATCH_SIZE == 1;
# confirm against BatchBuffer.__len__.
steps_per_epoch = math.ceil(len(train_dataloader) / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [11]:
# Create the checkpoint directory, including any missing parents, before training.
# An already-existing directory is fine. Real failures (e.g. permission denied)
# now raise here instead of being printed and ignored.
if CHECKPOINT_PATH is not None:
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)

In [12]:
# Batch fields presumably fed to the model's forward by train_bern_model. HF
# models take token_type_ids; the COIN model takes decoder_input_ids instead.
forward_args = [
    "input_ids",
    "attention_mask",
    "token_type_ids" if IS_HF_MODEL else "decoder_input_ids",
    "labels",
]

In [13]:
%%time
# Run the full train/eval loop. `stats` holds whatever train_bern_model returns
# (presumably per-epoch EpochStats — TODO confirm).
# GENERIC_OUTPUT_CLASS comes from the config cell so toggling it there takes effect.
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    id2label=id2label,
    batch_schema=batch_schema,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    vocab_size=VOCAB_SIZE,
    print_status=True,
    is_hf_model=IS_HF_MODEL,
    checkpoint_path=CHECKPOINT_PATH,
    train_batch_size=TRAIN_BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
    one_label_only=False,
    generic_output_class=GENERIC_OUTPUT_CLASS,
    calc_metrics=True,
    per_class_f1=True,
    forward_args=forward_args,
)



Training...
  Batch    17  of    862.    Elapsed:  0:00:01, Remaining:  0:00:50.
  Batch    34  of    862.    Elapsed:  0:00:01, Remaining:  0:00:49.
  Batch    51  of    862.    Elapsed:  0:00:02, Remaining:  0:00:48.
  Batch    68  of    862.    Elapsed:  0:00:02, Remaining:  0:00:47.
  Batch    85  of    862.    Elapsed:  0:00:03, Remaining:  0:00:46.
  Batch   102  of    862.    Elapsed:  0:00:03, Remaining:  0:00:45.
  Batch   119  of    862.    Elapsed:  0:00:04, Remaining:  0:00:44.
  Batch   136  of    862.    Elapsed:  0:00:04, Remaining:  0:00:43.
  Batch   153  of    862.    Elapsed:  0:00:05, Remaining:  0:00:42.
  Batch   170  of    862.    Elapsed:  0:00:06, Remaining:  0:00:41.
  Batch   187  of    862.    Elapsed:  0:00:06, Remaining:  0:00:40.
  Batch   204  of    862.    Elapsed:  0:00:07, Remaining:  0:00:39.
  Batch   221  of    862.    Elapsed:  0:00:07, Remaining:  0:00:38.
  Batch   238  of    862.    Elapsed:  0:00:08, Remaining:  0:00:37.
  Batch   255  of   