In [43]:
import itertools
import os

import abc
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass

from enum import Enum
from tqdm import tqdm
from typing import Any, Dict, List, Optional, Tuple, Set, Union, cast

import numpy as np
import torch
from torch import Tensor
from datasets import Dataset, load_dataset
from lightning import Fabric
from lightning.fabric.strategies import DDPStrategy
from lightning.pytorch.callbacks import DeviceStatsMonitor, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.bert.modeling_bert import BertLayer
from transformers.models.bert import BertForSequenceClassification
from transformers import AutoTokenizer, BertTokenizer, PreTrainedTokenizer

from _common import RESULTS_DIR, logging
from fusionlib.merge.task_arithmetic import task_arithmetic_merge_modules
from src.module.dict_moe import ParetoWeightEnsemblingModule
from src.module.utils import print_trainable_parameters
from src.phn.solvers import EPOSolver
from src.utils import timeit_context

In [44]:
# Configure log.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
log.addHandler(handler)

In [None]:
cfg = DictConfig(
    {
        "model": "bert-base-uncased",
        "version": None,
        "num_devices": 1,
        "tasks": {"task1": 1, "task2": 1},
        "partial": True,
        "init_lambda": 0.6,
        "router_hidden_layers": 1,
        "batch_size": 1,
        "train": True,
        "lr": 1e-2,
        "num_steps": 1000,
        "alpha": 1,
        "save_interval": 500,
    }
)

In [46]:
RESULTS_DIR

WindowsPath('c:/Users/Admin/OneDrive/Documents/DANC/source_code/pareto_set_learning/results')

In [47]:
tokenizer = BertTokenizer.from_pretrained(cfg.model)

In [48]:
class Mode(Enum):
    train = "train"
    dev = "dev"
    test = "test"
    inference = "inference"

In [49]:
def load_imdb_dataset(split: Mode) -> Tuple[Dict[str, List[str]], List[str]]:
    """
    Load IMDB dataset.
    :param split: Train or Test split.
    :return: Dataset in dictionary format and list of all labels of the dataset.
    """
    log.info(f"Load IMDB dataset from {split} split")
    # TODO: We convert the dataset into format {"text": [], "labels": []} which was used commonly. This step can require
    #  a large MEM as the dataset is duplicated.
    imdb_dataset = load_dataset("imdb")[split.value]
    output_dataset: Dict[str, List[str]] = {"text": [], "labels": []}
    all_labels: Set[str] = set()
    for sample in tqdm(imdb_dataset):
        output_dataset["text"].append(sample["text"])
        label = str(sample["label"])
        output_dataset["labels"].append(label)
        all_labels.add(label)
    log.info(f"Loaded dataset with {len(all_labels)} labels")
    return output_dataset, list(all_labels)


dataset_mapping = {
    "task1": load_imdb_dataset(Mode.train),
    "task2": load_imdb_dataset(Mode.train),
}

2025-04-15 22:50:22,698 - __main__ - INFO - Load IMDB dataset from Mode.train split
2025-04-15 22:50:22,698 - __main__ - INFO - Load IMDB dataset from Mode.train split


100%|██████████| 25000/25000 [00:01<00:00, 23733.73it/s]
2025-04-15 22:50:37,584 - __main__ - INFO - Loaded dataset with 2 labels
2025-04-15 22:50:37,584 - __main__ - INFO - Loaded dataset with 2 labels
2025-04-15 22:50:37,590 - __main__ - INFO - Load IMDB dataset from Mode.train split
2025-04-15 22:50:37,590 - __main__ - INFO - Load IMDB dataset from Mode.train split
100%|██████████| 25000/25000 [00:00<00:00, 26374.38it/s]
2025-04-15 22:50:50,663 - __main__ - INFO - Loaded dataset with 2 labels
2025-04-15 22:50:50,663 - __main__ - INFO - Loaded dataset with 2 labels


In [8]:
@dataclass
class InputExample:
    uid: str


@dataclass
class TextClassificationExample(InputExample):
    doc_tokens: List[str]
    label: Optional[str] = None
    positions: Optional[List[List[int]]] = None


@dataclass
class InputFeatures:
    input_ids: List[int]
    attention_mask: Optional[List[int]] = None
    token_type_ids: Optional[List[int]] = None
    positions: Optional[List[List[int]]] = None


@dataclass
class TextClassificationFeatures(InputFeatures):
    label: Optional[int] = None


class Processor(ABC):
    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, str],
        max_seq_len: int,
        label_list: List[str],
        **kwargs,
    ):
        if isinstance(tokenizer, PreTrainedTokenizer):
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer, **kwargs, use_fast=False
            )
        self.max_seq_len = max_seq_len
        self.label_list = label_list

    @abc.abstractclassmethod
    def convert_examples_to_features(
        self, examples: List[InputExample]
    ) -> List[InputFeatures]:
        """Generate input features from examples"""
        raise NotImplementedError()

    @abc.abstractclassmethod
    def features_to_dataset(
        self, features: List[InputFeatures], mode: Union[str, Mode]
    ) -> Dataset:
        """Get Pytorch Dataset object from list of input features"""
        raise NotImplementedError()


def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False

In [51]:
class TextClassificationProcessor(Processor):
    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, str],
        max_seq_len: int,
        label_list: List[str],
        multilabel: bool = False,
        quotechar: str = '"',
        skiprows: int = 1,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            label_list=label_list,
            **kwargs,
        )
        self.quotechar = quotechar
        self.skiprows = skiprows
        self.multilabel = multilabel

    def get_examples(self, training_data: Dict) -> List[TextClassificationExample]:
        """
        convert training data to bert examples
        :param training_data: dict{"text": text, "labels", labels}
        """
        assert len(training_data["text"]) == len(training_data["labels"]), (
            f"{len(training_data['text'])} text and {len(training_data['labels'])} labels"
        )
        examples = []
        is_contain_position = "positions" in training_data
        for ii in tqdm(range(len(training_data["text"]))):
            label = training_data["labels"][ii]
            assert label in self.label_list, (
                f"Non exist label: '{label}' in label list: {self.label_list}."
            )
            context_text = training_data["text"][ii]
            # List of tokens of the doc.
            doc_tokens: List[str] = []
            # Mapping between position of the character to the word position.
            char_to_word_offset = []
            prev_is_whitespace = True
            for cc in context_text:
                if is_whitespace(cc):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(cc)
                    else:
                        doc_tokens[-1] += cc
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            if is_contain_position:
                assert len(training_data["positions"][ii]) == len(doc_tokens)
                positions = training_data["positions"][ii]
            else:
                positions = [[0, 0, 0, 0]] * len(doc_tokens)
            examples.append(
                TextClassificationExample(
                    uid=f"{ii}", doc_tokens=doc_tokens, label=label, positions=positions
                )
            )
        return examples

    def _convert_example_to_feature(
        self, example: TextClassificationExample, label_to_idx: Dict[str, int]
    ) -> TextClassificationFeatures:
        all_positions = []
        all_doc_tokens = []
        for ii, token in enumerate(example.doc_tokens):
            sub_tokens = self.tokenizer.tokenize(token)
            all_doc_tokens.extend(sub_tokens)
            all_positions.extend([example.positions[ii]] * len(sub_tokens))
        encoded_dict = self.tokenizer.encode_plus(
            all_doc_tokens,
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_token_type_ids=True,
        )
        position_pad = [0, 0, 0, 0]
        if len(all_positions) <= self.max_seq_len - 2:
            positions = [position_pad] + all_positions
        else:
            positions = [position_pad] + all_positions[: self.max_seq_len - 2]

        if len(positions) < self.max_seq_len:
            positions += [position_pad] * (self.max_seq_len - len(positions))
        encoded_dict["positions"] = positions
        encoded_dict["label"] = label_to_idx[example.label]
        return TextClassificationFeatures(**encoded_dict)

    def convert_examples_to_features(
        self, examples: List[TextClassificationExample]
    ) -> List[TextClassificationFeatures]:
        """Generate text classification features from examples"""
        label_to_idx = {label: ii for ii, label in enumerate(self.label_list)}
        features: List[TextClassificationFeatures] = []
        for ii in tqdm(range(len(examples))):
            features.append(
                self._convert_example_to_feature(examples[ii], label_to_idx)
            )
        return features

    def features_to_dataset(
        self,
        features: List[TextClassificationFeatures],
        mode: Union[str, Mode] = Mode.train,
    ) -> TensorDataset:
        """Get Pytorch Dataset object from list of classification features"""
        if isinstance(mode, Mode):
            mode = mode.value
        dataset = [
            torch.tensor([f.input_ids for f in features], dtype=torch.long),
            torch.tensor([f.attention_mask for f in features], dtype=torch.long),
            torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
            torch.tensor([f.positions for f in features], dtype=torch.long),
            torch.tensor([f.label for f in features], dtype=torch.long),
        ]

        return TensorDataset(*dataset)

In [52]:
def get_dataloader(data: Dict[str, Any], labels: List[str]) -> DataLoader:
    dataprocessor = TextClassificationProcessor(
        tokenizer=tokenizer, max_seq_len=512, label_list=labels
    )
    log.info("Load examples")
    examples = dataprocessor.get_examples(data)
    log.info("Convert examples to features")
    features = dataprocessor.convert_examples_to_features(examples)
    log.info("Construct dataset")
    dataset = dataprocessor.features_to_dataset(features)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False if cfg.num_devices > 1 else True,
        sampler=(
            DistributedSampler(dataset, shuffle=True) if cfg.num_devices > 1 else None
        ),
    )
    return dataloader

In [53]:
train_loaders = {
    task_name: get_dataloader(*data)
    for task_name, data in dataset_mapping.items()
}

2025-04-15 22:52:19,309 - __main__ - INFO - Load examples
2025-04-15 22:52:19,309 - __main__ - INFO - Load examples
100%|██████████| 25000/25000 [00:20<00:00, 1208.22it/s]
2025-04-15 22:52:40,007 - __main__ - INFO - Convert examples to features
2025-04-15 22:52:40,007 - __main__ - INFO - Convert examples to features
100%|██████████| 25000/25000 [05:58<00:00, 69.77it/s] 
2025-04-15 22:58:38,323 - __main__ - INFO - Construct dataset
2025-04-15 22:58:38,323 - __main__ - INFO - Construct dataset
2025-04-15 22:58:47,665 - __main__ - INFO - Load examples
2025-04-15 22:58:47,665 - __main__ - INFO - Load examples
100%|██████████| 25000/25000 [00:13<00:00, 1847.51it/s]
2025-04-15 22:59:01,208 - __main__ - INFO - Convert examples to features
2025-04-15 22:59:01,208 - __main__ - INFO - Convert examples to features
100%|██████████| 25000/25000 [06:40<00:00, 62.35it/s] 
2025-04-15 23:05:42,189 - __main__ - INFO - Construct dataset
2025-04-15 23:05:42,189 - __main__ - INFO - Construct dataset


In [55]:
task_num_labels_mapping = {
    task_name: len(data[1]) for task_name, data in dataset_mapping.items()
}

In [56]:
task_num_labels_mapping

{'task1': 2, 'task2': 2}

In [57]:
for task_name, data in dataset_mapping.items():
    break

In [61]:
pretrained_model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=list(task_num_labels_mapping.values())[0]
)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [67]:
task_num_labels_mapping.values()

dict_values([2, 2])

In [None]:
pretrained_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [69]:
finetuned_models = dict()
for task in cfg.tasks:
    log.info(f"Loading finetuned model for task {task}")
    finetuned_models[task] = BertForSequenceClassification.from_pretrained(
        cfg.model, num_labels=task_num_labels_mapping[task]
    )

# Store the finetuned backbone.
finetuned_backbone: Dict[str, nn.Module] = {
    task: cast(BertLayer, model.bert)
    for task, model in finetuned_models.items()
}

2025-04-15 23:38:47,722 - __main__ - INFO - Loading finetuned model for task task1
2025-04-15 23:38:47,722 - __main__ - INFO - Loading finetuned model for task task1


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-04-15 23:38:48,671 - __main__ - INFO - Loading finetuned model for task task2
2025-04-15 23:38:48,671 - __main__ - INFO - Loading finetuned model for task task2
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
pretrained_backbone = pretrained_model.bert
model: nn.Module = task_arithmetic_merge_modules(
    pretrained_backbone,
    list(finetuned_backbone.values()),
    scaling_coef=cfg.init_lambda,
)

In [85]:
cast(BertLayer, model.encoder.layer[1].output.dense)

Linear(in_features=3072, out_features=768, bias=True)