In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive/')
    !pip install tweet_preprocessor
    !pip install attrdict
    !pip install transformers
except: pass

In [None]:
import torch
import yaml
import torch
import random 
import numpy as np 
import pandas as pd
import os
import shutil
from attrdict import AttrDict
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModel , 
    AdamW, get_linear_schedule_with_warmup)
from torch import Tensor
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
from posixpath import join as pathjoin

# Template path of the per-module YAML config; load_args fills '%s' with the module name.
config_path = pathjoin('..', 'config', '%s.yaml')

def check_device_match(device):
    """Fail fast when a CUDA device is requested but CUDA is not available.

    Args:
        device: 'cpu', 'cuda', an indexed form such as 'cuda:1', or a
            ``torch.device`` (compared through its string form).

    Raises:
        AssertionError: if a CUDA device is requested and
            ``torch.cuda.is_available()`` is False. Raised explicitly (not via
            ``assert``) so the check also runs under ``python -O``.
    """
    # 'cuda:1'.split(':')[0] == 'cuda'; str(torch.device('cuda', 0)) == 'cuda:0'.
    wants_cuda = str(device).split(':')[0] == 'cuda'
    if wants_cuda and not torch.cuda.is_available():
        raise AssertionError(
            'device {!r} requested but torch.cuda.is_available() is False'.format(device))
    return

def load_args(module):
    """Read the YAML config of ``module`` and validate its device entry.

    The file read is ``config_path % module``. Its parsed content is wrapped in
    an ``AttrDict`` so keys can be read as attributes. The ``device`` key is
    required (read as ``args.device``), checked with ``check_device_match`` and
    printed.

    Returns:
        The ``AttrDict`` of config values.
    """
    yaml_file = config_path % module
    with open(yaml_file, 'r') as stream:
        parsed = yaml.safe_load(stream)
        args = AttrDict(parsed)

    requested_device = args.device
    check_device_match(requested_device)
    print('device: {}'.format(requested_device))
    return args

In [None]:
def init_seed(seed):
    """Seed Python's ``random``, PyTorch and NumPy's global RNG with ``seed``."""
    # Same order as before: random -> torch -> numpy.
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
    return

In [None]:
class ZipDataset(Dataset):
    """Map-style dataset that zips several equally long sequences by index.

    ``datasets`` maps a field name to an indexable sequence (list, tensor, ...).
    Item ``i`` is the dict ``{name: sequence[i]}`` in the mapping's key order.
    ``keys`` / ``values`` keep list copies of the mapping's keys and values.
    """

    def __init__(self, datasets):
        super().__init__()
        field_names = list(datasets.keys())
        columns = list(datasets.values())
        self.keys = field_names
        self.values = columns
        self.datasets = datasets
        # Every column must be as long as the first one (trivially true when empty).
        assert all([len(column) == len(columns[0]) for column in columns])

    def __len__(self):
        # Length of the first column; all columns share it (checked in __init__).
        return len(self.values[0])

    def __getitem__(self, idx):
        return {name: column[idx] for name, column in self.datasets.items()}

In [None]:
# transformer specific
def tokenize_bert_inputs(texts, tokenizer, maxlen=100):
    """Batch-encode ``texts`` to fixed-length inputs with ``tokenizer``.

    Every sequence is padded to and truncated at ``maxlen`` tokens. The result
    includes attention masks but no token-type ids.

    Returns:
        Whatever ``tokenizer.batch_encode_plus`` returns.
    """
    encode_options = dict(
        return_attention_mask=True,
        return_token_type_ids=False,
        padding='max_length',
        truncation='longest_first',
        max_length=maxlen,
    )
    return tokenizer.batch_encode_plus(texts, **encode_options)
    
def _generate_bert_dataset(
    X,
    y,
    tokenizer,
    sentence_max_len = 82,
    split=None
):
    """Tokenize texts ``X`` and pack them with labels ``y`` into ZipDataset(s).

    Args:
        X: array-like of texts with a ``.tolist()`` method (e.g. a NumPy array).
        y: label array-like convertible to a float tensor, aligned with ``X``.
        tokenizer: tokenizer passed to ``tokenize_bert_inputs``.
        sentence_max_len: pad/truncate length in tokens.
        split: ``None`` for a single dataset. Otherwise the ``test_size`` of an
            unshuffled ``train_test_split``, so the first part keeps the
            original order and the second part is the tail.

    Returns:
        A ZipDataset with 'input_ids' and 'attention_mask' (int32) and 'y_true'
        (float32), or a ``(first, second)`` tuple of them when ``split`` is set.
    """
    if split is not None:
        X1, X2, y1, y2 = train_test_split(
            X, y,
            test_size=split,
            shuffle=False,
        )
        d1 = _generate_bert_dataset(X1, y1, tokenizer, sentence_max_len, split=None)
        d2 = _generate_bert_dataset(X2, y2, tokenizer, sentence_max_len, split=None)
        return (d1, d2)

    X_tk = tokenize_bert_inputs(
        X.tolist(),
        tokenizer=tokenizer,
        maxlen=sentence_max_len
    )
    # Build the integer tensors with their final dtype directly. Going through
    # the float32 `Tensor(...)` constructor costs an extra copy and is only
    # exact for ids below 2**24.
    return ZipDataset({
        'input_ids': torch.tensor(X_tk['input_ids'], dtype=torch.int32),
        'attention_mask': torch.tensor(X_tk['attention_mask'], dtype=torch.int32),
        'y_true': torch.tensor(y, dtype=torch.float32),
    })

def generate_bert_dataset(
    dataset_path,
    tokenizer,
    emotions,
    sentence_max_len=82,
    split=None
):
    """Build BERT-ready dataset(s) from a CSV file.

    The CSV must have a 'text' column plus one label column per entry of
    ``emotions``. Label columns are taken in the order ``emotions`` lists them.

    Returns:
        Whatever ``_generate_bert_dataset`` returns: one ZipDataset, or a
        ``(first, second)`` tuple when ``split`` is given.
    """
    frame = pd.read_csv(dataset_path)
    label_columns = list(emotions)
    texts, labels = frame['text'].to_numpy(), frame[label_columns].to_numpy()
    return _generate_bert_dataset(texts, labels, tokenizer, sentence_max_len, split)

In [None]:
# models
class TransformerEncoderBase(torch.nn.Module):
    """Base module wrapping a pretrained encoder together with a loss criterion.

    The encoder is serialized separately (through its own ``save_pretrained``)
    into ``<path>/encoder``. ``<path>/model.pt`` holds only the remaining
    (non-encoder) weights plus a config dict. Subclasses presumably add task
    heads and override ``forward``; the base ``forward`` just calls the encoder.
    """

    def __init__(self,
        encoder,
        criterion,
        config=None
    ) -> None:
        super().__init__()
        self.set_encoder(encoder)
        self.criterion = criterion
        # Extra constructor arguments of (sub)classes, i.e. everything except
        # encoder/criterion. Saved as config['module'] and passed back as
        # keyword arguments by from_pretrained.
        self.config = config if config else {}

    def unset_encoder(self):
        """Detach the encoder from this module and return it (may be None)."""
        tmp = self.encoder
        self.set_encoder(None)
        return tmp

    def set_encoder(self, encoder):
        """Attach ``encoder`` (or None) and refresh ``encoder_dim`` from its hidden size."""
        self.encoder = encoder
        self.encoder_dim = encoder.config.hidden_size if encoder is not None else None
        return

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids, attention_mask,)

    def save_pretrained(self, path):
        """Save the encoder to ``path/encoder`` and everything else to ``path/model.pt``.

        The encoder is detached while the state dict is taken, so ``model.pt``
        excludes its weights. It is re-attached even if saving fails.
        """
        encoder_path = pathjoin(path, 'encoder')
        pt_path = pathjoin(path, 'model.pt')
        encoder = self.unset_encoder()
        try:
            encoder.save_pretrained(encoder_path)
            torch.save({
                'model': self.state_dict(),
                'config': {
                    'encoder': encoder.config.to_dict(),
                    # Only the class name is stored; from_pretrained re-creates
                    # the criterion with default constructor arguments.
                    'criterion': self.criterion.__class__.__name__,
                    'architecture': str(self),
                    'module': self.config,
                }
            }, pt_path)
        finally:
            # Never leave the live model without its encoder.
            self.set_encoder(encoder)
        return

    @classmethod
    def from_pretrained(cls, path):
        """Rebuild a model saved by ``save_pretrained``. Weights are loaded on CPU.

        The criterion is ``torch.nn.<saved class name>()`` with default
        arguments, so constructor options such as ``pos_weight`` are not restored.
        """
        encoder_path = pathjoin(path, 'encoder')
        pt_path = pathjoin(path, 'model.pt')

        checkpoint = torch.load(pt_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        encoder = AutoModel.from_pretrained(encoder_path)

        model = cls(
            encoder=encoder,
            criterion = getattr(torch.nn, config['criterion'])(),
            **config['module']
        )
        model.load_state_dict(checkpoint['model'])
        return model

    def load_state_dict(self, *args, **kwargs):
        """Load non-encoder weights while the encoder is detached.

        Returns torch's result (``missing_keys`` / ``unexpected_keys``). The
        encoder is re-attached even when loading raises.
        """
        encoder = self.unset_encoder()
        try:
            result = super().load_state_dict(*args, **kwargs)
        finally:
            self.set_encoder(encoder)
        return result

def check_model_same(m1, m2):
    """Assert that two modules have identical state-dict keys and tensor values.

    Returns:
        True when they match.

    Raises:
        AssertionError: on differing key sets, or '<key> mismatch' for the
            first tensor (in ``m1``'s state-dict order) that differs.
    """
    left, right = m1.state_dict(), m2.state_dict()
    assert set(left.keys()) == set(right.keys())
    for name, tensor in left.items():
        assert torch.equal(tensor, right[name]), '{} mismatch'.format(name)
    return True


In [None]:
# train, test, evaluation, metrics
def compute_classification_metrics(y_true, proba, threshold):
    """Multi-label classification metrics from predicted probabilities.

    Args:
        y_true: NumPy array of 0/1 labels. The label set is
            ``range(y_true.shape[-1])``, so it is assumed to be
            (n_samples, n_labels).
        proba: predicted probabilities, aligned with ``y_true``.
        threshold: a label is predicted when ``proba >= threshold``. A scalar,
            or an array broadcastable against ``proba``.

    Returns:
        dict with:
        - 'accuracy': element-wise agreement (1 - Hamming loss), not exact-match;
        - 'auc_roc_macro' / 'auc_roc_micro': omitted when ``roc_auc_score``
          raises ValueError (e.g. a label column contains a single class);
        - '{macro,micro,weighted}_{precision,recall,f1}' with zero_division=0.

    Raises:
        AssertionError: if ``y_true`` and ``proba`` have different lengths.
    """
    assert len(y_true) == len(proba), \
        'y_true and y_pred length mismatch {} {}'.format(len(y_true), len(proba))
    labels = np.arange(y_true.shape[-1])

    results = {}
    y_true = y_true.astype(int)
    y_pred = (proba >= threshold).astype(int)

    results["accuracy"] = (y_true == y_pred).mean()
    try:
        results["auc_roc_macro"] = roc_auc_score(y_true, proba, average='macro')
        results["auc_roc_micro"] = roc_auc_score(y_true, proba, average='micro')
    except ValueError:
        # AUC is undefined for this batch (e.g. only one class present); skip it
        # instead of hiding unrelated errors behind a bare except.
        pass

    for average in ("macro", "micro", "weighted"):
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=average, labels=labels, zero_division=0)
        results[average + "_precision"] = precision
        results[average + "_recall"] = recall
        results[average + "_f1"] = f1

    return results

In [None]:
def predict_proba(
    model,
    dataset,
    batch_size=16,
    device='cpu',
    back_to_cpu=True):
  """Return sigmoid probabilities of ``model`` over ``dataset``, batch by batch.

  Every dataset field except 'y_true' is moved to ``device`` and passed to
  ``model(**batch)``, which must return logits. The model is switched to eval
  mode and left there.

  Args:
    batch_size: DataLoader batch size (no shuffling, so rows keep dataset order).
    device: device the model and batches are moved to.
    back_to_cpu: move the model back to CPU afterwards, also if an error occurs.

  Returns:
    NumPy array of the per-batch probabilities stacked row-wise.
  """
  eval_dataloader = DataLoader(
      dataset,
      batch_size=batch_size,
  )

  proba = []
  model.to(device)
  model.eval()
  try:
    with torch.no_grad():
      for batch in tqdm(eval_dataloader, desc='evaluation', leave=False):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'y_true'}
        logits = model(**inputs)
        # torch.sigmoid is overflow-safe, unlike 1 / (1 + np.exp(-x)).
        proba.append(torch.sigmoid(logits).cpu().numpy())
  finally:
    if back_to_cpu:
      model.cpu()

  return np.vstack(proba)

def predict_proba_examples(X_tk, model):
    """Sigmoid probabilities for one in-memory batch of model inputs.

    Args:
        X_tk: mapping of model keyword inputs. It is assumed to hold tensors
            (e.g. tokenizer output with tensors). A 'y_true' entry is ignored.
        model: module whose ``model(**inputs)`` returns logits.

    Returns:
        NumPy array of probabilities.

    Side effect: puts ``model`` in eval mode, so dropout is off, matching
    ``predict_proba``.
    """
    inputs = {k: v for k, v in X_tk.items() if k != 'y_true'}
    model.eval()
    with torch.no_grad():
        logits = model(**inputs)
    # .cpu() so this also works when the model lives on a GPU.
    return torch.sigmoid(logits).cpu().numpy()

def proba_to_emotion(proba, threshold, emotions):
  """Turn each probability row into the (emotion, probability) pairs that fire.

  A label fires when its probability is ``>= threshold``. Pairs keep the
  column order of ``emotions``.

  Raises:
    AssertionError: if the last axis of ``proba`` does not match ``emotions``.
  """
  assert proba.shape[-1] == len(emotions), 'emotions and proba mismatch {} vs {}'.format(len(emotions), proba.shape[-1])
  names = np.array(emotions)

  predictions = []
  for row in proba:
    fired = row >= threshold
    predictions.append(list(zip(list(names[fired]), row[fired])))
  return predictions

In [None]:
def evaluate(model,
             dataset,
             batch_size=16,
             threshold=0.5,
             device='cpu',
             sample_ratio=None,
             shuffle=False,
             back_to_cpu=True):
  """Compute the loss and classification metrics of ``model`` on ``dataset``.

  ``model(**batch)`` receives every dataset field, including 'y_true', and must
  return ``(loss, logits)``. Probabilities are ``sigmoid(logits)``. The model
  is switched to eval mode and left there.

  Args:
    batch_size: DataLoader batch size.
    threshold: probability cut-off for 'trigger_rate' and the metrics.
    device: device the model and batches are moved to.
    sample_ratio: if truthy, score only the first
      ``int(n_batches * sample_ratio)`` batches, clamped to at least one and at
      most all. Combine with ``shuffle=True`` for a random sample.
    shuffle: shuffle the DataLoader.
    back_to_cpu: move the model back to CPU afterwards, also if an error occurs.

  Returns:
    dict with 'loss' (mean per-batch loss), 'trigger_rate' (fraction of label
    probabilities >= threshold) and the keys of compute_classification_metrics.

  Raises:
    ValueError: if ``dataset`` yields no batches.
  """
  eval_dataloader = DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=shuffle,
  )

  n_available = len(eval_dataloader)
  if n_available == 0:
    raise ValueError('evaluate() received an empty dataset')
  max_batch = int(n_available * sample_ratio) if sample_ratio else n_available
  # At least one batch is always scored; never more than the loader holds.
  max_batch = min(max(max_batch, 1), n_available)

  n_batch = 0
  total_loss = 0.0
  y_true = []
  proba = []

  model.to(device)
  model.eval()
  try:
    with torch.no_grad():
      for batch in tqdm(eval_dataloader, desc='evaluation', total=max_batch, leave=False):
        batch = {k: v.to(device) for k, v in batch.items()}
        loss_per_batch, logits = model(**batch)
        total_loss += loss_per_batch.item()

        # torch.sigmoid is overflow-safe, unlike 1 / (1 + np.exp(-x)).
        proba.append(torch.sigmoid(logits).cpu().numpy())
        y_true.append(batch['y_true'].cpu().numpy())

        n_batch += 1
        if n_batch >= max_batch:
          break
  finally:
    if back_to_cpu:
      model.cpu()

  proba = np.vstack(proba)
  y_true = np.vstack(y_true)
  results = {
      'loss': total_loss / n_batch,
      'trigger_rate': (proba >= threshold).mean(),
      **compute_classification_metrics(y_true, proba, threshold)
  }

  return results

In [None]:
def save_checkpoint(
    model,
    archive_dir,
    model_name,
    checkpoint_id="null-model",
    metadata=None,
    tokenizer=None,
    optimizer=None,
    scheduler=None,
):
  """Save a checkpoint to ``archive_dir/model_name/checkpoint-<checkpoint_id>/``.

  Files written (existing ones are overwritten):
    - whatever ``model.save_pretrained`` writes. An object exposing
      ``.module`` (e.g. a torch DataParallel wrapper) is unwrapped first;
    - the tokenizer's files, if given;
    - ``meta.bin`` (``torch.save`` of ``metadata``) whenever ``metadata`` is
      not None, including an empty dict;
    - ``scheduler.pt`` / ``optimizer.pt`` state dicts, if given.

  Returns:
    The archive path ``archive_dir/model_name``.
  """
  archive_path = pathjoin(archive_dir, model_name)
  checkpoint_dir = pathjoin(archive_path, 'checkpoint-%s' % str(checkpoint_id))
  # makedirs also creates archive_path; exist_ok makes re-saving an id safe.
  os.makedirs(checkpoint_dir, exist_ok=True)

  model_to_save = model.module if hasattr(model, "module") else model
  model_to_save.save_pretrained(checkpoint_dir)
  if tokenizer is not None:
    tokenizer.save_pretrained(checkpoint_dir)
  # `is not None` rather than truthiness: load_from_checkpoint reads meta.bin
  # by default, so an empty metadata dict must still be written.
  if metadata is not None:
    torch.save(metadata, pathjoin(checkpoint_dir, "meta.bin"))
  if scheduler is not None:
    torch.save(scheduler.state_dict(), pathjoin(checkpoint_dir, 'scheduler.pt'))
  if optimizer is not None:
    torch.save(optimizer.state_dict(), pathjoin(checkpoint_dir, 'optimizer.pt'))

  return archive_path

def load_from_checkpoint(
    archive_dir,
    model_name,
    checkpoint_id="null-model",
    load_tokenizer=False,
    load_metadata=True,
    load_optimizer=False,
    load_model=True,
    model_cls=TransformerEncoderBase,
    tok_cls=AutoTokenizer
):
  """Load parts of a checkpoint written by save_checkpoint.

  Reads ``archive_dir/model_name/checkpoint-<checkpoint_id>/``.

  Returns:
    A tuple containing, in this fixed order and only when requested:
    ``(model, tokenizer, metadata, optimizer, scheduler)``.
    ``load_optimizer`` contributes both the optimizer and the scheduler.

  Rebuilding the optimizer/scheduler needs the loaded model (so
  ``load_optimizer`` requires ``load_model``) and the metadata keys
  'learning_rate', 'weight_decay', 'train_max_step' and 'warmup_ratio'. The
  saved optimizer/scheduler states are then loaded on top. Tensors are mapped
  to CPU, which matches the CPU model TransformerEncoderBase.from_pretrained
  returns (TODO confirm for other ``model_cls``).

  Raises:
    AssertionError: on a missing archive/checkpoint directory, when nothing is
      requested, or when ``load_optimizer`` is set without ``load_model``.
  """
  archive_path = pathjoin(archive_dir, model_name)
  checkpoint_dir = pathjoin(archive_path, 'checkpoint-%s' % str(checkpoint_id))

  assert os.path.exists(archive_path), archive_path
  assert os.path.exists(checkpoint_dir), checkpoint_dir
  assert load_model or load_metadata or load_optimizer or load_tokenizer
  # The optimizer is rebuilt over the loaded model's parameters.
  assert load_model or not load_optimizer, 'load_optimizer=True requires load_model=True'

  output = ()

  if load_model:
    model = model_cls.from_pretrained(checkpoint_dir)
    output += (model, )

  if load_tokenizer:
    tokenizer = tok_cls.from_pretrained(checkpoint_dir)
    output += (tokenizer, )

  if load_metadata or load_optimizer:
    # NOTE(review): meta.bin may contain NumPy scalars (metrics from evaluate());
    # torch>=2.6 defaults torch.load to weights_only=True, which may reject
    # them — confirm with the torch version in use.
    metadata = torch.load(pathjoin(checkpoint_dir, 'meta.bin'), map_location='cpu')
    if load_metadata:
      output += (metadata, )

  if load_optimizer:
    no_decay = ('bias', 'LayerNorm.weight')
    # Same group layout as the optimizer created in train(). Group
    # hyperparameters are replaced by the saved ones in optimizer.load_state_dict.
    grouped_parameters = [
        {'params': [param for name, param in model.named_parameters()
                    if not any(nd in name for nd in no_decay)]},
        {'params': [param for name, param in model.named_parameters()
                    if any(nd in name for nd in no_decay)]},
    ]

    optimizer = AdamW(grouped_parameters,
                      lr=metadata['learning_rate'],
                      weight_decay=metadata['weight_decay'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(metadata['train_max_step'] * metadata['warmup_ratio']),
        num_training_steps=metadata['train_max_step']
    )
    # map_location='cpu' lets GPU-saved states load on CPU-only machines.
    optimizer.load_state_dict(torch.load(pathjoin(checkpoint_dir, 'optimizer.pt'), map_location='cpu'))
    scheduler.load_state_dict(torch.load(pathjoin(checkpoint_dir, 'scheduler.pt'), map_location='cpu'))

    output += (optimizer, scheduler)
  return output

def _lookup_metadata_attr(metadata, attr):
    """Resolve ``attr`` in checkpoint metadata; dotted names walk nested dicts.

    Returns None when any level is missing, e.g. 'val_metrics.macro_f1' on a
    checkpoint saved before evaluation ran.
    """
    if not isinstance(attr, str) or '.' not in attr:
        return metadata.get(attr)
    value = metadata
    for part in attr.split('.'):
        getter = getattr(value, 'get', None)
        if getter is None:
            return None
        value = getter(part)
    return value


def get_attrs_from_checkpoints_meta(archive_dir, model_name, attrs,
    ignore_null_model=True,
    return_df=True,
    ):
    """Collect metadata attributes from every checkpoint of ``model_name``.

    Scans the sub-directories named ``checkpoint-<id>`` of
    ``archive_dir/model_name`` and reads each one's ``meta.bin`` through
    load_from_checkpoint.

    Args:
        attrs: metadata keys. Dotted names such as 'val_metrics.macro_f1' read
            nested dicts. Missing keys yield None.
        ignore_null_model: skip the 'checkpoint-null-model' directory.
        return_df: return a DataFrame sorted by integer checkpoint_id. Only
            honoured together with ``ignore_null_model``, because 'null-model'
            is not an integer id.

    Returns:
        A DataFrame, or a dict ``{attr: [values], 'checkpoint_id': [ids]}`` in
        directory-listing order.
    """
    archive_path = pathjoin(archive_dir, model_name)

    results = {k: [] for k in attrs}
    results['checkpoint_id'] = []

    for dirname in os.listdir(archive_path):
        # Prefix + directory check: plain substring matching also picked up
        # entries like '.ipynb_checkpoints'.
        if not dirname.startswith('checkpoint-'):
            continue
        if not os.path.isdir(pathjoin(archive_path, dirname)):
            continue
        if ignore_null_model and dirname == 'checkpoint-null-model':
            continue

        checkpoint_id = dirname[len('checkpoint-'):]
        (metadata,) = load_from_checkpoint(
            archive_dir,
            model_name, checkpoint_id,
            load_tokenizer=False,
            load_metadata=True,
            load_optimizer=False,
            load_model=False,
        )
        results['checkpoint_id'].append(checkpoint_id)

        for attr in attrs:
            results[attr].append(_lookup_metadata_attr(metadata, attr))

    if return_df and ignore_null_model:
        results['checkpoint_id'] = [int(k) for k in results['checkpoint_id']]
        results = pd.DataFrame(results)
        results = results.sort_values(by='checkpoint_id', ascending=True).reset_index(drop=True)

    return results

def clear_archive(archive_dir, model_name):
  """Make ``archive_dir/model_name`` an empty directory.

  Any existing content is deleted recursively.

  Returns:
    True if the path existed beforehand, False otherwise.
  """
  target = pathjoin(archive_dir, model_name)
  if not os.path.exists(target):
    os.makedirs(target)
    return False

  shutil.rmtree(target)
  os.makedirs(target)
  return True

In [None]:
def train_step(
    model,
    scheduler,
    batch,
    optimizer,
    grad_clip_max=1
):
  """Run a single optimisation step on ``batch``.

  Steps: train mode, forward via ``model(**batch)`` (which must return
  ``(loss, logits)``), backward, optional grad-norm clipping to
  ``grad_clip_max`` (skipped when None), optimizer step, scheduler step, and
  finally clearing the model's gradients.

  Returns:
    ``(loss as a Python float, logits)``.
  """
  model.train()

  loss, logits = model(**batch)
  loss.backward()

  should_clip = grad_clip_max is not None
  if should_clip:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max)

  # Optimizer first, then the LR scheduler.
  for stepper in (optimizer, scheduler):
    stepper.step()
  model.zero_grad()

  loss_value = loss.detach().item()
  return loss_value, logits

def train(
    model,
    train_dataset,
    val_dataset,
    metadata,
    tokenizer = None,
    epochs = 5,
    train_batch_size = 16,
    val_batch_size = 16,
    save_steps = 1e3,
    validation_steps = 1e3,
    archive_dir = None,
    model_name = 'model', # model & archive saved in archive_dir/model_name/..
    classification_threshold=0.5,
    learning_rate = 1e-3,
    grad_clip_max = 1,
    weight_decay = 1e-5,
    warmup_ratio=0.1,
    logging_metrics=None,
    optimizer=None,
    scheduler=None,
    continue_training=False,
    device = 'cpu'
):
  """Fine-tune ``model`` with AdamW and a linear warmup/decay LR schedule.

  ``model(**batch)`` must return ``(loss, logits)`` for batches of the
  datasets (see train_step / evaluate).

  Checkpoints go to ``archive_dir/model_name``: a 'null-model' checkpoint
  before training, then one every ``save_steps`` optimiser steps. Validation
  metrics, and train metrics on a sample of about as many batches as the
  validation set has, are computed every ``validation_steps`` steps and at
  every save.

  ``metadata`` is mutated in place. It receives 'train_max_step',
  'total_steps', 'val_metrics' and 'tr_metrics'. When absent, it also gets
  'learning_rate', 'weight_decay' and 'warmup_ratio', which
  load_from_checkpoint needs to rebuild the optimizer/scheduler.

  With ``continue_training=True`` the given optimizer/scheduler are reused and
  the step counter resumes from ``metadata['total_steps']``. Note that the
  epoch loop still makes ``epochs`` full passes from the start of the data.

  Returns:
    ``(model moved back to CPU, metadata)``.
  """
  assert archive_dir is not None, \
      'archive_dir is required: checkpoints are saved under archive_dir/model_name'

  if torch.cuda.is_available():
    torch.cuda.empty_cache()

  # No weight decay on LayerNorm weights and biases. The explicit 0.0 overrides
  # the optimizer-wide weight_decay; without it these params were decayed too.
  no_decay = ('bias', 'LayerNorm.weight')
  grouped_parameters = [
    {'params': [param for name, param in model.named_parameters()
                if not any(nd in name for nd in no_decay)]},
    {'params': [param for name, param in model.named_parameters()
                if any(nd in name for nd in no_decay)],
     'weight_decay': 0.0},
  ]

  train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size)
  # The scheduler steps once per *batch*, so the schedule length is
  # batches-per-epoch x epochs. len(train_dataset) * epochs overstated it by
  # about train_batch_size, stretching warmup/decay far beyond the real run.
  max_step = len(train_dataloader) * epochs
  metadata['train_max_step'] = max_step
  metadata.setdefault('learning_rate', learning_rate)
  metadata.setdefault('weight_decay', weight_decay)
  metadata.setdefault('warmup_ratio', warmup_ratio)

  if not continue_training:
    # use default b1, b2, eps
    optimizer = AdamW(grouped_parameters,
                      lr=learning_rate,
                      weight_decay=weight_decay)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(max_step * warmup_ratio),
        num_training_steps=max_step
    )
  else:
    assert optimizer is not None
    assert scheduler is not None

  total_steps = metadata['total_steps'] if continue_training else 0
  # The train-set evaluation uses this ratio so it scores about as many
  # batches as the validation set has.
  val_ratio = len(val_dataset) / len(train_dataset)

  model.zero_grad()
  model.to(device)

  save_checkpoint(model, archive_dir, model_name,
                  checkpoint_id='null-model',
                  metadata=metadata,
                  tokenizer=tokenizer,
                  optimizer=optimizer,
                  scheduler=scheduler)

  for epoch in range(epochs):
    print('training epoch %d' % epoch)
    for batch in tqdm(train_dataloader, desc="Training", leave=None):
      batch = {k: v.to(device) for k, v in batch.items()}
      train_step(model, scheduler, batch, optimizer, grad_clip_max)
      total_steps += 1

      is_save_step = total_steps % save_steps == 0
      # Evaluate on every save so each checkpoint carries fresh metrics.
      if is_save_step or total_steps % validation_steps == 0:
        metadata['val_metrics'] = evaluate(
          model, val_dataset, val_batch_size,
          classification_threshold,
          device=device, back_to_cpu=False)
        metadata['tr_metrics'] = evaluate(
          model, train_dataset, val_batch_size,
          classification_threshold,
          device=device, sample_ratio=val_ratio,
          back_to_cpu=False)

        print('evaluating at step %d' % total_steps)
        if logging_metrics is not None:
          print('val', {k: v for k, v in metadata['val_metrics'].items() if k in logging_metrics})
          print('tr', {k: v for k, v in metadata['tr_metrics'].items() if k in logging_metrics})

      if is_save_step:
        print('saving at step %d' % total_steps)
        metadata['total_steps'] = total_steps
        save_checkpoint(model, archive_dir, model_name,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        checkpoint_id=total_steps,
                        metadata=metadata)
  model.cpu()
  return model, metadata