In [None]:
!git clone https://github.com/elenanespolo/Sentiment_Sarcasm_Analysis

%cd Sentiment_Sarcasm_Analysis
!git pull

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
import os
import transformers
import tqdm
from transformers import BertConfig, BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertLayer, BertEncoder
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
import seaborn as sns
from sklearn.metrics import f1_score

pcgrad_repo = "./pcgrad_repo"
if not os.path.exists('./pcgrad_repo'):
    !git clone https://github.com/WeiChengTseng/Pytorch-PCGrad
    !mv Pytorch-PCGrad pcgrad_repo
from pcgrad_repo.pcgrad import PCGrad

import seaborn as sns
import matplotlib.pyplot as plt


# Set parameters

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------
ENABLE_WANDB = False  # set to True to enable Weights and Biases logging, False to disable
training_dataset_name = 'twitter'  # 'BESSTIE' or 'yelp' or 'bicodemix' or 'twitter'

training_dataset_task = 'sentiment'  # 'sentiment', 'sarcasm' or 'sarcasm-sentiment' (if sarcasm-sentiment, both sentiment and sarcasm labels must be present in the dataset)

use_spAtten = False # if True use SpAtten, otherwise use standard attention

# NOTE: sentiment always comes first in the 'task' field (index 0), sarcasm second (index 1) where not in a dictionary
# NOTE: task is always lowercase, but the check is always done anyway
# NOTE: class '0' (negative) is always first in the classes list (index 0), class '1' (positive) second (index 1), etc.
# NOTE: if a task is not present, its classes list is empty (length 0, not None)
# NOTE: the BESSTIE dataset applies a filter based on the values of the 'task', 'variety' and 'source' fields:
# if a field is None, no filter is applied for it; otherwise only samples with the given value are kept.
# If task is 'sentiment', only samples with at least sentiment labels are kept;
# if 'sarcasm', only samples with at least sarcasm labels are kept.
# EXAMPLE: task = 'sentiment', variety = 'en-IN', source = 'Reddit' means:
# keep only samples with non-NaN sentiment labels, variety 'en-IN' and source 'Reddit'

# One entry per supported training dataset; selected below via `training_dataset_name`.
train_dataset_CFGs = {
    'BESSTIE':{
        'dataset_name': 'BESSTIE',
        'root_folder': './dataset/besstie',
        'file_name': 'train_SS_with_nan.csv',
        'classes': {
            'sentiment': ['0', '1'],
            'sarcasm': ['0', '1'],
        },
        'task': training_dataset_task,
        'variety': 'en-IN',
        'source': 'Reddit',
    },
    'bicodemix': {
        'dataset_name': 'bicodemix',
        'root_folder': './dataset/bicodemix',
        'file_name': 'train_SS.csv',
        'classes': {
            'sentiment': ['0', '1', '2'],
            'sarcasm': ['0', '1'],
        },
        'task': training_dataset_task,
    },
    'twitter': {
        'dataset_name': 'twitter',
        'root_folder': './dataset/twitter',
        'file_name': 'twitter_sentiment_analysis.csv',
        'classes': {
            'sentiment': ['Negative', 'Positive'],
            'sarcasm': [],
        },
        'task': training_dataset_task,
    }
}

# Validation always uses the BESSTIE validation split, without variety/source filtering,
# and inherits the task of the selected training dataset.
valid_dataset_CFG = {
    'dataset_name': 'BESSTIE',
    'root_folder': './dataset/besstie',
    'file_name': 'valid_SS_with_nan.csv',
    'classes': {
        'sentiment': ['0', '1'],
        'sarcasm': ['0', '1'],
    },
    'task': train_dataset_CFGs[training_dataset_name]['task'],
    'variety': None,
    'source': None,
}

# Global training hyper-parameters plus the resolved dataset configurations.
CFG = {
    'lr': 2e-5,
    'start_epoch': 0,
    'epochs': 30,
    'batch_size': 8,
    'max_length': 200,
    'min_length': 40 if training_dataset_name == 'twitter' else 1,
    "train_dataset_CFG": train_dataset_CFGs[training_dataset_name],
    "valid_dataset_CFG": valid_dataset_CFG,
    'model_name': 'bert-base-uncased',
    'classification_head': 'linear', # ['linear', 'conv', 'lstm', 'multi_task_conv', 'cross_talk_conv']
    'seed': 0,
}

# Directory for model files; created eagerly so later cells can rely on it.
models_root_dir = "./models"
os.makedirs(models_root_dir, exist_ok=True)

# Multitask mode is active when the task string mentions both 'sentiment' and 'sarcasm'.
IS_MULTITASK = 'sentiment' in CFG['train_dataset_CFG']['task'].lower() and 'sarcasm' in CFG['train_dataset_CFG']['task'].lower()

# Fail fast when the chosen head does not match the single-task / multi-task mode.
if IS_MULTITASK:
    print("Training in multitask mode.")
    if CFG['classification_head'] in ['linear', 'conv', 'lstm']:
        raise ValueError(f"Invalid classification head {CFG['classification_head']} for multi-task learning. Please choose a multi-task head (multi_task_conv or cross_talk_conv).")
else:
    print("Training in single task mode.")
    if CFG['classification_head'] in ['multi_task_conv', 'cross_talk_conv']:
        raise ValueError(f"Invalid classification head {CFG['classification_head']} for single-task learning. Please choose a single-task head (linear, conv or lstm).")

print(f"Training with min_length={CFG['min_length']} and max_length={CFG['max_length']}")

# Run name for wandb: "<dataset>_<head>" with an optional "_spatten" suffix.
# NOTE(review): the task is not part of the name, so runs on the same dataset/head
# with different tasks share a run name — confirm this is intended.
run_name = f"{training_dataset_name}_{CFG['classification_head']}" + ("_spatten" if use_spAtten else "")
print(f"Run name is: {run_name}")

In [None]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module, then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random as py_random

    # Same seeding order as before: torch CPU, torch CUDA, numpy, python.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        py_random.seed,
    ):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, up front, with the seed from the experiment configuration.
set_seed(CFG['seed'])

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Datasets

In [None]:
# (remote CSV, local copy) pairs for the raw BESSTIE splits.
# Login using e.g. `huggingface-cli login` to access this dataset.
_besstie_files = (
    ("hf://datasets/unswnlporg/BESSTIE/train.csv", "./dataset/besstie/train.csv"),
    ("hf://datasets/unswnlporg/BESSTIE/valid.csv", "./dataset/besstie/valid.csv"),
)

# Re-download both splits as soon as either local copy is missing.
if not all(os.path.exists(_local_csv) for _, _local_csv in _besstie_files):
    print("Downloading BESSTIE dataset...")
    for _remote_csv, _local_csv in _besstie_files:
        df = pd.read_csv(_remote_csv)
        df.to_csv(_local_csv, index=False)
    print("BESSTIE dataset downloaded.")

# The configured training file must exist before any dataset object is built.
_train_file = os.path.join(
    CFG['train_dataset_CFG']['root_folder'],
    CFG['train_dataset_CFG']['file_name'],
)
if not os.path.exists(_train_file):
    raise Exception('Training file not found! Please check the train_dataset_CFG configuration.')


In [None]:
def get_dataset(dataset_CFG, minlength, maxlength, tokenizer):
    """Build the dataset object described by ``dataset_CFG``.

    The dataset class is chosen from ``dataset_CFG['dataset_name']``
    (case-insensitive) and imported lazily from the project's ``dataset``
    package. Every entry of ``dataset_CFG`` is forwarded to the constructor as
    a keyword argument, together with the tokenizer and the length bounds.

    Args:
        dataset_CFG: Dataset configuration dict; must contain 'dataset_name'.
        minlength: Forwarded to the dataset constructor as ``minlength``.
        maxlength: Forwarded to the dataset constructor as ``maxlength``.
        tokenizer: Forwarded to the dataset constructor as ``tokenizer``.

    Returns:
        The constructed dataset instance.

    Raises:
        Exception: If the dataset name is not 'twitter', 'besstie' or 'bicodemix'.
    """
    dataset_name = dataset_CFG['dataset_name'].lower()

    # Resolve the class first; only the selected module is ever imported.
    if dataset_name == 'twitter':
        from dataset.twitter.dataset_twitter import TwitterDataSet as dataset_cls
    elif dataset_name == 'besstie':
        from dataset.besstie.dataset_besstie import BesstieDataSet as dataset_cls
    elif dataset_name == 'bicodemix':
        from dataset.bicodemix.dataset_bicodemix import BicodemixDataSet as dataset_cls
    else:
        raise Exception(f"Dataset {dataset_name} not recognized.")

    return dataset_cls(
        **dataset_CFG,
        tokenizer=tokenizer,
        minlength=minlength,
        maxlength=maxlength,
    )


# Model

## Encoder

### Back-bone and tokenizer

In [None]:
def get_tokenizer_and_encoder(model_name:str):
    """Load the tokenizer and the encoder for ``model_name``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple obtained from ``transformers.AutoTokenizer``
        and ``transformers.AutoModel`` respectively.
    """
    # Tuple elements are evaluated left to right: tokenizer first, then the model.
    return (
        transformers.AutoTokenizer.from_pretrained(model_name),
        transformers.AutoModel.from_pretrained(model_name),
    )

### Head

In [None]:
class MultiKernelConvs(torch.nn.Module):
    """Parallel 1-D convolutions with different kernel sizes, pooled and concatenated.

    Each convolution reads the same channel-first input, is passed through ReLU
    and averaged over the length dimension. The pooled vectors are concatenated
    into a single feature vector of exactly ``hidden_size`` features.

    Args:
        input_size: Number of input channels of every convolution.
        hidden_size: Total number of output features across all convolutions.
        kernel_sizes: Kernel size of each parallel convolution (must be non-empty).
        dropout: Dropout probability applied to the concatenated features.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        kernel_sizes=(2, 3, 5),
        dropout=0.1
    ):
        super().__init__()

        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")

        # Split hidden_size across the convolutions. When it is not divisible,
        # the first `extra` convolutions get one more channel so the
        # concatenated output is exactly `hidden_size` wide (plain floor
        # division would silently produce fewer features and break any
        # following Linear(hidden_size, ...)). Divisible sizes are unchanged.
        base_channels, extra = divmod(hidden_size, len(kernel_sizes))

        self.convs = torch.nn.ModuleList([
            torch.nn.Conv1d(
                in_channels=input_size,
                out_channels=base_channels + (1 if idx < extra else 0),
                kernel_size=k,
                padding=k // 2
            )
            for idx, k in enumerate(kernel_sizes)
        ])

        self.activation = torch.nn.ReLU()
        self.pool = torch.nn.AdaptiveAvgPool1d(1)
        self.dropout = torch.nn.Dropout(dropout)
        self.flatten = torch.nn.Flatten()

    def forward(self, x):
        """Encode ``x`` of shape (B, input_size, L) into (B, hidden_size).

        NOTE(review): pooling averages over every position, padding included;
        no attention mask is available here.
        """
        conv_outputs = []

        for conv in self.convs:
            h = self.activation(conv(x))      # (B, C_k, L')
            h = self.pool(h).squeeze(-1)       # (B, C_k)
            conv_outputs.append(h)

        x = torch.cat(conv_outputs, dim=1)    # (B, hidden_size)
        # Flatten is a no-op on a 2-D tensor; kept so the module layout is unchanged.
        x = self.flatten(self.dropout(x))

        return x

class ConvClassificationHead(torch.nn.Module):
    """Single-kernel (size 3) convolutional encoder with an optional linear classifier.

    Args:
        input_size: Number of input channels.
        hidden_size: Number of features produced by the convolutional encoder.
        num_labels: Output size of the linear classifier (ignored when ``linear`` is False).
        linear: When True, append ``Linear(hidden_size, num_labels)`` so the
            module returns logits; when False it returns the pooled features.
    """

    def __init__(self, input_size: int, hidden_size: int, num_labels=2, linear=True):
        super().__init__()

        # The encoder is always present; the classifier is appended only on request.
        stages = [
            MultiKernelConvs(
                input_size=input_size,
                hidden_size=hidden_size,
                kernel_sizes=(3,),
            )  # (B, hidden_size)
        ]
        if linear:
            stages.append(torch.nn.Linear(hidden_size, num_labels))

        self.conv = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the encoder (and the classifier, if present) to ``x``."""
        return self.conv(x)

class MultiTaskConvHead(torch.nn.Module):
    """Two independent convolutional heads, one per task.

    Args:
        input_size: Number of input channels of both heads.
        hidden_size: Hidden feature size of both heads.
        num_sentiment_labels: Number of sentiment classes.
        num_sarcasm_labels: Number of sarcasm classes.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_sentiment_labels: int,
        num_sarcasm_labels: int
    ):
        super().__init__()

        shared_dims = dict(input_size=input_size, hidden_size=hidden_size)

        # Sentiment head is created first, then sarcasm (same order as the output dict).
        self.sentiment_head = ConvClassificationHead(
            num_labels=num_sentiment_labels, **shared_dims
        )
        self.sarcasm_head = ConvClassificationHead(
            num_labels=num_sarcasm_labels, **shared_dims
        )

    def forward(self, sequence_output):
        """Compute the logits of both tasks.

        Args:
            sequence_output: Encoder output in the channel-first layout expected
                by the convolutional heads, i.e. (batch, input_size, seq_len).

        Returns:
            Dict with the keys "sentiment" and "sarcasm" mapping to the
            respective logits.
        """
        return {
            "sentiment": self.sentiment_head(sequence_output),
            "sarcasm": self.sarcasm_head(sequence_output),
        }

class CrossTalkHead(torch.nn.Module):
    """Multi-task head in which each task sees the other task's features.

    A shared convolutional encoder feeds two task-specific linear embeddings.
    Both embeddings are concatenated and fused by a per-task linear layer
    before the per-task output layer.

    Args:
        input_size: Number of input channels of the shared encoder.
        conv_hidden_size: Feature size used by the encoder and all hidden layers.
        num_sentiment_labels: Number of sentiment classes.
        num_sarcasm_labels: Number of sarcasm classes.
    """

    def __init__(
        self,
        input_size,
        conv_hidden_size,
        num_sentiment_labels,
        num_sarcasm_labels,
    ):
        super().__init__()
        width = conv_hidden_size

        # Shared feature extractor (no classifier on top).
        self.encoder = ConvClassificationHead(
            input_size=input_size,
            hidden_size=width,
            linear = False
        )

        # Task-specific embeddings of the shared features.
        self.sentiment_embed = torch.nn.Linear(width, width)
        self.sarcasm_embed = torch.nn.Linear(width, width)

        # Cross-talk layers: each consumes both task embeddings.
        self.sentiment_fuse = torch.nn.Linear(2 * width, width)
        self.sarcasm_fuse = torch.nn.Linear(2 * width, width)

        # Per-task classifiers.
        self.sentiment_out = torch.nn.Linear(width, num_sentiment_labels)
        self.sarcasm_out = torch.nn.Linear(width, num_sarcasm_labels)

    def forward(self, sequence_output):
        """Return a dict with "sentiment" and "sarcasm" logits for ``sequence_output``."""
        shared = self.encoder(sequence_output)

        sent_feat = self.sentiment_embed(shared)
        sarc_feat = self.sarcasm_embed(shared)

        # Both fuse layers read the same [sarcasm, sentiment] concatenation.
        joint = torch.cat([sarc_feat, sent_feat], dim=-1)

        return {
            "sentiment": self.sentiment_out(self.sentiment_fuse(joint)),
            "sarcasm": self.sarcasm_out(self.sarcasm_fuse(joint)),
        }

In [None]:
def get_classification_head(method: str, input_size:int, hidden_size: int, num_labels: "tuple[int, int]"):
    """Create the classification head named by ``method``.

    Args:
        method: One of 'linear', 'conv', 'lstm', 'multi_conv' (single-task heads)
            or 'multi_task_conv', 'cross_talk_conv' (multi-task heads).
        input_size: Feature size of the encoder output fed to the head.
        hidden_size: Hidden size used inside the head.
        num_labels: Pair ``(num_sentiment_labels, num_sarcasm_labels)``; the
            count of an absent task is 0.

    Returns:
        The head module. For 'lstm' only the bidirectional LSTM is returned;
        the caller must add its own output layer on top.

    Raises:
        ValueError: If ``method`` is not a known head name.
    """
    num_sent_labels, num_sarc_labels = num_labels

    # Single-task heads classify whichever task is present, i.e. the larger of
    # the two counts. Computed unconditionally so the function neither depends
    # on the notebook-level IS_MULTITASK flag nor leaves the name unbound.
    num_task_labels = max(num_sent_labels, num_sarc_labels)

    if method == "linear":
        return torch.nn.Linear(input_size, num_task_labels)
    elif method == "conv":
        return ConvClassificationHead(input_size, hidden_size, num_task_labels)
    elif method == "lstm":
        return torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
    elif method == "multi_conv":
        return torch.nn.Sequential(
            MultiKernelConvs(
                input_size=input_size,
                hidden_size=hidden_size,
                kernel_sizes=(2, 3, 5),
                dropout=0.1
            ),
            torch.nn.Linear(hidden_size, num_task_labels)
        )
    elif method == 'multi_task_conv':
        return MultiTaskConvHead(input_size, hidden_size, num_sent_labels, num_sarc_labels)
    elif method == 'cross_talk_conv':
        return CrossTalkHead(input_size, hidden_size, num_sent_labels, num_sarc_labels)
    else:
        raise ValueError(f"Unknown classification head method: {method}")

### Classifier

In [None]:
class MyClassifier(torch.nn.Module):
    """Pretrained encoder followed by a configurable classification head.

    Args:
        base_model_name: Name of the pretrained encoder/tokenizer to load.
        classification_head_name: Head name understood by ``get_classification_head``
            ('linear', 'conv', 'lstm', 'multi_conv', 'multi_task_conv', 'cross_talk_conv').
        num_labels: Pair ``(num_sentiment_labels, num_sarcasm_labels)``.
        multitask: Whether the head predicts both tasks (returns a dict of logits).
    """

    # Heads that consume the channel-first sequence (batch, hidden_size, seq_len).
    _CHANNEL_FIRST_HEADS = ("conv", "multi_conv", "multi_task_conv", "cross_talk_conv")

    def __init__(self, base_model_name, classification_head_name, num_labels, multitask:bool):
        super().__init__()

        self.multitask = multitask

        num_sent_labels, num_sarc_labels = num_labels

        # Label count used by single-task layers: the larger of the two counts.
        # Bound unconditionally so the 'lstm' output layer below can never hit
        # an unbound name.
        num_task_labels = max(num_sent_labels, num_sarc_labels)

        self.tokenizer, self.base_model = get_tokenizer_and_encoder(base_model_name)
        self.hidden_size = self.base_model.config.hidden_size
        self.dropout = torch.nn.Dropout(self.base_model.config.hidden_dropout_prob)

        self.classification_head_name = classification_head_name

        self.classification_head = get_classification_head(
            classification_head_name, self.hidden_size, self.hidden_size, num_labels
        )

        if classification_head_name == "lstm":
            # The LSTM is bidirectional, so its output is 2 * hidden_size wide.
            self.output_layer = torch.nn.Linear(self.hidden_size*2, num_task_labels)

    def get_tokenizer(self) -> "transformers.PreTrainedTokenizer":
        """Return the tokenizer loaded together with the encoder."""
        return self.tokenizer

    def forward(self, inputs, task=None):
        """Run the encoder and the classification head.

        Args:
            inputs: Dict of keyword arguments for the encoder
                (e.g. 'input_ids', 'attention_mask').
            task: In multitask mode, optional key ('sentiment' or 'sarcasm')
                selecting a single entry of the logits dict.

        Returns:
            The logits of the head; in multitask mode a dict of logits, or the
            entry selected by ``task`` when it is given.

        Raises:
            ValueError: If the configured head name is not supported.
        """
        outputs = self.base_model(**inputs)
        sequence = self.dropout(outputs.last_hidden_state)  # (batch, seq_len, hidden_size)
        head_name = self.classification_head_name

        if head_name == "linear":
            # Classify from the first ([CLS]) position only.
            cls_rep = sequence[:, 0, :]
            logits = self.classification_head(cls_rep)

        elif head_name == "lstm":
            lstm_out, _ = self.classification_head(sequence)
            # NOTE(review): position 0 of a bidirectional LSTM holds the full
            # backward pass but only one step of the forward pass — confirm
            # this readout is intended.
            cls_rep = lstm_out[:, 0, :]
            logits = self.output_layer(cls_rep)

        elif head_name in self._CHANNEL_FIRST_HEADS:
            # Conv1d expects channels first: (batch, hidden_size, seq_len).
            # 'multi_conv' is created by get_classification_head and is handled
            # here as well (it previously had no branch in forward).
            logits = self.classification_head(sequence.transpose(1, 2))

        else:
            raise ValueError(f"Unknown classification head method: {head_name}")

        if self.multitask and task is not None:
            return logits[task]
        else:
            return logits


### SpAtten

#### Utility functions

In [None]:
def topk_masking(scores, keep_ratio, active_mask):
    """
    Create a hard mask by keeping the top-k active entries of each row of ``scores``.

    For every row, ``k = max(1, int(num_active * keep_ratio))`` where
    ``num_active`` is the number of active positions of that row. Only active
    positions compete for the k slots: inactive positions are excluded from the
    ranking, so they can neither be selected nor consume part of the budget.

    Args:
        scores (torch.Tensor): Score of each entry (batch_size, n).
        keep_ratio (float): Ratio of active entries to keep (between 0 and 1).
        active_mask (torch.Tensor): (batch_size, n) mask, non-zero/True for
            entries that may be kept and 0/False for the others (e.g. padding).
    Returns:
        torch.Tensor: Hard mask (batch_size, n), same dtype as ``scores``, with
        1s for kept entries and 0s everywhere else.
    """

    new_mask = torch.zeros_like(scores)

    for i, (sample, active_token) in enumerate(zip(scores, active_mask)):
        active = active_token.bool()
        seq_len = active.sum().item()  # number of active entries in the sample

        # Keep at least one entry and never ask topk for more than exist.
        k = min(max(1, int(seq_len * keep_ratio)), sample.numel())

        # Push inactive entries to the bottom of the ranking.
        if sample.is_floating_point():
            lowest = float("-inf")
        else:
            lowest = torch.iinfo(sample.dtype).min
        ranked = sample.masked_fill(~active, lowest)

        topk_indices = torch.topk(ranked, k, dim=-1).indices

        mask = torch.zeros_like(sample)
        mask.scatter_(0, topk_indices, 1)

        # Inactive positions stay 0 regardless of what topk returned.
        new_mask[i, active] = mask[active]

    return new_mask

def magnitude_head_scores(attention_output, num_heads, head_mask):
    """
    Score each head by the mean absolute magnitude of its slice of the output.

    The hidden dimension is split into ``num_heads`` equal chunks; for every
    chunk the absolute values are summed over all tokens and features, masked
    by ``head_mask`` and divided by the sequence length.

    attention_output: (batch, seq_len, hidden_size)
    head_mask: (batch, num_heads), 0 for heads that are already pruned
    returns: (batch, num_heads)
    """
    n_batch, n_tokens, n_hidden = attention_output.shape

    # (batch, seq_len, hidden) -> (batch, heads, seq_len, head_dim)
    per_head = attention_output.view(
        n_batch, n_tokens, num_heads, n_hidden // num_heads
    ).permute(0, 2, 1, 3)

    # s_h = sum_{l,d} |E|, zeroed for pruned heads
    magnitudes = per_head.abs().sum(dim=(2, 3)) * head_mask

    # Normalise by the sequence length.
    return magnitudes/n_tokens


#### A layer of bert

In [None]:
class CascadingMaskBertLayer(BertLayer):
    """BERT layer that prunes tokens and heads and passes the decisions on.

    On every call the layer (1) applies the token mask received from the
    previous layer, (2) runs self-attention, (3) ranks tokens by the attention
    they receive and heads by the magnitude of their slice of the attention
    output, and (4) returns updated token/head masks for the next layer.

    Args:
        config: BERT configuration forwarded to ``BertLayer``.
        prune_token_percent: Fraction of the original (non-padding) tokens to prune
            at this layer; ``1 - prune_token_percent`` is used as keep ratio.
        prune_head_percent: Fraction of heads to prune at this layer.
        visualize: When True, plot attention heatmaps and the head mask on every call.
    """

    def __init__(self, config: BertConfig, prune_token_percent, prune_head_percent, visualize=False):
        super().__init__(config)
        self.prune_token_percent = prune_token_percent
        self.prune_head_percent = prune_head_percent
        self.visualize = visualize

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        token_mask=None,
        head_mask=None,
        output_attentions=True,
        output_hidden_states=False,
    ):
        """Run the layer and update the cascading masks.

        Args:
            hidden_states: (batch_size, seq_len, hidden_size) input of the layer.
            attention_mask: Additive mask indexed as ``attention_mask[:, 0, 0, :]``;
                values greater than -1 mark real (non-padding) tokens. Required.
                # NOTE(review): assumes the extended additive mask produced by
                # BertModel — confirm for the installed transformers version.
            token_mask: (batch_size, seq_len) mask of tokens still active, or None
                on the first layer (then derived from ``attention_mask``).
            head_mask: (batch_size, num_heads) mask of heads still active, or None
                on the first layer (then all heads are active).
            output_attentions: Accepted for interface compatibility only. The
                attention probabilities are always requested because the token
                ranking needs them.
            output_hidden_states: Unused.

        Returns:
            Tuple ``(layer_output, token_mask, head_mask)``.
        """
        batch_size = hidden_states.size(0)
        num_heads = self.attention.self.num_attention_heads

        if token_mask is None:
            token_mask = (attention_mask[:, 0, 0, :] > -1).clone().int() # (batch_size, seq_len)

        if head_mask is None:
            head_mask = torch.ones(
                (batch_size, num_heads),
                device=hidden_states.device
            ) # (batch_size, num_heads)

        head_mask_expanded = head_mask[:, :, None, None]

        # ---- Apply previous cascade ----
        hidden_states = hidden_states * token_mask.unsqueeze(-1)

        # ---- Extend attention mask ----
        # Pruned tokens get a large negative bias so no token attends to them.
        cascade_attn_mask = (1.0 - token_mask) * -1e4

        # ---- Self-attention ----
        # output_attentions is forced to True: the two-value unpacking below and
        # the token scores both need the attention probabilities, so forwarding
        # a caller-supplied False would crash the layer.
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask + cascade_attn_mask.unsqueeze(1).unsqueeze(2),
            head_mask=head_mask_expanded,
            output_attentions=True,
        )
        attention_output, attention_scores = self_attention_outputs

        if self.visualize:
            for sample in range(attention_scores.size(0)):
                if head_mask_expanded[sample,0,0,0] == 0:  # Check if head 0 is active for this sample
                    print(f"Sample {sample} head 0 is pruned, skipping attention score visualization.")
                    continue
                plt.figure()
                plt.title(f"Attention scores for sample {sample} (head 0):")
                sns.heatmap(attention_scores[sample,0,:,:].cpu().detach().numpy(), cmap='viridis')
                plt.show()


            plt.figure()
            plt.title(f"Head mask for each sample")
            sns.heatmap(head_mask.cpu().detach().numpy(), cmap='viridis')
            plt.ylabel("Sample index")
            plt.xlabel("Head index")
            plt.show()


        # ---- Compute new token decisions ----
        # Importance of a token = total attention it receives over all heads and queries.
        token_scores = attention_scores.sum(dim=(1,2))  # (batch_size, seq_len)

        # The keep budget is relative to the number of non-padding tokens, not to
        # the tokens that survived earlier layers.
        new_token_mask = topk_masking(
            token_scores,
            keep_ratio=1-self.prune_token_percent,  # eliminate bottom pt% tokens
            active_mask=attention_mask[:,0,0,:] > -1
        )
        # Protect CLS
        new_token_mask[:, 0] = 1.0

        # ---- CASCADE ----
        # A token stays active only if it was active before and is kept now.
        token_mask = token_mask * new_token_mask



        # ---- Compute new head decisions ----
        # Already-pruned heads score 0 (scores are multiplied by head_mask).
        # NOTE(review): the hidden dimension of attention_output is split into
        # per-head chunks; confirm these chunks really correspond to individual
        # heads for the BertAttention output used here.
        heads_scores = magnitude_head_scores(attention_output, num_heads=num_heads, head_mask=head_mask)

        new_head_mask = topk_masking(
            heads_scores,  # (batch_size, num_heads)
            keep_ratio=1-self.prune_head_percent,
            active_mask=torch.ones_like(heads_scores)
        )

        # ---- CASCADE ----
        # The new mask replaces the old one (it is not multiplied with it).
        head_mask = new_head_mask


        batch_size, seq_len, hidden_size = attention_output.size()
        head_dim = hidden_size // num_heads

        E = attention_output.view(
            batch_size, seq_len, num_heads, head_dim
        ).permute(0, 2, 1, 3)

        E = E * head_mask[:, :, None, None] # Apply head mask to attention output

        attention_output = E.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, hidden_size)

        # ---- Feed-forward ----
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)

        return layer_output, token_mask, head_mask

#### Bert

In [None]:
class CascadingBertEncoder(BertEncoder):
    """BERT encoder whose layers prune tokens and heads in a cascade.

    Args:
        config: BERT configuration; must additionally provide the per-layer
            pruning ratios ``config.pt`` (tokens) and ``config.ph`` (heads).
        visualize: Per-layer flags enabling the plots of the corresponding layer.
        visualize_prune_decisions: When True, print the number of active tokens
            and heads after every layer.
    """

    def __init__(self, config, visualize, visualize_prune_decisions=False):
        super().__init__(config)

        # Swap the standard layers for their cascading-mask counterparts.
        cascading_layers = [
            CascadingMaskBertLayer(config, config.pt[idx], config.ph[idx], visualize=visualize[idx])
            for idx in range(config.num_hidden_layers)
        ]
        self.layer = torch.nn.ModuleList(cascading_layers)
        self.visualize_prune_decisions = visualize_prune_decisions

        # Per-layer masks of the most recent forward pass (for visualization).
        self.text_history = []

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ):
        """Run all layers, threading the token/head masks from one layer to the next.

        The incoming ``head_mask`` and any extra keyword arguments are ignored;
        both cascading masks start as None and are created by the first layer.
        ``self.text_history`` is reset and refilled with one entry per layer.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` carrying only
            ``last_hidden_state`` when ``return_dict`` is True, otherwise a
            one-element tuple with the final hidden states.
        """
        self.text_history = []
        token_mask, head_mask = None, None

        for layer_idx, cascading_layer in enumerate(self.layer):
            hidden_states, token_mask, head_mask = cascading_layer(
                hidden_states,
                attention_mask=attention_mask,
                token_mask=token_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if self.visualize_prune_decisions:
                print(f"Layer {layer_idx} active tokens:", token_mask.sum(dim=1))
                print(f"Layer {layer_idx} active heads:", head_mask.sum(dim=1))

            self.text_history.append(dict(
                layer=layer_idx,
                active_tokens=token_mask.cpu().detach().numpy(),
                active_heads=head_mask.cpu().detach().numpy(),
            ))

        # NOTE: the return value mimics what BertModel expects from its encoder;
        # only the final hidden states are populated.
        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
                cross_attentions=None,
            )
        return (hidden_states,) if hidden_states is not None else ()


#### Top level model
The only difference from the previously defined classifier `MyClassifier` is that `self.bert.encoder` is overwritten with a custom class that adds the required functionality.

In [None]:
class MyClassifierSpAtten(torch.nn.Module):
    """BERT classifier whose encoder performs cascading token/head pruning.

    Args:
        config: BERT configuration; must provide the pruning schedules
            ``config.pt`` and ``config.ph`` read by ``CascadingBertEncoder``.
        base_model_name: Pretrained model/tokenizer to load.
        num_labels: Either an int (single task) or a pair
            ``(num_sentiment_labels, num_sarcasm_labels)``; for a pair the output
            layer has ``num_labels[0] + num_labels[1]`` units.
        visualize_prune_decisions: Enables the per-layer plots and prints of the
            cascading encoder.
    """

    def __init__(self, config:BertConfig, base_model_name="bert-base-uncased", num_labels=2, visualize_prune_decisions=False):
        super().__init__()
        self.bert = BertModel.from_pretrained(base_model_name, config=config)
        self.tokenizer = BertTokenizer.from_pretrained(base_model_name)

        # Replace standard encoder with cascading mask encoder
        cascading_encoder = CascadingBertEncoder(
            self.bert.config, 
            visualize=[visualize_prune_decisions for _ in range(config.num_hidden_layers)], 
            visualize_prune_decisions=visualize_prune_decisions
        )
        # The cascading encoder is freshly (randomly) initialised. Copy the
        # pretrained encoder weights into it, otherwise swapping the encoder
        # would throw the pretrained transformer layers away. The cascading
        # layers add no parameters, so the state dicts have identical keys.
        cascading_encoder.load_state_dict(self.bert.encoder.state_dict())
        self.bert.encoder = cascading_encoder

        num = num_labels if isinstance(num_labels, int) else num_labels[0]+num_labels[1]
        self.dense = torch.nn.Linear(config.hidden_size, num)

    def forward(self, inputs, task:str=None):
        """Return the logits computed from the [CLS] representation.

        Args:
            inputs: Dict of keyword arguments for ``BertModel``.
            task: Accepted for interface compatibility; not used.

        Returns:
            (batch_size, num) logits. With two tasks the logits of both tasks are
            concatenated along the last dimension.
        """
        bert_outputs = self.bert(**inputs)

        cls_token = bert_outputs.last_hidden_state[:, 0]  # Use [CLS] token representation
        return self.dense(cls_token)

In [None]:
# BERT configuration used by the SpAtten classifier.
bert_cfg = BertConfig.from_pretrained(CFG['model_name'])

bert_cfg.output_attentions = True
bert_cfg.output_hidden_states = False
bert_cfg.return_dict = True

# Per-layer pruning ratios before accumulation: even-indexed layers
# (including the first one) do not prune, odd-indexed layers do.
pt_schedule = [
    0.0 if layer_idx % 2 == 0 else 0.1
    for layer_idx in range(bert_cfg.num_hidden_layers)
]
ph_schedule = [
    0.0 if layer_idx % 2 == 0 else 0.03
    for layer_idx in range(bert_cfg.num_hidden_layers)
]

def calculate_p_schedule(p_schedule):
    """Turn per-layer pruning ratios into cumulative ratios, in place.

    The first entry is forced to 0.0; every later entry becomes
    ``1 - (1 - previous_cumulative) * (1 - own_ratio)``.

    Args:
        p_schedule: Mutable sequence of per-layer pruning ratios.

    Returns:
        The same ``p_schedule`` object, now holding the cumulative ratios.
    """
    cumulative = 0.0
    for idx in range(len(p_schedule)):
        if idx > 0:
            # Combine what has been pruned so far with this layer's own ratio.
            cumulative = 1-(1-cumulative)*(1-p_schedule[idx])
        p_schedule[idx] = cumulative

    return p_schedule


# Accumulate the per-layer ratios into the cumulative schedules read by the encoder.
bert_cfg.pt = calculate_p_schedule(pt_schedule)
print("Cascading token pruning schedule:", bert_cfg.pt)

bert_cfg.ph = calculate_p_schedule(ph_schedule)
print("Cascading head pruning schedule:", bert_cfg.ph)


# Not finetuned mode
model = MyClassifierSpAtten(
    config=bert_cfg,
    base_model_name=CFG['model_name'],
    num_labels=len(
        CFG['train_dataset_CFG']['classes'][CFG['train_dataset_CFG']['task']]
    ),
    visualize_prune_decisions=True
)

# Finetuned mode (uncomment to load a finetuned model, make sure to specify the correct path to the model checkpoint)
# model = get_model(CFG, "bert-base-uncased_linear_twitter_sentiment_spatten", 30, IS_MULTITASK, device)

# The inputs below are moved to `device`, so the model must live there too;
# otherwise the forward pass fails with a device mismatch on CUDA machines.
model.to(device)
# Inference-only demo: disable dropout so the pruning decisions are reproducible.
model.eval()

tokenizer = BertTokenizer.from_pretrained(CFG['model_name'])

# Example texts to test the model and visualize the pruning decisions
texts = [
"The attention mechanism is becoming increasingly popular in Natural Language Processing (NLP) applications, showing superior performance than convolutional and recurrent architectures.",
"However, attention becomes the computation bot- tleneck because of its quadratic computational complexity to input length, complicated data movement and low arithmetic intensity.",
"Moreover, existing NN accelerators mainly focus on op- timizing convolutional or recurrent models, and cannot efficiently support attention.",
"In this paper, we present SpAtten, an efficient algorithm-architecture co-design that leverages token sparsity, head sparsity, and quantization opportunities to reduce the attention computation and memory access.",
"Inspired by the high redundancy of human languages, we propose the novel cascade token pruning to prune away unimportant tokens in the sentence.",
"We also propose cascade head pruning to remove unessential heads.",
"Cascade pruning is fundamentally different from weight pruning since there is no trainable weight in the attention mechanism, and the pruned tokens and heads are selected on the fly.",
"To efficiently support them on hardware, we design a novel top-k engine to rank token and head importance scores with high throughput.",
"Furthermore, we propose progressive quantization that first fetches MSBs only and performs the computation; if the confidence is low, it fetches LSBs and recomputes the attention outputs, trading computation for memory reduction",
]

inputs = tokenizer(
    texts,
    add_special_tokens=True,
    return_token_type_ids=False,
    padding=True,
    max_length=200,
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
).to(device)

# No gradients are needed to inspect the pruning decisions.
with torch.no_grad():
    model(inputs)

# print(model.bert.encoder.text_history)
# Show, layer by layer, which tokens of every example are still active.
for layer_info in model.bert.encoder.text_history:
    layer = layer_info["layer"]
    token_mask = layer_info["active_tokens"]

    print(f"\nInput used at layer {layer}:")
    for j, text in enumerate(texts):
        # The history holds NumPy arrays; index on the CPU with a boolean tensor.
        keep = torch.from_numpy(token_mask[j] == 1)
        active_token_ids = inputs["input_ids"][j].cpu()[keep]
        tokens = tokenizer.convert_ids_to_tokens(active_token_ids)

        print(tokens)

# Train

## Multitask

In [None]:
def train_SS(model, train_loader, optimizer, criterion, device):
    """Train a multitask model (dict outputs) for one epoch.

    Args:
        model: Model returning a dict with 'sentiment' and 'sarcasm' logits.
        train_loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'label'; labels hold sentiment in column 0 and sarcasm in column 1.
        optimizer: Optimizer exposing ``pc_backward``, ``step`` and ``zero_grad``.
        criterion: Pair ``(sentiment_criterion, sarcasm_criterion)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(sarcasm_loss, sentiment_loss, sarcasm_accuracy, sentiment_accuracy)``;
        losses are averaged over batches, accuracies over ``len(train_loader.dataset)``.
    """
    model.train()

    sarc_loss_sum, sent_loss_sum = 0.0, 0.0
    sarc_correct, sent_correct = 0.0, 0.0
    sentiment_criterion, sarcasm_criterion = criterion

    for batch in tqdm.tqdm(train_loader):
        inputs = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask')
        }
        labels = batch['label'].to(device)
        sent_labels, sarc_labels = labels[:,0], labels[:,1]

        outputs = model(inputs)
        sent_logits, sarc_logits = outputs['sentiment'], outputs['sarcasm']

        sent_loss = sentiment_criterion(sent_logits, sent_labels)
        sarc_loss = sarcasm_criterion(sarc_logits, sarc_labels)

        # Both losses are handed to the optimizer's pc_backward instead of
        # being summed and back-propagated together.
        optimizer.pc_backward([sarc_loss, sent_loss])
        optimizer.step()
        optimizer.zero_grad()

        sarc_loss_sum += sarc_loss.item()
        sent_loss_sum += sent_loss.item()

        sarc_preds = torch.max(sarc_logits, dim=1).indices
        sent_preds = torch.max(sent_logits, dim=1).indices
        sarc_correct += torch.sum(sarc_preds == sarc_labels).item()
        sent_correct += torch.sum(sent_preds == sent_labels).item()

    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    return sarc_loss_sum / n_batches, sent_loss_sum / n_batches, sarc_correct / n_samples, sent_correct / n_samples


In [None]:
def train_SS_spAtten(model, train_loader, optimizer, criterion, device, num_sent_labels=3):
    """Train a multitask model with concatenated logits for one epoch.

    The model returns one tensor per batch whose first ``num_sent_labels``
    columns are the sentiment logits and whose remaining columns are the
    sarcasm logits.

    Args:
        model: Model returning (batch_size, num_sent_labels + num_sarcasm_labels) logits.
        train_loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'label'; labels hold sentiment in column 0 and sarcasm in column 1.
        optimizer: Optimizer exposing ``pc_backward``, ``step`` and ``zero_grad``.
        criterion: Pair ``(sentiment_criterion, sarcasm_criterion)``.
        device: Device the batch tensors are moved to.
        num_sent_labels: Number of sentiment classes, i.e. where the logits are
            split. Defaults to 3, the previously hard-coded value.

    Returns:
        ``(sarcasm_loss, sentiment_loss, sarcasm_accuracy, sentiment_accuracy)``;
        losses are averaged over batches, accuracies over ``len(train_loader.dataset)``.
    """
    model.train()

    train_sarc_loss = 0.0
    train_sent_loss = 0.0
    train_sarc_acc = 0.0
    train_sent_acc = 0.0
    c1, c2 = criterion

    pbar = tqdm.tqdm(train_loader)
    for batch in pbar:
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }

        local_labels = batch['label'].to(device)
        outputs = model(inputs)

        # Split once and reuse the same slices for losses AND predictions, so
        # the two can never disagree about which columns belong to which task.
        sent_logits = outputs[:, :num_sent_labels]
        sarc_logits = outputs[:, num_sent_labels:]

        sent_loss = c1(sent_logits, local_labels[:,0])
        sarc_loss = c2(sarc_logits, local_labels[:,1])

        # Both losses are handed to the optimizer's pc_backward instead of
        # being summed and back-propagated together.
        optimizer.pc_backward([sarc_loss, sent_loss])
        optimizer.step()
        optimizer.zero_grad()

        train_sarc_loss += sarc_loss.item()
        train_sent_loss += sent_loss.item()

        # Predictions come from the task's own logits (previously swapped:
        # sarcasm was predicted from the sentiment columns and vice versa).
        preds_sent = sent_logits.argmax(dim=1)
        preds_sarc = sarc_logits.argmax(dim=1)
        train_sarc_acc += torch.sum(preds_sarc == local_labels[:,1]).item()
        train_sent_acc += torch.sum(preds_sent == local_labels[:,0]).item()

    return train_sarc_loss / len(train_loader), train_sent_loss / len(train_loader), train_sarc_acc / (len(train_loader.dataset)), train_sent_acc / (len(train_loader.dataset))


## Singletask

In [None]:
def train(model, task, train_loader, optimizer, criterion, device):
    """Run one single-task training epoch.

    Args:
        model: classifier called as ``model(inputs, task=task)``.
        task: task name forwarded to the model.
        train_loader: iterable of batches with ``input_ids``, ``attention_mask``
            and ``label``; must support ``len()`` and expose ``.dataset``.
        optimizer: optimizer exposing ``step`` and ``zero_grad``.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(loss, accuracy)``: loss averaged over batches, accuracy over
        dataset samples.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0.0

    for batch in tqdm.tqdm(train_loader):
        encoded = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask')
        }
        # Labels are flattened to a 1-D target vector before the loss.
        targets = batch['label'].flatten().to(device)

        logits = model(encoded, task=task)
        batch_loss = criterion(logits, targets)

        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == targets).sum().item()

    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    return loss_sum / n_batches, n_correct / n_samples

# Validation

## Multitask

In [None]:
def validate_SS(model, val_loader, criterion, device):
    """Evaluate a multitask model that returns per-task logits in a dict.

    The model output is indexed with the keys ``'sentiment'`` and
    ``'sarcasm'``; column 0 of ``batch['label']`` is compared with the
    sentiment predictions and column 1 with the sarcasm predictions.

    Args:
        model: classifier called as ``model(inputs)``.
        val_loader: iterable of batches; must support ``len()`` and expose
            ``.dataset``.
        criterion: pair ``(sentiment_criterion, sarcasm_criterion)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(sarc_loss, sarc_acc, sent_loss, sent_acc, f1_sarc, f1_sent)``;
        losses are averaged over batches, accuracies over dataset samples and
        the F1 scores are macro-averaged.
    """
    model.eval()
    sentiment_criterion, sarcasm_criterion = criterion

    sarc_loss_sum, sent_loss_sum = 0.0, 0.
    sarc_hits, sent_hits = 0.0, 0.0
    sent_pred_log, sent_gold_log = [], []
    sarc_pred_log, sarc_gold_log = [], []

    with torch.no_grad():
        for batch in val_loader:
            encoded = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask')
            }
            gold = batch['label'].to(device)
            sent_gold, sarc_gold = gold[:, 0], gold[:, 1]

            logits = model(encoded)
            sarc_logits = logits['sarcasm']
            sent_logits = logits['sentiment']

            sarc_loss_sum += sarcasm_criterion(sarc_logits, sarc_gold).item()
            sent_loss_sum += sentiment_criterion(sent_logits, sent_gold).item()

            sarc_pred = sarc_logits.argmax(dim=1)
            sent_pred = sent_logits.argmax(dim=1)

            sent_hits += (sent_pred == sent_gold).sum().item()
            sarc_hits += (sarc_pred == sarc_gold).sum().item()

            # Keep every prediction/label pair for the epoch-level F1.
            sent_pred_log.extend(sent_pred.cpu().numpy())
            sent_gold_log.extend(sent_gold.cpu().numpy())
            sarc_pred_log.extend(sarc_pred.cpu().numpy())
            sarc_gold_log.extend(sarc_gold.cpu().numpy())

    f1_sent = f1_score(sent_gold_log, sent_pred_log, average='macro')
    f1_sarc = f1_score(sarc_gold_log, sarc_pred_log, average='macro')

    n_batches = len(val_loader)
    n_samples = len(val_loader.dataset)
    return (
        sarc_loss_sum / n_batches,
        sarc_hits / n_samples,
        sent_loss_sum / n_batches,
        sent_hits / n_samples,
        f1_sarc,
        f1_sent,
    )

In [None]:
def validate_SS_spAtten(model, val_loader, criterion, device):
    """Evaluate a multitask SpAtten model that returns one concatenated logit tensor.

    The model output is sliced as a 2-D tensor: columns ``[:3]`` are used as
    sentiment logits and columns ``[3:]`` as sarcasm logits. Column 0 of
    ``batch['label']`` is the sentiment target and column 1 the sarcasm target.

    Args:
        model: classifier called as ``model(inputs)``.
        val_loader: iterable of batches; must support ``len()`` and expose
            ``.dataset``.
        criterion: pair ``(sentiment_criterion, sarcasm_criterion)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(sarc_loss, sarc_acc, sent_loss, sent_acc, f1_sarc, f1_sent)``;
        losses are averaged over batches, accuracies over dataset samples and
        the F1 scores are macro-averaged.
    """
    model.eval()
    sentiment_criterion, sarcasm_criterion = criterion

    sarc_loss_sum, sent_loss_sum = 0.0, 0.
    sarc_hits, sent_hits = 0.0, 0.0
    sent_pred_log, sent_gold_log = [], []
    sarc_pred_log, sarc_gold_log = [], []

    with torch.no_grad():
        for batch in val_loader:
            encoded = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask')
            }
            gold = batch['label'].to(device)
            sent_gold, sarc_gold = gold[:, 0], gold[:, 1]

            logits = model(encoded)
            sent_logits = logits[:, :3]
            sarc_logits = logits[:, 3:]

            sarc_loss_sum += sarcasm_criterion(sarc_logits, sarc_gold).item()
            sent_loss_sum += sentiment_criterion(sent_logits, sent_gold).item()

            sarc_pred = sarc_logits.argmax(dim=1)
            sent_pred = sent_logits.argmax(dim=1)

            sent_hits += (sent_pred == sent_gold).sum().item()
            sarc_hits += (sarc_pred == sarc_gold).sum().item()

            # Keep every prediction/label pair for the epoch-level F1.
            sent_pred_log.extend(sent_pred.cpu().numpy())
            sent_gold_log.extend(sent_gold.cpu().numpy())
            sarc_pred_log.extend(sarc_pred.cpu().numpy())
            sarc_gold_log.extend(sarc_gold.cpu().numpy())

    f1_sent = f1_score(sent_gold_log, sent_pred_log, average='macro')
    f1_sarc = f1_score(sarc_gold_log, sarc_pred_log, average='macro')

    n_batches = len(val_loader)
    n_samples = len(val_loader.dataset)
    return (
        sarc_loss_sum / n_batches,
        sarc_hits / n_samples,
        sent_loss_sum / n_batches,
        sent_hits / n_samples,
        f1_sarc,
        f1_sent,
    )

## Singletask

In [None]:
def validate(model, task, val_loader, criterion, device):
    """Evaluate a model on a single task.

    Args:
        model: classifier called as ``model(inputs, task=task)``.
        task: task name forwarded to the model.
        val_loader: iterable of batches with ``input_ids``, ``attention_mask``
            and ``label``; must support ``len()`` and expose ``.dataset``.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)``: loss averaged over batches, accuracy
        over dataset samples, macro-averaged F1 over all samples.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0.0
    pred_log, gold_log = [], []

    with torch.no_grad():
        for batch in val_loader:
            encoded = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask')
            }
            # Labels are flattened to a 1-D target vector before the loss.
            targets = batch['label'].flatten().to(device)

            logits = model(encoded, task=task)
            batch_preds = logits.argmax(dim=1)

            loss_sum += criterion(logits, targets).item()
            n_correct += (batch_preds == targets).sum().item()

            pred_log.extend(batch_preds.cpu().numpy())
            gold_log.extend(targets.cpu().numpy())

    macro_f1 = f1_score(gold_log, pred_log, average='macro')
    return loss_sum / len(val_loader), n_correct / len(val_loader.dataset), macro_f1

# Wandb

In [None]:
# Build the run id from model, head, dataset, optional variety and task.
# The variety segment is only added when the training config defines one.
_train_cfg = CFG['train_dataset_CFG']
_run_id_parts = [CFG['model_name'], CFG['classification_head'], _train_cfg['dataset_name']]
if _train_cfg.get('variety', None) is not None:
    _run_id_parts.append(_train_cfg['variety'])
_run_id_parts.append(_train_cfg['task'])
run_id = "_".join(f"{part}" for part in _run_id_parts) + ("_spatten" if use_spAtten else "")

if ENABLE_WANDB:
    # Imported lazily so the notebook runs without wandb when logging is off.
    import wandb
    run = wandb.init(
        entity="elena-nespolo02-politecnico-di-torino",
        project="Figurative Analysis",
        name=run_name,
        id=run_id,
        resume="allow",
        config=CFG,
        tags=[_train_cfg['dataset_name'], _train_cfg['task'], CFG['model_name']]
    )

    # Each metric family is plotted against its own step counter.
    for _metric_family in ("epoch", "train", "validate"):
        wandb.define_metric(f"{_metric_family}/step")
        wandb.define_metric(f"{_metric_family}/*", step_metric=f"{_metric_family}/step")


# ML pipeline

In [None]:
# Resolve the backbone name once; it is reused for the tokenizer and the classifier.
model_name = CFG['model_name']
tokenizer, model = get_tokenizer_and_encoder(model_name)

# NOTE(review): both values returned by get_tokenizer_and_encoder are replaced
# below (tokenizer here, model in the branch that follows) — confirm the helper
# call is still needed.
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

# Output size per entry of the 'classes' config: the number of classes when the
# entry's key occurs in the configured training task string, otherwise 0.
_train_task_name = CFG['train_dataset_CFG']['task'].lower()
num_labels_per_task = [
    len(class_list) if task_key in _train_task_name else 0
    for task_key, class_list in CFG['train_dataset_CFG']['classes'].items()
]

# setup classifier model
if use_spAtten:
    model = MyClassifierSpAtten(
        config=bert_cfg,
        base_model_name=CFG['model_name'],
        num_labels=num_labels_per_task,
        visualize_prune_decisions=False
    )
else:
    model = MyClassifier(
        base_model_name=model_name,
        classification_head_name=CFG['classification_head'],
        num_labels=num_labels_per_task,
        multitask=IS_MULTITASK
    )
model = model.to(device)

print(model)
# Both splits share the same length filter and tokenizer.
_dataset_kwargs = dict(
    minlength=CFG['min_length'],
    maxlength=CFG['max_length'],
    tokenizer=tokenizer,
)
train_ds = get_dataset(CFG['train_dataset_CFG'], **_dataset_kwargs)
val_ds = get_dataset(CFG['valid_dataset_CFG'], **_dataset_kwargs)

print("Training dataset size:", len(train_ds))
print("Validation dataset size:", len(val_ds))

# Keep the per-class counts only for tasks whose name occurs in the
# configured training task string.
_configured_task = CFG['train_dataset_CFG']['task'].lower()
labels_count = {
    task_key: list(class_counts.values())
    for task_key, class_counts in train_ds.get_label_count().items()
    if task_key in _configured_task
}
print(labels_count)

def _class_share_weights(class_counts):
    """Return each class count divided by the total count, as a float tensor on `device`."""
    total = sum(class_counts)
    return torch.tensor([count / total for count in class_counts], dtype=torch.float).to(device)


# NOTE(review): the loss weights are proportional to class frequency, so the
# most frequent class gets the largest weight — confirm this is intended rather
# than inverse-frequency weighting.
if IS_MULTITASK:
    # PCGrad wraps Adam; the training loop hands it one loss per task.
    optimizer = PCGrad(torch.optim.Adam(model.parameters(), lr=CFG['lr']))
    sarc_weights = labels_count['sarcasm']
    sent_weights = labels_count['sentiment']
    sarcasm_criterion = torch.nn.CrossEntropyLoss(weight=_class_share_weights(sarc_weights))
    sentiment_criterion = torch.nn.CrossEntropyLoss(weight=_class_share_weights(sent_weights))
    criterion = [sentiment_criterion, sarcasm_criterion] # a list of per-task losses
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'])
    # Single task: use the first entry kept in labels_count.
    weights = labels_count[list(labels_count.keys())[0]]
    criterion = torch.nn.CrossEntropyLoss(weight=_class_share_weights(weights))

# Same batch size for both splits; only the training split is shuffled.
train_loader, val_loader = (
    torch.utils.data.DataLoader(split, batch_size=CFG['batch_size'], shuffle=reshuffle)
    for split, reshuffle in ((train_ds, True), (val_ds, False))
)

# Resume from a saved checkpoint when training does not start from scratch.
if CFG['start_epoch'] > 0:
    checkpoint_file = f"{run_id}_epoch_{CFG['start_epoch']}.pth"

    if ENABLE_WANDB:
        # Double-quoted f-string: reusing the same quote type inside the
        # replacement field is a SyntaxError before Python 3.12.
        artifact = run.use_artifact(
            f"elena-nespolo02-politecnico-di-torino/Figurative Analysis/{run_id}:epoch_{CFG['start_epoch']}",
            type='model',
        )
        artifact_dir = artifact.download()
        artifact_path = os.path.join(artifact_dir, checkpoint_file)
    else:
        # Without W&B, fall back to the local file written by the training
        # loop; otherwise training would silently restart from untrained heads
        # while the epoch counter starts at start_epoch + 1.
        artifact_path = os.path.join(models_root_dir, checkpoint_file)
        if not os.path.exists(artifact_path):
            raise FileNotFoundError(
                f"start_epoch={CFG['start_epoch']} but no checkpoint was found at {artifact_path}"
            )

    checkpoint = torch.load(artifact_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    # In multitask mode the optimizer is a PCGrad wrapper; the saved state
    # belongs to the wrapped optimizer.
    if not IS_MULTITASK:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        optimizer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

## Main Loop

In [None]:
# Pick the multitask epoch routines once: the SpAtten variants slice a single
# logit tensor, the standard ones index a dict of per-task logits.
if IS_MULTITASK:
    train_epoch_fn = train_SS_spAtten if use_spAtten else train_SS
    validate_epoch_fn = validate_SS_spAtten if use_spAtten else validate_SS

for epoch in range(CFG['start_epoch']+1,CFG['epochs']+1):
    print(f"Epoch {epoch}/{CFG['epochs']}")

    if IS_MULTITASK:
        epoch_sarc_loss, epoch_sent_loss, epoch_sarc_acc, epoch_sent_acc = train_epoch_fn(model, train_loader, optimizer, criterion, device)

        val_sarc_loss, val_sarc_acc, val_sent_loss, val_sent_acc, f1_sarc, f1_sent = validate_epoch_fn(model, val_loader, criterion, device)

        if ENABLE_WANDB:
            run.log({
                    "epoch/step": epoch,
                    "epoch/train_sarc_loss": epoch_sarc_loss,
                    "epoch/train_sent_loss": epoch_sent_loss,
                    "epoch/train_sarc_acc": epoch_sarc_acc,
                    "epoch/train_sent_acc": epoch_sent_acc,
                    "epoch/val_sarc_loss": val_sarc_loss,
                    "epoch/val_sent_loss": val_sent_loss,
                    "epoch/val_sarc_acc": val_sarc_acc,
                    "epoch/val_sent_acc": val_sent_acc,
                    # The validation F1 scores were computed but never reported.
                    "epoch/val_sarc_f1": f1_sarc,
                    "epoch/val_sent_f1": f1_sent
                },
                commit=True,
            )
        print(f"Training Sarcasm Loss: {epoch_sarc_loss:.4f}")
        print(f"Training Sentiment Loss: {epoch_sent_loss:.4f}")
        print(f"Training Sarcasm Acc: {epoch_sarc_acc:.4f}")
        print(f"Training Sentiment Acc: {epoch_sent_acc:.4f}")
        print(f"Validation Sarcasm Loss: {val_sarc_loss:.4f}")
        print(f"Validation Sentiment Loss: {val_sent_loss:.4f}")
        print(f"Validation Sarcasm Acc: {val_sarc_acc:.4f}")
        print(f"Validation Sentiment Acc: {val_sent_acc:.4f}")
        print(f"Validation Sarcasm F1: {f1_sarc:.4f}")
        print(f"Validation Sentiment F1: {f1_sent:.4f}")

    else:
        epoch_loss, epoch_acc = train(model, CFG['train_dataset_CFG']['task'], train_loader, optimizer, criterion, device)

        val_loss, val_acc, val_f1 = validate(model, CFG['valid_dataset_CFG']['task'], val_loader, criterion, device)

        if ENABLE_WANDB:
            run.log({
                    "epoch/step": epoch,
                    "epoch/train_loss": epoch_loss,
                    "epoch/train_acc": epoch_acc,
                    "epoch/val_loss": val_loss,
                    "epoch/val_acc": val_acc,
                    "epoch/val_f1": val_f1
                },
                commit=True,
            )

        print(f"Training Loss: {epoch_loss:.4f}")
        print(f"Training Acc: {epoch_acc:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Acc: {val_acc:.4f}")
        print(f"Validation F1: {val_f1:.4f}")

    # Checkpoint every 10 epochs and always after the final epoch.
    if (epoch % 10) == 0 or (epoch == CFG['epochs']):
        checkpoint = {
            "model_state_dict": model.state_dict(),
            # In multitask mode the optimizer is a PCGrad wrapper; save the
            # state of the wrapped optimizer.
            "optimizer_state_dict": optimizer.optimizer.state_dict() if IS_MULTITASK else optimizer.state_dict(),
            "epoch/step": epoch
        }

        file_name = f"{run_id}_epoch_{epoch}.pth"

        # Saving the progress; make sure the target directory exists first.
        if models_root_dir:
            os.makedirs(models_root_dir, exist_ok=True)
        file_path = os.path.join(models_root_dir, file_name)
        torch.save(checkpoint, file_path)

        print(f"Model saved to {file_path}")

        if ENABLE_WANDB:
            artifact = wandb.Artifact(name=run_id, type="model")
            artifact.add_file(file_path)

            # The "epoch_N" alias is what the checkpoint-loading code looks up.
            run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

if ENABLE_WANDB:
    run.finish()

In [None]:
# Manual cleanup: run this cell if the training cell above stopped before
# reaching its own run.finish(), so the W&B run is still closed.
if ENABLE_WANDB:
    run.finish()

# Testing

## Load a model

In [None]:
def get_model(CFG, run_id, epoch, is_multitask, device):
    """Rebuild the classifier and load the weights saved for a given epoch.

    The architecture follows the notebook-level ``use_spAtten`` flag. The
    checkpoint is downloaded from the W&B artifact ``{run_id}:epoch_{epoch}``
    when ``ENABLE_WANDB`` is set, otherwise it is read from
    ``models_root_dir``.

    Args:
        CFG: configuration dict; ``model_name``, ``classification_head`` and
            ``train_dataset_CFG`` (``task``, ``classes``) are read.
        run_id: run identifier used in the checkpoint file / artifact name.
        epoch: epoch number of the checkpoint to load.
        is_multitask: forwarded to ``MyClassifier`` as ``multitask``.
        device: device the model and the checkpoint tensors are placed on.

    Returns:
        The model with loaded weights, on ``device`` and in eval mode.
    """
    train_task_name = CFG['train_dataset_CFG']['task'].lower()
    num_labels = [
        len(v) if k in train_task_name else 0
        for k, v in CFG['train_dataset_CFG']['classes'].items()
    ]

    if use_spAtten:
        model = MyClassifierSpAtten(
            config=bert_cfg,
            base_model_name=CFG['model_name'],
            num_labels=num_labels,
            visualize_prune_decisions=False
        )
    else:
        model = MyClassifier(
            base_model_name=CFG['model_name'],
            classification_head_name=CFG['classification_head'],
            num_labels=num_labels,
            # Fix: honour the argument instead of the global IS_MULTITASK.
            multitask=is_multitask
        )
    model = model.to(device)

    checkpoint_file = f"{run_id}_epoch_{epoch}.pth"
    if ENABLE_WANDB:
        artifact = run.use_artifact(f'elena-nespolo02-politecnico-di-torino/Figurative Analysis/{run_id}:epoch_{epoch}', type='model')
        checkpoint_path = os.path.join(artifact.download(), checkpoint_file)
    else:
        checkpoint_path = os.path.join(models_root_dir, checkpoint_file)

    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model

## For each variety

In [None]:
from dataset.besstie import dataset_besstie

test_model = get_model(CFG, run_id, CFG['epochs'], IS_MULTITASK, device)
train_variety = CFG['train_dataset_CFG'].get('variety', 'Unknown')
train_source = CFG['train_dataset_CFG'].get('source', 'Google')
train_task = CFG['train_dataset_CFG']['task']

target_varieties = ['en-AU', 'en-IN', 'en-UK']
row_results = {}

# Placeholder stored when a variety cannot be evaluated; it has the same shape
# as a real result (a pair in multitask mode) so the plot cell can index it.
missing_result = (0.0, 0.0) if IS_MULTITASK else 0.0


def _format_f1(value):
    """Render one F1 value, or a tuple of F1 values, with four decimals."""
    if isinstance(value, (list, tuple)):
        return "(" + ", ".join(f"{v:.4f}" for v in value) + ")"
    return f"{value:.4f}"


print(f"\n" + "="*50)
print(f" EVALUATION REPORT")
print(f" Model Trained on: {train_variety}")
print(f" Fixed Source:     {train_source}")
print(f" Task:             {train_task}")
print("="*50 + "\n")

for target_variety in target_varieties:
    # Same validation config, retargeted at one variety with the training source/task.
    dataset_CFG = CFG['valid_dataset_CFG'].copy()
    dataset_CFG['variety'] = target_variety
    dataset_CFG['source'] = train_source
    dataset_CFG['task'] = train_task

    try:
        test_ds = dataset_besstie.BesstieDataSet(
            **dataset_CFG,
            tokenizer=tokenizer,
            min_length=1,
            max_length=200,
        )
    except Exception as e:
        # Best effort: keep evaluating the other varieties, but say why this one is missing.
        print(f"Skipping {target_variety}: could not build the test set ({e})")
        row_results[target_variety] = missing_result
        continue

    if len(test_ds) == 0:
        print(f"Skipping {target_variety}: empty test set")
        row_results[target_variety] = missing_result
        continue

    test_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=CFG['batch_size'],
        shuffle=False
    )

    if IS_MULTITASK:
        criterion = [torch.nn.CrossEntropyLoss().to(device), torch.nn.CrossEntropyLoss().to(device)]
        # Same dispatch as the training loop: SpAtten models need the tensor-slicing validator.
        validate_fn = validate_SS_spAtten if use_spAtten else validate_SS
        # Both validators return (sarc_loss, sarc_acc, sent_loss, sent_acc, f1_sarc, f1_sent).
        _, _, _, _, test_sarc_f1, test_sent_f1 = validate_fn(test_model, test_loader, criterion, device)
        row_results[target_variety] = (float(test_sarc_f1), float(test_sent_f1))
        print(f"Test on {target_variety}: Sarcasm F1={test_sarc_f1:.4f} | Sentiment F1={test_sent_f1:.4f}")
    else:
        criterion = torch.nn.CrossEntropyLoss().to(device)
        # validate returns (loss, accuracy, f1).
        _, _, test_f1 = validate(test_model, train_task, test_loader, criterion, device)
        row_results[target_variety] = float(test_f1)
        print(f"Test on {target_variety}: F1={test_f1:.4f}")

print("\n" + "="*50)
print(" >>> SAVE THE RESULTS AND COPY THEM INTO THE PLOT CELL <<<")
print("="*50)

# One dict row in the format the plot cell expects: 'trained_on': {'tested_on': f1, ...},
row_entries = ", ".join(
    f"'{variety}': {_format_f1(row_results.get(variety, missing_result))}"
    for variety in target_varieties
)
formatted_dict = f"'{train_variety}': {{{row_entries}}},"
print(formatted_dict)
print("="*50 + "\n")

In [None]:
import seaborn as sns

# Paste here the rows printed by the evaluation cell, one per training variety.
results_collection = {
    # Trained on en-AU
    'en-AU': {'en-AU': 0., 'en-IN': 0., 'en-UK': 0.},

    # Trained on en-IN
    'en-IN': {'en-AU': 0., 'en-IN': 0., 'en-UK': 0.},

    # Trained on en-UK
    'en-UK': {'en-AU': 0., 'en-IN': 0., 'en-UK': 0.},
}

TASK_NAME = "Sentiment"
SOURCE_NAME = "GOOGLE"

# If Multitask -> Task to plot: 0=Sarcasm, 1=Sentiment
MULTITASK_INDEX = 1

if not results_collection:
    print("No data found in results_collection!")
else:
    row_labels = ['en-AU', 'en-IN', 'en-UK']
    col_labels = ['en-AU', 'en-IN', 'en-UK']

    # The first stored value decides how every cell is read:
    # a tuple/list means multitask results, a plain number means single task.
    first_row = list(results_collection.values())[0]
    first_val = list(first_row.values())[0]
    is_multitask_data = isinstance(first_val, (list, tuple))

    def _cell_value(stored):
        """Pick the plotted number out of one stored result."""
        return stored[MULTITASK_INDEX] if is_multitask_data else stored

    # Rows = training variety, columns = test variety; a missing row becomes zeros.
    data = [
        [_cell_value(results_collection[row][col]) for col in col_labels]
        if row in results_collection
        else [0.0, 0.0, 0.0]
        for row in row_labels
    ]

    df_heatmap = pd.DataFrame(data, index=row_labels, columns=col_labels)

    plt.figure(figsize=(6, 5))
    sns.set(font_scale=1.1)
    heatmap_style = dict(
        annot=True, cmap="RdBu_r", fmt=".3f",
        linewidths=.5, cbar=False, vmin=0.4, vmax=1.0,
    )
    ax = sns.heatmap(df_heatmap, **heatmap_style)

    title_str = f"BERT - {SOURCE_NAME} {TASK_NAME}"
    axis_font = dict(fontsize=12, weight='bold')
    plt.title(title_str, fontsize=14, weight='bold', pad=15)
    plt.ylabel("Trained On", **axis_font)
    plt.xlabel("Tested On", **axis_font)
    plt.yticks(rotation=0)
    plt.show()

## For different classification heads

In [None]:
from dataset.besstie import dataset_besstie

results = {}

# Per-class training counts depend on neither the head nor the test split:
# compute them once instead of once per classification head.
_head_eval_task = CFG['train_dataset_CFG']['task'].lower()
labels_count = {
    t: list(cs.values())
    for t, cs in train_ds.get_label_count().items()
    if t in _head_eval_task
}

for class_head in ['linear', 'conv', 'lstm']:
    # Work on a shallow copy so the notebook-wide CFG and run_id are not
    # silently changed for cells executed after this one.
    head_cfg = {**CFG, 'classification_head': class_head}
    # NOTE(review): the backbone name is hard-coded here while the training
    # run id is built from CFG['model_name'] — confirm they always agree.
    head_run_id = f"bert-base-uncased_{class_head}_{CFG['train_dataset_CFG']['dataset_name']}_{CFG['train_dataset_CFG']['task']}" + ("_spatten" if use_spAtten else "")

    # NOTE(review): checkpoint epoch 30 is hard-coded — confirm it matches the trained runs.
    test_model = get_model(head_cfg, head_run_id, 30, IS_MULTITASK, device)

    for variety in ['en-AU', 'en-IN', 'en-UK']:
        print(f"Testing on variety: {variety}")
        for source, task in [('Google', 'sentiment'), ('Reddit', 'sentiment'), ('Reddit', 'sarcasm')]:
            # Only evaluate tasks the model was trained on.
            if task not in _head_eval_task:
                continue

            print(f"\tTesting on {source} for task: {task}")

            dataset_CFG = CFG['valid_dataset_CFG'].copy()
            dataset_CFG['variety'] = variety
            dataset_CFG['source'] = source
            dataset_CFG['task'] = task

            test_ds = dataset_besstie.BesstieDataSet(
                **dataset_CFG,
                tokenizer=tokenizer,
                min_length=1,
                max_length=200,
            )

            test_loader = torch.utils.data.DataLoader(
                test_ds,
                batch_size=CFG['batch_size'],
                shuffle=False
            )
            print(f"\t\tTest dataset size: {len(test_ds)}")

            # If the model was trained on multitask, we still test one task at a time,
            # so the weights of the tested task are used; a single-task model
            # uses the first entry kept in labels_count.
            if IS_MULTITASK:
                weights = labels_count[task]
            else:
                weights = labels_count[list(labels_count.keys())[0]]
            sum_weights = sum(weights)
            criterion = torch.nn.CrossEntropyLoss(
                weight=torch.tensor([w/sum_weights for w in weights], dtype=torch.float).to(device)
            )

            test_loss, test_acc, test_f1 = validate(test_model, task, test_loader, criterion, device)

            results[(class_head, variety, source, task)] = {
                'loss': test_loss,
                'acc': test_acc,
                'f1': test_f1
            }

In [None]:
print("Final Results")

head_names = ['linear', 'conv', 'lstm']
variety_names = ['en-AU', 'en-IN', 'en-UK']

# `results` keys are (class_head, variety, source, task). Plot one heatmap per
# (source, task) pair that was actually evaluated, looking cells up by key
# instead of reshaping by insertion order: a source may have no entries
# (e.g. Google for a sarcasm-only model) or more than nine (multitask on Reddit).
evaluated_pairs = sorted({(key[2], key[3]) for key in results})
if not evaluated_pairs:
    print("No results to plot.")

for source, task in evaluated_pairs:
    # Rows = classification head, columns = variety; missing cells stay NaN.
    data = np.array(
        [
            [results.get((head, variety, source, task), {}).get('f1', np.nan) for variety in variety_names]
            for head in head_names
        ],
        dtype=float,
    )

    # Name the task on the axis only when the source has several tasks.
    tasks_for_source = [t for s, t in evaluated_pairs if s == source]
    x_label = source if len(tasks_for_source) == 1 else f"{source} - {task}"

    plt.figure()
    sns.set(font_scale=1.1)
    sns.heatmap(
            data,
            annot=True,
            yticklabels=head_names,
            xticklabels=variety_names,
            cmap='Blues',
            linewidths=.5,
            cbar=False,
            fmt=".3f"
        )
    plt.ylabel("Classification Head", fontsize=12, weight='bold')
    plt.xlabel(x_label, fontsize=12, weight='bold')
    plt.title(f'BERT - {CFG["train_dataset_CFG"]["dataset_name"]} {CFG["train_dataset_CFG"]["task"]}', fontsize=14, weight='bold', pad=15)
    plt.show()

# Domain shift

In [None]:
from collections import Counter
import torch

def extract_token_vocab(dataset, topk=None, ignore_special_tokens=True):
    """Count the token ids that occur in a tokenized dataset.

    Args:
        dataset: indexable object supporting ``len()``; each item exposes
            ``item["input_ids"]`` as an iterable whose elements support
            ``.item()``. ``dataset.tokenizer.all_special_ids`` is read when
            ``ignore_special_tokens`` is true.
        topk: when given, keep only the ``topk`` most frequent token ids.
        ignore_special_tokens: when true, special token ids are not counted.

    Returns:
        Tuple ``(vocab, counts)``: the set of kept token ids and a ``Counter``
        mapping each kept id to its frequency.
    """
    excluded_ids = set(dataset.tokenizer.all_special_ids) if ignore_special_tokens else set()

    token_counts = Counter()
    for sample_idx in range(len(dataset)):
        sample_ids = (tid.item() for tid in dataset[sample_idx]["input_ids"])
        token_counts.update(
            token_id for token_id in sample_ids if token_id not in excluded_ids
        )

    if topk is not None:
        # Rebuild the counter from the most frequent entries only.
        token_counts = Counter(dict(token_counts.most_common(topk)))

    return set(token_counts.keys()), token_counts

def vocabulary_overlap(dataset_a, dataset_b, topk=1000):
    """Compare the most frequent token vocabularies of two datasets.

    Args:
        dataset_a: first dataset, passed to ``extract_token_vocab``.
        dataset_b: second dataset, passed to ``extract_token_vocab``.
        topk: number of most frequent token ids kept per dataset
            (default 1000, the value previously hard-coded).

    Returns:
        Dict with ``vocab_size_a``, ``vocab_size_b``, ``intersection_size``,
        ``jaccard_similarity`` (intersection / union), ``coverage_a_in_b``
        (intersection / |A|) and ``coverage_b_in_a`` (intersection / |B|).
        A ratio whose denominator is zero is reported as 0.0.
    """
    vocab_a, _ = extract_token_vocab(dataset_a, topk=topk)
    vocab_b, _ = extract_token_vocab(dataset_b, topk=topk)

    intersection = vocab_a & vocab_b
    union = vocab_a | vocab_b

    def _ratio(numerator, denominator):
        # An empty dataset yields an empty vocabulary; avoid ZeroDivisionError.
        return numerator / denominator if denominator else 0.0

    return {
        "vocab_size_a": len(vocab_a),
        "vocab_size_b": len(vocab_b),
        "intersection_size": len(intersection),
        "jaccard_similarity": _ratio(len(intersection), len(union)),
        "coverage_a_in_b": _ratio(len(intersection), len(vocab_a)),
        "coverage_b_in_a": _ratio(len(intersection), len(vocab_b)),
    }

import numpy as np

def sentence_length_distribution(dataset, ignore_special_tokens=True):
    """Return the token length of every sample in a tokenized dataset.

    Args:
        dataset: indexable object supporting ``len()``; each item exposes
            ``item["input_ids"]``. ``dataset.tokenizer.all_special_ids`` is
            read when ``ignore_special_tokens`` is true.
        ignore_special_tokens: when true, special token ids are not counted;
            otherwise the length is ``len(input_ids)``.

    Returns:
        NumPy array with one length per sample, in dataset order.
    """
    # The special-id set is the same for every sample: build it once.
    special_ids = set(dataset.tokenizer.all_special_ids) if ignore_special_tokens else None

    lengths = []
    for i in range(len(dataset)):
        input_ids = dataset[i]["input_ids"]

        if ignore_special_tokens:
            length = sum(tid.item() not in special_ids for tid in input_ids)
        else:
            length = len(input_ids)

        lengths.append(length)

    return np.array(lengths)

def summarize_lengths(lengths):
    """Summarize a collection of lengths with basic descriptive statistics.

    Args:
        lengths: array-like of numeric lengths.

    Returns:
        Dict with ``mean``, ``std``, ``median`` and ``p95`` (95th percentile)
        as floats and ``min`` / ``max`` as ints.
    """
    float_stats = (("mean", np.mean), ("std", np.std), ("median", np.median))
    summary = {name: float(reducer(lengths)) for name, reducer in float_stats}

    summary["min"] = int(np.min(lengths))
    summary["max"] = int(np.max(lengths))
    summary["p95"] = float(np.percentile(lengths, 95))
    return summary

def _overlap_with_besstie(dataset_key):
    """Compute vocabulary overlap between one training dataset and every BESSTIE split.

    The source dataset is built from ``train_dataset_CFGs[dataset_key]``; each
    (source, task) BESSTIE combination is compared only when ``task`` occurs in
    the source dataset's configured task string.

    Args:
        dataset_key: key of the source dataset in ``train_dataset_CFGs``.

    Returns:
        Dict mapping ``(source, task, dataset_key, variety)`` to the statistics
        returned by ``vocabulary_overlap(source_dataset, besstie_dataset)``.
    """
    from dataset.besstie import dataset_besstie

    results = {}
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    source_cfg = train_dataset_CFGs[dataset_key]
    source_ds = get_dataset(
        source_cfg,
        minlength=40,
        maxlength=200,
        tokenizer=tokenizer,
    )
    source_tasks = source_cfg['task'].lower()

    for source, task in [('Reddit', 'sentiment'), ('Google', 'sentiment'), ('Reddit', 'sarcasm')]:
        if task not in source_tasks:
            continue

        print(f"Using {source} for task {task}")

        for variety_b in ['en-AU', 'en-IN', 'en-UK']:
            print(f"\tTarget variety: {variety_b}")

            # BESSTIE base config retargeted at one variety/source/task.
            dataset_CFG_b = train_dataset_CFGs['BESSTIE'].copy()
            dataset_CFG_b['variety'] = variety_b
            dataset_CFG_b['source'] = source
            dataset_CFG_b['task'] = task

            target_ds = dataset_besstie.BesstieDataSet(
                **dataset_CFG_b,
                tokenizer=tokenizer,
                min_length=1,
                max_length=200,
            )

            overlap_stats = vocabulary_overlap(source_ds, target_ds)
            print("\t", overlap_stats)

            results[(source, task, dataset_key, variety_b)] = overlap_stats

    return results


def overlap_twitter():
    """Vocabulary overlap between the 'twitter' dataset and the BESSTIE splits."""
    return _overlap_with_besstie('twitter')


def overlap_bicodemix():
    """Vocabulary overlap between the 'bicodemix' dataset and the BESSTIE splits."""
    return _overlap_with_besstie('bicodemix')

In [None]:
# Tokenizer handed to compute_coverage_matrix by the loop at the end of this cell.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# NOTE(review): dataset_a and dataset_b are not read by the rest of this cell —
# confirm they are still needed.
dataset_a = "twitter"
dataset_b = "BESSTIE"

# NOTE(review): `task` is rebound by the `for source, task in ...` loop further
# down this cell, so this initial value is presumably unused.
task = "sarcasm"

import numpy as np
import seaborn as sns

def compute_coverage_matrix(
    varieties,
    source,
    task,
    train_dataset_CFGs,
    tokenizer
):
    """Compute the pairwise vocabulary coverage between BESSTIE varieties.

    Args:
        varieties: list of variety names; defines both rows and columns.
        source: value written to the ``source`` field of each dataset config.
        task: value written to the ``task`` field of each dataset config.
        train_dataset_CFGs: dict whose ``'BESSTIE'`` entry is the base dataset
            config; it is copied, never modified.
        tokenizer: tokenizer forwarded to ``get_dataset``.

    Returns:
        Tuple ``(matrix, row_labels, col_labels)`` where ``matrix[i, j]`` is
        the ``coverage_a_in_b`` value of ``vocabulary_overlap`` for variety
        ``i`` (A) against variety ``j`` (B), and both label lists are
        ``varieties``.
    """
    # Build each variety's dataset once; it is reused as both row and column
    # instead of being reloaded for every matrix cell.
    datasets = []
    for variety in varieties:
        cfg = train_dataset_CFGs['BESSTIE'].copy()
        cfg.update({
            "variety": variety,
            "source": source,
            "task": task,
        })
        datasets.append(
            get_dataset(
                cfg,
                minlength=1,
                maxlength=200,
                tokenizer=tokenizer,
            )
        )

    matrix = np.zeros((len(varieties), len(varieties)))
    for i, source_ds in enumerate(datasets):
        for j, dest_ds in enumerate(datasets):
            overlap = vocabulary_overlap(source_ds, dest_ds)
            matrix[i, j] = overlap["coverage_a_in_b"]

    return matrix, varieties, varieties

import matplotlib.pyplot as plt

def plot_coverage_matrix(matrix, row_labels, col_labels, title):
    """Draw an annotated heatmap of a coverage matrix and show it.

    Args:
        matrix: 2-D array of coverage values.
        row_labels: tick labels for the y axis ("Variety A").
        col_labels: tick labels for the x axis ("Variety B").
        title: figure title.
    """
    fig = plt.figure(figsize=(6, 5))
    sns.set(font_scale=1.1)

    heatmap_style = dict(
        annot=True, cmap="RdBu_r", fmt=".3f",
        linewidths=.5, cbar=False, vmin=0.4, vmax=1.0,
    )
    ax = sns.heatmap(matrix, **heatmap_style)

    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    axis_font = dict(fontsize=12, weight='bold')
    ax.set_xlabel("Variety B", **axis_font)
    ax.set_ylabel("Variety A", **axis_font)
    ax.set_title(title, fontsize=14, weight='bold', pad=15)

    fig.tight_layout()
    plt.show()

varieties = ['en-AU', 'en-IN', 'en-UK']

# One coverage heatmap per (source, task) combination.
source_task_pairs = [
    ('Reddit', 'sentiment'),
    ('Google', 'sentiment'),
    ('Reddit', 'sarcasm'),
]

for source, task in source_task_pairs:
    matrix, rows, cols = compute_coverage_matrix(
        varieties, source, task, train_dataset_CFGs, tokenizer
    )
    heatmap_title = f"Coverage of A in B - {source.upper()} {task}"
    plot_coverage_matrix(matrix, rows, cols, title=heatmap_title)


In [None]:
# Compute the overlap statistics explicitly: the global `results` is assigned
# by the classification-head evaluation cell with a different schema, so
# reading it here only worked with leftover kernel state.
overlap_results = overlap_bicodemix()

# Keys are (source, task, dataset_name, variety_b); keep the A-in-B coverage.
coverage_df = pd.DataFrame(
    [
        {
            "source": source,
            "task": task,
            "variety_b": variety_b,
            "coverage": stats['coverage_a_in_b'],
        }
        for (source, task, _, variety_b), stats in overlap_results.items()
    ],
    columns=["source", "task", "variety_b", "coverage"],
)

if coverage_df.empty:
    print("No overlap statistics to plot.")

# One heatmap per task: pivoting all tasks together would hit duplicate
# (source, variety) pairs when a source is evaluated on more than one task.
for task_name, task_df in coverage_df.groupby("task", sort=False):
    heatmap_df = task_df.pivot(
        index="source",
        columns="variety_b",
        values="coverage"
    )

    fig = plt.figure(figsize=(6, 5))
    sns.set(font_scale=1.1)
    # Tick labels come from heatmap_df's own index/columns. Overriding them
    # with first-seen order mislabels rows, because pivot sorts its index.
    ax = sns.heatmap(heatmap_df, annot=True, cmap="RdBu_r", fmt=".3f",
                     linewidths=.5, cbar=False, vmin=0.4, vmax=1.0)
    ax.set_xlabel("Variety B", fontsize=12, weight='bold')
    ax.set_ylabel("Source", fontsize=12, weight='bold')
    ax.set_title(f"Coverage of Bicodemix {task_name} in Variety B", fontsize=14, weight='bold', pad=15)
    fig.tight_layout()
    plt.show()