In [95]:
import os
import json
from dataclasses import field, dataclass
from typing import Optional, List, Dict

from transformers import (
    TrainingArguments,
    HfArgumentParser,
    EarlyStoppingCallback,
    IntervalStrategy,
    AlbertTokenizer,
    BertTokenizer,
)
from torch.utils.data import Dataset
from transformers.utils import ModelOutput


In [96]:
# global constants for this script
SPAN_PAD = [3, 5]
LABEL_PAD = -100
LABEL_PAD_LIGHT = 0
UNK = 1
MAX_SENT_LENGTH = 500
THRESHOLD = 2
KEYS = ('content', 'label')
EXCEPT_KEYS = ('spans', 'input_lengths')

In [3]:
# arguments:
# ModelArguments for model initialize; ControlArguments for writing and reading related operation
# CustomTrainArguments inherited from Huggingface TrainArguments.
@dataclass
class ModelArguments:

  num_layers: int = field(default=2, metadata={'help': 'number of layers for RNNs'})
  input_dim: int = field(default=100, metadata={'help': 'number of dimension for word tokens'})
  hidden_dim: int = field(default=768)
  vector_dim: int = field(default=100)
  char_hidden_dim: int = field(default=50)
  model_path: str = field(default='cache/albert-English')
  pretrained_tokenizer: bool = field(default=False)
  word_vector: str = field(default='word2vector/glove.6B.100d.txt')

@dataclass
class CustomTrainArguments(TrainingArguments):
  evaluation_strategy: str = field(default='steps')
  output_dir: str = field(default='saved_model')
  per_device_train_batch_size: int = field(default=16)
  learning_rate: float = field(default=5e-5)
  num_train_epochs: int = field(default=5)
  eval_steps: int = field(default=5)
  save_total_limit: int = field(default=5)
  load_best_model_at_end: bool = field(default=True)
  metric_for_best_model: str = field(default='accuracy')

@dataclass
class ControlArguments:
  log_dir: str = field(default='log/', metadata={'help': 'directory to save log files'})
  dataset_path: str = field(default='spam-test/')
  reader: str = field(default='csv', metadata={'help': 'methods used to read source file'})
  task_type: str = field(default='sent')
  save_to: str = field(default='saved_model', metadata={'help': 'directory used to save model'})
  label_mapping: str = field(default='label_mapping.json', metadata={'help': 'saved label mapping for some specific tasks with string label'})
  best_model_path: str = field(default='saved_model/roberta-chinese_5.pth')
  given_best_model: bool = field(default=False, metadata={'help': 'if give best model path'})
  is_light: bool = field(default=True, metadata={'help': 'whether train in light mode '})
  cached_tokenizer: str = field(default='cache/cache_tokenizer.bin')
  seed: int = field(default=42)



control_arguments, train_arguments, model_arguments = ControlArguments(), CustomTrainArguments(), ModelArguments()

# init_args()



In [97]:
from collections.abc import Iterable

import pandas as pd

# initialize Dataset class and related methods
def csv_reader(file_path, keys=KEYS):
    """
    generate data as described in keys
    """
    data = pd.read_csv(file_path)
    if len(keys) > 2:
        data[keys[1]] = data[keys[1]].apply(eval)
        data[keys[2]] = data[keys[2]].apply(eval)
    data_ = []
    for key in keys:
        if key in data.columns:
            data_.append(list(data[key]))
        else:
            raise KeyError(f"{key} doesn't exist in source csv")
    return data_


class SentDataset(Dataset):
    def __init__(self, dataset_path, reader, label2idx=None):
        self.all_samples, self.labels = reader(dataset_path)
        if label2idx is not None:
            self.label2idx = label2idx
        else:
            self.label2idx = {}
            self._gen_label2idx()

    def _gen_label2idx(self):
        for label in self.labels:
            if isinstance(label, Iterable) and not isinstance(label, str):
                for inner_label in label:
                    if inner_label not in self.label2idx:
                        self.label2idx[inner_label] = len(self.label2idx)
            else:
                if label not in self.label2idx:
                    self.label2idx[label] = len(self.label2idx)

    def __getitem__(self, item):
        label_cur = self.labels[item]
        if isinstance(label_cur, Iterable) and not isinstance(label_cur, str):
            label_idx_cur = [self.label2idx.get(i, None) for i in label_cur]
            for i in label_idx_cur:
                if i is None:
                    raise KeyError(f"found unexisted key in {label_cur}")
        else:
            label_idx_cur = self.label2idx.get(label_cur, None)
            if label_idx_cur is None:
                raise KeyError(f"{label_cur} doesn't exist in label list")
        return self.all_samples[item], label_idx_cur

    def __len__(self):
        return len(self.all_samples)


class DocDataset(Dataset):
    def __init__(self, dataset_path, reader):
        self.all_samples, self.labels, self.entity_spans = reader(dataset_path)
        self.label2idx = {}
        for label in self.labels:
            if isinstance(label, Iterable) and not isinstance(label, str):
                for inner_label in label:
                    if inner_label not in self.label2idx:
                        self.label2idx[inner_label] = len(self.label2idx)
            else:
                if label not in self.label2idx:
                    self.label2idx[label] = len(self.label2idx)

    def __getitem__(self, item):
        label_cur = self.labels[item]
        if isinstance(label_cur, Iterable) and isinstance(label_cur, str):
            label_idx_cur = [self.label2idx.get(i, None) for i in label_cur]
            for i in label_idx_cur:
                if i is None:
                    raise KeyError(f"found unexisted key in {label_cur}")
        else:
            label_idx_cur = self.label2idx.get(label_cur, None)
            if label_idx_cur is None:
                raise KeyError(f"{label_cur} doesn't exist in label list")
        return self.all_samples[item], label_idx_cur, self.entity_spans[item]

    def __len__(self):
        return len(self.all_samples)



In [76]:
# model architecture.
import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    BigBirdModel
)


class PretrainSentModel(nn.Module):
    def __init__(self, model_path, hidden_dim, num_labels):
        super(PretrainSentModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.encoder = AutoModel.from_pretrained(model_path)
        # self.encoder = BigBirdModel.from_pretrained(model_path)
        self.linear_layer = nn.Linear(hidden_dim, num_labels)
        self.num_labels = num_labels
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                labels=None,
                spans=None,
                output_attentions=None,
                return_dict=None
                ):
        # (loss[optional], logit, hidden_states[optional], output_attentions[optional]
        output = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              position_ids=position_ids,
                              return_dict=return_dict)
        sequence_output = output[0]  # (batch_size, seq_len, hidden_dim)
        last_hidden_state = sequence_output[:, -1, :]  # (batch_size, hidden_dim)
        # batch_size, num_span = sequence_output.size(0), len(spans[0])

            

        pred_ = self.linear_layer(last_hidden_state)  # (batch_size, num_categories)
        if labels is not None:
            loss = self.loss_fn(pred_, labels)
        else:
            loss = None
        return ModelOutput(loss=loss, logits=pred_)

    def dynamic_quantization(self):
        quantized_model = torch.quantization.quantize_dynamic(self.encoder, {torch.nn.Linear}, dtype=torch.qint8)
        setattr(self, 'encoder', quantized_model)


class LongFormer(nn.Module):
    def __init__(self, model_path, hidden_dim, num_labels):
        super(LongFormer, self).__init__()
        self.hidden_dim = hidden_dim
        self.encoder = BigBirdModel.from_pretrained(model_path)
        self.linear_layer = nn.Linear(hidden_dim, num_labels)
        self.num_labels = num_labels

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                labels=None,
                spans=None,
                output_attentions=None,
                return_dict=None
                ):
        # (loss[optional], logit, hidden_states[optional], output_attentions[optional]
        output = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              position_ids=position_ids,
                              return_dict=return_dict)
        sequence_output = output[0]
        batch_size, num_span = sequence_output.size(0), len(spans[0])
        entity_embedding = torch.rand(batch_size, num_span, self.hidden_dim)
        for idx, span_items in enumerate(spans):
            for idx_span, span_item in enumerate(span_items):
                entity_rep = sequence_output[idx, span_item[0]:span_item[1]]
                entity_embedding[idx, idx_span, :] = torch.mean(entity_rep, dim=0)
        # (batch_size, num_spans, hidden_dim) -> (batch_size, num_spans, num_categories)
        pred_ = self.linear_layer(entity_embedding)
        pred_ = pred_.view(batch_size * num_span, -1)  # (batch_size, number of categories)
        return {'logits': pred_}


In [98]:
DATACLASS = {
    'sent': SentDataset,
    'doc-span': DocDataset
}

READER = {
    'csv': csv_reader
}

MODELS = {
    'sent': PretrainSentModel,
    'doc-span': LongFormer
}


In [99]:


import pickle
from collections import Counter

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

class PreTraining:
    def __init__(self, train_arguments, model_arguments, control_arguments):
        # model config
        self.hidden_dim = model_arguments.hidden_dim
        self.batch_size = train_arguments.per_device_train_batch_size
        self.model_path = model_arguments.model_path
        # self.tokenizer = AlbertTokenizer.from_pretrained(self.model_path)
        self.save_to = control_arguments.save_to

        # config for setting pipeline
        self.dataset_path = control_arguments.dataset_path
        self.reader = READER[control_arguments.reader]
        self.Dataset = DATACLASS[control_arguments.task_type]
        self.task_type = control_arguments.task_type

        # training arguments
        self.train_arguments = train_arguments

        # post-processing after initialization
        self.train, self.val, self.test = self.init_dataset()
        if os.path.exists(control_arguments.label_mapping):
            self.label2idx = json.load(open(control_arguments.label_mapping, 'r', encoding='utf-8'))
            self.train.label2idx = self.label2idx
        else:
            self.label2idx = self.train.label2idx
        #     json.dump(self.label2idx, open('label_mapping.json', 'w', encoding='utf-8'))
        self.val.label2idx = self.label2idx
        self.test.label2idx = self.label2idx

        self.idx2label = {}
        for key, value in self.label2idx.items():
            self.idx2label[value] = key

    def create_loader(self, collate_fn=None, data_sampler=None):
        train_loader = DataLoader(self.train, batch_size=self.batch_size, collate_fn=collate_fn,
                                  sampler=data_sampler, shuffle=True)
        val_loader = DataLoader(self.val, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=True)
        test_loader = DataLoader(self.test, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=True)
        return train_loader, val_loader, test_loader

    def init_dataset(self):
        train_dataset = self.Dataset(os.path.join(self.dataset_path, 'val.csv'), self.reader)
        val_dataset = self.Dataset(os.path.join(self.dataset_path, 'val.csv'), self.reader)
        test_dataset = self.Dataset(os.path.join(self.dataset_path, 'val.csv'), self.reader)
        return train_dataset, val_dataset, test_dataset

    def prepare_model(self):
        if 'albert' in self.model_path or 'Albert' in self.model_path:
            tokenizer = AlbertTokenizer.from_pretrained(self.model_path)
        else:
            tokenizer = BertTokenizer.from_pretrained(self.model_path)
        try:
            model_class = MODELS[self.task_type]
            model = model_class(self.model_path, self.hidden_dim, len(self.label2idx))
        except KeyError:
            raise f"{self.task_type} doesn't have a corresponding model "
        return tokenizer, model

    def prepare_optimizer(self, model, data_loader):
        all_steps = self.epoch * len(data_loader)
        optimizer = optim.AdamW(model.parameters(), lr=self.lr, eps=self.adam_eps)
        # optimizer = optim.SGD(model.parameters(), lr=self.lr)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_step, num_training_steps=all_steps)
        loss_fn = nn.CrossEntropyLoss(ignore_index=LABEL_PAD)
        return optimizer, scheduler, loss_fn




In [122]:
# Custom hugging face trainner class
from typing import Optional, List, Dict
from transformers import Trainer
from sklearn.metrics import accuracy_score, classification_report

class CustomTrainer(Trainer):
  def compute_loss(self, model, inputs, return_outputs=False):
    labels = inputs.get('labels')
    logit = model(**inputs).get('logits')
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(logit.view(-1, self.model.num_labels), labels.view(-1))
    return (loss, logit) if return_outputs else loss
  
  def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
      self._memory_tracker.start()
      all_pred, all_target = [], []
      report = {}
      correct = 0
      local_rank = self.args.local_rank
      dataloader = self.get_eval_dataloader(eval_dataset=eval_dataset)
      for batch_data in dataloader:
          target = batch_data['labels']
          output = self.model(**batch_data)
          pred_score = output.get('logits')
          pred_class = torch.argmax(pred_score, dim=-1)
          correct += torch.sum(pred_class==target).item()
          assert target.size(0) == pred_class.size(0)
          all_pred.append(pred_class)
          all_target.append(target)
      all_pred, all_target = torch.cat(all_pred, dim=0), torch.cat(all_target, dim=0)
      if all_pred.is_cuda:
          all_pred, all_target = all_pred.cpu().numpy(), all_target.cpu().numpy()
      else:
          all_pred, all_target = all_pred.numpy(), all_target.numpy()
      all_num = all_pred.shape[0]
    # pdb.set_trace()
      indies = all_pred == all_target
      correct_array = all_pred[indies]
      correct_num = np.sum(indies)
      acc = correct_num / all_num
      report['accuracy'] = acc
      print(f'number of correct: {correct}')
      print(f'number of correct: {correct_num}; number of all samples: {all_num} ')
          
      return report
  

      
def compute_metric(eval):
  logit, label = eval
  correct = 0
  pred_class = torch.argmax(logit, dim=-1)
  assert logit.size(0) == label.size(0)
  correct += torch.sum(pred_class == label).item()
  if logit.is_cuda:
    pred_class, label = pred_class.cpu().numpy(), label.cpu().numpy()
  else:
    pred_class, label = pred_class.numpy(), label.numpy()
  acc = accuracy_score(label, pred_class)
  report = classification_report(label, pred_class)
  report['accuracy'] = acc
  return report

In [123]:

import torch
import pandas as pd

from utilis.constants import SPAN_PAD, LABEL_PAD, MAX_SENT_LENGTH


class CollateFn:
    def __init__(self,
                 tokenizer,
                 label2idx,
                 idx2label=None,
                 is_split=False,
                 task_type='doc-span'
                 ):
        self.tokenizer = tokenizer
        self.label2idx = label2idx
        if idx2label:
            self.idx2label = idx2label
        else:
            self.idx2label = {}
            for label, idx in label2idx.items():
                self.idx2label[idx] = label
        self.is_split = is_split
        self.task_type = task_type

    def __call__(self, batch_data):
        batchfy_input, batch_data_sep = self.processing(batch_data)
        batchfy_input = self.post_process(batchfy_input, batch_data_sep)
        return batchfy_input

    def processing(self, batch_data):
        # (all_text, all_label, entity_spans[optional])
        batch_data_sep = _pre_processing(batch_data, task_type=self.task_type)
        batchfy_input = self.tokenizer(batch_data_sep[0],
                                       is_split_into_words=self.is_split,
                                       truncation=True,
                                       padding=True,
                                       return_tensors='pt',
                                       max_length=MAX_SENT_LENGTH,
                                       )
        return batchfy_input, batch_data_sep

    def post_process(self, batchfy_input, batch_data_sep):
        if self.task_type == 'doc-span':

            # pad_labels, pad_spans = _padding_token(all_labels, all_spans)
            pad_labels, pad_spans = _padding_entity(batchfy_input[1], batchfy_input[2])
            # pad_spans_df = _span_to_csv(pad_spans)
            batchfy_input['labels'] = torch.tensor(pad_labels, dtype=torch.long)
            batchfy_input['spans'] = pad_spans
            return batchfy_input
        else:
            batchfy_input['labels'] = torch.tensor(batch_data_sep[1], dtype=torch.long)

        return batchfy_input

def _pre_processing(batch_data, task_type):
    """
    processing batch data from dataloader before feeding into model depends on task type
    return:
    all_text: list[str]
    all_label: list[int] (see this task as a token classification tasks)
    entity_spans: pd.Dataframe (in this way, model can select list of indexes effciently
    """
    all_text, all_label = [], []
    if task_type == 'doc-span':
        entity_spans = []
    else:
        entity_spans = None
    for zip_sample in batch_data:
        text, label = zip_sample[0], zip_sample[1]
        all_text.append(text)
        all_label.append(label)
        if entity_spans is not None:
            entity_spans.append(zip_sample[2])
    # spans_csv = pd.DataFrame(entity_spans, columns=[for ])
    return (all_text, all_label, entity_spans) if entity_spans else (all_text, all_label)


def _padding_token(labels, spans):
    """
    padding labels into same length of spans and covert label format to token classification format
    """
    max_len = max(map(len, labels))
    pad_spans = []
    label_in_token = []
    for label_item, span_item in zip(labels, spans):
        cur_label_token = []
        for label_inner, span_inner in zip(label_item, span_item):
            cur_label_token += [label_inner] * (span_inner[1] - span_inner[0])
        label_in_token.append(cur_label_token)
        pad_num = (max_len - len(label_item))
        pad_spans.append(span_item + [SPAN_PAD] * pad_num)
    max_len_label = max(map(len, label_in_token))
    pad_label_token = []
    for label_token_item in label_in_token:
        pad_num = max_len_label - len(label_token_item)
        pad_label_token.append(label_token_item + [LABEL_PAD] * pad_num)
    return pad_label_token, pad_spans


def _padding_entity(labels, spans):
    """
    padding labels and span's entity in a classification format, using average
    entity representation to do prediction
    """
    max_len = max(map(len, labels))
    pad_labels, pad_spans = [], []
    for label_item, span_item in zip(labels, spans):
        pad_num = max_len - len(label_item)
        pad_labels.append(label_item + [LABEL_PAD] * pad_num)
        pad_spans.append(span_item + [SPAN_PAD] * pad_num)
    return pad_labels, pad_spans



In [124]:
import random

import numpy as np

def fix_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False

def train(train_arguments, model_arguments, control_arguments):
    fix_seed(control_arguments.seed)

 
    prepare = PreTraining(train_arguments, model_arguments, control_arguments)

    setattr(train_arguments, 'label_names', list(prepare.idx2label.keys()))
    # device = torch.device(control_arguments.device)
    data_sampler = None
    tokenizer, model = prepare.prepare_model()
    
    collate_fn = CollateFn(tokenizer, prepare.label2idx, task_type=prepare.task_type)
    # train_loader, val_loader, test_loader = prepare.create_loader(collate_fn=collate_fn, data_sampler=data_sampler)
    # model.to(device=device)
    # optimizer, scheduler, loss_fn = prepare.prepare_optimizer(model, train_loader)
    def preprocess_logits_for_metrics(logit, labels):
        if logit is tuple:
            return logit[0]
        else:
            return logit   

    trainer = CustomTrainer(
            model=model,
            args=train_arguments,
            train_dataset=prepare.train,
            eval_dataset=prepare.train,
            tokenizer=tokenizer,
            data_collator=collate_fn,
            compute_metrics=compute_metric,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            
        )
    return trainer


In [125]:
trainer = train(train_arguments, model_arguments, control_arguments)

In [126]:
trainer.train()

Step,Training Loss,Validation Loss


number of correct: 20
number of correct: 20; number of all samples: 50 
number of correct: 20
number of correct: 20; number of all samples: 50 
number of correct: 20
number of correct: 20; number of all samples: 50 
number of correct: 20
number of correct: 20; number of all samples: 50 
number of correct: 20
number of correct: 20; number of all samples: 50 


TrainOutput(global_step=20, training_loss=2.132588005065918, metrics={'train_runtime': 217.297, 'train_samples_per_second': 1.15, 'train_steps_per_second': 0.092, 'total_flos': 0.0, 'train_loss': 2.132588005065918, 'epoch': 5.0})