In [266]:
import os

In [267]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [268]:
%pwd

'/Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project'

In [269]:
# changing working directory to the root of the project:/Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project
os.chdir("/Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project")

# Purpose of the Notebook

In this notebook I create the pipeline for using the model built in the MAGPIE Repository in order to understand its components and test how the elements work together.
Besides, I already look into the parts I want to change, like more detailed logging / debugging steps to better understand the process. Apart from that I will use the code provided by the Repository.

In order to understand the core elements of the model architecture and pipeline, I will only display the main parts here and import utils and other functions.

# Data Ingestion and Preprocessing
Since the repository already provides the datasets in a preprocessed way, I will use these files for the model training according to the data sets I chose (see README file).
Nevertheless, I will need to include data preprocessing in order to do inference on new data. Therefore, the following part tests the preporcessing of random text input and the tokenization of the text.

In [270]:
from transformers import DistilBertTokenizerFast
from media_bias_detection.utils.logger import general_logger


class Tokenizer:
    """Singleton class to maintain a single tokenizer instance."""

    _instance = None
    _tokenizer = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(Tokenizer, cls).__new__(cls)
            cls._tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        return cls._instance

    def __len__(self):
        """Return vocabulary size"""
        return len(self._tokenizer)  # Add this method

    def __call__(self, *args, **kwargs):
        """Make the tokenizer callable directly."""
        return self._tokenizer(*args, **kwargs)

    @property
    def tokenizer(self):
        return self._tokenizer


# Global tokenizer instance
tokenizer = Tokenizer()

In [271]:
text = "This is a test"
tokenized = tokenizer(text, truncation=True, return_tensors="pt")
print(tokenized)

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 3231,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}


# Data Initialization with Task and Subtask Classes
Since the MTL approach is about combining different tasks, the data needs to contain the information bout which model head it needs in the model pipeline.
I will use the Task and Subtask classes from the MAGPIE Repository to create the tasks I want to use. The tasks are defined in the following part.

The Sub Tasks I am going to use according to the datasets I chose are:
- Token-Level Classification (POS) --> noch besser verstehen was das genau ist
- Binary Classification
- Multi-Class Classification
- (Regression) - not used in current implementation
- (Masked Language Modelling) - not used in current implementation

The Subtask class defines how to load and structure the respective data set, as well as other functions like weight scaling and class weights for imbalanced datasets

The Task class is a wrapper for the subtasks and contains the task id and the subtasks list. Since I am using only one dataset for each subtask, the subtask list contains only one subtask. Nevertheless, I am leaving the wrapper in the code in order to use the other steps in the same way as in the repository.

The SubTaskDataset class creates then the actual data loaders and also contains the BatchList class for Training and for Evaluation.



In [272]:
import os
from typing import List, Tuple, Optional, Dict
import pandas as pd
import torch
import numpy as np
import re
from pathlib import Path

from media_bias_detection.utils.common import get_class_weights
from media_bias_detection.utils.enums import Split
from media_bias_detection.utils.logger import general_logger
from media_bias_detection.config.config import DEV_RATIO, MAX_LENGTH, REGRESSION_SCALAR, TRAIN_RATIO


In [273]:
class DataProcessingError(Exception):
    """Custom exception for data processing errors."""
    pass

In [274]:
"""This part contains the Task class."""

class Task:
    """Wrap subtasks."""

    def __init__(self, task_id, subtasks_list):
        """Initialize a Task."""
        self.task_id = task_id
        self.subtasks_list = subtasks_list

    def __repr__(self):
        """Represent a task."""
        return (
            f"Task {self.task_id} with {len(self.subtasks_list)} subtask{'s' if len(self.subtasks_list) > 1 else ''}"
        )

    def __str__(self) -> str:
        return str(self.task_id)

In [275]:


def get_pos_idxs(pos: str, text: str):
    """
    Get the correct idxs of the pos for a given text.

    @param pos: A pattern as text.
    @param text: The text to search trough.
    @return: The ids of the tokens in the text that match the pattern.
    """
    if pos == text:
        mask = np.array(np.ones((len(text))), dtype="int")
    else:
        pos = pos.replace("[", "\[")
        pos = pos.replace("$", "\$")
        pos = pos.replace("?", "\?")
        pos = pos.replace(")", "\)")
        pos = pos.replace("(", "\(")
        pos = pos.replace("*", "\*")
        pos = pos.replace("+", "\+")
        start, end = re.search(pos, text).span()

        mask = np.zeros((len(text)), dtype=int)
        mask[start:end] = 1
    c, idx_list = 0, []
    for t in text.split():
        idx_list.append(c)
        c += len(t) + 1
    mask_idxs = [mask[i] for i in idx_list]
    return mask_idxs


def align_labels_with_tokens(labels: List[int], word_ids: List[int]):
    """Align labels with tokens.

    C/p from https://huggingface.co/course/chapter7/2
    """
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = (
                -100 if word_id is None else labels[word_id]
            )  # -100 is an index that will be ignored by cross entropy
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)

    return new_labels


def get_tokens_and_labels(pos_list_list, text_list, labels):
    """Get tokens and labels for scattered POS.

    In this objective, we have a list of consecutive spans.
    For each of these consecutive spans, find the correct index of the corresponding tokens in the text_list.
    Returns the bitwise or ('union') of this ids.
    """
    mask_idxs_list = []
    for i, pos_list in enumerate(pos_list_list):
        label = labels[i]
        text = text_list[i]
        observation_mask_idxs = []
        for pos in pos_list:
            if len(pos) == 0:
                # If there is no POS, we just return zeros
                observation_mask_idxs.append(get_pos_idxs("", text))
            else:
                for pos in pos_list:
                    if label == 0:  # In that case, the label is the neutral class
                        observation_mask_idxs.append(get_pos_idxs(pos, text))
                    else:
                        pos_idxs = get_pos_idxs(pos, text)
                        pos_idxs = [label if idx == 1 else 0 for idx in pos_idxs]
                        observation_mask_idxs.append(pos_idxs)

        # reduce observation_mask_idxs
        observation_mask_idxs = np.bitwise_or.reduce(observation_mask_idxs, axis=0)
        mask_idxs_list.append(observation_mask_idxs)

    return [t.split() for t in text_list], mask_idxs_list

In [276]:
"""This part contains the Subtask."""
class SubTask:
    """Base class for all subtasks.

    Attributes:
        id: Unique identifier for the subtask
        task_id: ID of parent task
        filename: Path to data file
        src_col: Column name for input text
        tgt_cols_list: List of target column names
    """

    def __init__(
            self,
            id: int,
            task_id: int,
            filename: str,
            src_col: str = "text",
            tgt_cols_list: List[str] = ["label"],
            cache_dir: Optional[str] = None
    ):
        if type(self) == SubTask:
            raise RuntimeError("Abstract class <SubTask> must not be instantiated.")

        self.id = id
        self.task_id = task_id
        self.src_col = src_col
        self.tgt_cols_list = tgt_cols_list
        self.filename = Path(os.path.join("datasets", filename))
        self.cache_dir = Path(cache_dir) if cache_dir else None

        # Data attributes
        self.attention_masks: Optional[Dict[Split, torch.Tensor]] = None
        self.X: Optional[Dict[Split, torch.Tensor]] = None
        self.Y: Optional[Dict[Split, torch.Tensor]] = None
        self.class_weights: Optional[torch.Tensor] = None
        self.processed = False

        general_logger.info(
            f"Initialized SubTask {id} for task {task_id} "
            f"using file {self.filename}"
        )

    def process(self, force_download: bool = False) -> None:
        """Process and split the data.

        Args:
            force_download: Whether to force data reprocessing

        Raises:
            DataProcessingError: If data processing fails
        """
        try:
            # Check cache first
            if self.cache_dir and not force_download:
                if self._load_from_cache():
                    return

            general_logger.info(f"Processing SubTask {self.id}")
            X, Y, attention_masks = self.load_data()

            # Validate data
            if not (len(X) == len(Y) == len(attention_masks)):
                raise DataProcessingError("Mismatched lengths in processed data")

            # Split data
            train_split = int(len(X) * TRAIN_RATIO)
            dev_split = train_split + int(len(X) * DEV_RATIO)

            self.X = {Split.TRAIN: X[:train_split], Split.DEV: X[train_split:dev_split],
                      Split.TEST: X[dev_split:]}

            self.attention_masks = {
                Split.TRAIN: attention_masks[:train_split],
                Split.DEV: attention_masks[train_split:dev_split],
                Split.TEST: attention_masks[dev_split:],
            }
            self.Y = {Split.TRAIN: Y[:train_split], Split.DEV: Y[train_split:dev_split],
                      Split.TEST: Y[dev_split:]}

            self.create_class_weights()
            self._save_to_cache()

            self.processed = True
            general_logger.info(
                f"SubTask {self.id} processed successfully. "
                f"Splits: Train={len(self.X[Split.TRAIN])}, "
                f"Dev={len(self.X[Split.DEV])}, "
                f"Test={len(self.X[Split.TEST])}"
            )

        except Exception as e:
            raise DataProcessingError(f"Failed to process subtask {self.id}: {str(e)}")

    def _load_from_cache(self) -> bool:
        """Try to load processed data from cache."""
        if not self.cache_dir:
            return False

        cache_file = self.cache_dir / f"subtask_{self.id}.pt"
        if cache_file.exists():
            try:
                cached_data = torch.load(cache_file)
                self.X = cached_data['X']
                self.Y = cached_data['Y']
                self.attention_masks = cached_data['attention_masks']
                self.class_weights = cached_data.get('class_weights')
                self.processed = True
                general_logger.info(f"Loaded cached data for SubTask {self.id}")
                return True
            except Exception as e:
                general_logger.warning(f"Failed to load cache for SubTask {self.id}: {e}")
                return False
        return False

    def _save_to_cache(self) -> None:
        """Save processed data to cache."""
        if not self.cache_dir:
            return

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        cache_file = self.cache_dir / f"subtask_{self.id}.pt"

        try:
            torch.save({
                'X': self.X,
                'Y': self.Y,
                'attention_masks': self.attention_masks,
                'class_weights': self.class_weights
            }, cache_file)
            general_logger.info(f"Saved cache for SubTask {self.id}")
        except Exception as e:
            general_logger.warning(f"Failed to save cache for SubTask {self.id}: {e}")

    # Abstract methods
    def load_data(self) -> Tuple:
        """Load the data of a SubTask.

        Must be implemented for inherited.
        """
        raise NotImplementedError

    def create_class_weights(self):
        """Compute the weights for imbalanced classes."""
        pass

    def get_scaling_weight(self):
        """Get the scaling weight of a Subtask.

        Needs to be overwritten.
        """
        raise NotImplementedError

    def get_X(self, split: Split):
        """Get all X of a given split."""
        return self.X[split]

    def get_att_mask(self, split: Split):
        """Get attention_masks for inputs of a given split."""
        return self.attention_masks[split]

    def get_Y(self, split: Split):
        """Get all Y of a given split."""
        return self.Y[split]

    def __str__(self) -> str:
        return str(self.id)

# a[43485:43500]
class ClassificationSubTask(SubTask):
    """A ClassificationSubTask."""

    def __init__(self, num_classes=2, *args, **kwargs):
        """Initialize a ClassificationSubTask."""
        super(ClassificationSubTask, self).__init__(*args, **kwargs)
        self.num_classes = num_classes

    def load_data(self) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """Load the data of a ClassificationSubTask."""
        df = pd.read_csv(self.filename)

        X, Y = df[self.src_col], df[self.tgt_cols_list]
        tokenized_inputs = tokenizer(X.to_list(), padding="max_length", truncation=True,
                                     max_length=MAX_LENGTH)
        X = tokenized_inputs.get("input_ids")
        attention_masks = tokenized_inputs.get("attention_mask")
        assert Y.nunique().squeeze() == self.num_classes
        assert Y[self.tgt_cols_list[0]].min(axis=0) == 0
        if self.num_classes == 2:  # if it's binary classification
            Y = Y.to_numpy()
        else:
            Y = Y[self.tgt_cols_list].to_numpy()
        return torch.LongTensor(X), torch.LongTensor(Y), torch.LongTensor(attention_masks)

    def __repr__(self):
        """Represent a Classification Subtask."""
        return f"{'Multi-class' if self.num_classes != 2 else 'Binary'} Classification"

    def create_class_weights(self):
        """Compute the weights."""
        self.class_weights = get_class_weights(self.Y[Split.TRAIN], method="isns")

    def get_scaling_weight(self):
        """Get the weight of a Classification Subtask.

        As with the other tasks, we normalize by the natural logarithm of the domain size.
        """
        return 1 / np.log(self.num_classes)


# in current implementation, the regression subtask is not used
class RegressionSubTask(SubTask):
    """A RegressionSubTask."""

    def __init__(self, *args, **kwargs):
        """Initialize a RegressionSubTask."""
        super(RegressionSubTask, self).__init__(*args, **kwargs)

    def load_data(self) -> Tuple[torch.LongTensor, torch.FloatTensor, torch.LongTensor]:
        """Load the data of a RegressionSubTask."""
        df = pd.read_csv(self.filename)
        X, Y = df[self.src_col], df[self.tgt_cols_list]
        tokenized_inputs = tokenizer(X.to_list(), padding="max_length", truncation=True,
                                     max_length=MAX_LENGTH)
        X = tokenized_inputs.get("input_ids")
        attention_masks = tokenized_inputs.get("attention_mask")
        Y = (((Y - Y.min()) / (Y.max() - Y.min())).to_numpy()).astype("float32")  # scale from 0 to 1
        return torch.LongTensor(X), torch.FloatTensor(Y), torch.LongTensor(attention_masks)

    def __repr__(self):
        """Represent a Regression Subtask."""
        return "Regression"

    def get_scaling_weight(self):
        """Get the scaling weight of a Regression Subtask.

        As of now, this scaling weight is a simple scalar and is a mere heuristic-based approximation (ie. we eyeballed it).
        """
        return REGRESSION_SCALAR


class MultiLabelClassificationSubTask(SubTask):
    """A MultiLabelClassificationSubTask."""

    def __init__(self, num_classes=2, num_labels=2, *args, **kwargs):
        """Initialize a MultiLabelClassificationSubTask."""
        super(MultiLabelClassificationSubTask, self).__init__(*args, **kwargs)
        self.num_classes = None
        self.num_classes = num_classes
        self.num_labels = num_labels
        print(f"MultiClass Subtask {self.id}:\nNum classes: {num_classes}, Num labels: {num_labels}")


    def load_data(self) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """Load the data of a MultiLabelClassificationSubTask."""
        print(f"Loading data for MultiLabelClassificationSubTask {self.id}")
        df = pd.read_csv(self.filename)
    
        #X = df[self.src_col].tolist()
        #Y = df[self.tgt_cols_list].values
        X, Y = df[self.src_col], df[self.tgt_cols_list]
        
        print(f"X type: {type(X)}")
        print(f"Y type: {type(Y)}")
        print(f"X shape: {len(X)}")
        print(f"Y shape: {Y.shape}")
    
        tokenized_inputs = tokenizer(X.tolist(), padding="max_length", truncation=True,
                                     max_length=MAX_LENGTH)
        X = torch.LongTensor(tokenized_inputs.get("input_ids"))
        attention_masks = torch.LongTensor(tokenized_inputs.get("attention_mask"))
        assert Y.max(axis=0).to_numpy().max() == 1
        Y = Y.to_numpy()
        Y = torch.LongTensor(Y)
    
        print(f"X shape: {X.shape}")
        print(f"Y shape: {Y.shape}")
        print(f"Attention masks shape: {attention_masks.shape}")
    
        return X, Y, attention_masks


    def __repr__(self):
        """Represent a Multi-label Classification Subtask."""
        return "Multi-label Classification"

    def get_scaling_weight(self):
        """Get the weight of a Multi-label Classification Subtask.

        As with the other tasks, we normalize by the natural logarithm of the domain size.
        """
        return 1 / np.log(self.num_classes * self.num_labels)


class POSSubTask(SubTask):
    """A POSSubTask.

    Each POSSubTask can be either binary classification or multiclass classification.
    If it is binary classification, zero (0) must be the neutral class.
    This neutral class is also applied to all other, 'normal' tokens.
    """

    def __init__(self, tgt_cols_list, label_col=None, *args, **kwargs):
        """Initialize a POSSubTask.

        Normally, we have 3 classes: (0=no-tag, 1=tag-start, 2=tag-continue)
        However, we have POS-tasks where we have more than just 'binary token level classification'.
        In these scenarios, each class has two tags: 'tag-start' and 'tag-continue'.
        The 'no-class' tag has no 'tag-continue'.
        """
        self.num_classes = 3  # The default num_classes is 2 or 3 (0=no-tag, 1=tag-start, 2=tag-continue)
        self.label_col = label_col
        assert len(tgt_cols_list) == 1
        super(POSSubTask, self).__init__(tgt_cols_list=tgt_cols_list, *args, **kwargs)

    def load_data(self) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """Load the data of a POSSubTask."""
        df = pd.read_csv(self.filename)

        df[self.tgt_cols_list] = df[self.tgt_cols_list].fillna("")
        mask = df.apply(
            lambda row: all([p in row[self.src_col] for p in row[self.tgt_cols_list[0]].split(";")]), axis=1
        )
        df = df[mask].reset_index(drop=True)
        assert sum(mask) == len(df[self.tgt_cols_list]), "At least one POS is not contained in the source column."

        pos_list_list = df[self.tgt_cols_list[0]].apply(lambda x: x.split(";")).to_list()
        X = df[self.src_col].values
        # If we do not provide a labels column, we assume that, whenever a pos is present, that is the non-neutral class
        labels = (
            df[self.label_col]
            if self.label_col
            else [1 if len(pos) > 0 else 0 for pos in df[self.tgt_cols_list[0]].to_list()]
        )
        tokens, labels = get_tokens_and_labels(pos_list_list=pos_list_list, text_list=X, labels=labels)
        tokenized_inputs = tokenizer(
            tokens, padding="max_length", is_split_into_words=True, truncation=True,
            max_length=MAX_LENGTH
        )
        new_labels = []
        for i, labels in enumerate(labels):
            word_ids = tokenized_inputs.word_ids(i)
            new_labels.append(align_labels_with_tokens(labels, word_ids))
        Y = np.array(new_labels)
        # This should in most cases not alter self.num_classes, as we only use binary tags (+ tag-continue = 3 classes).
        # However, we leave this generic implementation for future tasks.
        self.num_classes = len(np.unique(Y)) - 1
        X = tokenized_inputs.get("input_ids")
        attention_masks = tokenized_inputs.get("attention_mask")
        return torch.LongTensor(X), torch.LongTensor(Y), torch.LongTensor(attention_masks)

    def __repr__(self):
        """Represent a Token-level classification Subtask."""
        return "Token-level classification"

    def create_class_weights(self):
        """Compute the weights."""
        labels = self.Y[Split.TRAIN]
        only_class_labels = labels[labels != -100]
        self.class_weights = get_class_weights(only_class_labels, method="isns")

    def get_scaling_weight(self):
        """Get the weight of a POS Subtask.

        As with the other tasks, we normalize by the natural logarithm of the domain size.
        In case of POS subtask, the domain size equals the vocab size.
        """
        return 1 / np.log(self.num_classes)

# not used in current implementation, but important for testing?
class MLMSubTask(SubTask):
    """A Masked Language Modelling Subtask."""

    def __init__(self, *args, **kwargs):
        """Initialize a MLMSubTask."""
        super(MLMSubTask, self).__init__(*args, **kwargs)

    def load_data(self) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """Load the data of a MLMSubTask."""
        df = pd.read_csv(self.filename)
        X = df[self.src_col]
        tokenized_inputs = tokenizer(X.to_list(), padding="max_length", truncation=True, max_length=MAX_LENGTH)
        X = torch.LongTensor(tokenized_inputs.get("input_ids"))
        attention_masks = tokenized_inputs.get("attention_mask")

        MASK_TOKEN = tokenizer.mask_token_id
        SEP_TOKEN = tokenizer.sep_token_id
        CLS_TOKEN = tokenizer.cls_token_id
        PAD_TOKEN = tokenizer.pad_token_id

        Y = X.clone()
        rand = torch.rand(X.shape)
        masking_mask = (rand < 0.15) * (X != SEP_TOKEN) * (X != CLS_TOKEN) * (X != PAD_TOKEN)
        X[masking_mask] = MASK_TOKEN
        Y[~masking_mask] = -100
        return torch.LongTensor(X), torch.LongTensor(Y), torch.LongTensor(attention_masks)

    def __repr__(self):
        """Represent a MLM Subtask."""
        return "Masked Language Modelling"

    def get_scaling_weight(self):
        """Get the weights for imbalanced classes."""
        return 1 / np.log(len(tokenizer))

In [277]:
"""Dataset handling module for MTL model.

This module provides dataset classes for handling different types of data loading
and batch generation for the MTL training process.
"""

from typing import List, Dict, Iterator, Tuple, Optional
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from dataclasses import dataclass
from collections import defaultdict

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.enums import Split
from media_bias_detection.utils.common import set_random_seed

@dataclass
class BatchData:
    """Container for batch data.

    Attributes:
        input_ids: Token IDs from tokenizer
        attention_mask: Attention mask for padding
        labels: Target labels
        subtask_id: ID of the subtask this batch belongs to
    """
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    labels: torch.Tensor
    subtask_id: int


class SubTaskDataset(Dataset):
    """Dataset class for a single SubTask.

    This dataset handles the loading and iteration over data for a specific subtask,
    with support for shuffling and automatic reset.

    Attributes:
        split: The data split (TRAIN/DEV/TEST)
        subtask: The subtask this dataset is for
        observations: List of indices into the data
        _counter: Internal counter for iteration
        cache: Optional cache for frequently accessed items
    """

    def __init__(
            self,
            subtask: SubTask,
            split: Split,
            cache_size: int = 100
    ) -> None:
        """Initialize the dataset.

        Args:
            subtask: SubTask instance containing the data
            split: Which data split to use
            cache_size: Number of items to keep in memory cache
        """
        general_logger.info(f"Initializing dataset for subtask {subtask.id} with split {split}")

        if not subtask.processed:
            raise RuntimeError(f"Subtask {subtask.id} must be processed before creating dataset")

        self.split = split
        self.subtask = subtask
        self.observations: List[int] = []
        self._counter: int = 0
        self._cache: Dict[int, BatchData] = {}
        self._cache_size = cache_size
        self._reset()

    def __len__(self) -> int:
        """Get number of items in dataset."""
        return len(self.observations)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """Get a single item from the dataset.

        Args:
            idx: Index of the item to get

        Returns:
            BatchData containing the item data

        Raises:
            IndexError: If index is out of bounds
        """
        try:
            if self._counter >= len(self.observations):
                self._reset()

            i = self.observations[self._counter]

            # Check cache first
            if i in self._cache:
                self._counter += 1
                return self._cache[i]

            # Load and process item
            x = self.subtask.get_X(split=self.split)[i]
            masks = self.subtask.get_att_mask(split=self.split)[i]
            y = self.subtask.get_Y(split=self.split)[i]

            batch_data = BatchData(
                input_ids=x,
                attention_mask=masks,
                labels=y,
                subtask_id=self.subtask.id
            )

            # Update cache
            if len(self._cache) >= self._cache_size:
                # Remove oldest item
                del self._cache[next(iter(self._cache))]
            self._cache[i] = batch_data

            self._counter += 1
            return x, masks, y, self.subtask.id

        except Exception as e:
            general_logger.error(f"Error retrieving item {idx} from dataset: {str(e)}")
            raise

    def _reset(self) -> None:
        """Reset the dataset state and shuffle observations."""
        general_logger.info(f"Resetting dataset for subtask {self.subtask.id}")
        self.observations = [i for i in range(len(self.subtask.get_X(split=self.split)))]
        set_random_seed()
        np.random.shuffle(self.observations)  # Not a real 'reshuffling' as it will always arrange same.
        self._counter = 0
        self._cache.clear()


class BatchList:
    """Wrapper around dataloaders for continuous batch generation.

    This class provides an infinite stream of batches by automatically resetting
    exhausted dataloaders. It includes support for dynamic batch sizing and
    memory-efficient data loading.

    Attributes:
        sub_batch_size: Size of each sub-batch
        datasets: Mapping of subtask IDs to datasets
        dataloaders: Mapping of subtask IDs to dataloaders
        iter_dataloaders: Mapping of subtask IDs to dataloader iterators
    """

    def __init__(
            self,
            subtask_list: List[SubTask],
            sub_batch_size: int,
            split: Split = Split.TRAIN,
            num_workers: int = 0,
            pin_memory: bool = True
    ) -> None:
        """Initialize BatchList.

        Args:
            subtask_list: List of subtasks to create batches for
            sub_batch_size: Size of each sub-batch
            split: Which data split to use
            num_workers: Number of worker processes for data loading
            pin_memory: Whether to pin memory in GPU training
        """
        general_logger.info(
            f"Creating BatchList with {len(subtask_list)} subtasks, "
            f"batch size {sub_batch_size}"
        )

        self.sub_batch_size = sub_batch_size
        self.split = split

        # Initialize datasets and dataloaders
        self.datasets = {
            str(st.id): SubTaskDataset(subtask=st, split=split)
            for st in subtask_list
        }

        self.dataloaders = {
            st_id: DataLoader(
                dataset,
                batch_size=self.sub_batch_size,
                num_workers=num_workers,
                pin_memory=pin_memory
            )
            for st_id, dataset in self.datasets.items()
        }

        self.iter_dataloaders = {
            st_id: iter(dl)
            for st_id, dl in self.dataloaders.items()
        }

        # Statistics tracking
        self._batch_counts = defaultdict(int)

    def __next__(self) -> List[BatchData]:
        """Get next batch of sub-batches.

        Returns:
            List of BatchData, one for each task

        Raises:
            RuntimeError: If batch generation fails
        """
        try:
            data = []
            items = list(self.iter_dataloaders.items())
            random.shuffle(items)

            for st_id, dl in items:
                try:
                    batch = next(dl)
                except StopIteration:
                    # Reset iterator and try again
                    self.iter_dataloaders[st_id] = iter(self.dataloaders[st_id])
                    batch = next(self.iter_dataloaders[st_id])

                data.append(batch)
                self._batch_counts[st_id] += 1

            general_logger.debug(f"Generated batch with {len(data)} sub-batches")
            return data

        except Exception as e:
            general_logger.error(f"Error generating batch: {str(e)}")
            raise RuntimeError(f"Batch generation failed: {str(e)}")

    def _reset(self):
        """Reset this BatchListEvalTest."""
        self.iter_dataloaders = {f"{st_id}": iter(dl) for st_id, dl in self.dataloaders.items()}


class BatchListEvalTest:
    """A BatchListEvalTest is a wrapper around dataloaders for each subtask."""

    def __init__(self, subtask_list: List[SubTask], sub_batch_size, split=Split.TRAIN):
        self.sub_batch_size = sub_batch_size
        self.datasets = {f"{st.id}": SubTaskDataset(subtask=st, split=split) for st in subtask_list}
        self.dataloaders = {
            f"{st_id}": DataLoader(ds, batch_size=self.sub_batch_size) for st_id, ds in self.datasets.items()
        }
        self.iter_dataloaders = {f"{st_id}": iter(dl) for st_id, dl in self.dataloaders.items()}

    def __len__(self):
        return min(len(dl) for dl in self.dataloaders.values())

    def _reset(self): # Add this method matching the original
        self.iter_dataloaders = {f"{st_id}": iter(dl) for st_id, dl in self.dataloaders.items()}

In [278]:
"""Data initialization and task definitions for MTL model."""

from typing import List
import itertools

from media_bias_detection.utils.logger import general_logger


# initializing the sub-tasks I want to use
st_1_cw_hard_03 = ClassificationSubTask(
task_id=3,
filename="03_CW_HARD/preprocessed.csv",
id=300001)
st_1_me_too_ma_108 = MultiLabelClassificationSubTask(
num_classes=2,
num_labels=2,
task_id=108,
filename="108_MeTooMA/preprocessed.csv",
id=10801,
tgt_cols_list=["hate_speech_label", "sarcasm_label"],
)
st_1_mdgender_116 = ClassificationSubTask(
task_id=116,
id=11601,
filename="116_MDGender/preprocessed.csv",
num_classes=6
)
st_1_mpqa_103 = ClassificationSubTask(
task_id=103,
id=10301,
filename="103_MPQA/preprocessed.csv")
st_1_stereotype_109 = ClassificationSubTask(
task_id=109,
id=10901,
filename="109_stereotype/preprocessed.csv")
st_2_stereotype_109 = MultiLabelClassificationSubTask(
task_id=109,
id=10902,
filename="109_stereotype/preprocessed.csv",
tgt_cols_list=["stereotype_explicit_label", "stereotype_explicit_label"],
num_classes=2,
num_labels=2,
)
st_1_good_news_everyone_42 = POSSubTask(
tgt_cols_list=["cue_pos"],
task_id=42,
id=42001,
filename="42_GoodNewsEveryone/preprocessed.csv"
)
st_2_good_news_everyone_42 = POSSubTask(
tgt_cols_list=["experiencer_pos"],
task_id=42,
id=42002,
filename="42_GoodNewsEveryone/preprocessed.csv",
)
st_1_pheme_12 = ClassificationSubTask(
task_id=12,
id=12001,
filename="12_PHEME/preprocessed.csv")
st_2_pheme_12 = ClassificationSubTask(
task_id=12,
id=12002,
filename="12_PHEME/preprocessed.csv",
tgt_cols_list=["veracity_label"],
num_classes=3,
)
st_1_babe_10 = ClassificationSubTask(
task_id=10,
id=10001,
filename="10_BABE/preprocessed.csv",
num_classes=2)
st_2_babe_10 = POSSubTask(
task_id=10,
id=10002,
filename="10_BABE/preprocessed.csv",
tgt_cols_list=["biased_words"])
st_1_gwsd_128 = ClassificationSubTask(
task_id=128,
num_classes=3,
filename="128_GWSD/preprocessed.csv",
id=12801)

# Tasks
cw_hard_03 = Task(task_id=3, subtasks_list=[st_1_cw_hard_03])
babe_10 = Task(task_id=10, subtasks_list=[st_1_babe_10, st_2_babe_10])
me_too_ma_108 = Task(task_id=108, subtasks_list=[st_1_me_too_ma_108])
mdgender_116 = Task(task_id=116, subtasks_list=[st_1_mdgender_116])
pheme_12 = Task(task_id=12, subtasks_list=[st_2_pheme_12, st_1_pheme_12])
mpqa_103 = Task(task_id=103, subtasks_list=[st_1_mpqa_103])
stereotype_109 = Task(task_id=109, subtasks_list=[st_1_stereotype_109,
                                              st_2_stereotype_109])
good_news_everyone_42 = Task(task_id=42,
                         subtasks_list=[st_1_good_news_everyone_42,
                                        st_2_good_news_everyone_42])
gwsd_128 = Task(task_id=128, subtasks_list=[st_1_gwsd_128])


# MBIB ###
# st_linguistic = ClassificationSubTask(task_id=11111, id=11111, filename="mbib_linguistic/preprocessed.csv", num_classes=2)
# mbib_lingustic = Task(task_id=11111, subtasks_list=[st_linguistic])

# Create task object
all_tasks = [
babe_10,
cw_hard_03,
me_too_ma_108,
pheme_12,
mdgender_116,
mpqa_103,
stereotype_109,
good_news_everyone_42,
gwsd_128,
]

# Get all subtasks
all_subtasks = list(itertools.chain.from_iterable(t.subtasks_list for t in all_tasks))

# Task families
media_bias = [babe_10]
subjective_bias = [cw_hard_03]
hate_speech = [me_too_ma_108]
gender_bias = [mdgender_116]
sentiment_analysis = [mpqa_103]
fake_news = [pheme_12]
group_bias = [stereotype_109]
emotionality = [good_news_everyone_42]
stance_detection = [gwsd_128]
#mlm = [mlm_0]

[2024-12-09 17:50:21,884: INFO: 1312152763: Initialized SubTask 300001 for task 3 using file datasets/03_CW_HARD/preprocessed.csv]
[2024-12-09 17:50:21,886: INFO: 1312152763: Initialized SubTask 10801 for task 108 using file datasets/108_MeTooMA/preprocessed.csv]
MultiClass Subtask 10801:
Num classes: 2, Num labels: 2
[2024-12-09 17:50:21,887: INFO: 1312152763: Initialized SubTask 11601 for task 116 using file datasets/116_MDGender/preprocessed.csv]
[2024-12-09 17:50:21,888: INFO: 1312152763: Initialized SubTask 10301 for task 103 using file datasets/103_MPQA/preprocessed.csv]
[2024-12-09 17:50:21,889: INFO: 1312152763: Initialized SubTask 10901 for task 109 using file datasets/109_stereotype/preprocessed.csv]
[2024-12-09 17:50:21,890: INFO: 1312152763: Initialized SubTask 10902 for task 109 using file datasets/109_stereotype/preprocessed.csv]
MultiClass Subtask 10902:
Num classes: 2, Num labels: 2
[2024-12-09 17:50:21,891: INFO: 1312152763: Initialized SubTask 42001 for task 42 using 

# Building the Model
After initializing the datasets I want to use in the last step, I now build the model:
- The backbone is changes to DistilBERT as it is a smaller model and therefore faster to train.
- For each task a specific model head is needed to fulfill the task. For this a head factory is used to decide which head to use for the specific task type.
- Apart from that, the model needs a GradsWrapper to get and set the gradients of the weights and biases of all trainable layers.
- In the model factory the model is then instantiated by combining the backbone with the different model heads for the different tasks.

In [279]:
# from helper classes the accumuator classes
"""Module for gradient accumulation and manipulation."""

from typing import Dict
import copy
import torch
from torch import nn

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.common import rsetattr


class AccumulatorError(Exception):
    """Custom exception for accumulator-related errors."""
    pass


class Accumulator:
    """Abstract Accumulator for gradient handling."""
    
    def __init__(self):
        """Initialize abstract accumulator."""
        if type(self) == Accumulator:
            raise RuntimeError("Abstract class <Accumulator> must not be instantiated.")
        self.gradients = None
        self.n = 0

    def update(self, gradients: Dict[str, torch.Tensor], weight: float = 1.0) -> None:
        """Update gradient values (must be implemented by concrete classes)."""
        raise NotImplementedError

    def get_avg_gradients(self) -> Dict[str, torch.Tensor]:
        """Return gradients normalized across 0-axis."""
        try:
            if not self.gradients:
                raise AccumulatorError("No gradients available")

            out_gradients = copy.deepcopy(self.gradients)
            for k, v in self.gradients.items():
                out_gradients[k] /= self.n
                out_gradients[k] = out_gradients[k].squeeze(dim=0)
            return out_gradients
        except Exception as e:
            raise AccumulatorError(f"Failed to get average gradients: {str(e)}")

    def get_gradients(self) -> Dict[str, torch.Tensor]:
        """Return raw gradients."""
        if not self.gradients:
            raise AccumulatorError("No gradients available")
        return self.gradients


class StackedAccumulator(Accumulator):
    """Accumulator that stacks gradients along 0-axis."""
    
    def __init__(self):
        """Initialize StackedAccumulator."""
        try:
            super().__init__()
            general_logger.debug("Initialized StackedAccumulator")
        except Exception as e:
            raise AccumulatorError(f"Failed to initialize StackedAccumulator: {str(e)}")

    def update(self, gradients: Dict[str, torch.Tensor], weight: float = 1.0) -> None:
        """Update by concatenating new gradients along 0-axis."""
        try:
            if not self.gradients:
                self.gradients = gradients
                # Unsqueeze all gradients for later concatenation
                for k, v in self.gradients.items():
                    self.gradients[k] = self.gradients[k].unsqueeze(dim=0) * weight
            else:
                for k, v in self.gradients.items():
                    new_value = gradients[k].unsqueeze(dim=0) * weight
                    self.gradients[k] = torch.cat((v, new_value), dim=0)
            self.n += 1
        except Exception as e:
            raise AccumulatorError(f"Failed to update stacked gradients: {str(e)}")

    def set_gradients(self, gradients: Dict[str, torch.Tensor]) -> None:
        """Set gradients directly."""
        try:
            for k, v in self.gradients.items():
                self.gradients[k] = gradients[k].unsqueeze(dim=0)
        except Exception as e:
            raise AccumulatorError(f"Failed to set gradients: {str(e)}")


class RunningSumAccumulator(Accumulator):
    """Accumulator that maintains running sum of gradients."""
    
    def __init__(self):
        """Initialize RunningSumAccumulator."""
        try:
            super().__init__()
            general_logger.debug("Initialized RunningSumAccumulator")
        except Exception as e:
            raise AccumulatorError(f"Failed to initialize RunningSumAccumulator: {str(e)}")

    def update(self, gradients: Dict[str, torch.Tensor], weight: float = 1.0) -> None:
        """Update by summing gradients along 0-axis."""
        try:
            if not self.gradients:
                self.gradients = gradients
                # Unsqueeze all gradients for later addition
                for k, v in self.gradients.items():
                    self.gradients[k] = self.gradients[k].unsqueeze(dim=0) * weight
            else:
                for k, v in self.gradients.items():
                    new_value = gradients[k].unsqueeze(dim=0) * weight
                    self.gradients[k] = torch.add(v, new_value)
            self.n += 1
        except Exception as e:
            raise AccumulatorError(f"Failed to update running sum gradients: {str(e)}")

In [280]:
class GradsWrapper(nn.Module):
    """Wrapper for getting/setting gradients of trainable layers."""
    
    def __init__(self, *args, **kwargs):
        """Initialize GradsWrapper."""
        if type(self) == GradsWrapper:
            raise RuntimeError("Abstract class <GradsWrapper> must not be instantiated.")
        super().__init__()
        general_logger.debug("Initialized GradsWrapper")

    def get_grads(self) -> Dict[str, torch.Tensor]:
        """Get gradients of weights and biases for all trainable layers."""
        try:
            return {
                k: v.grad.clone() if v.grad is not None else None 
                for k, v in dict(self.named_parameters()).items()
            }
        except Exception as e:
            raise AccumulatorError(f"Failed to get gradients: {str(e)}")

    def set_grads(self, grads: Dict[str, torch.Tensor]) -> None:
        """Set gradients of weights and biases for all trainable layers."""
        try:
            for k, v in grads.items():
                rsetattr(self, f"{k}.grad", v)
        except Exception as e:
            raise AccumulatorError(f"Failed to set gradients: {str(e)}")


In [281]:
"""Module for combining potentially conflicting gradients with error handling."""

import random
from typing import Dict, Optional
import torch

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.enums import AggregationMethod


class GradientError(Exception):
    """Custom exception for gradient-related errors."""
    pass


class GradientAggregator:
    """Aggregator class for combining possibly conflicting gradients into one 'optimal' grad."""

    def __init__(self, aggregation_method: AggregationMethod = AggregationMethod.MEAN):
        try:
            self.aggregation_method = aggregation_method
            self.accumulator = (
                RunningSumAccumulator() if aggregation_method == AggregationMethod.MEAN else StackedAccumulator()
            )
            self._conflicting_gradient_count = 0
            self._nonconflicting_gradient_count = 0
            general_logger.info(f"Initialized GradientAggregator with {aggregation_method}")
        except Exception as e:
            raise GradientError(f"Failed to initialize GradientAggregator: {str(e)}")

    def reset_accumulator(self) -> None:
        try:
            self.accumulator = (
                RunningSumAccumulator() if self.aggregation_method == AggregationMethod.MEAN else StackedAccumulator()
            )
            general_logger.debug("Reset gradient accumulator")
        except Exception as e:
            raise GradientError(f"Failed to reset accumulator: {str(e)}")

    def find_nonconflicting_grad(self, grad_tensor: torch.tensor) -> torch.tensor:
        try:
            if self.aggregation_method == AggregationMethod.PCGRAD:
                return self.pcgrad(grad_tensor).mean(dim=0)
            elif self.aggregation_method == AggregationMethod.PCGRAD_ONLINE:
                assert len(grad_tensor) == 2
                return self.pcgrad_online(grad_tensor)
            else:
                raise GradientError(f"Unsupported aggregation method: {self.aggregation_method}")
        except Exception as e:
            raise GradientError(f"Failed to find nonconflicting gradient: {str(e)}")

    def aggregate_gradients(self) -> Dict[str, torch.tensor]:
        try:
            conflicting_grads = self.accumulator.get_gradients()
            length = len(conflicting_grads[list(conflicting_grads.keys())[0]])

            if (self.aggregation_method == AggregationMethod.PCGRAD_ONLINE
                    or self.aggregation_method == AggregationMethod.MEAN):
                assert length == 1
                return self.accumulator.get_avg_gradients()

            elif self.aggregation_method == AggregationMethod.PCGRAD:
                conflicting_grads = [{k: v[i, ...] for k, v in conflicting_grads.items()} for i in range(length)]
                final_grad: Dict[str, torch.Tensor] = {}

                if len(conflicting_grads) == 1:
                    return conflicting_grads[0]

                keys = list(conflicting_grads[0].keys())
                for layer_key in keys:
                    list_of_st_grads = [st_grad[layer_key] for st_grad in conflicting_grads]
                    final_grad.update({layer_key: self.find_nonconflicting_grad(torch.stack(list_of_st_grads, dim=0))})

                return final_grad
            else:
                raise GradientError(f"Unsupported aggregation method: {self.aggregation_method}")
        except Exception as e:
            raise GradientError(f"Gradient aggregation failed: {str(e)}")

    def pcgrad(self, grad_tensor: torch.tensor) -> torch.tensor:
        try:
            pc_grads, num_of_tasks = grad_tensor.clone(), len(grad_tensor)
            original_shape = grad_tensor.shape
            
            pc_grads = pc_grads.view(num_of_tasks, -1)
            grad_tensor = grad_tensor.view(num_of_tasks, -1)

            for g_i in range(num_of_tasks):
                task_index = list(range(num_of_tasks))
                random.shuffle(task_index)
                for g_j in task_index:
                    dot_product = pc_grads[g_i].dot(grad_tensor[g_j])
                    if dot_product < 0:
                        pc_grads[g_i] -= (dot_product / (grad_tensor[g_j].norm() ** 2)) * grad_tensor[g_j]
                        self._conflicting_gradient_count += 1
                    else:
                        self._nonconflicting_gradient_count += 1
            return pc_grads.view(original_shape)
        except Exception as e:
            raise GradientError(f"PCGrad processing failed: {str(e)}")

    def pcgrad_online(self, grad_tensor: torch.tensor) -> torch.tensor:
        try:
            assert len(grad_tensor) == 2
            p = grad_tensor[0].view(-1)
            g = grad_tensor[-1].view(-1)

            dot_product = p.dot(g)
            if dot_product < 0:
                p = p - (dot_product / (g.norm() ** 2)) * g
                self._conflicting_gradient_count += 1
            else:
                self._nonconflicting_gradient_count += 1

            p += g
            return p.view(grad_tensor[0].shape)
        except Exception as e:
            raise GradientError(f"Online PCGrad processing failed: {str(e)}")

    def aggregate_gradients_online(self) -> Dict[str, torch.tensor]:
        try:
            conflicting_grads = self.accumulator.get_gradients()
            length = len(conflicting_grads[list(conflicting_grads.keys())[0]])
            conflicting_grads = [{k: v[i, ...] for k, v in conflicting_grads.items()} for i in range(length)]
            current_overall_grad: Dict[str, torch.Tensor] = {}

            if length == 1:
                return conflicting_grads[0]
            elif length == 2:
                keys = list(conflicting_grads[0].keys())
                for layer_key in keys:
                    list_of_st_grads = [st_grad[layer_key] for st_grad in conflicting_grads]
                    current_overall_grad.update(
                        {layer_key: self.find_nonconflicting_grad(torch.stack(list_of_st_grads, dim=0))}
                    )
                return current_overall_grad
            else:
                raise GradientError("Invalid gradient length for online aggregation")
        except Exception as e:
            raise GradientError(f"Online gradient aggregation failed: {str(e)}")

    def update(self, gradients: Dict[str, torch.tensor], scaling_weight: float) -> None:
        try:
            self.accumulator.update(gradients=gradients, weight=scaling_weight)
            if self.aggregation_method == AggregationMethod.PCGRAD_ONLINE:
                self.accumulator.set_gradients(gradients=self.aggregate_gradients_online())
        except Exception as e:
            raise GradientError(f"Failed to update gradients: {str(e)}")

    def get_conflicting_gradients_ratio(self) -> Optional[float]:
        try:
            if self.aggregation_method == AggregationMethod.MEAN:
                raise GradientError("Cannot get conflict ratio for MEAN method")
            if self._conflicting_gradient_count + self._nonconflicting_gradient_count == 0:
                raise GradientError("No gradients processed yet")
            return self._conflicting_gradient_count / (
                self._conflicting_gradient_count + self._nonconflicting_gradient_count
            )
        except Exception as e:
            raise GradientError(f"Failed to calculate gradient conflict ratio: {str(e)}")

In [282]:
"""Backbone model module providing the shared language model."""

import torch
from torch import nn
from transformers import DistilBertModel
from media_bias_detection.utils.logger import general_logger

class BackboneLM(GradsWrapper):
    """Language model backbone shared across all tasks.

    This class wraps the pretrained DistilBERT model and handles
    gradient manipulation for the shared parameters.

    Attributes:
        backbone: The underlying DistilBERT model
    """

    def __init__(self, pretrained_path: str = None):
        """Initialize the backbone model.

        Args:
            pretrained_path: Optional path to pretrained weights
        """
        super().__init__()

        try:
            general_logger.info("Initializing backbone language model")
            self.backbone = DistilBertModel.from_pretrained('distilbert-base-uncased')

            if pretrained_path:
                self.load_pretrained(pretrained_path)

        except Exception as e:
            general_logger.error(f"Failed to initialize backbone: {str(e)}")
            raise

    def load_pretrained(self, path: str) -> None:
        """Load pretrained weights.

        Args:
            path: Path to pretrained weights

        Raises:
            RuntimeError: If loading fails
        """
        try:
            state_dict = torch.load(path)
            self.backbone.load_state_dict(state_dict)
            general_logger.info(f"Loaded pretrained weights from {path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load pretrained weights: {str(e)}")

In [283]:
"""Model heads implementation for MTL model.

This module contains all task-specific heads and the factory for creating them.
Each head implements specific logic for different types of tasks while maintaining
consistent interfaces for the MTL architecture.
"""

from typing import Dict, Tuple, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, MeanSquaredError, Perplexity, R2Score

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.common import get_class_weights


class HeadError(Exception):
    """Custom exception for head-related errors."""
    pass


def HeadFactory(st: SubTask, *args, **kwargs) -> 'BaseHead':
    """Create appropriate head based on subtask type.

    Args:
        st: Subtask to create head for
        *args, **kwargs: Additional arguments for head initialization

    Returns:
        Initialized head instance

    Raises:
        HeadError: If head creation fails or subtask type is unsupported
    """
    try:
        if isinstance(st, ClassificationSubTask):
            print(f"Creating ClassificationHead for subtask {st.id}")
            print(f"num_classes: {st.num_classes}")
            return ClassificationHead(
                num_classes=st.num_classes,
                class_weights=st.class_weights,
                *args,
                **kwargs
            )
        elif isinstance(st, MultiLabelClassificationSubTask):
            print(f"Creating MultiLabelClassificationHead for subtask {st.id}")
            print(f"num_classes: {st.num_classes}, num_labels: {st.num_labels}")
            return ClassificationHead(
                num_classes=st.num_classes,
                num_labels=st.num_labels if st.num_labels is not None else 2,
                class_weights=st.class_weights,
                *args,
                **kwargs
            )
        elif isinstance(st, POSSubTask):
            return TokenClassificationHead(
                num_classes=st.num_classes,
                class_weights=st.class_weights,
                *args,
                **kwargs
            )
        elif isinstance(st, RegressionSubTask):
            return RegressionHead(*args, **kwargs)
        elif isinstance(st, MLMSubTask):
            return LanguageModellingHead(*args, **kwargs)
        else:
            raise HeadError(f"Unsupported subtask type: {type(st)}")
    except Exception as e:
        raise HeadError(f"Head creation failed: {str(e)}")


class BaseHead(GradsWrapper):
    """Base class for all model heads.

    Attributes:
        metrics: Dictionary of metric names to metric instances
    """

    def __init__(self):
        super().__init__()
        self.metrics: Dict = {}

    def forward(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Forward pass through the head.

        Args:
            X: Input features (batch_size, seq_len, hidden_dim)
            y: Target labels

        Returns:
            Tuple of (logits, loss, metric_values)
        """
        raise NotImplementedError


class ClassificationHead(BaseHead):
    def __init__(
            self,
            input_dimension: int,
            hidden_dimension: int,
            dropout_prob: float,
            num_classes: int = 2,
            num_labels: int = 1,
            class_weights: Optional[torch.Tensor] = None
    ):
        super().__init__()
        print(f"Initializing ClassificationHead")
        print(f"num_classes: {num_classes}")
        print(f"num_labels: {num_labels}")
        
        # Common layers
        self.dense = nn.Linear(input_dimension, hidden_dimension)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.out_proj = nn.Linear(hidden_dimension, num_classes * num_labels)
        
        # Store dimensions
        self.num_classes = num_classes
        self.num_labels = num_labels
        
        # Use CrossEntropyLoss for both cases
        self.loss = nn.CrossEntropyLoss(weight=class_weights)
        print(f"Initializing ClassificationHead with {num_labels} labels")
        # Set up metrics based on task type
        if num_labels > 1:  # Multi-label case
            self.metrics = {
                "f1": F1Score(
                    num_classes=num_classes,
                    num_labels=num_labels,
                    task="multilabel",
                    average="macro"
                ),
                "acc": Accuracy(
                    task="multilabel",
                    num_classes=num_classes,
                     num_labels=num_labels,
                ),
            }
        else:  # Regular classification case
            self.metrics = {
                "f1": F1Score(
                    num_classes=num_classes,
                    task="binary" if num_classes == 2 else "multiclass",
                    average="macro"
                ),
                "acc": Accuracy(
                    task="binary" if num_classes == 2 else "multiclass",
                    num_classes=num_classes,
                ),
            }

    def forward(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        try:
            batch_size = y.shape[0]
            print(f"Batch size: {batch_size}")
            print(f"X shape: {X.shape}")
            print(f"y shape: {y.shape}")
    
            # Get CLS token representation
            x = X[:, 0, :]  # take <s> token (equiv. to [CLS])
            
            # Pass through layers
            x = self.dropout(x)
            x = self.dense(x)
            x = torch.tanh(x)
            x = self.dropout(x)
            logits = self.out_proj(x)
    
            # Compute loss
            loss = self.loss(logits.view(-1, self.num_classes), y.view(-1))
            print(f"Logits shape: {logits.shape}")
            
            # Reshape logits based on task type
            if self.num_labels > 1:  # Multi-label case
                logits = logits.view(batch_size, self.num_labels, self.num_classes)
                y = y.view(batch_size, self.num_labels)
            else:  # Binary/multiclass case
                logits = logits.view(batch_size, self.num_classes)
                y = y.view(batch_size)  # Flatten targets
    
            # Compute loss
            loss = self.loss(logits, y)
            
            # Get predictions in correct shape for metrics
            predictions = torch.argmax(logits, dim=-1)  # Use last dimension for class prediction
            
            # Calculate metrics
            metrics = {
                name: metric(predictions.cpu(), y.cpu())
                for name, metric in self.metrics.items()
            }
    
            return logits, loss, metrics

        except Exception as e:
            raise HeadError(f"Classification forward pass failed: {str(e)}")


class TokenClassificationHead(BaseHead):
    """Head for token-level classification tasks.

    Attributes:
        dropout: Dropout layer
        classifier: Classification layer
        num_classes: Number of classes
        loss: Loss function
        metrics: Dictionary of metrics
    """

    def __init__(
            self,
            num_classes: int,
            class_weights: Optional[torch.Tensor],
            hidden_dimension: int,
            dropout_prob: float,
            *args,
            **kwargs
    ):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout_prob)
        self.classifier = nn.Linear(hidden_dimension, num_classes)
        self.num_classes = num_classes
        self.loss = nn.CrossEntropyLoss(weight=class_weights)

        self.metrics = {
            "f1": F1Score(
                num_classes=num_classes,
                task="multiclass",
                average="macro"
            ),
            "acc": Accuracy(
                task="multiclass",
                num_classes=num_classes
            ),
        }

        general_logger.info(
            f"Initialized TokenClassificationHead with {num_classes} classes"
        )

    def forward(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        try:
            # Process sequence
            sequence_output = self.dropout(X)
            logits = self.classifier(sequence_output)

            # Compute loss
            loss = self.loss(logits.view(-1, self.num_classes), y.view(-1))

            # Mask padding tokens for metrics
            mask = torch.where(y != -100, 1, 0)
            logits = torch.masked_select(
                logits,
                mask.unsqueeze(-1).expand(logits.size()) == 1
            )
            y = torch.masked_select(y, mask == 1)
            logits = logits.view(y.shape[0], self.num_classes)

            # calculate metrics with predictions instead of logits
            predictions = torch.argmax(logits, dim=1)
            metrics = {
                name: metric(predictions.cpu(), y.cpu())
                for name, metric in self.metrics.items()
            }

            return logits, loss, metrics

        except Exception as e:
            raise HeadError(f"Token classification forward pass failed: {str(e)}")


class RegressionHead(BaseHead):
    """Head for regression tasks.

    Attributes:
        dense: Dense layer
        dropout: Dropout layer
        out_proj: Output projection layer
        loss: Loss function
        metrics: Dictionary of metrics
    """

    def __init__(
            self,
            input_dimension: int,
            hidden_dimension: int,
            dropout_prob: float
    ):
        super().__init__()

        self.dense = nn.Linear(input_dimension, hidden_dimension)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.out_proj = nn.Linear(hidden_dimension, 1)

        self.loss = nn.MSELoss()
        self.metrics = {
            "R2": R2Score(),
            "MSE": MeanSquaredError()
        }

        general_logger.info("Initialized RegressionHead")

    def forward(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        try:
            # Get CLS token
            x = X[:, 0, :]

            # Pass through layers
            x = self.dropout(x)
            x = self.dense(x)
            x = torch.tanh(x)
            x = self.dropout(x)
            logits = self.out_proj(x)

            loss = self.loss(logits.squeeze(), y.squeeze())

            metrics = {
                name: metric(logits.cpu(), y.cpu()).detach()
                for name, metric in self.metrics.items()
            }

            return logits, loss, metrics

        except Exception as e:
            raise HeadError(f"Regression forward pass failed: {str(e)}")


class LanguageModellingHead(BaseHead):
    """Head for masked language modeling tasks.

    Attributes:
        dense: Dense layer
        layer_norm: Layer normalization
        decoder: Output decoder
        loss: Loss function
        metrics: Dictionary of metrics
    """

    def __init__(
            self,
            input_dimension: int,
            hidden_dimension: int,
            dropout_prob: float
    ):
        super().__init__()

        self.dense = nn.Linear(input_dimension, hidden_dimension)
        self.layer_norm = nn.LayerNorm(hidden_dimension, eps=1e-5)
        self.gelu = nn.GELU()

        self.decoder = nn.Linear(hidden_dimension, tokenizer.vocab_size)
        self.bias = nn.Parameter(torch.zeros(tokenizer.vocab_size))
        self.decoder.bias = self.bias

        self.loss = nn.CrossEntropyLoss()
        self.metrics = {"perplexity": Perplexity()}

        general_logger.info("Initialized LanguageModellingHead")

    def forward(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        try:
            x = self.dense(X)
            x = self.gelu(x)
            x = self.layer_norm(x)

            logits = self.decoder(x)
            loss = self.loss(
                logits.view(-1, tokenizer.vocab_size),
                y.view(-1)
            )

            metrics = {
                name: metric(logits.cpu(), y.cpu())
                for name, metric in self.metrics.items()
            }

            return logits, loss, metrics

        except Exception as e:
            raise HeadError(f"Language modeling forward pass failed: {str(e)}")

In [284]:
"""Main MTL model implementation."""

from typing import Dict, Tuple, Optional
import torch
from torch import nn
from typing import List

from media_bias_detection.utils.logger import general_logger



class Model(nn.Module):
    """MTL model combining backbone and task-specific heads."""

    def __init__(self, stl: List, *args, **kwargs):
        """Initialize model with subtasks list and create task-specific heads.

        Args:
            stl: List of subtasks to create heads for
            *args: Additional positional arguments for heads
            **kwargs: Additional keyword arguments for heads
        """
        super().__init__()
        self.stl = stl
        self.subtask_id_to_subtask = {int(f"{st.id}"): st for st in stl}
        # Setup device
        self.device = torch.device("cuda" if torch.cuda.is_available() else torch.device("cpu"))

        # Initialize backbone
        self.language_model = BackboneLM()
        self.language_model.backbone.resize_token_embeddings(len(tokenizer))
        # Initialize heads
        self.heads = nn.ModuleDict({str(st.id): HeadFactory(st, *args, **kwargs) for st in stl})

        # Move model to device
        self.to(self.device)
        general_logger.info(f"Initialized model with {len(self.heads)} heads on {self.device}")

    def forward(self, X, attention_masks, Y, st_id):
        """Pass data through model and task-specific head.

        Args:
            X: Input tensor
            attention_masks: Attention mask tensor
            Y: Target tensor
            st_id: Subtask ID

        Returns:
            Tuple of (loss, metrics)
        """
        # Pass through backbone
        with torch.set_grad_enabled(self.training):
            x_enc = self.language_model.backbone(
                input_ids=X,
                attention_mask=attention_masks
            ).last_hidden_state

            # Pass through task-specific head
            head = self.heads[str(st_id.item())]
            logits, loss, metrics = head(x_enc, Y)

            return loss, metrics

In [285]:
#TO DO: muss mit restlichen modulen abgeglichen werden

"""Module for creating instantiating the appropriate model defined by the task list only."""

from typing import List

import torch

from media_bias_detection.utils.enums import Split

def ModelFactory(
        task_list: List,
        sub_batch_size: int,
        eval_batch_size: int,
        pretrained_path: str = None,
        *args,
        **kwargs
):
    """Create model and return it along with dataloaders."""
    # Get all subtasks from task list
    subtask_list = [st for t in task_list for st in t.subtasks_list]

    # Verify data is processed
    for st in subtask_list:
        assert st.processed, "Data must be loaded at this point."

    # Create model
    model = Model(stl=subtask_list, **kwargs)

    if pretrained_path is not None:
        model = load_pretrained_weights(model, pretrained_path=pretrained_path)

    # Move model to appropriate device
    model.to(model.device)

    try:
        # Create dataloaders with updated classes
        batch_list_train = BatchList(
            subtask_list=subtask_list,
            sub_batch_size=sub_batch_size,
            split=Split.TRAIN
        )
        
        batch_list_dev = BatchList(
            subtask_list=subtask_list, 
            sub_batch_size=eval_batch_size,
            split=Split.DEV
        )
        
        batch_list_eval = BatchListEvalTest(
            subtask_list=subtask_list,
            sub_batch_size=sub_batch_size, 
            split=Split.DEV
        )
        
        batch_list_test = BatchListEvalTest(
            subtask_list=subtask_list,
            sub_batch_size=sub_batch_size,
            split=Split.TEST
        )
        
        return model, batch_list_train, batch_list_dev, batch_list_eval, batch_list_test
        
    except Exception as e:
        general_logger.error(f"Failed to create model and dataloaders: {str(e)}")
        raise


def save_head_initializations(model):
    """Save weight initialization of the head. This method will not be called anymore.
     It's only for the initial saving of weight inits for all tasks."""
    for head_name in model.heads.keys():
        torch.save(model.heads[head_name].state_dict(), 'model_files/heads/' + head_name + '_init.pth')


def load_head_initializations(model):
    """Load fixed weight initialization for each head in order to ensure reproducibility."""
    for head_name in model.heads.keys():
        weights_path = 'model_files/heads/' + head_name + '_init.pth'
        head_weights = torch.load(weights_path)
        model.heads[head_name].load_state_dict(head_weights, strict=True)


def load_pretrained_weights(model, pretrained_path):
    """Load the weights of a pretrained model."""
    weight_dict = torch.load(pretrained_path)
    model.load_state_dict(weight_dict, strict=False)
    return model

# Training the Model

For the model training the MAGPIE repository first introduces some helper functions. Since they are specific to the training, I include them into the notebook, instead of using them as a separate module like the other utility functions.

In [286]:
"""Training utilities module for MTL training.

Contains utility classes for:
- Logging
- Early stopping

"""

import math
import os
import logging

from enum import Enum
from typing import Dict, List, Any, Optional

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.enums import Split

import wandb



class Logger:
    """Logger to keep track of metrics, losses and artifacts."""

    def __init__(self, experiment_name: str):
        PATH = "logging/" + experiment_name
        os.makedirs(PATH, exist_ok=True)

        self.experiment_logfilename = PATH + "/train_data.log"
        experiment_logfile_handler = logging.FileHandler(filename=self.experiment_logfilename)
        experiment_logfile_formatter = logging.Formatter(fmt="%(message)s")
        experiment_logfile_handler.setFormatter(experiment_logfile_formatter)

        self.experiment_logger = logging.getLogger("experiment_logger")
        self.experiment_logger.addHandler(experiment_logfile_handler)
        self.experiment_logger.setLevel("INFO")

    def log(self, out: Dict[str, Any]) -> None:
        try:
            self.experiment_logger.info(out)
            wandb.log(out)
        except Exception as e:
            print(f"Logging failed: {str(e)}")


class EarlyStoppingMode(Enum):
    """Mode for early stopping behavior."""
    HEADS = "heads"  # Only stop heads
    BACKBONE = "backbone"  # Also stop backbone
    NONE = "none"  # No early stopping


class EarlyStopperSingle:
    """Early stopping tracker for a single model component."""

    def __init__(
            self,
            patience: int,
            min_delta: float,
            resurrection: bool,
            zombie_patience: int = 10
    ):
        """Initialize early stopping tracker.

        Args:
            patience: How many epochs to wait before stopping
            min_delta: Minimum change to count as improvement
            resurrection: Whether to allow resurrection
            zombie_patience: Patience for zombie state
        """
        self.patience = patience
        self.patience_zombie = zombie_patience
        self.min_delta = min_delta
        self.counter = 0
        self.counter_zombie = 0
        self.min_dev_loss = float('inf')
        self.min_dev_loss_zombie = float('inf')
        self.resurrection = resurrection

    def early_stop(self, dev_loss: float) -> bool:
        """Check if training should stop.

        Args:
            dev_loss: Current validation loss

        Returns:
            Whether to stop training
        """
        if math.isnan(dev_loss):
            return False

        if dev_loss < self.min_dev_loss:
            self.min_dev_loss = dev_loss
            self.counter = 0
        elif dev_loss > (self.min_dev_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

    def resurrect(self, dev_loss: float) -> bool:
        """Check if training should resume.

        Args:
            dev_loss: Current validation loss

        Returns:
            Whether to resume training
        """
        if math.isnan(dev_loss) or not self.resurrection:
            return False

        if dev_loss < self.min_dev_loss_zombie:
            self.min_dev_loss_zombie = dev_loss
            self.counter_zombie = 0
        elif dev_loss > self.min_dev_loss_zombie:
            self.counter_zombie += 1
            if self.counter_zombie >= self.patience_zombie:
                return True
        return False

    def reset(self) -> None:
        """Reset early stopping state."""
        self.counter_zombie = 0
        self.counter = 0
        self.min_dev_loss_zombie = float('inf')
        self.min_dev_loss = float('inf')


class EarlyStopper:
    """Container for managing multiple early stoppers."""

    def __init__(
            self,
            st_ids: List[str],
            mode: EarlyStoppingMode,
            patience: Dict[str, int],
            resurrection: bool,
            min_delta: float = 0
    ):
        """Initialize early stopping manager.

        Args:
            st_ids: List of subtask IDs
            mode: Early stopping mode
            patience: Dictionary of patience values per subtask
            resurrection: Whether to allow resurrection
            min_delta: Minimum improvement threshold
        """
        self.mode = mode
        self.early_stoppers = {
            st_id: EarlyStopperSingle(
                patience=patience[st_id],
                min_delta=min_delta,
                resurrection=resurrection
            )
            for st_id in st_ids
        }
        general_logger.info(
            f"Initialized early stopping manager with mode {mode}"
        )

    def early_stop(self, st_id: str, dev_loss: float) -> bool:
        """Check if specific task should stop."""
        return (
            False if self.mode == EarlyStoppingMode.NONE
            else self.early_stoppers[st_id].early_stop(dev_loss=dev_loss)
        )

    def resurrect(self, st_id: str, dev_loss: float) -> bool:
        """Check if specific task should resurrect."""
        return (
            False if self.mode == EarlyStoppingMode.NONE
            else self.early_stoppers[st_id].resurrect(dev_loss=dev_loss)
        )

    def reset_early_stopper(self, st_id: str) -> None:
        """Reset early stopper for specific task."""
        self.early_stoppers[st_id].reset()

In [287]:
"""Metrics tracking and computation module.

This module provides classes and utilities for tracking metrics during training,
computing running averages, and managing metric history.
"""

from typing import Dict, List, Optional, Union, Any
import numpy as np
import json
from pathlib import Path

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.enums import Split

class MetricError(Exception):
    """Custom exception for metric-related errors."""
    pass


class AverageMeter:
    """Tracks running average of a metric.

    This class maintains a history of values and provides
    different methods for computing averages.

    Attributes:
        name: Name of the metric
        values: List of recorded values
    """

    def __init__(self, name: str):
        self.name = name
        self.values: List[float] = []

    def update(self, value: float) -> None:
        """Add a new value to history.

        Args:
            value: Value to add
        """
        try:
            self.values.append(float(value))
        except (TypeError, ValueError) as e:
            raise MetricError(f"Invalid value for metric {self.name}: {str(e)}")

    def mean_last_k(self, k: int = 10) -> float:
        """Calculate mean of last k values.

        Args:
            k: Number of last values to average

        Returns:
            Mean of last k values

        Raises:
            MetricError: If not enough values available
        """
        try:
            if k < 1:
                raise MetricError("k must be positive")
            if not self.values:
                raise MetricError("No values recorded")
            if len(self.values) < k:
                return float("nan")
            return float(np.mean(self.values[-k:]))
        except Exception as e:
            raise MetricError(f"Error computing mean_last_k: {str(e)}")

    def mean_all(self) -> float:
        """Calculate mean of all values.

        Returns:
            Mean of all values

        Raises:
            MetricError: If no values available
        """
        try:
            if not self.values:
                raise MetricError("No values recorded")
            return float(np.mean(self.values))
        except Exception as e:
            raise MetricError(f"Error computing mean_all: {str(e)}")

    def reset(self) -> None:
        """Clear all recorded values."""
        self.values.clear()

    def get_history(self) -> List[float]:
        """Get complete history of values.

        Returns:
            List of all recorded values
        """
        return self.values.copy()

    def __repr__(self) -> str:
        """String representation showing latest value."""
        return f"{self.mean_last_k(1):.4f}"


class Tracker:
    """Tracks metrics and losses across training.

    This class manages metrics and losses for different splits
    and tasks, providing logging and analysis capabilities.

    Attributes:
        metrics: Nested dictionary of metrics for each split/task
        losses: Nested dictionary of losses for each split/task
        combined_losses: Dictionary of combined losses per split
        logger: Logger instance
    """

    def __init__(self, heads: Dict, logger: Any):
        """Initialize tracker.

        Args:
            heads: Dictionary of model heads
            logger: Logger instance
        """
        try:
            self.metrics = self._init_metrics(heads)
            self.losses, self.combined_losses = self._init_losses(heads)
            self.logger = logger

            # Track best metrics
            self.best_metrics: Dict[str, float] = {}

            general_logger.info("Initialized metric tracker")

        except Exception as e:
            raise MetricError(f"Failed to initialize tracker: {str(e)}")

    def _init_metrics(self, heads: Dict) -> Dict:
        """Initialize metric tracking structures.

        Args:
            heads: Dictionary of model heads

        Returns:
            Initialized metrics dictionary
        """
        try:
            metrics = {}
            for split in Split:
                metrics[split] = {
                    st_id: {
                        m: AverageMeter(name=f"{st_id}_{split.value}_{m}")
                        for m in head.metrics.keys()
                    }
                    for st_id, head in heads.items()
                }
            return metrics

        except Exception as e:
            raise MetricError(f"Failed to initialize metrics: {str(e)}")

    def _init_losses(self, heads: Dict) -> tuple:
        """Initialize loss tracking structures.

        Args:
            heads: Dictionary of model heads

        Returns:
            Tuple of (loss_dict, combined_loss_dict)
        """
        try:
            # Task-specific losses
            losses = {}
            for split in Split:
                losses[split] = {
                    st_id: AverageMeter(name=f"{st_id}_{split.value}_loss")
                    for st_id in heads.keys()
                }

            # Combined losses
            combined_losses = {
                split: AverageMeter(name=f"combined_{split.value}_loss")
                for split in Split
            }

            return losses, combined_losses

        except Exception as e:
            raise MetricError(f"Failed to initialize losses: {str(e)}")

    def update_metric(
            self,
            split: Split,
            st_id: str,
            metric: str,
            value: float
    ) -> None:
        """Update a specific metric value.

        Args:
            split: Data split
            st_id: Subtask ID
            metric: Metric name
            value: New value
        """
        try:
            self.metrics[split][st_id][metric].update(value)

            # Track best metrics for validation
            if split == Split.DEV:
                metric_key = f"{st_id}_{metric}"
                current_value = value
                if metric_key not in self.best_metrics or current_value > self.best_metrics[metric_key]:
                    self.best_metrics[metric_key] = current_value

        except Exception as e:
            raise MetricError(f"Failed to update metric: {str(e)}")

    def update_loss(
            self,
            split: Split,
            st_id: str,
            value: float
    ) -> None:
        """Update a specific loss value.

        Args:
            split: Data split
            st_id: Subtask ID
            value: New loss value
        """
        try:
            self.losses[split][st_id].update(value)
        except Exception as e:
            raise MetricError(f"Failed to update loss: {str(e)}")

    def update_combined_loss(
            self,
            split: Split,
            value: float
    ) -> None:
        """Update combined loss for a split.

        Args:
            split: Data split
            value: New loss value
        """
        try:
            self.combined_losses[split].update(value)
        except Exception as e:
            raise MetricError(f"Failed to update combined loss: {str(e)}")

    def get_last_st_loss(
            self,
            split: Split,
            st_id: str,
            k: int
    ) -> float:
        """Get mean of last k loss values for a subtask.

        Args:
            split: Data split
            st_id: Subtask ID
            k: Number of values to average

        Returns:
            Mean loss value
        """
        try:
            return self.losses[split][st_id].mean_last_k(k=k)
        except Exception as e:
            raise MetricError(f"Failed to get subtask loss: {str(e)}")

    def get_last_st_metric(
            self,
            split: Split,
            st_id: str,
            k: int
    ) -> float:
        """Get mean of last k metric values for a subtask.

        Args:
            split: Data split
            st_id: Subtask ID
            k: Number of values to average

        Returns:
            Mean metric value
        """
        try:
            # Get first metric as representative
            first_metric = next(iter(self.metrics[split][st_id]))
            return self.metrics[split][st_id][first_metric].mean_last_k(k=k)
        except Exception as e:
            raise MetricError(f"Failed to get subtask metric: {str(e)}")

    def log(
            self,
            splits: List[Split],
            additional_payload: Optional[Dict[str, float]] = None
    ) -> None:
        """Log metrics and losses.

        Args:
            splits: List of splits to log
            additional_payload: Optional additional values to log
        """
        try:
            out: Dict[str, float] = additional_payload or {}

            for split in splits:
                # For training and validation, log last values
                if split in [Split.DEV, Split.TRAIN]:
                    # Log metrics
                    metrics = {
                        m.name: m.mean_last_k(1)
                        for d in self.metrics[split].values()
                        for m in d.values()
                    }
                    # Log losses
                    combined_loss = self.combined_losses[split].mean_last_k(1)
                    losses = {
                        v.name: v.mean_last_k(1)
                        for v in self.losses[split].values()
                    }
                # For test and eval, log means
                else:
                    # Log metrics
                    metrics = {
                        m.name: m.mean_all()
                        for d in self.metrics[split].values()
                        for m in d.values()
                    }
                    # Log losses
                    combined_loss = self.combined_losses[split].mean_all()
                    losses = {
                        v.name: v.mean_all()
                        for v in self.losses[split].values()
                    }

                out.update(metrics)
                out[f"combined_{split.value}_loss"] = combined_loss
                out.update(losses)

            # Log to wandb and local logger
            self.logger.log(out)

        except Exception as e:
            raise MetricError(f"Failed to log metrics: {str(e)}")

    def save_history(self, path: Union[str, Path]) -> None:
        """Save complete metric history.

        Args:
            path: Path to save history
        """
        try:
            path = Path(path)
            history = {
                'metrics': {
                    split.value: {
                        st_id: {
                            metric: meter.get_history()
                            for metric, meter in st_metrics.items()
                        }
                        for st_id, st_metrics in split_metrics.items()
                    }
                    for split, split_metrics in self.metrics.items()
                },
                'losses': {
                    split.value: {
                        st_id: meter.get_history()
                        for st_id, meter in split_losses.items()
                    }
                    for split, split_losses in self.losses.items()
                },
                'combined_losses': {
                    split.value: meter.get_history()
                    for split, meter in self.combined_losses.items()
                },
                'best_metrics': self.best_metrics
            }

            with open(path, 'w') as f:
                json.dump(history, f, indent=2)

            general_logger.info(f"Saved metric history to {path}")

        except Exception as e:
            raise MetricError(f"Failed to save history: {str(e)}")

    def load_history(self, path: Union[str, Path]) -> None:
        """Load metric history.

        Args:
            path: Path to load history from
        """
        try:
            path = Path(path)
            with open(path) as f:
                history = json.load(f)

            # Restore metrics
            for split_name, split_metrics in history['metrics'].items():
                split = Split(split_name)
                for st_id, st_metrics in split_metrics.items():
                    for metric, values in st_metrics.items():
                        for value in values:
                            self.metrics[split][st_id][metric].update(value)

            # Restore losses
            for split_name, split_losses in history['losses'].items():
                split = Split(split_name)
                for st_id, values in split_losses.items():
                    for value in values:
                        self.losses[split][st_id].update(value)

            # Restore combined losses
            for split_name, values in history['combined_losses'].items():
                split = Split(split_name)
                for value in values:
                    self.combined_losses[split].update(value)

            # Restore best metrics
            self.best_metrics = history['best_metrics']

            general_logger.info(f"Loaded metric history from {path}")

        except Exception as e:
            raise MetricError(f"Failed to load history: {str(e)}")

    def reset(self) -> None:
        """Reset all metrics and losses."""
        try:
            # Reset metrics
            for split_metrics in self.metrics.values():
                for st_metrics in split_metrics.values():
                    for meter in st_metrics.values():
                        meter.reset()

            # Reset losses
            for split_losses in self.losses.values():
                for meter in split_losses.values():
                    meter.reset()

            # Reset combined losses
            for meter in self.combined_losses.values():
                meter.reset()

            general_logger.info("Reset all metrics and losses")

        except Exception as e:
            raise MetricError(f"Failed to reset metrics: {str(e)}")


In [288]:
"""Checkpoint management module for MTL model.

This module handles saving, loading, and managing model checkpoints,
including best model tracking and checkpoint rotation.
"""

from typing import Dict, Optional, Union, Any
import torch
from pathlib import Path
import json
import shutil
import time
from dataclasses import dataclass
from collections import deque

from media_bias_detection.utils.logger import general_logger
from media_bias_detection.utils.enums import Split


@dataclass
class CheckpointMetadata:
    """Container for checkpoint metadata.

    Attributes:
        epoch: Training epoch number
        global_step: Global training step
        train_loss: Training loss
        val_loss: Validation loss
        metrics: Dictionary of metrics
        timestamp: When checkpoint was created
    """
    epoch: int
    global_step: int
    train_loss: float
    val_loss: float
    metrics: Dict[str, float]
    timestamp: float


class CheckpointError(Exception):
    """Custom exception for checkpoint-related errors."""
    pass


class CheckpointManager:
    """Manages model checkpoints.

    This class handles saving and loading checkpoints, including
    maintaining best models and checkpoint rotation.

    Attributes:
        save_dir: Directory for saving checkpoints
        max_checkpoints: Maximum number of checkpoints to keep
        checkpoint_name: Base name for checkpoint files
        save_best_only: Whether to save only best models
        best_metric: Name of metric to track for best model
        minimize_metric: Whether metric should be minimized
    """

    def __init__(
            self,
            save_dir: Union[str, Path],
            max_checkpoints: int = 5,
            checkpoint_name: str = "model",
            save_best_only: bool = False,
            best_metric: str = "val_loss",
            minimize_metric: bool = True
    ):
        """Initialize checkpoint manager.

        Args:
            save_dir: Directory to save checkpoints in
            max_checkpoints: Maximum number of checkpoints to keep
            checkpoint_name: Base name for checkpoint files
            save_best_only: Whether to save only best models
            best_metric: Metric to track for best model
            minimize_metric: Whether metric should be minimized
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints
        self.checkpoint_name = checkpoint_name
        self.save_best_only = save_best_only
        self.best_metric = best_metric
        self.minimize_metric = minimize_metric

        # Track checkpoints
        self.checkpoints = deque(maxlen=max_checkpoints)
        self.best_checkpoint: Optional[Path] = None
        self.best_metric_value = float('inf') if minimize_metric else float('-inf')

        # Load existing checkpoints if any
        self._load_existing_checkpoints()

        general_logger.info(
            f"Initialized checkpoint manager in {save_dir} "
            f"(max_checkpoints={max_checkpoints}, save_best_only={save_best_only})"
        )

    def _load_existing_checkpoints(self) -> None:
        """Load information about existing checkpoints."""
        try:
            # Find all checkpoint files
            checkpoint_files = sorted(
                self.save_dir.glob(f"{self.checkpoint_name}*.pt")
            )

            for checkpoint_file in checkpoint_files:
                metadata_file = checkpoint_file.with_suffix('.json')
                if metadata_file.exists():
                    with open(metadata_file) as f:
                        metadata = json.load(f)

                    # Update best checkpoint if applicable
                    if self.best_metric in metadata['metrics']:
                        metric_value = metadata['metrics'][self.best_metric]
                        if self._is_better_metric(metric_value):
                            self.best_checkpoint = checkpoint_file
                            self.best_metric_value = metric_value

                    self.checkpoints.append(checkpoint_file)

            general_logger.info(
                f"Found {len(self.checkpoints)} existing checkpoints"
            )

        except Exception as e:
            raise CheckpointError(f"Failed to load existing checkpoints: {str(e)}")

    def _is_better_metric(self, new_value: float) -> bool:
        """Check if new metric value is better than current best.

        Args:
            new_value: New metric value to compare

        Returns:
            Whether new value is better
        """
        if self.minimize_metric:
            return new_value < self.best_metric_value
        return new_value > self.best_metric_value

    def _save_metadata(
            self,
            path: Path,
            metadata: CheckpointMetadata
    ) -> None:
        """Save checkpoint metadata to JSON file.

        Args:
            path: Path to save metadata
            metadata: Metadata to save
        """
        try:
            metadata_dict = {
                'epoch': metadata.epoch,
                'global_step': metadata.global_step,
                'train_loss': metadata.train_loss,
                'val_loss': metadata.val_loss,
                'metrics': metadata.metrics,
                'timestamp': metadata.timestamp
            }

            with open(path.with_suffix('.json'), 'w') as f:
                json.dump(metadata_dict, f, indent=2)

        except Exception as e:
            raise CheckpointError(f"Failed to save metadata: {str(e)}")

    def save(
            self,
            model: torch.nn.Module,
            optimizer: Optional[torch.optim.Optimizer],
            scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
            metadata: CheckpointMetadata
    ) -> Optional[Path]:
        """Save model checkpoint.

        Args:
            model: Model to save
            optimizer: Optional optimizer to save
            scheduler: Optional scheduler to save
            metadata: Checkpoint metadata

        Returns:
            Path to saved checkpoint if saved, None otherwise

        Raises:
            CheckpointError: If saving fails
        """
        try:
            # Check if we should save
            metric_value = metadata.metrics.get(self.best_metric)
            if self.save_best_only and metric_value is not None:
                if not self._is_better_metric(metric_value):
                    return None

            # Create checkpoint
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'metadata': {
                    'epoch': metadata.epoch,
                    'global_step': metadata.global_step,
                    'metrics': metadata.metrics
                }
            }

            # Generate checkpoint path
            checkpoint_path = self.save_dir / (
                f"{self.checkpoint_name}_epoch{metadata.epoch:03d}.pt"
            )

            # Save checkpoint and metadata
            torch.save(checkpoint, checkpoint_path)
            self._save_metadata(checkpoint_path, metadata)

            # Update checkpoint tracking
            self.checkpoints.append(checkpoint_path)

            # Update best checkpoint if applicable
            if metric_value is not None and self._is_better_metric(metric_value):
                if self.best_checkpoint is not None:
                    old_best = self.best_checkpoint
                    if old_best != checkpoint_path:
                        shutil.copy(checkpoint_path, self.best_checkpoint.parent / 'best.pt')
                else:
                    shutil.copy(checkpoint_path, self.save_dir / 'best.pt')
                self.best_checkpoint = checkpoint_path
                self.best_metric_value = metric_value

            # Clean up old checkpoints if necessary
            while len(self.checkpoints) > self.max_checkpoints:
                old_checkpoint = self.checkpoints.popleft()
                if old_checkpoint != self.best_checkpoint:
                    old_checkpoint.unlink()
                    old_checkpoint.with_suffix('.json').unlink()

            general_logger.info(f"Saved checkpoint to {checkpoint_path}")
            return checkpoint_path

        except Exception as e:
            raise CheckpointError(f"Failed to save checkpoint: {str(e)}")

    def load(
            self,
            path: Optional[Union[str, Path]] = None,
            load_best: bool = False,
            map_location: Optional[torch.device] = None
    ) -> Dict[str, Any]:
        """Load checkpoint.

        Args:
            path: Path to checkpoint to load, or None for latest
            load_best: Whether to load best checkpoint
            map_location: Optional device to map tensors to

        Returns:
            Loaded checkpoint dictionary

        Raises:
            CheckpointError: If loading fails
        """
        try:
            if load_best:
                if self.best_checkpoint is None:
                    raise CheckpointError("No best checkpoint available")
                path = self.best_checkpoint
            elif path is None:
                if not self.checkpoints:
                    raise CheckpointError("No checkpoints available")
                path = self.checkpoints[-1]
            else:
                path = Path(path)

            if not path.exists():
                raise CheckpointError(f"Checkpoint not found: {path}")

            checkpoint = torch.load(path, map_location=map_location)
            general_logger.info(f"Loaded checkpoint from {path}")
            return checkpoint

        except Exception as e:
            raise CheckpointError(f"Failed to load checkpoint: {str(e)}")

    def get_latest_checkpoint(self) -> Optional[Path]:
        """Get path to latest checkpoint.

        Returns:
            Path to latest checkpoint or None if no checkpoints exist
        """
        return self.checkpoints[-1] if self.checkpoints else None

    def get_best_checkpoint(self) -> Optional[Path]:
        """Get path to best checkpoint.

        Returns:
            Path to best checkpoint or None if no best checkpoint exists
        """
        return self.best_checkpoint

    def cleanup(self) -> None:
        """Remove all checkpoints."""
        try:
            for checkpoint in self.checkpoints:
                checkpoint.unlink()
                checkpoint.with_suffix('.json').unlink()
            self.checkpoints.clear()
            self.best_checkpoint = None
            self.best_metric_value = float('inf') if self.minimize_metric else float('-inf')
            general_logger.info("Removed all checkpoints")

        except Exception as e:
            raise CheckpointError(f"Failed to cleanup checkpoints: {str(e)}")

In [289]:
"""Training module for MTL model.

Provides enhanced training functionality with:
- Comprehensive error handling
- Memory optimization
- Detailed logging
- Training efficiency improvements
"""
from typing import Dict, List, Optional, Any
from pathlib import Path
import gc
import time

import numpy as np
import psutil
import torch
import statistics as stats
from tqdm import tqdm
from transformers import get_polynomial_decay_schedule_with_warmup

from media_bias_detection.config.config import MAX_NUMBER_OF_STEPS
from media_bias_detection.utils.enums import Split, LossScaling
from media_bias_detection.utils.logger import general_logger



class TrainerError(Exception):
    """Custom exception for training-related errors."""
    pass


class Trainer:
    """Enhanced trainer for MTL model.

    Features:
    - Automatic mixed precision training
    - Memory-optimized batch processing
    - Detailed progress tracking
    - Comprehensive error handling
    """

    def __init__(
            self,
            task_list: List[Task],
            initial_lr: float,
            model_name: str,
            pretrained_path: Optional[str],
            sub_batch_size: int,
            eval_batch_size: int,
            early_stopping_mode,
            resurrection: bool,
            aggregation_method: AggregationMethod,
            loss_scaling: LossScaling,
            num_warmup_steps: int,
            head_specific_lr_dict: Dict[str, float],
            head_specific_patience_dict: Dict[str, int],
            head_specific_max_epoch_dict: Dict[str, int],
            logger: Logger,
            device: Optional[torch.device] = None,
            use_amp: bool = True,
            *args,
            **kwargs,
    ):
        """Initialize trainer with enhanced configuration."""
        try:
            self.logger = logger
            general_logger.info("Initializing trainer...")

            # Basic setup
            self.early_stopping_mode = early_stopping_mode
            self.loss_scaling = loss_scaling
            self.use_amp = use_amp and torch.cuda.is_available()
            self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")


            self.model, batch_list_train, batch_list_dev, batch_list_eval, batch_list_test = ModelFactory(
                task_list=task_list,
                sub_batch_size=sub_batch_size,
                eval_batch_size=eval_batch_size,
                pretrained_path=pretrained_path,
                *args,
                **kwargs,
            )
            self.batch_lists = {
                Split.TRAIN: batch_list_train,
                Split.DEV: batch_list_dev,
                Split.EVAL: batch_list_eval,
                Split.TEST: batch_list_test,
            }

            # shared backbone model optimizer
            self.lm_optimizer = torch.optim.AdamW(self.model.language_model.backbone.parameters(), lr=initial_lr)
            self.lm_lr_scheduler = get_polynomial_decay_schedule_with_warmup(
                optimizer=self.lm_optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=max([len(dl) for dl in self.batch_lists[Split.TRAIN].dataloaders.values()])
                                   * stats.median(head_specific_max_epoch_dict.values()),
            )

            # task-specifics optimizers
            self.head_optimizers = {
                str(st_id): torch.optim.AdamW(head.parameters(), lr=head_specific_lr_dict[st_id])
                for st_id, head in self.model.heads.items()
            }
            self.head_lr_schedulers = {
                str(st_id): get_polynomial_decay_schedule_with_warmup(
                    optimizer=self.head_optimizers[st_id],
                    num_warmup_steps=num_warmup_steps,
                    num_training_steps=len(self.batch_lists[Split.TRAIN].dataloaders[st_id])
                                       * head_specific_max_epoch_dict[st_id],
                )
                for st_id in self.model.heads.keys()
            }

            # flags controlling stopping and resurrection
            self.task_alive_flags = {str(st_id): True for st_id in self.model.heads.keys()}
            self.task_zombie_flags = {str(st_id): False for st_id in self.model.heads.keys()}
            self.early_stopper = EarlyStopper(
                st_ids=self.model.heads.keys(),
                mode=self.early_stopping_mode,
                patience=head_specific_patience_dict,
                resurrection=resurrection,
            )

            # Initialize tracking components
            self.tracker = Tracker(heads=self.model.heads, logger=logger)
            self.GA = GradientAggregator(aggregation_method=aggregation_method)
            self.progress_bar = tqdm(range(len(self.model.heads)))
            self.model_name = model_name
            self.scaling_weights = {str(st.id): st.get_scaling_weight() for t in task_list for st in t.subtasks_list}
            self.MAX_NUMBER_OF_STEPS = MAX_NUMBER_OF_STEPS
            self.k = 50

            # Memory tracking
            self._last_memory_check = time.time()
            self._memory_check_interval = 60  # seconds

            general_logger.info(
                f"Trainer initialized successfully on {self.device}"
                f"{' with AMP' if self.use_amp else ''}"
            )

        except Exception as e:
            general_logger.error(f"Failed to initialize trainer: {str(e)}")
            raise TrainerError(f"Failed to initialize trainer: {str(e)}")

    def _optimize_memory(self) -> None:
        """Perform memory optimization."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            process = psutil.Process()
            memory_info = process.memory_info().rss / 1024 ** 2  # Convert to MB
            general_logger.debug(f"Current memory usage: {memory_info:.2f} MB")

        except Exception as e:
            general_logger.warning(f"Memory optimization failed: {str(e)}")

    def head_specific_optimization(self, st_id: str, lm_grads, scaling_weight):
        """
        Perform the optimization of a task-specific head.

        This method is only called when mode is training.
        @param st_id: The subtask id.
        @param lm_grads: The LM gradients.
        @param scaling_weight: The scaling weight of that subtask.
        @return: A dictionary with additional payload containing the conflicting gradients ratio.
        """
        try:
            additional_payload = {}
            
            # Check if we have dev metrics recorded
            if self.tracker.losses[Split.DEV][st_id].values:
                last_dev_loss = self.tracker.get_last_st_loss(split=Split.DEV, st_id=st_id, k=self.k)
                
                # Check early stopping conditions only if we have dev metrics
                should_stop_now = (
                    self.early_stopper.early_stop(st_id=st_id, dev_loss=last_dev_loss)
                    if (self.task_alive_flags[st_id] or self.task_zombie_flags[st_id])
                    else False
                )
                
                should_resurrect_now = (
                    self.early_stopper.resurrect(st_id=st_id, dev_loss=last_dev_loss)
                    if (not self.task_zombie_flags[st_id] and not self.task_alive_flags[st_id])
                    else False
                )
                
                should_stay_zombie = (
                    not self.task_alive_flags[st_id] and 
                    self.task_zombie_flags[st_id] and 
                    not should_stop_now
                )
                
                # Update task states based on conditions
                if should_stop_now and self.task_alive_flags[st_id]:
                    general_logger.info(f"Subtask {st_id} is now DEAD.")
                    self.eval_st(split=Split.EVAL, st_id=st_id)
                    self.tracker.log(splits=[Split.EVAL], additional_payload={st_id + "_STOPPED": 0})
                    self.progress_bar.update()
                
                elif should_resurrect_now and not self.task_zombie_flags[st_id]:
                    general_logger.info(f"Subtask {st_id} is now ZOMBIE.")
                    additional_payload[st_id + "_ZOMBIE"] = 0
                    self.early_stopper.reset_early_stopper(st_id=st_id)
                
                elif should_stop_now and self.task_zombie_flags[st_id]:
                    general_logger.info(f"Subtask {st_id} is now DEAD AGAIN.")
                    additional_payload[st_id + "_DEAD_ZOMBIE"] = 0
                    self.early_stopper.reset_early_stopper(st_id=st_id)
                
                # Update flags
                self.task_alive_flags[st_id] = (
                    self.task_alive_flags[st_id] and 
                    not (should_stop_now or self.tracker.get_last_st_metric(split=Split.DEV, st_id=st_id, k=10) == 1)
                )
                self.task_zombie_flags[st_id] = should_resurrect_now or should_stay_zombie
                
            else:
                # If no dev metrics yet, keep task alive and don't trigger early stopping
                should_stop_now = False
                self.task_alive_flags[st_id] = True
                self.task_zombie_flags[st_id] = False
    
            # Optimize task if it's alive or zombie
            optimize_task = self.task_alive_flags[str(st_id)] or self.task_zombie_flags[str(st_id)]
            if optimize_task:
                self.head_optimizers[st_id].step()
                self.head_lr_schedulers[st_id].step()
    
            # Update gradients if appropriate
            if self.early_stopping_mode != EarlyStoppingMode.BACKBONE or optimize_task:
                self.GA.update(lm_grads, scaling_weight=scaling_weight)
    
            return additional_payload
    
        except Exception as e:
            raise TrainerError(f"Failed to do head specific optimization: {str(e)}")


    def backbone_optimization(self) -> Dict[str, Any]:
        """
        Perform the optimization of the backbone.

        This method is only called when mode is training.
        @return: A dictionary with additional payload containing the conflicting gradients ratio.
        """
        # Optimize the LM such that: we aggregate gradients from subtasks and set the final
        # gradient to the LM and subsequently optimize (only the LM)
        try:
            additional_payload = {}
            if any(self.task_alive_flags.values()):
                aggregated_gradients = self.GA.aggregate_gradients()
                self.model.language_model.set_grads(aggregated_gradients)
                self.lm_optimizer.step()
                self.lm_lr_scheduler.step()
            if self.GA.aggregation_method in [AggregationMethod.PCGRAD, AggregationMethod.PCGRAD_ONLINE]:
                conflicting_gradients_ratio = self.GA.get_conflicting_gradients_ratio()
                additional_payload["conflicting_gradients_ratio"] = conflicting_gradients_ratio
        except Exception as e:
            raise TrainerError(f"Failed to do backbone optimization: {str(e)}")
        return additional_payload


    def handle_batch(self, batch, split: Split = Split.TRAIN) -> Dict[str, Any]:
        """Handle a batch.

         (always) Pass a batch of sub_batches through the network.
         (in train-mode) For each sub_batch, accumulate the gradients of the LM.
         For each sub_batch and each st_id,
            - (in train-mode) accumulate the gradients of the respective head,
            - (always) accumulate the metric of the respective head,
            - (always) accumulate the loss of the respective head.
        (always) Log all metrics and losses to wandb.
         (in train-mode) After all sub_batches are processed, normalize the LM gradients and the head-specific gradients.
         (in train-mode) Then, perform the step of the lr_scheduler and the optimizer.

        @param batch: The batch containing sub-batches.
        @param split: The split (TRAIN, DEV, TEST)
        @return: A dictionary containing additional payload that needs to be logged.
        """
        try:
            training = split == Split.TRAIN
            losses = []
            additional_payloads = {}

            if training:
                self.GA.reset_accumulator()
                general_logger.debug("Reset gradient accumulator")

            for sub_batch in batch:
                if isinstance(sub_batch, BatchData):
                    X = sub_batch.input_ids
                    attention_masks = sub_batch.attention_mask
                    Y = sub_batch.labels
                    st_id = sub_batch.subtask_id
                else:
                    X, attention_masks, Y, st_id = sub_batch
                st_id_str = str(st_id.unique().item())

                general_logger.debug(f"Processing sub-batch for task {st_id_str}")

                # Forward pass and compute loss
                loss, metric_values, lm_grads = self._step(
                    (X, attention_masks, Y, st_id.unique()),
                    training=training
                )

                scaling_weight = (
                    self.scaling_weights[st_id_str]
                    if self.loss_scaling == LossScaling.STATIC
                    else 1.0
                )

                if training:
                    payload = self.head_specific_optimization(
                        st_id=st_id_str,
                        lm_grads=lm_grads,
                        scaling_weight=scaling_weight
                    )
                    additional_payloads.update(payload)

                # Update metrics and losses
                for metric, value in metric_values.items():
                    self.tracker.update_metric(split=split, st_id=st_id_str, metric=metric, value=value)
                self.tracker.update_loss(split=split, st_id=st_id_str, value=loss.item())
                losses.append(loss.item())

            if training:
                payload = self.backbone_optimization()
                additional_payloads.update(payload)

            mean_loss = np.mean(losses)
            self.tracker.update_combined_loss(split=split, value=mean_loss)
            general_logger.debug(f"Batch processed. Mean loss: {mean_loss:.4f}")

            return additional_payloads

        except Exception as e:
            general_logger.error(f"Failed to handle batch: {str(e)}")
            raise TrainerError(f"Batch processing failed: {str(e)}")

    def _step(self, batch, training: bool = True):
        """Perform a single training/evaluation step."""
        inputs = {
                "X": batch[0].to(self.device),
                "attention_masks": batch[1].to(self.device),
                "Y": batch[2].to(self.device),
                "st_id": batch[3]
        }
        
        try:
            if training:
                self.model.train()
                self.lm_optimizer.zero_grad()
                for optim in self.head_optimizers.values():
                    optim.zero_grad()

                loss, metric_values = self.model(**inputs)
                loss.backward()
                lm_gradients = self.model.language_model.get_grads()

            else:
                self.model.eval()
                with torch.no_grad():
                    loss, metric_values = self.model(**inputs)
                lm_gradients = None

            return loss, metric_values, lm_gradients

        except Exception as e:
            general_logger.error(f"Step execution failed: {str(e)}")
            raise TrainerError(f"Step execution failed: {str(e)}")
        
        finally:
            # Clean up to prevent memory leaks
            tensor_keys = [k for k, v in inputs.items() if isinstance(v, torch.Tensor)]
            for k in tensor_keys:
                del inputs[k]


    def fit_debug(self, k: int):
        """Fit for k iterations only to check if a model can process the data."""
        try:
            general_logger.info(f"Starting the debug training for {k} iterations")
            step = 0
            for _ in range(k):
                step += 1
                batch = next(self.batch_lists[Split.TRAIN])
                self.handle_batch(batch=batch, split=Split.TRAIN)
                # Evaluate on dev-batch
                batch = next(self.batch_lists[Split.DEV])
                self.handle_batch(batch=batch, split=Split.DEV)
        except Exception as e:
            general_logger.error(f"Debug training failed: {str(e)}")
            raise TrainerError(f"Debug training failed: {str(e)}")


    def fit(self):
        """Train the model."""
        try:
            general_logger.info("Starting training")
            step = 0

            while step < self.MAX_NUMBER_OF_STEPS:
                if not any(self.task_alive_flags.values()):
                    general_logger.info("No tasks remaining alive, stopping training")
                    break

                step += 1
                general_logger.debug(f"Starting step {step}")

                batch = next(self.batch_lists[Split.TRAIN])
                train_payload = self.handle_batch(batch=batch, split=Split.TRAIN)

                if step % 3 == 0:
                    batch = next(self.batch_lists[Split.DEV])
                    dev_payload = self.handle_batch(batch=batch, split=Split.DEV)
                    train_payload.update(dev_payload)

                self._update_progress()
                self.tracker.log(
                    splits=[Split.TRAIN, Split.DEV],
                    additional_payload=train_payload
                )

                # Periodic memory optimization
                if step % 100 == 0:
                    self._optimize_memory()

            general_logger.info("Training completed")
            self.eval(split=Split.EVAL)

        except Exception as e:
            general_logger.error(f"Training failed: {str(e)}")
            raise TrainerError(f"Training failed: {str(e)}")


    def eval(self, split):
        """Evaluate the model."""
        try:
            general_logger.info(f"Starting evaluation on {split}")
            assert split in [Split.EVAL, Split.TEST]

            for st_id in self.batch_lists[split].iter_dataloaders.keys():
                self.eval_st(split=split, st_id=st_id)

            self.tracker.log(splits=[split])
            general_logger.info(f"Evaluation on {split} completed")

        except Exception as e:
            general_logger.error(f"Evaluation failed: {str(e)}")
            raise TrainerError(f"Evaluation failed: {str(e)}")


    def eval_st(self, split, st_id):
        """Evaluate on a specific subtask."""
        try:
            general_logger.debug(f"Evaluating subtask {st_id} on {split}")
            batch_list = self.batch_lists[split]
            batch_list._reset()
            idl = batch_list.iter_dataloaders[st_id]

            for batch in idl:
                _ = self.handle_batch(batch=[batch], split=split)

        except Exception as e:
            general_logger.error(f"Subtask evaluation failed: {str(e)}")
            raise TrainerError(f"Subtask evaluation failed: {str(e)}")


    def save_model(self):
        """Save the trained model."""
        try:
            path = Path("model_files")
            path.mkdir(exist_ok=True)
            model_path = path / f"{self.model_name}.pth"

            torch.save(self.model.state_dict(), model_path)
            general_logger.info(f"Model saved to {model_path}")

        except Exception as e:
            general_logger.error(f"Failed to save model: {str(e)}")
            raise TrainerError(f"Model saving failed: {str(e)}")


    def _update_progress(self):
        """Update progress bar."""
        try:
            desc = str(self.tracker)
            self.progress_bar.set_description(desc=desc)
            self.progress_bar.refresh()

        except Exception as e:
            general_logger.warning(f"Failed to update progress bar: {str(e)}")

# Running the experiment
For actually running the experiment the configurations from the "cotrain_random_tasks.py" were taken and adapted to the changes (MFFLOW logging etc.). 
Instead of the .fit() method of the trainer class, I use the .fit_debug() method, to check the general ability of the model to process the dta.
The experiment was run on the local machine.

In [290]:
print(f"Current working directory: {os.getcwd()}")

Current working directory: /Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project


In [291]:
# changing working directory to the root of the project:/Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project
os.chdir("/Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project")

In [292]:
print(f"Current working directory: {os.getcwd()}")

Current working directory: /Users/heddafiedler/Documents/MASTER_DATA_SCIENCE/Semester_3/DL/DL_Project


In [294]:
"""Script for executing the experiment 1. Run co-training of all families."""
import os
import wandb
from media_bias_detection.utils.enums import Split, AggregationMethod, LossScaling
from media_bias_detection.utils.common import set_random_seed
from media_bias_detection.config.config import (
    head_specific_lr,
    head_specific_max_epoch,
    head_specific_patience)

EXPERIMENT_NAME = "experiment_baseline_check"
MODEL_NAME = "baseline_check"
selected_tasks = [cw_hard_03,
me_too_ma_108,
good_news_everyone_42]

tasks = selected_tasks

for t in tasks:
    for st in t.subtasks_list:
        st.process()


# training config
config = {
   "sub_batch_size": 32,
   "eval_batch_size": 128,
   "initial_lr": 4e-5,
   "dropout_prob": 0.1,
   "hidden_dimension": 768,
   "input_dimension": 768,
   "aggregation_method": AggregationMethod.MEAN,
   "early_stopping_mode": EarlyStoppingMode.HEADS,
   "loss_scaling": LossScaling.STATIC,
   "num_warmup_steps": 10,
   "pretrained_path": None,
   "resurrection": True,
   "model_name": "YOUR_MODEL_NAME",
   "head_specific_lr_dict": head_specific_lr,
   "head_specific_patience_dict": head_specific_patience,
   "head_specific_max_epoch_dict": head_specific_max_epoch,
   "logger": Logger(EXPERIMENT_NAME),
 }

set_random_seed() # default is 321
#wandb.init(project=EXPERIMENT_NAME,name=MODEL_NAME)
trainer = Trainer(task_list=tasks, **config)
trainer.fit_debug(k=1)
trainer.eval(split=Split.TEST)
trainer.save_model()
#wandb.finish()

[2024-12-09 17:53:16,451: INFO: 1312152763: Processing SubTask 300001]
[2024-12-09 17:53:16,982: INFO: 1312152763: SubTask 300001 processed successfully. Splits: Train=5474, Dev=684, Test=685]
[2024-12-09 17:53:16,983: INFO: 1312152763: Processing SubTask 10801]
Loading data for MultiLabelClassificationSubTask 10801
X type: <class 'pandas.core.series.Series'>
Y type: <class 'pandas.core.frame.DataFrame'>
X shape: 7388
Y shape: (7388, 2)


  0%|          | 0/13 [25:17<?, ?it/s]

X shape: torch.Size([7388, 128])
Y shape: torch.Size([7388, 2])
Attention masks shape: torch.Size([7388, 128])
[2024-12-09 17:53:18,060: INFO: 1312152763: SubTask 10801 processed successfully. Splits: Train=5910, Dev=738, Test=740]
[2024-12-09 17:53:18,061: INFO: 1312152763: Processing SubTask 42001]





[2024-12-09 17:53:18,446: INFO: 1312152763: SubTask 42001 processed successfully. Splits: Train=3542, Dev=442, Test=444]
[2024-12-09 17:53:18,447: INFO: 1312152763: Processing SubTask 42002]
[2024-12-09 17:53:18,812: INFO: 1312152763: SubTask 42002 processed successfully. Splits: Train=3542, Dev=442, Test=444]
[2024-12-09 17:53:18,817: INFO: 2527385809: Initializing trainer...]
[2024-12-09 17:53:18,817: INFO: 4203745325: Initializing backbone language model]
Creating ClassificationHead for subtask 300001
num_classes: 2
Initializing ClassificationHead
num_classes: 2
num_labels: 1
Initializing ClassificationHead with 1 labels
Creating MultiLabelClassificationHead for subtask 10801
num_classes: 2, num_labels: 2
Initializing ClassificationHead
num_classes: 2
num_labels: 2
Initializing ClassificationHead with 2 labels
[2024-12-09 17:53:19,144: INFO: 2558031327: Initialized TokenClassificationHead with 3 classes]
[2024-12-09 17:53:19,146: INFO: 2558031327: Initialized TokenClassificationHead

  0%|          | 0/4 [00:00<?, ?it/s]

[2024-12-09 17:53:19,178: INFO: 2527385809: Trainer initialized successfully on cpu]


  0%|          | 0/13 [18:43<?, ?it/s]

[2024-12-09 17:53:19,282: INFO: 2527385809: Starting the debug training for 1 iterations]





Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 1])
Logits shape: torch.Size([32, 2])
Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 2])
Logits shape: torch.Size([32, 4])
Batch size: 128
X shape: torch.Size([128, 128, 768])
y shape: torch.Size([128, 1])
Logits shape: torch.Size([128, 2])
Batch size: 128
X shape: torch.Size([128, 128, 768])
y shape: torch.Size([128, 2])
Logits shape: torch.Size([128, 4])
[2024-12-09 17:53:42,483: INFO: 2527385809: Starting evaluation on Split.TEST]
Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 1])
Logits shape: torch.Size([32, 2])
Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 1])
Logits shape: torch.Size([32, 2])
Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 1])
Logits shape: torch.Size([32, 2])
Batch size: 32
X shape: torch.Size([32, 128, 768])
y shape: torch.Size([32, 1])
Logits shape: torch.Size([32, 2])
B

In [295]:
#model architecture check
def print_model_summary(model):
    """Create a detailed custom model summary."""
    backbone_params = sum(p.numel() for p in model.language_model.backbone.parameters())
    
    print("=== MTL Model Summary ===")
    print(f"\nBackbone: DistilBERT")
    print(f"Backbone Parameters: {backbone_params:,}")
    
    print("\nTask Heads:")
    for task_id, head in model.heads.items():
        head_params = sum(p.numel() for p in head.parameters())
        print(f"\nTask {task_id}:")
        print(f"  Type: {head.__class__.__name__}")
        print(f"  Parameters: {head_params:,}")
        
        # Head-specific details
        if isinstance(head, ClassificationHead):
            print(f"  Classes: {head.num_classes}")
            print(f"  Labels: {head.num_labels}")
            print(f"  Metrics: {list(head.metrics.keys())}")
            
        elif isinstance(head, TokenClassificationHead):
            print(f"  Classes: {head.num_classes}")
            print(f"  Metrics: {list(head.metrics.keys())}")
            
        elif isinstance(head, RegressionHead):
            print(f"  Output Dimension: 1")
            print(f"  Metrics: {list(head.metrics.keys())}")
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal Parameters: {total_params:,}")
    
    # Task distribution summary
    head_types = {}
    for head in model.heads.values():
        head_type = head.__class__.__name__
        head_types[head_type] = head_types.get(head_type, 0) + 1
    
    print("\nTask Distribution:")
    for head_type, count in head_types.items():
        print(f"  {head_type}: {count}")

# Use it
print_model_summary(trainer.model)

=== MTL Model Summary ===

Backbone: DistilBERT
Backbone Parameters: 66,362,880

Task Heads:

Task 300001:
  Type: ClassificationHead
  Parameters: 592,130
  Classes: 2
  Labels: 1
  Metrics: ['f1', 'acc']

Task 10801:
  Type: ClassificationHead
  Parameters: 593,668
  Classes: 2
  Labels: 2
  Metrics: ['f1', 'acc']

Task 42001:
  Type: TokenClassificationHead
  Parameters: 2,307
  Classes: 3
  Metrics: ['f1', 'acc']

Task 42002:
  Type: TokenClassificationHead
  Parameters: 2,307
  Classes: 3
  Metrics: ['f1', 'acc']

Total Parameters: 67,553,292

Task Distribution:
  ClassificationHead: 2
  TokenClassificationHead: 2
