In [1]:
import sys
import os
from dotenv import load_dotenv

load_dotenv(os.path.expanduser('~/.env'), verbose=True)

data_dir = os.getenv('DATA_IGN_DIR')
adapter_lib_path = os.getenv('ADAPTER_LIB_PATH')

sys.path.insert(0, adapter_lib_path)

In [2]:
import logging
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import random
from dataclasses import dataclass, field
from typing import Optional, List

import datasets
import numpy as np
from datasets import load_dataset, concatenate_datasets

from pprint import pprint

import evaluate
import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    get_scheduler,
    PfeifferConfig
)
from transformers.adapters import AdapterArguments, AdapterTrainer, AdapterConfigBase, AutoAdapterModel, setup_adapter_training
from transformers import AdapterConfig, EvalPrediction, TextClassificationPipeline
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from torch.utils.data import DataLoader
import torch

from pdb import set_trace
import transformers.adapters.composition as ac

from transformers.adapters.heads import ClassificationHead
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.trainer_utils import EvalLoopOutput

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, recall_score

from tqdm import tqdm
import json
from datetime import datetime
import random
from datasets import concatenate_datasets, ClassLabel, Value

from transformers import EarlyStoppingCallback

import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import f1_score, accuracy_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_count = torch.cuda.device_count()
print(device, device_count)

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),

    'rotten_tomatoes': ("text", None),
    'imdb': ("text", None),
    'yelp_polarity': ("text", None),
    
}

# adapter_info = {'cola': {'load_adapter': 'lingaccept/cola@ukp', 'adapter_config': 'pfeiffer'},
#                 # 'mnli'
#                 'mrpc': {'load_adapter': 'sts/mrpc@ukp',        'adapter_config': 'pfeiffer'},
#                 'qnli': {'load_adapter': 'nli/qnli@ukp',        'adapter_config': 'pfeiffer'},
#                 'qqp' : {'load_adapter': 'sts/qqp@ukp',         'adapter_config': 'pfeiffer'},
#                 'rte' : {'load_adapter': 'nli/rte@ukp',         'adapter_config': 'pfeiffer'},
#                 'sst2': {'load_adapter': 'sentiment/sst-2@ukp', 'adapter_config': 'pfeiffer'},
#                 'stsb': {'load_adapter': 'sts/sts-b@ukp',       'adapter_config': 'pfeiffer'},
                
#                 'rotten_tomatoes': {'load_adapter': 'AdapterHub/bert-base-uncased-pf-rotten_tomatoes', 'adapter_config': 'pfeiffer'},
#                 'imdb': {'load_adapter': 'AdapterHub/bert-base-uncased-pf-imdb', 'adapter_config': 'pfeiffer'},
#                 'yelp_polarity': {'load_adapter': 'AdapterHub/bert-base-uncased-pf-yelp_polarity', 'adapter_config': 'pfeiffer'},
#                }

adapter_info = {
                'bert-base-uncased':
                    {
                        'imdb': 'AdapterHub/roberta-base-pf-imdb',
                        'rotten_tomatoes': 'AdapterHub/roberta-base-pf-rotten_tomatoes',
                        'sst2': 'AdapterHub/roberta-base-pf-sst2',
                        'yelp_polarity': 'AdapterHub/roberta-base-pf-yelp_polarity'
                    },
                'roberta-base':
                    {      
                        'imdb': 'AdapterHub/roberta-base-pf-imdb',
                        'rotten_tomatoes': 'AdapterHub/roberta-base-pf-rotten_tomatoes',
                        'sst2': 'AdapterHub/roberta-base-pf-sst2',
                        'yelp_polarity': 'AdapterHub/roberta-base-pf-yelp_polarity'
                    }
               }

eval_data_dict = {'imdb': 'test', 'yelp_polarity': 'test'}

is_glue = {"cola": True,
            "mnli": True,
            "mrpc": True,
            "qnli": True,
             "qqp": True,
             "rte": True,
            "sst2": True,
            "stsb": True,
            "wnli": True,}

metric_dict = {'rotten_tomatoes': 'sst2', 'imdb': 'sst2', 'yelp_polarity': 'sst2'}

current_time = datetime.now().strftime('%Y%m%d-%H%M%S')

cuda 1


In [3]:
# if len(sys.argv) - 1 != 2:
#     print('Argument error')
#     exit(1)

# _, arg1, arg2 = sys.argv

# task_name_1 = arg1
# adapter_count = int(arg2)

task_name_1= 'rotten_tomatoes'
adapter_count = 16

In [4]:


task_name_str = f'moe_sentiment_{task_name_1}_{adapter_count}E'
model_name_or_path = 'roberta-base'
pad_to_max_length = True
max_seq_length = 128
output_dir = os.path.join(data_dir, f'tmp_case3_sentiment_moeBaseline/{task_name_str}_{current_time}')


adapter_config_default = 'pfeiffer'

adapter_k = 2
noisy_gating = True
gating_layer = None

num_labels = 2

train_test_ratio = 0.2
random_seed = 0

set_seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

print(output_dir)

/home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224


In [5]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
)

def load_dataset_with_glue(task_name):
    if task_name == 'scitail':
        return load_dataset(task_name, 'tsv_format')
    elif task_name in is_glue:
        return load_dataset('glue', task_name)
    else:
        return load_dataset(task_name)

def get_eval_dataset(dataset, task_name):
    if task_name == 'snli' or task_name == 'imdb' or task_name == 'yelp_polarity':
        return dataset['test']
    elif task_name == 'mnli':
        return dataset['validation_matched']
    else:
        return dataset['validation']

def get_data(task_name, raw_datasets):
    sentence1_key, sentence2_key = task_to_keys[task_name]

    if pad_to_max_length:
        padding = "max_length"

    def preprocess_function(examples):    
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
    
        # Map labels to IDs (not necessary for GLUE tasks)
        # if label_to_id is not None and "label" in examples:
            # result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        result["label"] = [(l if l != -1 else -1) for l in examples["label"]]
        return result
        
    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

    return raw_datasets

In [6]:
raw_datasets = load_dataset_with_glue(task_name_1)

In [7]:
dataset = get_data(task_name_1, raw_datasets)

train_dataset = dataset['train']

_train_dataset = dataset['train'].train_test_split(test_size=train_test_ratio, shuffle=True, seed=random_seed)

train_dataset = _train_dataset['train']
valid_dataset = _train_dataset['test']

eval_dataset = get_eval_dataset(dataset, task_name_1)

In [8]:
train_dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 6824
})

In [9]:
valid_dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 1706
})

In [10]:
eval_dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 1066
})

In [11]:
model = AutoAdapterModel.from_pretrained(
    model_name_or_path,
    ignore_mismatched_sizes=False
)

model.freeze_model(True)

loaded_adapters = []
for i in range(adapter_count):
    adapter_name = f'expert_{i}'
    model.add_adapter(adapter_name, config=adapter_config_default)
    loaded_adapters.append(adapter_name)

model.active_adapters = ac.Parallel(*loaded_adapters, mode='gating_token')

model.init_gating_network(task_name_str, adapter_k, noisy_gating, gating_layer)

model.add_classification_head(task_name_str)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaAdapterModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaAdapterModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
print(model.adapter_summary())

Name                     Architecture         #Param      %Param  Active   Train
--------------------------------------------------------------------------------
expert_0                 bottleneck          894,528       0.716       1       1
expert_1                 bottleneck          894,528       0.716       1       1
expert_2                 bottleneck          894,528       0.716       1       1
expert_3                 bottleneck          894,528       0.716       1       1
expert_4                 bottleneck          894,528       0.716       1       1
expert_5                 bottleneck          894,528       0.716       1       1
expert_6                 bottleneck          894,528       0.716       1       1
expert_7                 bottleneck          894,528       0.716       1       1
expert_8                 bottleneck          894,528       0.716       1       1
expert_9                 bottleneck          894,528       0.716       1       1
expert_10                bot

In [13]:
model.active_head

'moe_sentiment_rotten_tomatoes_16E'

In [14]:
for k, v in model.named_parameters():
    if 'heads' in k or 'gating' in k or 'adapter' in k:
        v.requires_grad = True
    else:
        v.requires_grad = False

In [15]:
total_params = format(sum(p.numel() for p in model.parameters()), ',')
total_params_train = format(sum(p.numel() for p in model.parameters() if p.requires_grad), ',')
print(f'{total_params_train} / {total_params}')

15,199,490 / 139,845,122


In [16]:
# for k, v in model.named_parameters():
#     if v.requires_grad:
#         print(k)

In [17]:
per_device_train_batch_size = 32
per_device_eval_batch_size = 512
weight_decay = 0.0
learning_rate = 1e-3
num_train_epochs = 10
lr_scheduler_type = 'linear'
warmup_ratio = 0.1
patience = 4
alpha_info = 0.5

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
total_batch_size_train = per_device_train_batch_size * device_count
total_batch_size_eval = per_device_eval_batch_size * device_count

In [18]:
def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

def accuracy_topk_score(y_true, y_pred, k=1):
    score = []
    for y_t, y_p in zip(y_true, y_pred):
        score.append(1 if y_t in y_p[:k] else 0)

    return np.mean(score)

In [19]:
training_args = TrainingArguments(
    report_to='all',
    remove_unused_columns=False,
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    logging_dir="./logs",
    seed=random_seed,
    data_seed=random_seed,
    do_train=True,
    do_eval=True,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    # evaluation_strategy='steps',
    # logging_strategy='steps',
    # save_strategy='steps',
    # eval_steps=2000,
    # logging_steps=2000,
    # save_steps=2000,
    save_total_limit=1,
    load_best_model_at_end = True,
    metric_for_best_model = 'loss'
)

loss_fct = CrossEntropyLoss()

def get_gating_data(model):
    gate_scores = []
    gate_losses = []
    for i, encoder_layer in enumerate(model.base_model.encoder.layer):
        gate_score = encoder_layer.output.gating_data.pop('gate_score')
        gate_loss = encoder_layer.output.gating_data.pop('gate_loss')

        gate_scores.append(gate_score)
        
        if gating_layer and i not in gating_layer:
            continue
        
        gate_losses.append(gate_loss)


    return gate_scores, torch.stack(gate_losses, 0).mean(0)

def loss_gating(logits, gate_loss, labels):
    loss_cls = loss_fct(logits.view(-1, num_labels), labels.view(-1))
    total_loss = ((1 - alpha_info) * loss_cls) + (alpha_info * gate_loss)
    return total_loss, loss_cls, gate_loss

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs):
        labels = inputs.pop('labels')

        # Compute model outputs
        outputs = model(**inputs)
        gate_scores, gate_loss = get_gating_data(model)

        logits = outputs[0].logits
        
        loss, _, _ = loss_gating(logits, gate_loss, labels)

        return loss
        
    def evaluation_loop(
        self,
        dataloader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ):
        # This is a simple modification. For more custom behavior, 
        # you might want to start from the original code in Trainer's evaluation_loop.
        
        # Initialize metrics, etc.
        self.model.eval()
        total_eval_loss = 0.0
        total_eval_loss_cls = 0.0
        total_eval_loss_gate = 0.0
        total_preds = []
        total_logits = []
        total_labels = []
        total_eval_metrics = {}

        adapter_freq = np.array([[0] * adapter_count] * len(model.base_model.encoder.layer))
        
        for step, inputs in enumerate(dataloader):
            labels = inputs.pop('labels').to(self.args.device)
            
            # Move inputs to appropriate device
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)
            
            # Forward pass and compute loss and metrics
            with torch.no_grad():
                outputs = model(**inputs)
                gate_scores, gate_loss = get_gating_data(model)

                logits = outputs[0].logits

            loss, loss_cls, loss_gate = loss_gating(logits, gate_loss, labels)

            total_eval_loss += loss.item()
            total_eval_loss_cls += loss_cls.item()
            total_eval_loss_gate += loss_gate.item()

            for i, gate_scores_layer in enumerate(gate_scores):
                for gate_scores_batch in gate_scores_layer:
                    top_scores_batch, top_indices_batch = gate_scores_batch.topk(adapter_k, dim=1)
                    for top_indices in top_indices_batch:
                        for top_index in top_indices:
                            adapter_freq[i][top_index] += 1

            total_logits.extend(logits.detach().cpu().numpy())
            total_preds.extend(logits.argmax(dim=-1))
            total_labels.extend(labels.detach().cpu().numpy())

        average_eval_loss = total_eval_loss / len(dataloader)
        average_eval_loss_cls = total_eval_loss_cls / len(dataloader)
        average_eval_loss_gate = total_eval_loss_gate / len(dataloader)
        
        eval_pred = EvalPrediction(predictions=total_logits, label_ids=total_labels)
        
        metrics = self.compute_metrics(eval_pred)

        num_eval_samples = len(dataloader.dataset)

        all_adapter_freq = np.round(adapter_freq / num_eval_samples, decimals=4)
        avg_adapter_freq = np.around(np.mean(adapter_freq, axis=0) / num_eval_samples, decimals=4)
        
        if gating_layer and len(gating_layer) == 1:
            freq_all = None
        else:
            freq_all = [list(o) for o in all_adapter_freq]
            
        total_eval_metrics = {f'{metric_key_prefix}_loss': average_eval_loss,
                              f'{metric_key_prefix}_loss_cls': average_eval_loss_cls,
                              f'{metric_key_prefix}_loss_gate': average_eval_loss_gate,
                              f'{metric_key_prefix}_accuracy': metrics['accuracy'],
                              f'{metric_key_prefix}_gate_freq_avg': list(avg_adapter_freq),
                              f'{metric_key_prefix}_gate_freq_all': freq_all,
                             }

        # return total_eval_loss, total_eval_metrics
        return EvalLoopOutput(predictions=total_preds, 
                              label_ids=total_labels, 
                              metrics=total_eval_metrics, 
                              num_samples=num_eval_samples)


trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        callbacks = [EarlyStoppingCallback(early_stopping_patience=patience)]
    )

In [20]:
os.makedirs(output_dir, exist_ok=True)
train_result = trainer.train()
metrics = train_result.metrics

loss_history = {'base_model': model_name_or_path,
                'max_seq_length': max_seq_length,
                'random_seed': random_seed,
                'lr': learning_rate,
                'warmup_ratio': warmup_ratio,
                'early_stopping_patience': patience,
                'total_batch_size': total_batch_size_train,
                'num_train_epoch': num_train_epochs,
                'adapter_count': adapter_count,
                'adapter_k': adapter_k,
                'noisy_gating': noisy_gating,
                'alpha_info': alpha_info,
                'gating_layer': gating_layer}


with open(os.path.join(output_dir, "hyperparameters.json"), "w") as f:
    json.dump(loss_history, f)

trainer.save_model()

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

os.makedirs(os.path.join(output_dir, f"trained_gating_network"), exist_ok=True)
model.save_gating_network(os.path.join(output_dir, f"trained_gating_network/{task_name_str}"), task_name_str)

os.makedirs(os.path.join(output_dir, f"trained_adapters"), exist_ok=True)
for adapter in loaded_adapters:
    model.save_adapter(os.path.join(output_dir, f"trained_adapters/{adapter}"), adapter)

os.makedirs(os.path.join(output_dir, f"trained_head"), exist_ok=True)
model.save_head(os.path.join(output_dir, f"trained_head/{task_name_str}"), task_name_str)

***** Running training *****
  Num examples = 6824
  Num Epochs = 10
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 2140
  Number of trainable parameters = 15199490


Epoch,Training Loss,Validation Loss,Loss Cls,Loss Gate,Accuracy,Gate Freq Avg,Gate Freq All
1,0.276,0.452187,0.793406,0.110969,0.810082,"[23.9297, 13.7951, 9.2984, 11.6721, 11.4894, 19.1574, 20.97, 17.5376, 16.7266, 15.8474, 16.8836, 21.9895, 18.5515, 13.2515, 17.7794, 7.121]","[[0.6934, 25.8136, 4.939, 83.8986, 13.4988, 0.0, 0.7456, 0.2808, 3.7948, 89.7397, 5.1161, 2.4719, 5.9461, 19.0188, 0.027, 0.0158], [90.3083, 93.9789, 0.3224, 0.8992, 0.3476, 0.1032, 0.0035, 28.9027, 0.5651, 1.1606, 12.7673, 14.5563, 2.3775, 0.1483, 1.6172, 7.942], [102.2755, 0.0311, 1.2732, 2.4601, 0.0064, 0.7169, 0.0, 0.0346, 10.4256, 21.0516, 81.7544, 0.0545, 10.1243, 10.337, 0.7339, 14.721], [0.0, 0.0018, 0.0064, 7.4631, 0.0, 65.6184, 6.6172, 10.2655, 8.0703, 12.5258, 0.0903, 11.4742, 1.1917, 60.1577, 72.4332, 0.0844], [0.1325, 9.9725, 3.5762, 2.2884, 24.49, 18.871, 1.4402, 11.9097, 0.1712, 5.6565, 10.7649, 112.8593, 7.0803, 3.4543, 34.9959, 8.337], [7.221, 24.5012, 18.4994, 3.558, 0.1002, 77.272, 26.9074, 0.7052, 0.0006, 5.8054, 0.9678, 8.4285, 55.4056, 7.2198, 7.0035, 12.4045], [5.0457, 0.0, 53.6753, 9.9988, 0.32, 9.422, 13.2696, 63.9596, 45.3066, 0.0305, 4.8916, 13.4637, 0.0012, 5.2016, 31.4138, 0.0], [0.0569, 8.3001, 0.8347, 0.3599, 29.7526, 16.374, 20.6213, 50.745, 0.1366, 0.0, 26.1383, 40.541, 42.3986, 2.9426, 0.136, 16.6624], [20.3828, 0.0006, 0.2075, 11.5563, 42.2937, 15.6102, 1.3271, 1.6166, 53.7098, 3.9965, 42.8283, 0.0604, 12.313, 0.1395, 46.2444, 3.7134], [0.0, 1.8681, 4.7831, 5.1413, 7.9232, 23.6125, 72.6483, 1.9555, 0.0182, 48.4625, 17.2585, 35.5809, 21.1841, 0.0, 0.0041, 15.5598], [16.0229, 1.0733, 23.4631, 12.4414, 18.4912, 2.289, 46.296, 39.9004, 3.6084, 0.0, 0.0217, 24.3834, 5.6231, 39.3535, 18.7433, 4.2896], [45.017, 0.0, 0.0, 0.0, 0.6483, 0.0, 61.7632, 0.1758, 74.9115, 1.7397, 0.0047, 0.0, 58.973, 11.0445, 0.0, 1.7222]]"
2,0.1943,0.207948,0.312325,0.10357,0.87163,"[8.4195, 9.5398, 18.2963, 15.9447, 16.9168, 16.2247, 20.1295, 16.7506, 19.6154, 11.5378, 26.8729, 8.6438, 14.7053, 24.5061, 12.8042, 15.0927]","[[0.0147, 24.2691, 82.0076, 3.908, 8.8224, 0.0, 40.8453, 33.0574, 18.51, 0.0, 0.0123, 0.323, 4.0358, 0.4004, 5.7052, 34.0891], [0.0604, 2.0199, 0.1841, 17.4543, 11.1483, 1.6442, 0.7005, 25.7327, 67.371, 17.5293, 61.0281, 0.5012, 0.5574, 44.0264, 3.0059, 3.0363], [8.0557, 1.0076, 0.2122, 5.1372, 42.5516, 3.517, 12.5721, 10.1348, 28.2767, 0.0123, 8.0, 1.078, 87.1817, 9.2743, 28.3277, 10.6612], [0.0012, 0.0938, 0.6213, 0.0, 0.0, 30.9496, 1.5698, 7.8341, 15.8013, 12.3353, 73.6489, 2.3177, 5.6571, 104.3927, 0.1342, 0.643], [0.4607, 16.0727, 1.6284, 56.7034, 24.296, 66.9777, 10.9338, 4.9513, 0.3828, 10.17, 1.8118, 3.8136, 0.6049, 16.6254, 26.0615, 14.5059], [6.1934, 25.7597, 23.4226, 23.2057, 2.7737, 15.1964, 33.4132, 0.0, 0.0, 8.5797, 3.9871, 41.561, 0.1366, 15.3236, 38.1676, 18.2796], [3.0826, 0.9273, 39.7579, 8.1043, 8.0305, 29.7198, 32.5334, 15.2362, 32.1055, 2.364, 63.3687, 1.4179, 3.711, 9.8066, 4.2028, 1.6313], [12.9941, 3.5029, 4.5305, 2.2409, 19.7421, 22.1952, 19.8705, 9.4543, 3.4068, 34.1032, 38.9209, 11.2907, 2.0657, 11.262, 0.1096, 60.3107], [10.7773, 10.4179, 2.6225, 36.1659, 14.0674, 0.786, 15.2315, 50.7884, 0.9683, 17.857, 3.4607, 8.4478, 4.493, 19.7474, 37.9678, 22.2011], [29.8892, 18.578, 28.9865, 19.7081, 38.6079, 3.5533, 39.5358, 1.6008, 10.4144, 0.3546, 18.7069, 27.2737, 6.2784, 2.459, 4.1717, 5.8816], [15.347, 3.7104, 4.5135, 1.3822, 32.7767, 10.4977, 30.7198, 11.0662, 3.5199, 18.1712, 35.7145, 2.527, 37.9519, 43.221, 4.34, 0.541], [14.1577, 8.1184, 31.068, 17.3265, 0.1852, 9.6594, 3.6284, 31.1506, 54.6284, 16.9766, 13.8154, 3.1735, 23.7896, 17.5346, 1.4566, 9.3312]]"
3,0.1572,0.255346,0.400203,0.110489,0.861079,"[10.8163, 8.159, 18.9494, 14.7458, 10.918, 15.6252, 19.1049, 15.2458, 15.2125, 6.2448, 14.5965, 22.3553, 14.2371, 18.3429, 32.6826, 18.7638]","[[0.007, 22.7081, 94.6559, 0.4965, 5.9408, 0.0, 4.3154, 0.5317, 12.7087, 0.0657, 0.2134, 0.0006, 11.8265, 14.3546, 0.0199, 88.1553], [0.0029, 1.3271, 0.061, 2.524, 4.1946, 0.0182, 33.5217, 34.9894, 13.1899, 0.5451, 3.4953, 48.2274, 0.0914, 17.5897, 96.2169, 0.0053], [32.4953, 0.0, 1.0627, 1.5117, 20.8277, 6.4678, 0.0006, 0.0, 17.381, 2.7116, 0.6096, 0.0475, 83.6366, 61.6313, 10.8312, 16.7855], [0.0264, 0.0106, 27.2333, 15.9385, 0.9402, 86.976, 9.1706, 30.109, 0.1653, 6.6049, 11.2175, 5.5914, 0.126, 35.8681, 11.9842, 14.0381], [1.9625, 3.2907, 15.0176, 70.6823, 0.1547, 45.6454, 1.1624, 4.4326, 1.3816, 2.932, 0.0979, 16.9121, 0.0006, 2.6125, 11.9256, 77.7896], [6.6465, 25.4103, 0.6225, 26.7438, 1.8312, 1.1249, 30.4988, 0.0, 8.2556, 13.6577, 0.9004, 85.0381, 0.0287, 3.2837, 48.5744, 3.3834], [3.6764, 0.0, 47.881, 4.3634, 15.6958, 1.6032, 36.2808, 4.027, 2.9138, 0.8195, 70.9719, 0.1735, 1.6729, 13.6758, 51.6923, 0.5528], [3.7814, 25.6659, 19.8775, 2.0451, 20.7644, 14.8376, 9.9619, 10.7796, 0.0996, 14.5662, 33.6671, 78.8798, 2.796, 5.9549, 0.524, 11.7989], [44.7562, 7.0047, 0.2128, 20.7614, 30.6301, 0.2948, 14.8646, 42.0064, 6.1401, 4.9355, 1.714, 0.9091, 17.7802, 0.4572, 59.5082, 4.0246], [5.8341, 7.0991, 4.5662, 5.1325, 3.49, 5.8242, 42.5082, 3.8933, 21.898, 6.1196, 23.7087, 19.8072, 17.0457, 5.8933, 81.3535, 1.8265], [8.3277, 1.068, 10.7919, 0.1038, 14.9449, 20.5574, 39.0328, 13.9004, 40.078, 7.3687, 8.5416, 11.061, 26.0457, 35.0217, 15.5932, 3.5633], [22.279, 4.3242, 5.4103, 26.6471, 11.6014, 4.153, 7.9408, 38.2802, 58.3382, 14.6114, 20.0211, 1.6161, 9.7948, 23.7714, 3.9683, 3.2427]]"
4,0.1362,0.193753,0.27586,0.111646,0.891559,"[12.6164, 6.467, 17.5795, 16.1856, 9.9367, 16.2198, 29.0226, 14.2903, 14.4934, 8.2112, 23.3778, 11.8036, 15.6054, 22.0126, 28.1186, 10.0596]","[[0.0914, 25.6137, 85.7773, 0.0, 6.7415, 0.7685, 82.2761, 0.0018, 11.4326, 0.9637, 0.337, 2.2362, 10.2081, 0.4672, 0.0135, 29.0715], [0.7702, 0.2128, 6.0651, 18.6114, 12.8476, 0.0205, 22.466, 27.6565, 34.7749, 0.1448, 2.6653, 1.4771, 1.5117, 95.2374, 28.0387, 3.5], [38.8441, 0.7691, 0.5651, 6.8054, 1.0774, 1.1196, 0.3581, 0.0, 15.0826, 0.041, 87.5574, 1.2825, 69.4273, 11.6928, 7.8945, 13.483], [0.0147, 0.0176, 3.7087, 1.9789, 0.99, 58.2591, 3.6184, 15.4068, 0.0, 7.1237, 58.041, 11.6811, 0.0, 39.1061, 55.2778, 0.7761], [0.051, 0.0363, 22.0387, 90.6782, 1.8658, 34.1706, 1.0041, 37.7304, 2.3113, 1.0973, 0.0229, 22.255, 3.289, 6.3189, 2.3652, 30.7655], [9.561, 22.5311, 0.0, 29.721, 14.0481, 0.7491, 54.5774, 0.0, 3.4144, 14.415, 0.0492, 48.8875, 0.3558, 3.1823, 52.0996, 2.4086], [13.466, 0.0522, 54.0258, 11.9455, 4.1383, 3.5991, 2.2251, 0.5264, 3.2649, 3.3189, 60.3781, 0.1665, 3.68, 10.2896, 84.864, 0.0598], [0.0041, 9.2579, 26.9308, 0.8206, 32.2603, 43.1032, 40.4103, 0.0, 0.0, 44.6852, 42.9461, 4.905, 0.0, 8.7351, 0.3247, 1.6166], [32.8699, 14.7591, 0.2239, 4.4385, 14.0815, 0.1741, 0.0006, 41.5909, 5.1055, 8.4367, 0.0, 0.823, 36.8763, 0.8206, 60.9414, 34.8581], [41.915, 2.9566, 0.0, 8.4443, 6.3488, 7.6184, 72.7433, 3.3189, 1.4285, 15.303, 2.3171, 5.7567, 47.592, 0.0, 40.0856, 0.1717], [1.9766, 0.803, 11.0006, 0.0, 2.6389, 39.068, 45.4478, 4.0264, 34.1114, 0.0692, 10.7755, 39.1835, 12.2046, 48.058, 3.8546, 2.7819], [11.8329, 0.5944, 0.6184, 20.7837, 22.2022, 5.9871, 23.1442, 41.2251, 62.9947, 2.9355, 15.4443, 2.9894, 2.1196, 40.2427, 1.663, 1.2227]]"
5,0.1109,0.228585,0.347533,0.109636,0.882181,"[7.9412, 10.6282, 14.1116, 12.3515, 19.1297, 16.9568, 22.641, 18.7249, 11.7642, 5.9914, 13.8163, 19.8625, 13.4665, 23.5871, 29.7349, 15.2923]","[[0.129, 26.041, 41.4308, 0.0, 6.6067, 1.6981, 59.3165, 0.0059, 7.8429, 0.0006, 0.6454, 0.296, 13.5633, 0.3212, 0.0258, 98.0768], [1.1389, 0.092, 1.1946, 0.4502, 0.4971, 0.425, 18.7866, 38.5, 4.9144, 0.4191, 1.0141, 52.2327, 1.9947, 61.8921, 71.9349, 0.5135], [21.126, 1.2567, 0.3822, 29.2737, 12.8652, 8.5293, 1.4924, 0.1208, 10.1981, 2.2438, 7.4601, 0.2597, 54.959, 86.2784, 12.4654, 7.0891], [0.5305, 22.2743, 1.0639, 3.6981, 0.5328, 51.5498, 1.0281, 26.0604, 0.0, 7.8769, 37.6577, 6.1512, 0.0, 7.7773, 64.7796, 25.0193], [0.0, 0.0363, 18.5914, 60.0879, 2.4062, 46.3664, 1.007, 56.204, 1.2327, 0.4596, 0.0006, 47.3816, 1.6161, 0.7233, 4.2233, 15.6635], [3.2562, 9.8124, 8.578, 2.8523, 56.1389, 0.1659, 51.4232, 0.0, 2.5117, 11.5358, 0.0, 54.5393, 0.1184, 4.8974, 48.459, 1.7116], [5.2456, 0.0047, 61.5023, 2.9596, 34.5586, 9.3804, 13.8394, 0.8447, 9.8681, 0.9543, 47.2591, 0.279, 12.4666, 16.7186, 40.1125, 0.0064], [0.2198, 25.197, 28.6934, 1.7784, 20.5492, 37.1483, 34.7884, 0.0, 0.0, 8.3335, 44.4596, 37.857, 0.0, 16.3792, 0.2825, 0.3136], [33.7433, 0.0, 2.2192, 2.7521, 45.1659, 0.0, 0.0, 43.2081, 9.7978, 4.2245, 0.0, 1.5662, 6.5885, 5.2433, 71.6055, 29.8857], [0.9818, 36.1401, 1.0322, 12.2638, 14.5592, 0.7251, 72.8552, 6.575, 0.1981, 29.3447, 0.1284, 5.9959, 39.8183, 2.4132, 31.0053, 1.9637], [2.7761, 5.4045, 4.5961, 1.7479, 0.5674, 44.2919, 12.2169, 3.5193, 38.4525, 0.0023, 24.7497, 29.8998, 26.1477, 47.8945, 11.8728, 1.8605], [26.1471, 1.279, 0.0545, 30.354, 35.109, 3.2016, 4.9379, 49.6612, 56.1536, 6.5012, 2.4215, 1.8921, 4.3253, 32.507, 0.0516, 1.4033]]"
6,0.0904,0.207781,0.305514,0.110048,0.886284,"[9.7218, 12.8265, 10.3774, 11.7386, 22.1044, 19.7308, 27.6791, 11.9632, 10.2061, 7.3607, 17.8356, 11.2287, 11.0628, 27.2823, 28.2512, 16.6309]","[[0.5674, 25.2122, 24.6987, 0.1864, 6.0973, 0.0, 92.0698, 0.0029, 10.1295, 0.0023, 0.7309, 0.5012, 14.6149, 0.6964, 0.0158, 80.4742], [0.0, 0.2907, 1.3077, 6.7778, 0.374, 0.0023, 32.762, 22.1758, 5.2403, 0.1594, 0.0955, 46.2995, 2.0463, 115.7333, 22.6032, 0.1319], [50.1794, 0.0, 0.1272, 24.4918, 9.1641, 0.2403, 0.357, 0.1301, 11.1647, 0.0094, 0.7562, 0.3259, 37.0914, 105.5932, 9.612, 6.7573], [0.0164, 19.9308, 0.6524, 0.4771, 0.2978, 61.9209, 4.6788, 5.8576, 0.0, 5.8118, 74.2884, 5.483, 0.0, 12.9103, 62.4449, 1.2298], [0.0, 0.0621, 9.7157, 46.1858, 0.3447, 68.6852, 1.0475, 16.9297, 14.541, 0.1155, 0.0018, 15.6184, 0.0, 2.1882, 9.0938, 71.4707], [0.7022, 19.095, 3.2386, 6.1712, 53.5504, 0.2263, 60.6758, 0.0, 1.1225, 18.5457, 0.2128, 31.6893, 0.0633, 12.3118, 48.3939, 0.0012], [1.9127, 0.0193, 49.0346, 5.7075, 70.7134, 3.3394, 5.7216, 2.8171, 5.527, 1.3277, 51.8324, 0.2667, 1.405, 16.9531, 39.4156, 0.007], [1.3646, 9.8669, 29.9115, 1.5739, 33.6589, 44.3288, 39.3312, 0.551, 1.5897, 8.7579, 52.6729, 1.0868, 0.0, 8.6231, 16.609, 6.0739], [36.1805, 3.7655, 0.0, 4.7128, 54.6354, 0.5727, 0.0533, 44.527, 2.2691, 3.956, 0.0, 0.0023, 5.2403, 14.9513, 57.1876, 27.9461], [2.5182, 46.7784, 0.0363, 18.6061, 12.6178, 0.2491, 62.6073, 1.5258, 1.0709, 22.9601, 0.0, 6.075, 47.7257, 2.2145, 30.1067, 0.908], [2.9924, 15.8816, 5.6032, 0.7814, 1.6073, 53.6231, 29.6401, 5.6225, 20.4543, 0.0012, 20.075, 8.0041, 20.1747, 26.4977, 43.1114, 1.9302], [20.228, 13.0152, 0.2028, 25.1911, 22.1923, 3.5815, 3.2046, 43.4191, 49.364, 26.6811, 13.3611, 19.3921, 4.3916, 8.7145, 0.4203, 2.6407]]"
7,0.0721,0.225586,0.346491,0.104681,0.888042,"[12.5136, 8.5907, 10.9527, 12.2519, 17.0731, 18.1158, 27.8012, 14.4875, 11.7695, 11.3764, 18.7212, 13.5471, 13.8542, 25.6154, 21.9054, 17.4243]","[[0.6301, 25.1665, 55.9713, 0.8453, 3.119, 0.1213, 84.6958, 0.0023, 13.0516, 0.1823, 0.2614, 0.4607, 15.8458, 0.4871, 0.0047, 55.1547], [0.0, 1.796, 1.1788, 25.0768, 0.6336, 0.1811, 34.442, 29.2374, 13.0404, 0.4455, 0.2327, 44.3177, 0.6208, 73.6987, 30.4185, 0.68], [52.7743, 0.1213, 0.2005, 8.493, 19.7134, 0.1096, 0.3529, 0.1589, 10.7098, 0.0047, 4.1395, 0.1653, 62.6723, 91.7972, 0.2538, 4.3335], [0.0727, 20.6243, 1.8341, 4.6682, 0.7497, 62.7603, 0.0006, 8.34, 0.0, 12.9549, 55.7573, 7.8822, 0.0, 26.9379, 45.3957, 8.0223], [0.0, 0.0, 8.7732, 40.8599, 2.7556, 55.5252, 1.6272, 23.2526, 20.0363, 3.2485, 0.0023, 24.3406, 0.1794, 0.7732, 12.5047, 62.1213], [5.7421, 9.0358, 0.0, 1.9009, 77.6835, 0.5885, 50.4566, 0.0, 2.8687, 43.9601, 2.0569, 31.0574, 1.0399, 15.9683, 13.5199, 0.1213], [0.3951, 0.0135, 40.0293, 5.7192, 30.6653, 5.8453, 25.9496, 2.7978, 13.0504, 2.228, 55.6571, 0.3781, 0.0, 11.5703, 61.6899, 0.0111], [2.5574, 7.0076, 9.4678, 4.0487, 28.2198, 51.772, 47.6958, 0.0551, 0.0012, 12.303, 34.4642, 0.4912, 0.0, 39.2515, 10.9385, 7.7263], [29.4361, 21.9725, 0.3617, 14.456, 15.1465, 0.5979, 1.0903, 21.8839, 5.5158, 12.5891, 31.1659, 0.0006, 20.7802, 4.0498, 44.6014, 32.3523], [30.3007, 7.7222, 0.0557, 8.0838, 6.2485, 9.7433, 43.8798, 18.0275, 0.7585, 19.0481, 4.0487, 16.2397, 34.129, 4.061, 31.1981, 22.4555], [4.4443, 1.303, 9.4385, 0.1923, 7.0586, 28.4666, 27.2591, 38.4543, 21.1032, 0.0006, 30.7304, 13.3025, 21.5586, 28.18, 12.1987, 12.3095], [23.8107, 8.3253, 4.1213, 32.6782, 12.8834, 1.6782, 16.1647, 31.6407, 41.0979, 29.5522, 6.1383, 23.9297, 9.4244, 10.6096, 0.1413, 3.8042]]"
8,0.0553,0.225402,0.345454,0.10535,0.885111,"[14.1208, 9.5362, 9.2585, 14.0748, 16.0394, 17.8002, 24.7877, 14.109, 12.0541, 12.1739, 20.6892, 13.5276, 15.8872, 26.3602, 18.5508, 17.0301]","[[0.6635, 25.3447, 37.7339, 0.5487, 2.8916, 0.1213, 88.3892, 0.0094, 12.4349, 1.0223, 0.9484, 0.3511, 14.8558, 0.8036, 0.0047, 69.8769], [0.1213, 2.9478, 0.9191, 30.68, 0.4437, 0.4009, 19.9543, 23.6905, 5.3288, 0.3769, 0.9801, 48.8576, 2.507, 105.8441, 12.3921, 0.5557], [66.5557, 0.1729, 0.1026, 22.3535, 14.5147, 0.0574, 0.3998, 0.1788, 13.5715, 0.0041, 2.1829, 0.0399, 57.2022, 72.7544, 0.3998, 5.51], [0.0141, 23.4824, 1.7995, 0.4818, 0.4625, 52.1178, 0.0, 11.7526, 0.0, 7.3335, 80.5692, 5.8189, 0.0, 27.1717, 42.1706, 2.8253], [0.0, 0.0, 8.8734, 43.0023, 1.1336, 53.8664, 1.4484, 21.9578, 11.4174, 0.7978, 0.0018, 19.041, 0.1794, 1.17, 9.6219, 83.4889], [5.4121, 29.6917, 6.2098, 1.3687, 77.8277, 0.2433, 46.7702, 0.0, 4.0141, 37.0293, 1.0129, 22.7591, 1.1313, 13.6243, 8.6987, 0.2069], [0.6125, 0.0275, 38.8962, 17.534, 14.8921, 6.9965, 16.4121, 6.1893, 26.3834, 1.9519, 51.1518, 0.0064, 0.0, 7.9683, 66.9162, 0.0615], [25.0604, 4.9631, 6.9396, 2.0868, 28.8142, 47.1278, 53.1764, 0.0551, 0.0, 9.5686, 32.9496, 0.2872, 0.0, 31.9812, 4.2972, 8.6928], [21.5375, 7.2814, 0.0, 11.007, 13.4918, 0.2831, 0.1671, 29.7819, 13.973, 26.6583, 45.0223, 0.0035, 35.4127, 2.7421, 41.3775, 7.2608], [22.6131, 5.8886, 0.0533, 9.5739, 8.0041, 8.5604, 39.1026, 12.1424, 0.922, 37.1489, 2.3189, 17.1841, 46.5123, 3.956, 22.7667, 19.2526], [0.9654, 3.9132, 6.0123, 0.1319, 5.3036, 43.1823, 19.3693, 34.7374, 19.8916, 0.0006, 23.4701, 26.5873, 25.078, 27.3605, 13.9642, 6.0322], [25.8933, 10.7216, 3.5627, 30.1295, 24.6934, 0.6454, 12.2632, 28.8124, 36.7128, 24.1952, 7.663, 21.3957, 7.7679, 20.9467, 0.0, 0.5973]]"


Trainer is attempting to log a value of "[23.9297, 13.7951, 9.2984, 11.6721, 11.4894, 19.1574, 20.97, 17.5376, 16.7266, 15.8474, 16.8836, 21.9895, 18.5515, 13.2515, 17.7794, 7.121]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[[0.6934, 25.8136, 4.939, 83.8986, 13.4988, 0.0, 0.7456, 0.2808, 3.7948, 89.7397, 5.1161, 2.4719, 5.9461, 19.0188, 0.027, 0.0158], [90.3083, 93.9789, 0.3224, 0.8992, 0.3476, 0.1032, 0.0035, 28.9027, 0.5651, 1.1606, 12.7673, 14.5563, 2.3775, 0.1483, 1.6172, 7.942], [102.2755, 0.0311, 1.2732, 2.4601, 0.0064, 0.7169, 0.0, 0.0346, 10.4256, 21.0516, 81.7544, 0.0545, 10.1243, 10.337, 0.7339, 14.721], [0.0, 0.0018, 0.0064, 7.4631, 0.0, 65.6184, 6.6172, 10.2655, 8.0703, 12.5258, 0.0903, 11.4742, 1.1917, 60.1577, 72.4332, 0.0844], [0.1325, 9.9725, 3.5762, 2.2884, 24.49, 18.871, 1.4402, 11.9097, 0.1712, 5.6565, 10.7649

***** train metrics *****
  epoch                    =        8.0
  total_flos               =  3937761GF
  train_loss               =     0.1366
  train_runtime            = 0:40:45.57
  train_samples_per_second =     27.903
  train_steps_per_second   =      0.875


Module weights saved in /home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224/trained_adapters/expert_12/pytorch_adapter.bin
Configuration saved in /home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224/trained_adapters/expert_13/adapter_config.json
Module weights saved in /home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224/trained_adapters/expert_13/pytorch_adapter.bin
Configuration saved in /home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224/trained_adapters/expert_14/adapter_config.json
Module weights saved in /home/jaehan/research/adapter/adapter-poisoning/data_ign/tmp_case3_moeBaseline/moe_sentiment_rotten_tomatoes_16E_20231218-161224/trained_adapters/expert_14/pytorch_adapter.bin
Co

In [21]:
metrics = trainer.evaluate(eval_dataset=eval_dataset)
pprint(metrics)
trainer.save_metrics("eval", metrics)

Trainer is attempting to log a value of "[12.3778, 6.6641, 17.4005, 16.6528, 10.3441, 15.9994, 29.4254, 13.8707, 14.5707, 7.8677, 23.373, 11.8183, 15.4928, 21.9837, 28.3972, 9.7616]" of type <class 'list'> for key "eval/gate_freq_avg" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "[[0.1088, 25.8902, 84.9737, 0.0009, 6.666, 0.6764, 84.575, 0.0019, 11.5685, 0.8743, 0.3602, 2.2655, 10.1585, 0.6914, 0.0159, 27.1726], [0.6801, 0.2064, 7.0835, 20.4784, 14.3161, 0.0216, 21.9456, 26.56, 33.0441, 0.1313, 2.6548, 1.6961, 1.728, 93.6501, 28.0525, 3.7514], [37.9737, 0.6764, 0.591, 8.3255, 0.9615, 0.8386, 0.3555, 0.0816, 15.0882, 0.0403, 88.03, 1.3096, 68.0816, 11.6407, 8.6051, 13.4006], [0.0188, 0.1642, 2.9428, 1.6895, 1.3039, 58.7495, 3.606, 15.0816, 0.0, 6.4362, 57.0882, 11.1332, 0.0, 40.5094, 56.4756, 0.8011], [0.1932, 0.0, 20.6238, 90.1323, 2.2552, 34.0957, 1.0084, 35.7786, 2.4878, 1.258, 0.0

{'epoch': 8.0,
 'eval_accuracy': 0.8902438879013062,
 'eval_gate_freq_all': [[0.1088,
                         25.8902,
                         84.9737,
                         0.0009,
                         6.666,
                         0.6764,
                         84.575,
                         0.0019,
                         11.5685,
                         0.8743,
                         0.3602,
                         2.2655,
                         10.1585,
                         0.6914,
                         0.0159,
                         27.1726],
                        [0.6801,
                         0.2064,
                         7.0835,
                         20.4784,
                         14.3161,
                         0.0216,
                         21.9456,
                         26.56,
                         33.0441,
                         0.1313,
                         2.6548,
                         1.6961,
               

In [22]:
# input('Remove files?\n')
# import shutil
# directory_path = output_dir
# shutil.rmtree(directory_path)

In [23]:
# import os
# os._exit(00)

In [24]:
# for layer in model.roberta.encoder.layer:
#     layer.output.gating_data.pop('gate_score')
#     layer.output.gating_data.pop('gate_loss')