# Joint Modeling Notebook

Run the cells below, in order, to train a joint classifier.

## 0. Imports

In [4]:
%load_ext autoreload
%autoreload 2
    
%load_ext tensorboard

import sys
sys.path.append('../jointclassifier/')
from joint_args import ModelArguments, DataTrainingArguments, TrainingArguments
from joint_dataloader import load_dataset
from joint_trainer import JointTrainer
from single_trainer import SingleTrainer
from joint_model_v1 import JointSeqClassifier

from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## 1. Initialize the Arguments

In [5]:
# --- Experiment settings ------------------------------------------------------
# All values are strings because they are handed to the argument parser below
# as command-line style tokens.
task = "formality+jokes"
data_dir = "../data/processed/"
model_name = "distilbert-base-cased"
model_nick = "distilbert"
output_dir = "../models/"
freeze_encoder = "False"
skip_preclassifier = "False"
train_jointly = "True"
epochs = "3"
train_batch_size = "256"
eval_batch_size = "256"

# --- Derived paths ------------------------------------------------------------
# Checkpoints are written to <output_dir>/<model_nick>/<task>/joint; the value
# given to --cache_dir is <output_dir>/cache.
run_output_dir = os.path.join(output_dir, model_nick, task, 'joint')
cache_dir = os.path.join(output_dir,"cache")

# --- Command-line style argument list -----------------------------------------
# Each option is written as "flag, value" on one line; --overwrite_cache is a
# bare flag and therefore has no value after it.
cli_args = [
    "--model_name_or_path", model_name,
    "--model_nick", model_nick,
    "--task", task,
    "--data_dir", data_dir,
    "--output_dir", run_output_dir,
    "--cache_dir", cache_dir,
    "--freeze_encoder", freeze_encoder,
    "--skip_preclassifier", skip_preclassifier,
    "--train_jointly", train_jointly,
    "--overwrite_cache",
    "--per_device_train_batch_size", train_batch_size,
    "--per_device_eval_batch_size", eval_batch_size,
    "--max_seq_len", "64",
    "--gradient_accumulation_steps", "1",
    "--num_train_epochs", epochs,
    "--logging_steps", "100",
    "--save_steps", "100",
]

# Parse the tokens into the three argument dataclasses used by the later cells.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(cli_args)


PyTorch: setting up devices
  return torch._C._cuda_getDeviceCount() > 0


## 2. Load the Tokenizer

In [6]:
# Both objects are loaded from the same pretrained identifier and share one
# cache directory, so read those two settings once.
pretrained_name = model_args.model_name_or_path
hf_cache_dir = model_args.cache_dir

model_config = AutoConfig.from_pretrained(pretrained_name, cache_dir=hf_cache_dir)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_name,
    cache_dir=hf_cache_dir,
    # The tokenizer's maximum length follows the configured max_seq_len.
    model_max_length=data_args.max_seq_len,
)
    

loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at ../models/cache/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e12b458ebff27a
Model config DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.3.3",
  "vocab_size": 28996
}

loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at ../models/cache/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e12b458ebff27a
Model con

## 3. Load the datasets 
Note: joint training uses a single dataset, while separate training uses a dict of datasets.

In [7]:
# Task names are encoded in a single "+"-separated string, e.g. "formality+jokes".
tasks = data_args.task.split('+')

# Keyword arguments common to the train and dev loads; only `mode` differs.
shared_load_kwargs = dict(
    model_name=model_args.model_name_or_path,
    tasks=tasks,
    n_proc=1024,
)

# The train split also supplies the index-to-class mapping used by the trainer;
# the second return value of the dev load is discarded.
train_dataset, idx_to_classes = load_dataset(
    data_args.data_dir, tokenizer, mode="train", **shared_load_kwargs
)
dev_dataset, _ = load_dataset(
    data_args.data_dir, tokenizer, mode="dev", **shared_load_kwargs
)

100%|██████████| 165/165 [00:05<00:00, 27.64it/s]


torch.Size([169735, 64]) torch.Size([169735, 64]) torch.Size([169735, 2]) torch.Size([169735])


100%|██████████| 294/294 [00:20<00:00, 14.41it/s]
 12%|█▏        | 5/41 [00:00<00:00, 43.90it/s]

torch.Size([471222, 64]) torch.Size([471222, 64]) torch.Size([471222, 2]) torch.Size([471222])


100%|██████████| 41/41 [00:01<00:00, 36.25it/s]
  0%|          | 0/73 [00:00<?, ?it/s]

torch.Size([42434, 64]) torch.Size([42434, 64]) torch.Size([42434, 2]) torch.Size([42434])


100%|██████████| 73/73 [00:02<00:00, 25.47it/s]

torch.Size([117806, 64]) torch.Size([117806, 64]) torch.Size([117806, 2]) torch.Size([117806])





## 4. Initialize the Trainer and the Model & Train!

In [7]:
# Open TensorBoard inline, reading event files from the `runs` directory
# (resolved relative to the notebook's working directory).
# NOTE(review): assumes the trainer writes its logs under `runs` — confirm.
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 9917), started 5:51:32 ago. (Use '!kill 9917' to kill it.)

In [18]:
print(f"Processing Joint Task : {tasks}")

# Build the classifier from the pretrained checkpoint named in the model arguments.
model = JointSeqClassifier.from_pretrained(
    model_args.model_name_or_path,
    tasks=tasks,
    model_args=model_args,
    task_if_single=None,
    joint=training_args.train_jointly,
)

# The trainer takes the three argument objects bundled in this order.
trainer_arguments = [training_args, model_args, data_args]
trainer = JointTrainer(trainer_arguments, model, train_dataset, dev_dataset, idx_to_classes)

# Kept as the cell's last expression so that its return value is displayed.
trainer.train()

Processing Joint Task : ['formality', 'jokes']
loading configuration file https://huggingface.co/distilbert-base-cased/resolve/main/config.json from cache at /home/vivek/.cache/huggingface/transformers/ebe1ea24d11aa664488b8de5b21e33989008ca78f207d4e30ec6350b693f073f.302bfd1b5e031cc1b17796e0b6e5b242ba2045d31d00f97589e12b458ebff27a
Model config DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.4.0.dev0",
  "vocab_size": 28996
}

loading weights file https://huggingface.co/distilbert-base-cased/resolve/main/pytorch_model.bin from cache at /home/vivek/.cache/huggingface/transformers/9c9f39769dba4c5fe379b4bc82973eb

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=921.0, style=ProgressStyle(description_wi…

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 100, epoch 0: f1 = 0.9202204838491258
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 200, epoch 0: f1 = 0.9305576779070397
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 300, epoch 0: f1 = 0.9356108974443103
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 400, epoch 0: f1 = 0.9388886133874716
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 500, epoch 0: f1 = 0.9409457077372437
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 500, epoch 0

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 700, epoch 0: f1 = 0.9426760688934019
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 800, epoch 0: f1 = 0.9446770012023191
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 900, epoch 0: f1 = 0.9460506854095714



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=921.0, style=ProgressStyle(description_wi…

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 900, epoch 0

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1100, epoch 1: f1 = 0.9470740695309554
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1200, epoch 1: f1 = 0.9478049245623469
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1300, epoch 1: f1 = 0.948101179358412
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1400, epoch 1: f1 = 0.9484118435100404
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1500, epoch 1: f1 = 0.9489761098382863
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 1500, epoch 1

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 1500, epoch 1

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1800, epoch 1: f1 = 0.9502943809304



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=921.0, style=ProgressStyle(description_wi…

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 1900, epoch 2: f1 = 0.950810329733419
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 1900, epoch 2

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 1900, epoch 2

***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 2200, epoch 2: f1 = 0.9511136685950362
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 2300, epoch 2: f1 = 0.9512224817264607
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 2400, epoch 2: f1 = 0.9512435596656639
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 2500, epoch 2: f1 = 0.9516163884126394
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Configuration saved in ../models/distilbert/formality+jokes/joint/config.json

Model weights saved in ../models/distilbert/formality+jokes/joint/pytorch_model.bin
Saving model checkpoint to ../models/distilbert/formality+jokes/joint
New best model saved at step 2600, epoch 2: f1 = 0.9516536196409374
***** Running Evaluation *****
Num examples = <torch.utils.data.dataset.TensorDataset object at 0x7f26f5c9df40>
Total eval batch size = 512


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=231.0, style=ProgressStyle(description_wi…

Best model still at step 2600, epoch 2





(2763, 0.2426984514953475)

## 5. Predict for a sentence

In [8]:
# Reload the classifier from the training output directory (where the
# checkpoints above were saved) instead of the original pretrained name.
model = JointSeqClassifier.from_pretrained(
    training_args.output_dir,
    tasks=tasks,
    model_args=model_args,
    task_if_single=None,
    joint=training_args.train_jointly,
)

# Wrap the reloaded model in a fresh trainer; the prediction cells below use it.
trainer_arguments = [training_args, model_args, data_args]
trainer = JointTrainer(trainer_arguments, model, train_dataset, dev_dataset, idx_to_classes)

loading configuration file ../models/distilbert/formality+jokes/joint/config.json
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-cased",
  "activation": "gelu",
  "architectures": [
    "JointSeqClassifier"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.3.3",
  "vocab_size": 28996
}

loading weights file ../models/distilbert/formality+jokes/joint/pytorch_model.bin
All model checkpoint weights were used when initializing JointSeqClassifier.

All the weights of JointSeqClassifier were initialized from the model checkpoint at ../models/distilbert/formality+jokes/joint.
If your task is similar to the task the model of the check

In [17]:
# Example input; the later prediction and salience cells reuse this variable.
sentence = "lol, that's dope"
# Per the recorded output, this returns a dict keyed by task name, each entry
# holding the predicted class and its probability.
trainer.predict_for_sentence(sentence, tokenizer)

{'formality': {'class': 'informal', 'prob': 0.96661514},
 'jokes': {'class': 'nojoke', 'prob': 0.9876714}}

In [18]:
# Same prediction with salience=True: judging by the recorded output, the result
# becomes a tuple of (predictions, per-task / per-class salience tensors).
trainer.predict_for_sentence(sentence, tokenizer, salience=True)

({'formality': {'class': 'informal', 'prob': 0.96661514},
  'jokes': {'class': 'nojoke', 'prob': 0.9876714}},
 {'formality': {'informal': tensor([0.0138, 0.0118, 0.0044, 0.0045, 0.0027, 0.0010, 0.0019, 0.0023, 0.0032,
           0.0161, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000]),
   'formal': tensor([0.0091, 0.0041, 0.0016, 0.0015, 0.0009, 0.0004, 0.0008, 0.0007, 0.0010,
           0.0055, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
    

In [58]:
# Put the sibling `../salience/` directory on the import path so the local
# `salience` module can be imported.
sys.path.append('../salience/')
import salience

In [59]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [60]:
# Call the standalone salience routine for the first task named in `task`
# ("formality" for "formality+jokes"). The trailing 0 is passed positionally;
# presumably it is a class index — TODO confirm against salience.salience.
# NOTE(review): the recorded output of this cell ends in `KeyError: 'formality'`,
# so the call fails as written; the cause is not visible from this notebook.
salience.salience(sentence, model, tokenizer, device, task.split('+')[0], 0)

({'formality': 0, 'jokes': 0}, {}, {}, (tensor([[[ 0.5653,  0.1711, -0.0571,  ..., -0.0399,  0.1794, -0.1331],
         [ 0.7349,  0.0629,  0.0695,  ...,  0.3597,  1.8188,  1.2880],
         [-0.9291, -0.3515,  0.3030,  ..., -0.3559, -0.2129,  0.5705],
         ...,
         [ 0.4051, -0.7623, -1.0151,  ...,  1.8879,  0.6867,  0.9014],
         [-0.0872,  0.3867,  0.4939,  ...,  0.5595,  0.3159,  1.1503],
         [-0.0970,  0.0824,  0.2094,  ...,  0.6579, -1.2650,  0.1309]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward>), tensor([[[ 0.2481,  0.0979, -0.0618,  ..., -0.0034, -0.0617, -0.0446],
         [ 0.4034, -0.1321, -0.0512,  ...,  0.8332,  1.8625,  1.4630],
         [-1.7014, -0.4414,  0.5886,  ..., -1.1669, -0.4128,  0.5319],
         ...,
         [ 0.6451, -0.3434, -1.1787,  ...,  1.5239,  1.0760,  1.1709],
         [-0.4663,  0.2110,  1.0824,  ...,  0.0592, -0.2417,  0.9485],
         [-0.1500, -0.0756,  0.0585,  ...,  0.1612, -0.3053, -0.0310]]],
       device='c

KeyError: 'formality'