# Train Adapter + attention features

In [None]:
# Imports

from sklearn.metrics import f1_score, accuracy_score
from sklearn.feature_selection import f_regression
import pandas as pd
import numpy as np
import wandb
import torch
import gc
import os

from datasets import Dataset, DatasetDict

from transformers import (
    AutoAdapterModel,
    AutoTokenizer,
    PfeifferConfig,
    TrainingArguments,
    AdapterTrainer,
    AutoConfig,
    TrainerCallback,
    EarlyStoppingCallback
)
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn


# Constants

# Relative directory containing the annotated task CSV files (see TASK_PATH).
DATA_PATH = "../data/annotated/"
# Relative root directory under which the training output directory is created.
MODELS_PATH = "../models/fewshot/"

## Train task adapter

In [None]:
# Key selecting the head implementation in `type2head`; also used as the
# wandb group and as a run tag in the training loop.
feat_head = "feat_att_head"

In [None]:
# Experiment settings shared by data loading, tokenization, model construction
# and the training loop. The whole dict is also logged to wandb as run config.
CONFIG = {
    "task_name": "twittercovidq2",  # CSV basename, adapter name, head name and wandb project
    "model_name": "roberta-large",  # checkpoint name passed to the from_pretrained calls
    "max_length": 128,  # tokenizer max_length
    "batch_size": 1,  # per-device train and eval batch size
    "epochs": 30,  # num_train_epochs
    "seeds" : [0],  # one training run per seed for every few-shot size
    "learning_rate": 1e-4,
    "gradient_accumulation_steps": 1,
    "fewshot_train": [10, 25, 50]  # few-shot training sizes iterated by the training loop
}

In [None]:
# Path of the annotated CSV for the configured task: <DATA_PATH><task_name>.csv
TASK_PATH = f'{DATA_PATH}{CONFIG["task_name"]}.csv'

### Load dataset

In [None]:
# Read the task CSV and drop every row that contains at least one missing value.
task_df = pd.read_csv(TASK_PATH).dropna()
# Notebook display: (rows, columns) remaining after dropna().
task_df.shape

### Extract pvals and create feature vector

In [None]:
# Label strings (compared lower-cased) that are treated as the positive class
# (id 1); any other label value found in the data is assigned id 0.
pos_labels = ["contains-bias", "clickbait", "false", "fake", "has_propaganda", "yes", "contains_false"]

# Distinct label values present in the dataset.
labels = set(task_df["labels"].to_list())

id2label = {}
for label in labels:
    class_id = 1 if str(label).lower() in pos_labels else 0
    # A later label that maps to the same id replaces the earlier one.
    id2label[class_id] = label

# Inverse lookup: label value -> class id.
label2id = dict(zip(id2label.values(), id2label.keys()))

# Notebook display of the resulting mapping.
id2label

In [None]:
import json

# Numeric class id of every row, used as the regression target. It does not
# depend on the feature column, so it is computed once instead of once per
# column.
# NOTE(review): assumes the second CSV column holds the labels — confirm.
target = task_df.iloc[:, 1].apply(lambda x: label2id[x]).values

# One (n_rows, n_selected_features) array per feature column.
feature_arrays = []
for col_name in task_df.columns[2:]:
    # Every cell is a JSON-encoded vector; decode and stack the rows into a
    # 2-D array of shape (n_rows, n_features).
    col_array = np.vstack(task_df[col_name].apply(json.loads).values)
    # Univariate test of each feature dimension against the class ids.
    _, pvals = f_regression(col_array, target)
    # Boolean mask of the feature dimensions with p-value below 0.05.
    selected_pval = pvals < 0.05
    # Keep only the selected dimensions with a single column mask instead of
    # filtering every row in a Python loop.
    feature_arrays.append(col_array[:, selected_pval])

In [None]:
# Concatenate the selected features of all columns side by side, giving one
# combined feature vector per row.
features = np.hstack(feature_arrays)

In [None]:
# Attach each row's combined feature vector as a list-valued column.
task_df["features"] = features.tolist()
# DataFrame.sample returns a new frame, so the result must be assigned for the
# shuffle to take effect. The index is reset so it is contiguous again after
# dropna() and the shuffle.
task_df = task_df.sample(frac=1, random_state=0).reset_index(drop=True)
# Length of a single feature vector (used as the head's default features_size).
# Positional access is used because index label 0 is not guaranteed to exist.
config = len(task_df["features"].iloc[0])

### Tokenize dataset

In [None]:
# Options consumed by `encode_batch` and by the `Dataset.map` call below.
truncation = True  # forwarded to the tokenizer as `truncation`
padding = "max_length"  # forwarded to the tokenizer as `padding`
batched = True  # forwarded to `Dataset.map` as `batched`

In [None]:
# Tokenizer loaded for the configured pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])

In [None]:
def encode_batch(batch):
    """Tokenize the ``"text"`` entries of ``batch`` with the module-level tokenizer.

    The call uses ``CONFIG["max_length"]`` together with the module-level
    ``truncation`` and ``padding`` settings; the tokenizer output is returned
    unchanged.
    """
    texts = batch["text"]
    tokenizer_kwargs = {
        "max_length": CONFIG["max_length"],
        "truncation": truncation,
        "padding": padding,
    }
    return tokenizer(texts, **tokenizer_kwargs)

In [None]:
# Build a Hugging Face Dataset from the prepared DataFrame.
task_dataset = Dataset.from_pandas(task_df)
# Encode the input data
task_dataset = task_dataset.map(encode_batch, batched=batched)
# Transform to pytorch tensors and only output the required columns
task_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels", "features"])
# Re-encode the "labels" column as class ids.
# NOTE(review): the ids are assigned by `class_encode_column`, not taken from
# `label2id` — confirm that both mappings agree.
task_dataset = task_dataset.class_encode_column("labels")

### Define model architecture

In [None]:
class MeanPooling(nn.Module):
    """Average token embeddings over the sequence axis, ignoring masked positions."""

    def __init__(self, model_config):
        super().__init__()
        # Stored for reference only; pooling itself does not read it.
        self.model_config = model_config

    def forward(self, last_hidden_state, attention_mask):
        """Return the masked mean of ``last_hidden_state`` along dimension 1.

        ``attention_mask`` is broadcast to the shape of ``last_hidden_state``;
        positions with mask 0 contribute nothing to the sum or to the count.
        """
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        masked_sum = (last_hidden_state * mask).sum(dim=1)
        # Clamp the token count so fully masked rows do not divide by zero.
        token_count = mask.sum(dim=1).clamp(min=1e-9)
        return masked_sum / token_count

In [None]:
class FeatAttHead(nn.Module):
    """Sentence-level classification head that gates text and feature inputs.

    The pooled text representation and the external feature vector are each
    projected to ``config["hidden_size"]``. From the concatenated projections
    two softmax weights are computed, one scaling the raw text representation
    and one scaling the raw feature vector. The scaled inputs are concatenated
    and classified by a small MLP.

    Args:
        model_config: Object exposing ``hidden_size`` (width of ``cls_output``).
        config: Mapping with the keys ``hidden_size``, ``features_size``,
            ``dropout`` and ``concat_size``; ``concat_size`` must equal
            ``model_config.hidden_size + config["features_size"]``. The
            optional key ``num_labels`` sets the number of output logits and
            defaults to 2.
        name: Name stored on the head.
    """

    def __init__(self, model_config, config, name):
        super().__init__()
        self.name = name
        # Projection of the pooled text representation.
        self.pooler_out = torch.nn.Sequential(
            nn.Linear(model_config.hidden_size, config["hidden_size"]),
            nn.LayerNorm(config["hidden_size"], eps=1e-12),
        )
        # Projection of the external feature vector.
        self.features_out = torch.nn.Sequential(
            nn.Linear(config["features_size"], config["hidden_size"]),
            nn.LayerNorm(config["hidden_size"], eps=1e-12),
        )
        # Two gate weights (text, features) that sum to 1 for every example.
        # The width 2 here is the number of gated inputs, not of labels.
        self.out_concat = torch.nn.Sequential(
            nn.Linear(config["hidden_size"]*2, 2),
            nn.Softmax(dim=1)
        )
        # Final classifier. The output width follows the configured number of
        # labels so the logits match the loss computed by subclasses; it falls
        # back to 2 when the key is absent.
        self.predictor = torch.nn.Sequential(
            nn.Dropout(config["dropout"]),
            nn.Linear(config["concat_size"], config["concat_size"]),
            nn.Tanh(),
            nn.Dropout(config["dropout"]),
            nn.Linear(config["concat_size"], config.get("num_labels", 2))
        )

    def forward(self, cls_output, features, **kwargs):
        """Return logits for a batch.

        Args:
            cls_output: 2-D tensor (batch, model hidden size).
            features: 2-D tensor (batch, features_size).
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape (batch, number of labels).
        """
        out_pooler = self.pooler_out(cls_output)
        out_features = self.features_out(features)
        out = self.out_concat(torch.cat((out_pooler, out_features), dim=-1))

        # Scale the *raw* inputs (not the projections) by their gate weight.
        roberta_cls1 = cls_output * out[:,0].unsqueeze(1)
        feature_vector1 = features * out[:,1].unsqueeze(1)

        return self.predictor(torch.cat((roberta_cls1, feature_vector1), dim=1))

    def get_output_embeddings(self):
        return None  # override for heads with output embeddings

    def get_label_names(self):
        """Return the names of the label arguments this head consumes."""
        return ["labels"]

In [None]:
# Maps a head-type key (see `feat_head`) to the nn.Module class implementing it.
type2head = {"feat_att_head": FeatAttHead}

In [None]:
class CustomClassificationHead(type2head[feat_head]):
    """Prediction head wrapping the selected feature head for an adapter model.

    Builds the head configuration, initialises the weights with the host
    model's ``_init_weights`` and, in ``forward``, pools the encoder output,
    runs the parent head on the pooled output plus the ``features`` tensor and
    optionally computes the loss.

    Args:
        model: Host model; its ``config``, ``_init_weights`` and ``training``
            attributes are used.
        head_name: Name passed to the parent head.
        num_labels: Number of labels; ``1`` selects the regression (MSE) loss.
        dropout: Dropout probability stored in the head config.
        features_size: Length of the external feature vector.
        feat_head: Head-type key stored in the head config.
        hidden_size: Width of the gating projections in the parent head.
        id2label: Optional ``{id: label}`` mapping; stored inverted as
            ``label2id``.
        use_pooler: When true, mean-pool the last hidden state over the
            attention mask; otherwise use the first token's hidden state.
        concat_size: Input width of the final classifier; defaults to the
            model hidden size plus ``features_size``.
    """

    def __init__(
        self,
        model,
        head_name,
        num_labels=2,
        dropout=0.1,
        features_size=config,
        feat_head=feat_head,
        hidden_size=64,
        id2label=None,
        use_pooler=False,
        concat_size=None
    ):
        self.model_config = model.config
        self.config = {
            "num_labels": num_labels,
            "dropout": dropout,
            "label2id": {label: id_ for id_, label in id2label.items()} if id2label is not None else None,
            "use_pooler": use_pooler,
            "features_size": features_size,
            "feat_head": feat_head,
            "hidden_size": hidden_size,
            "concat_size": concat_size if concat_size is not None else self.model_config.hidden_size+features_size,
        }
        super().__init__(self.model_config, self.config, head_name)

        self.apply(model._init_weights)
        self.train(model.training)  # make sure training mode is consistent

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        """Compute logits (and the loss when ``labels`` is supplied).

        Args:
            outputs: Encoder outputs; the last hidden state is read from it.
            cls_output: Optional precomputed sentence representation. When
                omitted it is derived from ``outputs``.
            attention_mask: Mask used for mean pooling when ``use_pooler`` is
                enabled.
            return_dict: Return a ``SequenceClassifierOutput`` instead of a
                tuple.
            **kwargs: Must contain ``features``; may contain ``labels``.

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` is true, else a
            tuple ``(loss?, logits, *outputs[1:])``.

        Raises:
            ValueError: If no ``features`` tensor is supplied.
        """
        if cls_output is None:
            if self.config["use_pooler"]:
                # MeanPooling holds no parameters or buffers, so it needs no
                # device placement and works on whatever device the inputs use.
                pooler = MeanPooling(self.model_config)
                cls_output = pooler(outputs.last_hidden_state, attention_mask)
            else:
                # Without pooling, fall back to the first token's hidden state
                # instead of handing None to the parent head.
                cls_output = outputs[0][:, 0]
        features = kwargs.pop("features", None)
        if features is None:
            raise ValueError(
                "CustomClassificationHead.forward requires a 'features' tensor "
                "to be passed as a keyword argument."
            )
        logits = super().forward(cls_output, features)
        loss = None
        labels = kwargs.pop("labels", None)
        if labels is not None:
            if self.config["num_labels"] == 1:
                #  We are doing regression
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config["num_labels"]), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            # Tuple form: logits replace the first encoder output and the loss
            # is prepended when labels were given.
            outputs = (logits,) + outputs[1:]
            if labels is not None:
                outputs = (loss,) + outputs
            return outputs

### Train model

In [None]:
class AdapterDropTrainerCallback(TrainerCallback):
    """Trainer callback that randomly skips the adapter in leading layers."""

    def on_step_begin(self, args, state, control, **kwargs):
        """Re-activate the first active adapter with a fresh random skip range."""
        # randint's upper bound is exclusive: 0 to 10 leading layers are skipped.
        num_skipped = np.random.randint(0, 11)
        model = kwargs['model']
        model.set_active_adapters(
            model.active_adapters[0],
            skip_layers=list(range(num_skipped)),
        )

    def on_evaluate(self, args, state, control, **kwargs):
        """Re-activate the first active adapter with no layers skipped."""
        # Clear the randomly chosen skip_layers so evaluation results do not
        # depend on the last training step's draw and stay comparable across
        # epochs.
        model = kwargs['model']
        model.set_active_adapters(model.active_adapters[0], skip_layers=None)

In [None]:
def acc_and_f1(y_true, y_pred):
    """Return accuracy and macro-averaged F1 for the given labels and predictions.

    The result is a dict with the keys ``"accuracy"`` and ``"f1"``; the F1
    value is converted to a plain ``float``.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": float(f1_score(y_true, y_pred, average='macro')),
    }

def compute_metrics(eval_preds):
    """Turn a ``(logits, labels)`` pair into the accuracy/F1 metric dict.

    The predicted class is the argmax of the logits over the last axis.
    """
    scores, gold = eval_preds
    predicted = np.argmax(scores, axis=-1)
    return acc_and_f1(gold, predicted)

In [None]:
# Options forwarded to TrainingArguments / EarlyStoppingCallback in the
# training loop below.
strategy = "epoch"  # used for logging, evaluation and checkpoint saving alike
# Checkpoint directory: <MODELS_PATH><model_name>/att-adapt/<task_name>
output_dir = f'{MODELS_PATH}{CONFIG["model_name"]}{os.sep}att-adapt{os.sep}{CONFIG["task_name"]}'
overwrite_output_dir = True
# Presumably needed so the extra "features" column reaches the model — confirm.
remove_unused_columns = False
save_total_limit = 1
report_to = "wandb"
load_best_model_at_end = True
# Presumably the "f1" value from compute_metrics under the Trainer's eval prefix.
metric_for_best_model = "eval_f1"
early_stopping_patience = 10

In [None]:
def get_model():
    """Build a fresh adapter model with the custom classification head.

    Loads the configured checkpoint, adds a Pfeiffer adapter named after the
    task, sets it up for training and activates it, then registers and adds
    the custom head (with mean pooling enabled) under the same name.

    Returns:
        The configured adapter model.
    """
    checkpoint = CONFIG["model_name"]
    task_name = CONFIG["task_name"]

    task_model = AutoAdapterModel.from_pretrained(
        checkpoint,
        config=AutoConfig.from_pretrained(checkpoint, id2label=id2label),
    )

    # Task adapter: add, switch to adapter training, activate.
    task_model.add_adapter(task_name, config=PfeifferConfig())
    task_model.train_adapter(task_name)
    task_model.set_active_adapters(task_name)

    # Custom prediction head registered under a type key, then instantiated.
    head_type = "custom_classification"
    task_model.register_custom_head(head_type, CustomClassificationHead)
    head_kwargs = {
        "head_name": task_name,
        "num_labels": len(id2label),
        "use_pooler": True,
        "id2label": id2label,
    }
    task_model.add_custom_head(head_type, **head_kwargs)
    return task_model

In [None]:
# Train and evaluate one model per (few-shot size, seed) combination.
for fs in CONFIG["fewshot_train"]:
    for seed in CONFIG["seeds"]:
        wandb.init(
            project=CONFIG["task_name"], 
            config=CONFIG,
            job_type=f'{CONFIG["model_name"]}_{fs}',
            group=feat_head,
            tags=[
                feat_head,
                CONFIG['model_name'],
                f"mx: {CONFIG['max_length']}",
                f"bs: {CONFIG['batch_size']}",
                f"ep: {CONFIG['epochs']}",
                f"lr: {CONFIG['learning_rate']}",
            ],
            name=f'seed_{seed}',
            anonymous='must'
        )

        try:
            # Request exactly `fs` training examples. Converting the size to a
            # ceil'ed whole-number percentage of the dataset yields a different
            # count (e.g. 30 instead of 25 for 1000 rows).
            train_test = task_dataset.train_test_split(train_size=int(fs), generator=np.random.RandomState(0))
            # Split the remaining examples 80/20.
            test_valid = train_test['test'].train_test_split(test_size=0.2, generator=np.random.RandomState(0))

            # Note the naming: the 20% part is used for validation and the
            # 80% part as the held-out test set.
            dataset = DatasetDict(
                {
                    'train': train_test['train'],
                    'valid': test_valid['test'],
                    'test': test_valid['train']
                }
            )

            training_args = TrainingArguments(
                learning_rate=CONFIG["learning_rate"],
                num_train_epochs=CONFIG["epochs"],
                per_device_train_batch_size=CONFIG["batch_size"],
                per_device_eval_batch_size=CONFIG["batch_size"],
                gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
                logging_strategy=strategy,
                evaluation_strategy=strategy,
                save_strategy=strategy,
                output_dir=output_dir,
                overwrite_output_dir=overwrite_output_dir,
                # The next line is important to ensure the dataset labels are properly passed to the model
                remove_unused_columns=remove_unused_columns,
                save_total_limit=save_total_limit,
                report_to=report_to,
                load_best_model_at_end=load_best_model_at_end,
                metric_for_best_model=metric_for_best_model,
                seed=seed
            )

            # model_init builds a fresh model for every run.
            trainer = AdapterTrainer(
                model_init=get_model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["valid"],
                compute_metrics=compute_metrics,
                callbacks = [
                    EarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
                    AdapterDropTrainerCallback()
                ]
            )

            trainer.train()
            trainer.evaluate(dataset["test"], metric_key_prefix="test")
        finally:
            # Always close the wandb run and release cached memory, even when
            # a run fails, so the next run starts from a clean state.
            wandb.finish()

            gc.collect()
            torch.cuda.empty_cache()