In [90]:
# Auto reload modules when they change
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [91]:
import numpy as np

In [92]:
from HateClassifier import HateClassifier
from TrainingConfig import TrainingConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

In [93]:
# Build the project's training configuration and the classifier wrapper around it.
# NOTE(review): constructing HateClassifier prints the "weights ... newly
# initialized" warning for bert-base-uncased, so it presumably loads its own copy
# of the checkpoint with an untrained classification head — confirm in HateClassifier.
config = TrainingConfig()
hc = HateClassifier(config)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [94]:
# Stand-alone tokenizer and sequence-classification model for this experiment,
# both taken from the checkpoint named in the config.
# NOTE(review): the same initialization warning is printed when `hc` is built,
# so the checkpoint is presumably loaded twice — confirm whether `hc` could
# supply the model instead.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    config.model_name, num_labels=config.num_labels
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [95]:
# Single example used throughout the notebook.
sample_text = "I hate you so much"
tokenized = tokenizer(sample_text, return_tensors="pt")

# Hand-written example rationale mask, one flag per position of `input_ids`
# (1 = rationale token, 0 = not a rationale token).
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor constructor.
human_rationale = torch.tensor([[0, 0, 1, 1, 0, 0, 0]], dtype=torch.float32)

# The mask is hard-coded, so fail loudly if the text or tokenizer changes and
# the flags no longer line up with the tokenized sequence.
assert human_rationale.shape == tokenized["input_ids"].shape, (
    f"rationale mask shape {tuple(human_rationale.shape)} does not match "
    f"input_ids shape {tuple(tokenized['input_ids'].shape)}"
)

# Forward pass that also returns the attention weights. Gradients stay enabled
# because later cells build loss terms from these attentions.
outputs = model(**tokenized, output_attentions=True)

In [96]:
# Inspect the encoding: 7 positions for the 5-word sentence (ids 101 and 102 at
# the two ends are added by the tokenizer), all of them attended.
tokenized

{'input_ids': tensor([[ 101, 1045, 5223, 2017, 2061, 2172,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [97]:
# Attention entropies computed by the project helper from the model's attention
# weights and the attention mask.
# NOTE(review): the layout of the returned value is defined in
# HateClassifier.compute_token_entropy; later cells index [0] and reduce over
# dim 0 — confirm what those axes are.
token_entropies = hc.compute_token_entropy(outputs.attentions, tokenized['attention_mask'])

In [98]:
# Take entry 0 and flip its sign.
# NOTE(review): the variable name suggests the helper returns negated entropies
# (sum of p*log p) and that [0] selects the only example in the batch — confirm
# against compute_token_entropy.
token_entropies_positive = -token_entropies[0]

In [99]:
# Reduce over the leading axis, leaving one entropy value per token position.
# Recomputed from `token_entropies` instead of re-assigning the variable from
# itself, so re-running this cell cannot apply the mean a second time.
# NOTE(review): the result keeps one value per token, so dim 0 is presumably
# heads/layers rather than tokens — confirm against compute_token_entropy.
token_entropies_positive = (-token_entropies[0]).mean(0)

In [100]:
# Maximum attainable entropy for this sequence: ln(number of attended tokens) = ln(7).
torch.log(tokenized['attention_mask'].sum())

tensor(1.9459)

In [101]:
# Normalize by the maximum entropy, ln(#attended tokens), so each value is the
# fraction of the largest entropy a token could reach (1.0 = uniform attention).
final_entropy = token_entropies_positive.div(tokenized["attention_mask"].sum().log())
final_entropy

tensor([0.9847, 0.9911, 0.9876, 0.9916, 0.9900, 0.9889, 0.9715],
       grad_fn=<DivBackward0>)

In [102]:
# Goal of the regularizer sketched in the following cells:
#   * rationale tokens     -> low entropy, but not below `lower_bound`
#   * non-rationale tokens -> entropy as high as possible

# Positions of the attended (non-masked) tokens in the single example.
attention_mask = tokenized["attention_mask"][0]
valid_mask = attention_mask.to(torch.bool)
valid_indices = valid_mask.nonzero(as_tuple=True)[0]

In [103]:
# Keep only the rationale flags that sit on attended token positions.
human = human_rationale[0][valid_indices]

In [104]:
# Rationale flags aligned with the valid tokens (1.0 marks a rationale token).
human

tensor([0., 0., 1., 1., 0., 0., 0.])

In [105]:
# Boolean selectors over the valid tokens: flagged as rationale vs. flagged as not.
rationale_mask = human.gt(0)
non_rationale_mask = human.eq(0)

# Show the normalized entropies of each group.
print("Rationale token entropies:", final_entropy[rationale_mask])
print("Non-rationale token entropies:", final_entropy[non_rationale_mask])

Rationale token entropies: tensor([0.9876, 0.9916], grad_fn=<IndexBackward0>)
Non-rationale token entropies: tensor([0.9847, 0.9911, 0.9900, 0.9889, 0.9715], grad_fn=<IndexBackward0>)


In [117]:
# Entropy targets on the same normalized scale as `final_entropy`.
# NOTE(review): `lower_bound` (0.99) is larger than `upper_bound` (0.8), and
# `upper_bound` is not used by any cell shown here — confirm the intended values.
lower_bound, upper_bound = 0.99, 0.8

# Split the normalized entropies into the two token groups.
non_rationale_entropies = final_entropy[non_rationale_mask]
rationale_entropies = final_entropy[rationale_mask]

In [118]:
# Hinge penalty: positive only where a rationale token's entropy falls below
# `lower_bound`, zero otherwise.
lower_violation = (lower_bound - rationale_entropies).relu()

In [121]:
# Maximizing the mean non-rationale entropy is expressed as minimizing its negative.
non_rationale_loss = non_rationale_entropies.mean().neg()

In [None]:
# Negative by construction: minimizing this loss maximizes the mean entropy of
# the non-rationale tokens.
non_rationale_loss

tensor(-0.9852, grad_fn=<NegBackward0>)