# **Try 2: Classification version**

*  Dataset: SuperGLUE
*  Model: T5-base



In [None]:
!pip install transformers
!pip install wandb

In [6]:
import json

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW, T5TokenizerFast, T5ForConditionalGeneration, DataProcessor

import wandb
# from dataset import SoftDataset, task_to_target_len, get_tasks_processor
# from utils import compute_task_metrics


In [None]:
class Finetuner:
    """Fine-tune a T5 text-to-text model on one task, validating after every epoch.

    All hyper-parameters are read from ``args``. The attributes accessed in this
    class are: ``task``, ``data_dir``, ``lr``, ``exp``, ``weight_decay``, ``eps``,
    ``batch_size``, ``seq_len``, ``num_prompts``, ``use_prompt_token``,
    ``model_name``, ``processor``, ``use_wandb``, ``save_model`` and ``output_dir``.

    NOTE(review): ``task_to_target_len``, ``get_tasks_processor``, ``SoftDataset``
    and ``compute_task_metrics`` are used below, but their imports are commented
    out at the top of this notebook; they must be defined before use.
    """

    def __init__(self, args, device):
        """Load tokenizer and model, build the data loaders, and move the model to ``device``."""
        # SuperGLUE configuration names (not read by the methods of this class).
        self.superglue_datasets = ['axb', 'axg', 'boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc', 'wsc.fixed']

        # Training settings are passed in through ``args``; an argument object
        # (e.g. an argparse.Namespace) must be built before training.
        self.args = args
        self.task = args.task
        self.data_dir = args.data_dir

        self.lr = args.lr
        self.exp = args.exp
        self.weight_decay = args.weight_decay
        self.eps = args.eps
        self.batch_size = args.batch_size
        self.seq_len = args.seq_len
        self.device = device
        self.num_prompts = args.num_prompts
        self.use_prompt_token = args.use_prompt_token
        self.model_name = args.model_name
        self.tokenizer = T5TokenizerFast.from_pretrained(args.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(args.model_name)
        # Train / val / test loaders, built by get_data_loader_dict.
        self.tasks_data_dict = self.get_data_loader_dict()
        self.model = self.model.to(device)

    def get_optimizer(self, lr, weight_decay, eps):
        """Build an AdamW optimizer over ``self.model`` with two parameter groups.

        Biases and layer-norm weights are excluded from weight decay; every
        other parameter is decayed.

        Args:
            lr: Learning rate applied to both groups.
            weight_decay: Weight decay applied to the decayed group only.
            eps: AdamW epsilon.

        Returns:
            The configured AdamW optimizer.
        """
        # "LayerNorm.weight" is BERT-style naming; T5 names its norms
        # "...layer_norm.weight" / "...final_layer_norm.weight", so match both.
        no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            # parameters with weight decay
            {"params": decay_params, "weight_decay": weight_decay, "lr": lr},
            # parameters without weight decay
            {"params": no_decay_params, "weight_decay": 0.0, "lr": lr},
        ]
        return AdamW(optimizer_grouped_parameters, eps=eps)

    def get_data_loader_dict(self):
        """Build the train / val / test DataLoaders for ``self.task``.

        Returns:
            dict: ``{'train': ..., 'val': ..., 'test': ...}``; only the train
            loader shuffles.
        """
        data_loader_dict = {}

        target_len = task_to_target_len[self.task]
        Processor = get_tasks_processor(self.args.processor, self.args.task)

        train_data = Processor().get_train_examples(self.data_dir)
        data_loader_dict['train'] = DataLoader(
            # NOTE(review): the trailing boolean presumably marks the training split — confirm in SoftDataset.
            dataset=SoftDataset(train_data, self.tokenizer, self.num_prompts, self.seq_len, target_len, True),
            batch_size=self.batch_size, shuffle=True)

        dev_data = Processor().get_dev_examples(self.data_dir)
        data_loader_dict['val'] = DataLoader(
            dataset=SoftDataset(dev_data, self.tokenizer, self.num_prompts, self.seq_len, target_len, False),
            batch_size=self.batch_size)

        test_data = Processor().get_test_examples(self.data_dir)
        data_loader_dict['test'] = DataLoader(
            dataset=SoftDataset(test_data, self.tokenizer, self.num_prompts, self.seq_len, target_len, False),
            batch_size=self.batch_size)

        return data_loader_dict

    def _checkpoint_prefix(self):
        """Return the output path prefix shared by saved models and metric JSON files."""
        model_tag = self.args.model_name.replace("/", "_")
        return f'{self.args.output_dir}/{self.args.exp}_task{self.args.task}_{model_tag}'

    def finetune(self, epochs):
        """Train for ``epochs`` epochs, running a validation pass after each one.

        Each epoch:
        - Decodes the generated validation predictions and scores them with
          ``compute_task_metrics``.
        - Updates the best ``acc`` / ``f1_score`` / ``em`` seen so far.
        - When ``args.save_model`` is set and a new best ``acc`` or ``em`` is
          reached, saves the model and a JSON file with the metrics.
        - Sends training and validation logs to wandb when ``args.use_wandb``
          is set.
        """
        print('finetune')
        print(f'task = {self.task}')
        print(f'pad_id = {self.tokenizer.pad_token_id}')
        model = self.model
        pad_id = self.tokenizer.pad_token_id
        optimizer = self.get_optimizer(lr=self.lr, weight_decay=self.weight_decay, eps=self.eps)
        dataloader_train = self.tasks_data_dict['train']
        dataloader_val = self.tasks_data_dict['val']
        target_len = task_to_target_len[self.task]
        checkpoint_prefix = self._checkpoint_prefix()

        generation_arguments = {
            "max_length": target_len,
        }

        max_acc, max_f1_score, max_em = 0, 0, 0
        for epoch in range(epochs):
            # ---- training ----
            model.train()
            logs = None
            loss_values = []
            for batch in tqdm(dataloader_train):
                batch = {k: v.to(model.device) for k, v in batch.items()}
                # Pad positions become -100 so the LM loss ignores them.
                lm_labels = batch["target_ids"].clone()
                lm_labels[lm_labels == pad_id] = -100

                loss = model(
                    input_ids=batch["source_ids"],
                    labels=lm_labels
                ).loss

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                loss_values.append(loss.item())
                logs = {
                    'epoch': epoch,
                    'train_loss': np.mean(loss_values)
                }
                if self.args.use_wandb:
                    wandb.log(logs)

                del batch, lm_labels, loss

            print(logs)

            # ---- validation (model stays in eval mode for the whole pass) ----
            model.eval()
            y_pred, y_true, loss_values = [], [], []
            with torch.no_grad():
                for i, batch in enumerate(tqdm(dataloader_val)):
                    batch = {k: v.to(model.device) for k, v in batch.items()}

                    outputs = model.generate(
                        input_ids=batch['source_ids'],
                        **generation_arguments,
                    )

                    pred = [self.tokenizer.decode(token_ids=ids, skip_special_tokens=True) for ids in outputs]
                    target = [self.tokenizer.decode(token_ids=ids, skip_special_tokens=True) for ids in batch['target_ids']]
                    if i == 0:
                        print(f'example predictions pred: {pred} target: {target}')

                    y_pred += pred
                    y_true += target

                    # Clone so the batch tensor is not modified in place.
                    lm_labels = batch["target_ids"].clone()
                    lm_labels[lm_labels == pad_id] = -100

                    loss = model(
                        input_ids=batch['source_ids'],
                        labels=lm_labels
                    ).loss

                    loss_values.append(loss.item())

                    del lm_labels, loss

            # ---- metrics, best-so-far tracking and checkpointing ----
            logs = compute_task_metrics(task=self.task,
                                        y_pred=y_pred,
                                        y_true=y_true,
                                        val_dataset=dataloader_val.dataset)
            save_json = False
            if 'acc' in logs:
                if max_acc < logs['acc']:
                    max_acc = logs['acc']
                    if self.args.save_model:
                        # save_pretrained writes a directory; the '.pth' suffix is kept for existing paths.
                        model.save_pretrained(f'{checkpoint_prefix}.pth')
                    save_json = True
                logs['max_acc'] = float(max_acc)
            if 'f1_score' in logs:
                if max_f1_score < logs['f1_score']:
                    max_f1_score = logs['f1_score']
                logs['max_f1_score'] = float(max_f1_score)
            if 'em' in logs:
                if max_em < logs['em']:
                    max_em = logs['em']
                    if self.args.save_model:
                        model.save_pretrained(f'{checkpoint_prefix}.pth')
                    save_json = True
                logs['max_em'] = float(max_em)

            logs['epoch'] = epoch
            logs['val_loss'] = float(np.mean(loss_values))

            print(logs)
            if self.args.use_wandb:
                wandb.log(logs)
            if self.args.save_model and save_json:
                logs['model_name'] = self.args.model_name
                logs['exp'] = self.args.exp
                logs['task'] = self.args.task
                logs['lr'] = self.args.lr
                with open(f'{checkpoint_prefix}.json', 'w') as outfile:
                    json.dump(logs, outfile, indent=4)


# **Reference code below (이하 레퍼런스 코드)**

In [None]:
import pandas as pd
import numpy as np

from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import datasets


class T5Dataset:
    """Load GLUE / SuperGLUE / text-classification datasets and encode them for T5.

    Inputs and labels are turned into text (T5's text-to-text format) and
    tokenized into ``source_ids`` / ``source_mask`` / ``target_ids`` /
    ``target_mask`` tensors served by a torch ``DataLoader``.
    """

    def __init__(self, tokenizer, task):
        """Dataset class for T5 model experiments.

        Args:
            task (str): Name of the downstream task.
            tokenizer (HuggingFace Tokenizer): T5 model tokenizer to use.
        """
        self.tokenizer = tokenizer
        self.glue_datasets = ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli',
                              'mnli_mismatched', 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax']
        self.superglue_datasets = ['copa', 'boolq', 'wic', 'wsc', 'cb', 'record', 'multirc', 'rte_superglue', 'wsc_bool']

        # Column keys used in the dataset (second entry None = single-text input).
        self.task_to_keys = {
            "cola": ("sentence", None),
            "mnli": ("premise", "hypothesis"),
            "mnli-mm": ("premise", "hypothesis"),
            "mrpc": ("sentence1", "sentence2"),
            # qnli is loaded from SetFit/qnli, whose columns are text1/text2
            # (GLUE's own qnli uses question/sentence).
            "qnli": ("text1", "text2"),
            "qqp": ("question1", "question2"),
            "rte": ("sentence1", "sentence2"),
            "sst2": ("sentence", None),
            "stsb": ("sentence1", "sentence2"),
            "wnli": ("sentence1", "sentence2"),

            "boolq": ("passage", "question"),
            "copa": ('choice1', 'choice2', 'premise', 'question'),
            "wic": ("start1", "end1", "sentence1", "start2", "end2", "sentence2", "word"),
            "wsc": ("span1_text", "span1_index", "span2_text", "span2_index", "text"),
            "wsc_bool": ("span1_text", "span1_index", "span2_text", "span2_index", "text"),
            "cb": ("premise", "hypothesis"),
            "record": ("passage", "query", "entities"),
            "multirc": ("question", "answer", "paragraph"),
            "rte_superglue": ("premise", "hypothesis"),

            "scicite": ("sectionName", "string"),
            "imdb": ("text", None),

            "ag_news": ("text", None),
            "yelp_review_full": ("text", None),
            "yahoo_answers_topics": ("question_content", "best_answer"),
            "dbpedia_14": ("title", "content"),

            "ag": ("content", None),
            "yelp": ("content", None),
            "yahoo": ("content", None),
            "dbpedia": ("content", None),
            "amazon": ("content", None),
        }

        # Label text for T5 tasks (T5 uses a text-to-text format for inputs
        # and labels); an integer label indexes into the tuple.
        self.task_to_labels = {
            "cola": ("not_acceptable", "acceptable"),
            "mnli": ("entailment", "neutral", "contradiction"),
            "mnli-mm": (),
            "mrpc": ("not_equivalent", "equivalent"),
            "qnli": ("entailment", "not_entailment"),
            "qqp": ("not_duplicate", "duplicate"),
            "rte": ("entailment", "not_entailment"),
            "sst2": ("negative", "positive"),
            "stsb": (),
            "wnli": (),

            "boolq": ("false", "true"),
            "copa": ("false", "true"),
            "wic": ("false", "true"),
            "wsc_bool": ("false", "true"),
            "cb": ("entailment", "contradiction", "neutral"),
            "multirc": ("false", "true"),
            "rte_superglue": ("entailment", "not_entailment"),

            "scicite": (),
            "imdb": ("negative", "positive"),

            "ag_news": ("world", "sports", "business", "science"),
            "yelp_review_full": ("terrible", "bad", "middle", "good", "wonderful"),
            "yahoo_answers_topics": ("society and culture", "science", "health", "education and reference",
                                     "computers and internet", "sports", "business", "entertainment and music",
                                     "family and relationships", "politics and government"),
            "dbpedia_14": ("company", "educationalinstitution", "artist", "athlete", "officeholder",
                           "meanoftransportation", "building", "naturalplace", "village", "animal",
                           "plant", "album", "film", "writtenwork"),

            "ag": ("world", "sports", "business", "science"),
            "yelp": ("terrible", "bad", "middle", "good", "wonderful"),
            "yahoo": ("society and culture", "science", "health", "education and reference",
                      "computers and internet", "sports", "business", "entertainment and music",
                      "family and relationships", "politics and government"),
            "dbpedia": ("company", "educationalinstitution", "artist", "athlete", "officeholder",
                        "meanoftransportation", "building", "naturalplace", "village", "animal",
                        "plant", "album", "film", "writtenwork"),
            "amazon": ("terrible", "bad", "middle", "good", "wonderful"),
        }

        self.task = task
        # Name of the label column; a few datasets use a different key.
        self.label_key = 'label'
        if 'yahoo_' in task: self.label_key = 'topic'
        if 'stsb' in task: self.label_key = 'similarity_score'
        if task == 'record': self.label_key = 'answers'

    def save_multirc_questions_idx(self, val_ds):
        """Store in ``self.multirc_idx`` a question id for every row of ``val_ds``.

        Consecutive rows that share the same (paragraph, question) pair get the
        same id; the ids are needed later for MultiRC metric computation.
        """
        idx = []
        question_id = 0
        prev_paragraph, prev_question = val_ds['paragraph'][0], val_ds['question'][0]

        for paragraph, question in zip(val_ds['paragraph'], val_ds['question']):
            if prev_paragraph != paragraph or prev_question != question:
                question_id += 1
            prev_paragraph, prev_question = paragraph, question
            idx.append(question_id)
        self.multirc_idx = np.array(idx)

    def select_subset_ds(self, ds, k=2000, seed=0):
        """Return a shuffled subset of ``ds`` with at most ``k`` samples per class.

        For tasks without discrete labels ('stsb', 'record', 'wsc') at most ``k``
        samples are drawn from the whole dataset instead.

        Args:
            ds: HuggingFace ``Dataset`` (uses ``shape``, column indexing and ``select``).
            k (int): Maximum number of samples per class (or in total, see above).
            seed (int): Seed for NumPy's global RNG. It is set before sampling, so
                both the chosen rows and their order are reproducible.

        Returns:
            The selected rows, via ``ds.select``.
        """
        # Reseeds NumPy's global RNG (as before), but now ahead of the sampling.
        np.random.seed(seed)
        if self.task in ['stsb', 'record', 'wsc']:  # non-discrete labels
            idx_total = np.random.choice(np.arange(ds.shape[0]), min(k, ds.shape[0]), replace=False)
        else:
            labels = np.array(ds[self.label_key])
            idx_total = np.array([], dtype='int64')
            # sorted(): set iteration order of str labels depends on hash seed.
            for label in sorted(set(ds[self.label_key])):
                idx = np.where(labels == label)[0]
                # we cannot take more samples than there are available
                idx_total = np.concatenate([idx_total,
                                            np.random.choice(idx, min(k, idx.shape[0]), replace=False)])

        np.random.shuffle(idx_total)
        return ds.select(idx_total)

    def process_wsc(self, wsc_row):
        """Build WSC input text and target from one raw row.

        The word at ``span2_index`` (whitespace-split) is wrapped in ``*...*``.
        The target is the word at ``span1_index``.

        Returns:
            tuple[str, str]: ``(text, target)``.
        """
        words = wsc_row['text'].split(' ')
        target = words[wsc_row['span1_index']]
        words[wsc_row['span2_index']] = '*' + words[wsc_row['span2_index']] + '*'
        return ' '.join(words), target

    def preprocess_function(self, examples, task,
                            max_length=512, max_length_target=2,
                            prefix_list=None):
        """Turn one raw example into tokenized T5 source / target ids.

        Args:
            examples (dict): A single dataset row (``map`` is called with ``batched=False``).
            task (str): Task name; selects the input columns and the label strings.
            max_length (int): Source length in tokens (padded / truncated to it).
            max_length_target (int): Target length in tokens (padded / truncated to it).
            prefix_list (list[str] | None): Optional strings prepended to the input text.

        Returns:
            dict: ``source_ids``, ``source_mask``, ``target_ids``, ``target_mask``.
        """
        tokenizer = self.tokenizer
        keys = self.task_to_keys[task]
        label_key = self.label_key

        if keys[1] is not None:
            if task == 'record':
                text = 'passage : ' + str(examples['passage']) + ' query: ' + str(examples['query']) + ' entities: ' + '; '.join(examples['entities'])
            elif task == 'wsc':
                text, target = self.process_wsc(examples)
            else:
                text = ''.join(key + ': ' + str(examples[key]) + ' ' for key in keys)
        else:
            text = examples[keys[0]]

        if prefix_list:
            text = ' '.join(prefix_list) + ' ' + text
        source = tokenizer(text.strip() + ' </s>',
                           truncation=True,
                           padding='max_length',
                           max_length=max_length)

        if task == 'stsb':
            target = str(examples[label_key])[:3]
        elif task == 'record':
            target = '; '.join(examples[label_key])
        elif task == 'wsc':
            pass  # target was already produced by process_wsc above
        else:
            target = self.task_to_labels[task][examples[label_key]]
        target += ' </s>'
        # padding='max_length' + truncation=True is what the deprecated
        # pad_to_max_length=True resolved to when max_length was given.
        target = tokenizer(target, max_length=max_length_target,
                           padding='max_length', truncation=True)

        return {"source_ids": source['input_ids'],
                "source_mask": source['attention_mask'],
                "target_ids": target['input_ids'],
                "target_mask": target['attention_mask']}

    def _encode_to_dataloader(self, dataset, task, batch_size, target_len, max_length, prefix_list):
        """Tokenize ``dataset`` row by row and wrap it in a (non-shuffling) torch DataLoader."""
        encoded_dataset = dataset.map(lambda x: self.preprocess_function(x, task,
                                                                        max_length=max_length,
                                                                        max_length_target=target_len,
                                                                        prefix_list=prefix_list),
                                      batched=False)
        encoded_dataset.set_format(type='torch', columns=['source_ids', 'source_mask',
                                                          'target_ids', 'target_mask'])
        return DataLoader(encoded_dataset, batch_size=batch_size)

    def get_final_ds(self,
                     task,
                     split,
                     batch_size,
                     k=-1,
                     seed=0,
                     return_test=False,
                     target_len=2,
                     max_length=512,
                     prefix_list=None):
        """Function that returns final T5 dataloader.

        Args:
            task (str): Name of the downstream task.
            split (str): Which data split to use (train/validation/test).
            batch_size (int): Batch size to use in the dataloader.
            k (int, optional): Number of samples to use for each class. Defaults to -1, not sub-sample the data.
            seed (int, optional): Seed used for sub-sampling and random shuffle. Defaults to 0.
            return_test (bool, optional): Whether to create a test split.
                When True, two Dataloaders are returned. Defaults to False.
            target_len (int, optional): Length of the model output (in tokens). Defaults to 2.
            max_length (int, optional): Length of the model input (in tokens). Defaults to 512.
            prefix_list (List[str], optional): List of prompt virtual tokens to pre-pend to the input.
                We do not encode soft prompt as extra virtual tokens in the latest implementation.
                Defaults to None (no prefix).

        Returns:
            Dataloader: Torch Dataloader with preprocessed input text & label
            (a ``[val, test]`` list of two Dataloaders when ``return_test`` is True).
        """
        if task in ['amazon']:  # amazon not available with hugging face
            df = pd.read_csv('../datasets/src/data/' + task + '/' + split + '.csv', header=None)
            df = df.rename(columns={0: "label", 1: "title", 2: "content"})
            df['label'] = df['label'] - 1  # shift labels to start at 0
            dataset = datasets.Dataset.from_pandas(df)
        elif task == 'mnli':
            dataset = load_dataset('LysandreJik/glue-mnli-train', split=split)
        elif task == 'qnli':
            dataset = load_dataset('SetFit/qnli', split=split)
        elif task == 'stsb':
            dataset = load_dataset('stsb_multi_mt', name='en', split=split if split == 'train' else 'dev')
        elif task not in self.glue_datasets and task not in self.superglue_datasets:
            dataset = load_dataset(task, split=split)
        else:
            benchmark = 'glue' if task not in self.superglue_datasets else 'super_glue'
            dataset = load_dataset(benchmark,
                                   task.replace('_superglue', '').replace('_bool', ''),
                                   split=split)

        # For yahoo dataset we need to filter out empty rows
        # (i.e. where "question" field is empty)
        if self.task == "yahoo_answers_topics":
            if split == 'train':
                dataset = dataset.select(np.load('good_id_yahoo_train.npy'))
            elif split == 'test':
                dataset = dataset.select(np.load('good_id_yahoo_test.npy'))

        # Using Lester et al. setting for WSC task, e.g.
        # using only positive samples (for output generation)
        if self.task == 'wsc':
            idx = np.where(np.array(dataset['label']) == 1)[0]
            dataset = dataset.select(idx)

        # Selecting k subset of the samples (if requested)
        if k != -1:
            dataset = self.select_subset_ds(dataset, k=k, seed=seed)

        if k == -1 and split != 'train' and self.task == 'multirc':
            # we do not shuffle full validation set of multirc
            # but we save idx of the same questions
            # which are used for multirc test metric computation
            self.save_multirc_questions_idx(dataset)
        else:
            dataset = dataset.shuffle(seed=seed)

        # Returning the selected data split (train/val/test)
        if not return_test:
            return self._encode_to_dataloader(dataset, task, batch_size, target_len, max_length, prefix_list)

        # Creating an extra test set from the selected data split (first half / second half)
        N = len(dataset)
        dataset_val = dataset.select(np.arange(0, N // 2))
        dataset_test = dataset.select(np.arange(N // 2, N))
        return [self._encode_to_dataloader(part, task, batch_size, target_len, max_length, prefix_list)
                for part in (dataset_val, dataset_test)]

In [None]:
import torch
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import logging, os, argparse

from t5_continual import T5ContinualLearner


def main(args):
    """Train a ``T5ContinualLearner`` on ``args.task_list`` and save the results.

    Outputs go to ``os.path.join(args.save_dir, args.save_name)`` (created,
    including any missing parent directories).
    - With ``args.multitask == 1``: runs ``multi_task_training``.
    - Otherwise: trains the tasks sequentially with ``train_continual`` and
      also saves the learned prompts.

    Integer CLI flags documented as 0/1 are converted to booleans here.
    """
    save_path = os.path.join(args.save_dir, args.save_name)
    # makedirs also creates a missing save_dir, where os.mkdir would raise.
    os.makedirs(save_path, exist_ok=True)
    task_list = args.task_list

    model_name = args.model_name
    continual_learner = T5ContinualLearner(model_name,
                                           task_list,
                                           batch_size=args.batch_size,
                                           select_k_per_class=args.select_k_per_class,
                                           prefix_len=args.prefix_len,
                                           freeze_weights=args.freeze_weights == 1,
                                           freeze_except=args.freeze_except,
                                           lr=args.lr,
                                           seq_len=args.seq_len,
                                           early_stopping=args.early_stopping == 1,
                                           prefix_MLP=args.prefix_MLP,
                                           prefix_path=args.prefix_path if args.prefix_path != '' else None,
                                           mlp_layer_norm=args.mlp_layer_norm == 1,
                                           bottleneck_size=args.bottleneck_size,
                                           get_test_subset=args.get_test_subset == 1,
                                           memory_perc=args.memory_perc
                                           )
    if args.get_test_subset == 0:
        print("Not creating test subset")

    if args.multitask == 1:
        print('Multi task learning')
        results_dict = continual_learner.multi_task_training(num_epochs=args.num_epochs, save_path=save_path)
        np.save(os.path.join(save_path, 'results_dict.npy'), results_dict)
        return

    # Evaluate less frequently for longer runs.
    if args.num_epochs <= 50:
        eval_every_N = 1
    elif args.num_epochs <= 200:
        eval_every_N = 5
    else:
        eval_every_N = 10

    results_dict = continual_learner.train_continual(continual_learner.task_list,
                                                    epochs=args.num_epochs,
                                                    save_path=save_path,
                                                    progressive=args.progressive == 1,
                                                    eval_every_N=eval_every_N,
                                                    test_eval_after_every_task=args.test_eval_after_every_task == 1,
                                                    data_replay_freq=args.data_replay_freq,
                                                    )
    np.save(os.path.join(save_path, 'results_dict.npy'), results_dict)
    np.save(os.path.join(save_path, 'prompts.npy'), continual_learner.previous_prompts.detach().cpu().numpy())




if __name__ == "__main__":
    # Command-line entry point: declare every option, parse sys.argv and pass
    # the namespace to main(). Integer "0/1" flags become booleans in main().
    cli_parser = argparse.ArgumentParser(description='NLP training script in PyTorch')

    # (flag, add_argument keyword arguments), in the order shown by --help.
    cli_options = [
        ('--save_dir', dict(
            type=str,
            help='base directory of all models / features (should not be changed)',
            default='/data/home/arazdai/T5_prompts/T5_continual/',  # alternative: '/scratch/hdd001/home/anastasia/CL/'
        )),
        ('--save_name', dict(type=str, help='folder name to save', required=True)),
        ('--task_list', dict(nargs='+', help='List of tasks for training', required=True)),
        ('--model_name', dict(type=str, help='Name of the model used for training', default="t5-base")),
        ('--num_epochs', dict(type=int, help='Number of epochs to train model', default=5)),
        ('--multitask', dict(type=int, help='Whether to perform multi-task training', default=0)),
        ('--batch_size', dict(type=int, help='Batch size', default=8)),
        ('--seq_len', dict(type=int, help='Length of a single repeat (in #tokens)', default=512)),
        ('--prefix_len', dict(type=int, help='Length of prompt (in #tokens)', default=10)),
        ('--prefix_path', dict(
            type=str,
            help='path to a pre-trained progressive prefix (for superGLUE experiments)',
            default='',  # empty string is turned into None by main()
        )),
        ('--lr', dict(type=float, help='Learning rate', default=0.3)),
        ('--memory_perc', dict(type=float, help='Memory perc', default=0.01)),
        ('--data_replay_freq', dict(type=float, help='Replay data every X iterations', default=-1)),
        ('--select_k_per_class', dict(
            type=int,
            help='Select k examples from each class (default -1, i.e. no changes to the original dataset)',
            default=-1,
        )),
        ('--test_eval_after_every_task', dict(
            type=int,
            help='Whether to re-evaluate test accuracy after every task (0 - False, 1 - True)',
            default=0,
        )),
        ('--progressive', dict(
            type=int,
            help='Whether to concatenate prompts in a progressive way (0 - False, 1 - True)',
            default=1,
        )),
        ('--freeze_weights', dict(type=int, help='Whether to freeze model weigts (except word emb)', default=0)),
        ('--freeze_except', dict(
            type=str,
            help='If freeze_weights==1, freeze all weights except those that contain this keyword',
            # Placeholder keyword, intended to freeze all weights (presumably it matches no parameter name).
            default='xxxxxxx',
        )),
        ('--get_test_subset', dict(type=int, help='Whether to create a separate test split', default=1)),
        ('--early_stopping', dict(
            type=int,
            help='If early_stopping==1, do early stopping based on val accuracy',
            default=1,  # early stopping enabled
        )),
        ('--prefix_MLP', dict(
            type=str,
            help='Type of MLP reparametrization (if None - use Lester original implementation)',
            # NOTE(review): the *string* 'None', not the None object — presumably handled by T5ContinualLearner.
            default='None',
        )),
        ('--mlp_layer_norm', dict(type=int, help='Do layer norm in MLP', default=1)),  # layer norm enabled
        ('--bottleneck_size', dict(type=int, help='MLP bottleneck size', default=800)),
    ]
    for flag, options in cli_options:
        cli_parser.add_argument(flag, **options)

    main(cli_parser.parse_args())
