# Joint Modeling Notebook

Run the following cells to train a joint classifier

## 0. Imports

In [1]:
%load_ext autoreload
%autoreload 2
    
%load_ext tensorboard

import sys
sys.path.append('../jointclassifier/')
from joint_args import ModelArguments, DataTrainingArguments, TrainingArguments
from joint_dataloader import load_dataset
from joint_trainer import JointTrainer
from single_trainer import SingleTrainer
from joint_model_v1 import JointSeqClassifier

from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
import os

## 1. Initialize the Arguments

In [7]:
task = "formality+jokes"
data_dir = "../data/processed/"
model_name = "distilbert-base-cased"
model_nick = "distilbert"
output_dir = "../models/"
freeze_encoder = "False"
skip_preclassifier = "False"
train_jointly = "True"
epochs = "3"

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses([
    "--model_name_or_path",
    model_name,
    "--model_nick",
    model_nick,
    "--task",
    task,
    "--data_dir",
    data_dir,
    "--output_dir",
    os.path.join(output_dir, model_nick, task, 'joint'),
    "--cache_dir",
    os.path.join(output_dir, model_nick,"cache"),
    "--freeze_encoder",
    freeze_encoder,
    "--skip_preclassifier",
    skip_preclassifier,
    "--train_jointly",
    train_jointly,
    "--overwrite_cache",
    "--per_device_train_batch_size",
    "16",
    "--per_device_eval_batch_size",
    "16",
    "--max_seq_len",
    "64",
    "--gradient_accumulation_steps",
    "1",
    "--num_train_epochs",
    epochs,
    "--logging_steps",
    "2000",
    "--save_steps",
    "2000"
])


PyTorch: setting up devices


## 2. Load the Tokenizer

In [8]:
model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir,
                                         model_max_length = data_args.max_seq_len)
    

loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at ../models/distilbert/cache/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e12b458ebff27a
Model config DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.3.3",
  "vocab_size": 28996
}

loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at ../models/distilbert/cache/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e1

## 3. Load the datasets 
Note : Single for Joint Training, Dict for Separate Training

In [9]:
tasks = data_args.task.split('+')
train_dataset = load_dataset(data_args.data_dir, tokenizer, model_name=model_args.model_name_or_path, 
                            tasks=tasks, mode="train", n_proc=256)
dev_dataset = load_dataset(data_args.data_dir, tokenizer, model_name=model_args.model_name_or_path, 
                            tasks=tasks, mode="dev", n_proc=256)

100%|██████████| 663/663 [00:13<00:00, 50.96it/s] 


torch.Size([169735, 64]) torch.Size([169735, 64]) torch.Size([169735, 2]) torch.Size([169735])


100%|██████████| 1177/1177 [01:03<00:00, 18.52it/s]
 10%|▉         | 16/165 [00:00<00:00, 155.69it/s]

torch.Size([471222, 64]) torch.Size([471222, 64]) torch.Size([471222, 2]) torch.Size([471222])


100%|██████████| 165/165 [00:01<00:00, 104.57it/s]
  3%|▎         | 8/294 [00:00<00:03, 73.00it/s]

torch.Size([42434, 64]) torch.Size([42434, 64]) torch.Size([42434, 2]) torch.Size([42434])


100%|██████████| 294/294 [00:05<00:00, 51.17it/s]

torch.Size([117806, 64]) torch.Size([117806, 64]) torch.Size([117806, 2]) torch.Size([117806])





## 4. Initialize the Trainer and the Model

In [10]:
# Open TensorBoard
%tensorboard --logdir runs

In [11]:
print(f"Processing Joint Task : {tasks}")
model = JointSeqClassifier.from_pretrained(model_args.model_name_or_path,tasks=tasks, model_args=model_args,
                                                   task_if_single=None, joint = training_args.train_jointly)
trainer = JointTrainer([training_args,model_args, data_args], model, train_dataset, dev_dataset)
trainer.train()

Processing Joint Task : ['formality', 'jokes']


loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at /home/nuwandavek/.cache/huggingface/transformers/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e12b458ebff27a
Model config DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.3.3",
  "vocab_size": 28996
}

loading weights file https://huggingface.co/distilbert-base-cased/resolve/main/pytorch_model.bin from cache at /home/nuwandavek/.cache/huggingface/transformers/9c9f39769dba4c5fe379b4bc82973eb01297bd607954621434eb9f1bc85a23a0.06b428c8

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=29452.0, style=ProgressStyle(description_…

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f57dc916ac0>
Total eval batch size = 16


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=7363.0, style=ProgressStyle(description_w…






KeyError: 'formality'