In [None]:
import dataclasses
import pickle
from time import time
from typing import List, Union, Dict

import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
from transformers import AutoTokenizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from transformers.training_args import is_tpu_available
from transformers.trainer import get_tpu_sampler
from transformers.data.data_collator import DataCollator, InputDataClass
from transformers.data.data_collator import DefaultDataCollator

In [None]:
# globals

# GLUE tasks that are trained jointly; the per-task dictionaries below are keyed by these names.
TASKS = ['rte', 'wnli']

# Token budget used when encoding sentence pairs (sequences are padded up to it).
# Lower this value if your hardware runs out of memory.
# Note: the sequence-length sweep cell further down reassigns this name.
max_length = 340

# Model family to fine-tune. To run with BERT instead of XLNet, swap which of
# the two assignments below is commented out.
# model_class = 'bert'
model_class = 'xlnet'
model_name = f'{model_class}-base-cased'

In [None]:
# Load every GLUE task listed in TASKS; the result maps task name -> loaded dataset.
dataset_dict = dict(
    (task, nlp.load_dataset('glue', name=task)) for task in TASKS
)

In [None]:
class MultitaskModel(transformers.PreTrainedModel):
    """
    Bundle of single-task models that all share one encoder.

    The task-specific models are kept in an ``nn.ModuleDict`` keyed by task
    name; ``forward`` dispatches to the model registered for ``task_name``.
    """

    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features.

        Args:
            encoder: the encoder module shared by all task models.
            taskmodels_dict: mapping of task name -> task-specific model.
        """
        # The config only carries the module-level `max_length` at construction time.
        super().__init__(transformers.PretrainedConfig(max_length=max_length))

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict=None):
        """
        Build a MultitaskModel from single-task model classes and configs.

        Each single-task model is created with ``from_pretrained``; the
        encoder of the first model created is then installed on every
        following model, so all tasks share the same encoder object.

        Args:
            model_name: identifier passed to ``from_pretrained``.
            model_type_dict: mapping of task name -> model class.
            model_config_dict: optional mapping of task name -> config.
                When omitted, ``config=None`` is passed to ``from_pretrained``
                for every task (previously the ``None`` default raised a
                ``TypeError`` on the first lookup).

        Returns:
            A new instance of ``cls`` wrapping the shared encoder and the
            per-task models.

        Raises:
            KeyError: if ``model_config_dict`` is given but lacks a task name.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            task_config = (
                None if model_config_dict is None else model_config_dict[task_name]
            )
            model = model_type.from_pretrained(model_name, config=task_config)
            # `base_model_prefix` names the attribute that holds the encoder.
            encoder_attr = model.base_model_prefix
            if shared_encoder is None:
                shared_encoder = getattr(model, encoder_attr)
            else:
                setattr(model, encoder_attr, shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    def forward(self, task_name, **kwargs):
        """Run the model registered for ``task_name`` with the given keyword inputs."""
        return self.taskmodels_dict[task_name](**kwargs)

In [None]:
def _encode_sentence_pairs(example_batch):
    """
    Encode a batch of (sentence1, sentence2) pairs and attach their labels.

    Reads the module-level ``tokenizer`` and ``max_length`` at call time, so
    both must be defined before this function is invoked.

    Args:
        example_batch: mapping with 'sentence1', 'sentence2' and 'label'
            entries, each holding one value per example.

    Returns:
        The tokenizer's ``batch_encode_plus`` output with an extra "labels"
        entry copied from ``example_batch["label"]``.
    """
    inputs = list(zip(example_batch['sentence1'], example_batch['sentence2']))
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["label"]
    return features


def convert_to_wnli_features(example_batch):
    """Convert a batch of WNLI examples into model features (see `_encode_sentence_pairs`)."""
    return _encode_sentence_pairs(example_batch)


def convert_to_rte_features(example_batch):
    """Convert a batch of RTE examples into model features (see `_encode_sentence_pairs`)."""
    return _encode_sentence_pairs(example_batch)

In [None]:
class NLPDataCollator(DataCollator):
    """
    Extends the existing DataCollator to work with NLP dataset batches.
    """

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into a single batch of tensors.

        When the examples are dictionaries (one per example), the "labels"
        entries are gathered into one tensor and every other non-``None``,
        non-string entry is stacked with ``torch.stack``. Any other example
        type is handed to ``DefaultDataCollator``.

        Args:
            features: non-empty list of examples.

        Returns:
            Mapping of feature name -> batched tensor.
        """
        first = features[0]
        if not isinstance(first, dict):
            # Otherwise, revert to using the default collate_batch.
            return DefaultDataCollator().collate_batch(features)

        # Start from an empty batch so examples without labels no longer hit
        # an unbound `batch` in the loop below.
        batch = {}
        if "labels" in first and first["labels"] is not None:
            # The `.dtype` access assumes label values are tensors: int64
            # labels stay integral, anything else is collated as float.
            label_dtype = torch.long if first["labels"].dtype == torch.int64 else torch.float
            batch["labels"] = torch.tensor(
                [f["labels"] for f in features], dtype=label_dtype
            )
        for k, v in first.items():
            if k != "labels" and v is not None and not isinstance(v, str):
                batch[k] = torch.stack([f[k] for f in features])
        return batch


class StrIgnoreDevice(str):
    """
    A ``str`` that tolerates ``.to(device)`` calls.

    This is a hack: the Trainer calls ``.to(device)`` on every input value,
    but the batch also carries a plain ``task_name`` string. Giving the
    string a no-op ``to`` keeps that call from raising.
    """

    def to(self, device) -> "StrIgnoreDevice":
        # There is nothing to move for a string: hand back the same object.
        return self


class DataLoaderWithTaskname:
    """
    Thin wrapper that tags every batch of a DataLoader with its task name.

    The wrapped loader's ``batch_size`` and ``dataset`` are mirrored as
    attributes of the wrapper.
    """

    def __init__(self, task_name, data_loader):
        self.data_loader = data_loader
        self.task_name = task_name

        # Mirror the attributes of the wrapped loader.
        self.dataset = data_loader.dataset
        self.batch_size = data_loader.batch_size

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.data_loader)

    def _with_task_name(self, batch):
        """Store the task name in ``batch`` (in place) and return the batch."""
        batch["task_name"] = StrIgnoreDevice(self.task_name)
        return batch

    def __iter__(self):
        # Lazily tag each batch as it is pulled from the wrapped loader.
        yield from map(self._with_task_name, self.data_loader)


class MultitaskDataloader:
    """
    Iterable that interleaves the batches of several single-task data loaders.
    """

    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.task_name_list = list(dataloader_dict)

        # Count batches per task and the total number of examples in one pass.
        self.num_batches_dict = {}
        total_examples = 0
        for name, loader in dataloader_dict.items():
            self.num_batches_dict[name] = len(loader)
            total_examples += len(loader.dataset)

        # Placeholder list whose only meaningful property is its length
        # (the summed size of all task datasets).
        self.dataset = [None] * total_examples

    def __len__(self):
        # Total number of batches over all tasks.
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        Yield one batch at a time, each drawn from a randomly chosen task.

        Every task index appears once per batch that task provides and the
        resulting schedule is shuffled, i.e. sampling is proportional to
        task size. Swap the schedule construction to sample differently.
        """
        schedule = np.array([
            task_index
            for task_index, name in enumerate(self.task_name_list)
            for _ in range(self.num_batches_dict[name])
        ])
        np.random.shuffle(schedule)

        # One fresh iterator per task, consumed in schedule order.
        iterators = {
            name: iter(loader) for name, loader in self.dataloader_dict.items()
        }
        for task_index in schedule:
            yield next(iterators[self.task_name_list[task_index]])

class MultitaskTrainer(transformers.Trainer):
    """
    Trainer whose training data loader samples batches from several tasks.

    ``self.train_dataset`` is iterated with ``.items()``, i.e. it is used as a
    mapping of task name -> single-task dataset.
    """

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names.

        Args:
            task_name: name stored in every batch under "task_name".
            train_dataset: dataset for this task.

        Returns:
            A ``DataLoaderWithTaskname``; on TPU it is additionally wrapped in
            a per-device ``ParallelLoader``.

        Raises:
            ValueError: if the trainer was built without a train_dataset.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        on_tpu = is_tpu_available()
        if on_tpu:
            train_sampler = get_tpu_sampler(train_dataset)
        elif self.args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=self.data_collator.collate_batch,
            ),
        )

        if on_tpu:
            # `pl` was referenced without ever being imported (NameError on
            # TPU). Import it lazily so non-TPU setups do not need torch_xla.
            import torch_xla.distributed.parallel_loader as pl

            data_loader = pl.ParallelLoader(
                data_loader, [self.args.device]
            ).per_device_loader(self.args.device)
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each
        task Dataloader.
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })

In [None]:
# Capture the cell-level namespace once: inside the generator expression
# below, locals() would refer to that expression's own scope instead.
local_caller = locals()

# Map each task to its converter, looked up by the naming convention
# convert_to_<task>_features.
convert_func_dict = dict(
    (task, local_caller['convert_to_%s_features' % task]) for task in TASKS
)

In [None]:
# Columns of the encoded datasets that are exposed as torch tensors.
columns_dict = {task : ['input_ids', 'attention_mask', 'labels'] for task in TASKS}

# The convert_to_*_features functions read the module-level `tokenizer`, which
# is otherwise only assigned in the training cell further down; create it here
# so this cell also works on a fresh kernel (Restart & Run All).
# NOTE(review): do_lower_case=True mirrors the training cell although
# model_name is a "-cased" checkpoint — confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)

# features_dict[task][phase] holds the encoded version of dataset_dict[task][phase].
features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        features_dict[task_name][phase].set_format(
            type="torch",
            columns=columns_dict[task_name],
        )
        # Sanity check: number of examples before and after encoding.
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

# Training split of every task, keyed by task name.
train_dataset = {
    task_name: dataset["train"]
    for task_name, dataset in features_dict.items()
}

In [None]:
# scores[task][i] holds the GLUE metric obtained with max_length == 2**i.
scores = {task: {} for task in TASKS}

# The tokenizer does not depend on the sequence length, so load it once.
# NOTE(review): do_lower_case=True is kept from the original cell although
# model_name is a "-cased" checkpoint — confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)


def encode_all_tasks():
    """
    Encode every task and split of `dataset_dict` with the current settings.

    The converters in `convert_func_dict` read the module-level `max_length`
    and `tokenizer` at call time, so calling this after changing `max_length`
    yields features for the new length. The cache is bypassed
    (`load_from_cache_file=False`) so earlier encodings are not reused.

    Returns:
        Nested dict: task name -> phase name -> encoded dataset in torch format.
    """
    encoded = {}
    for task_name, raw_splits in dataset_dict.items():
        encoded[task_name] = {}
        for phase, phase_dataset in raw_splits.items():
            phase_features = phase_dataset.map(
                convert_func_dict[task_name],
                batched=True,
                load_from_cache_file=False,
            )
            phase_features.set_format(
                type="torch",
                columns=columns_dict[task_name],
            )
            encoded[task_name][phase] = phase_features
    return encoded


# This will calculate for sequence length 8, 16, ..., 128
# Adjust loop range to fit your experiment and hardware constraints
for i in range(3, 8):
    max_length = 2**i

    # Re-encode the data for this length. Previously only the variable was
    # reassigned while training kept using the features encoded earlier with
    # the initial max_length, so every iteration saw identical inputs.
    length_features = encode_all_tasks()
    length_train_dataset = {
        task_name: splits["train"]
        for task_name, splits in length_features.items()
    }

    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            "rte": transformers.AutoModelForSequenceClassification,
            "wnli": transformers.AutoModelForSequenceClassification
        },
        model_config_dict={
            "rte": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
            "wnli": transformers.AutoConfig.from_pretrained(model_name, num_labels=2)
        }
    )

    trainer = MultitaskTrainer(
        model=multitask_model,
        args=transformers.TrainingArguments(
            output_dir="./models/mtl_%s_model_%i" % (model_class, i),
            overwrite_output_dir=True,
            learning_rate=5e-5,
            do_train=True,
            num_train_epochs=3,
            # Adjust batch size if this doesn't fit on the Colab GPU
            per_device_train_batch_size=8,
            save_steps=3000,
            logging_steps=100,
            logging_dir='./mtl_%s_logs_%i' % (model_class, i)
        ),
        data_collator=NLPDataCollator(),
        train_dataset=length_train_dataset,
    )
    print('Training for max_length=', max_length)
    start = time()
    trainer.train()
    print('Training time=', time() - start)

    # Predict on the validation split of every task with the trained model.
    preds_dict = {}
    for task_name in TASKS:
        eval_dataloader = DataLoaderWithTaskname(
            task_name,
            trainer.get_eval_dataloader(
                eval_dataset=length_features[task_name]["validation"]
            )
        )
        preds_dict[task_name] = trainer._prediction_loop(
            eval_dataloader,
            description=f"Validation: {task_name}",
        )

    # Score each task with its GLUE metric on the arg-max class predictions.
    for task_name in TASKS:
        scores[task_name][i] = nlp.load_metric('glue', name=task_name).compute(
            np.argmax(preds_dict[task_name].predictions, axis=1),
            preds_dict[task_name].label_ids
        )

In [None]:
# Display the collected metrics: scores[task][i] for each exponent i of the sweep above.
scores

In [None]:
# Persist the collected scores; the file name carries the model family.
scores_path = 'seq-length-%s.pkl' % model_class
with open(scores_path, 'wb') as scores_file:
    pickle.dump(scores, scores_file)