This notebook requires a Google Colab **T4 runtime**

In [None]:
# Workaround for https://github.com/googlecolab/colabtools/issues/3409
import locale


def _utf8_preferred_encoding(do_setlocale=True):
    """Report UTF-8 as the preferred encoding, regardless of the locale.

    The parameter mirrors the standard ``locale.getpreferredencoding``
    signature so that callers passing ``do_setlocale`` (positionally or by
    keyword) keep working against the patched function; its value is ignored.
    """
    return "UTF-8"


# Replace the module-level function so every later lookup sees UTF-8.
locale.getpreferredencoding = _utf8_preferred_encoding

In [None]:
!git clone https://github.com/casszhao/ReAGent.git
!pip install -r ReAGent/requirments.txt
!python ReAGent/setup_nltk.py

## 1. Generation

### 1.1 Global configuration

In [None]:
import torch

# Prefer the GPU (the notebook targets a Colab T4 runtime), but fall back to
# the CPU so the following cells still run when no CUDA device is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the global RNG seed and request deterministic implementations.
# warn_only=True makes operations without a deterministic implementation emit
# a warning instead of raising.
torch.manual_seed(42)
torch.use_deterministic_algorithms(True, warn_only=True)

### 1.2 Load model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# The tokenizer and the model are loaded from the same checkpoint name.
checkpoint_name = "gpt2-medium"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)

# Load the causal LM, then move it to the configured device.
model = AutoModelForCausalLM.from_pretrained(checkpoint_name)
model = model.to(device)

### 1.3 Configure prediction

In [None]:
# Prompt that the model continues.
input_string = "Super Mario Land is a game that developed by"

# Number of tokens to generate beyond the prompt; it is added to the prompt
# length to form the generation length limit.
max_inference_length = 5

### 1.4 Run prediction

In [None]:
# Tokenize the prompt and keep the first (only) sequence on the model's device.
prompt_encoding = tokenizer(input_string, return_tensors='pt')
input_ids = prompt_encoding['input_ids'][0].to(model.device)

# Generate without sampling, up to prompt length + max_inference_length tokens.
generation_limit = input_ids.shape[0] + max_inference_length
generated_batch = model.generate(
    input_ids=input_ids.unsqueeze(0),
    max_length=generation_limit,
    do_sample=False,
)
generated_ids = generated_batch[0]

# Decode every token id on its own so each token gets its own text piece.
generated_texts = []
for token_id in generated_ids:
    generated_texts.append(tokenizer.decode(token_id))
print(f'generated full sequence --> {generated_texts}')

## 2. Rationalization

### 2.1 Construct rationalizer

In [None]:
from ReAGent.src.rationalization.rationalizer.aggregate_rationalizer import AggregateRationalizer
from ReAGent.src.rationalization.rationalizer.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator
from ReAGent.src.rationalization.rationalizer.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator
from ReAGent.src.rationalization.rationalizer.token_replacement.token_replacer.uniform import UniformTokenReplacer
from ReAGent.src.rationalization.rationalizer.token_replacement.token_sampler.postag import POSTagTokenSampler

# Passed as top_n / top_n_ratio to both the stopping condition evaluator and
# the rationalizer below.
rational_size = 5
rational_size_ratio = None
top_n_settings = dict(top_n=rational_size, top_n_ratio=rational_size_ratio)

# Token sampler shared by the stopping condition evaluator and the token replacer.
token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=device)

stopping_condition_evaluator = TopKStoppingConditionEvaluator(
    model=model,
    token_sampler=token_sampler,
    top_k=3,
    tokenizer=tokenizer,
    **top_n_settings,
)

token_replacer = UniformTokenReplacer(token_sampler=token_sampler, ratio=0.3)

importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
    model=model,
    tokenizer=tokenizer,
    token_replacer=token_replacer,
    stopping_condition_evaluator=stopping_condition_evaluator,
    max_steps=3000,
)

rationalizer = AggregateRationalizer(
    importance_score_evaluator=importance_score_evaluator,
    batch_size=8,
    overlap_threshold=2,
    overlap_strict_pos=True,
    **top_n_settings,
)

### 2.2 Run rationalization

In [None]:
# rationalize each generated token

importance_scores = []  # NOTE(review): never filled or read in this notebook

prompt_length = input_ids.shape[0]
sequence_length = generated_ids.shape[0]

# One row per generated token; row i holds the scores over the tokens that
# precede the i-th generated token (columns beyond that prefix stay zero).
importance_score_map = torch.zeros([sequence_length - prompt_length, sequence_length - 1], device=device)

# Iterate with plain Python ints (instead of torch.arange elements) so the
# progress line prints e.g. "10 / 14" rather than "tensor(10) / 14".
for target_pos in range(prompt_length, sequence_length):

    # extract target: the token at target_pos, explained from all tokens before it
    target_id = generated_ids[target_pos]
    context_ids = generated_ids[:target_pos]

    # rationalization (both arguments get a leading batch dimension of 1)
    pos_rational = rationalizer.rationalize(torch.unsqueeze(context_ids, 0), torch.unsqueeze(target_id, 0))[0]

    ids_rational = generated_ids[pos_rational]
    text_rational = [ tokenizer.decode([id_rational]) for id_rational in ids_rational ]

    # store the scores of this step in the row belonging to the target token
    importance_score_map[target_pos - prompt_length, :target_pos] = rationalizer.mean_important_score

    print(f'{target_pos + 1} / {sequence_length}')
    print(f'Target word     --> {tokenizer.decode(target_id)}')
    print(f"Rational pos    --> {pos_rational}")
    print(f"Rational text   --> {text_rational}")

    print()

### 2.3 Visualize rationalization results

In [None]:
import seaborn

# Wide figure so that all token labels fit on the axes.
seaborn.set(rc={ 'figure.figsize': (30, 10) })

# Columns: every token except the last one; rows: the generated tokens only.
heatmap_options = dict(
    xticklabels=generated_texts[:-1],
    yticklabels=generated_texts[input_ids.shape[0]:],
    annot=True,
    square=True,
)
s = seaborn.heatmap(importance_score_map.cpu(), **heatmap_options)
s.set_ylabel('Target')
s.set_xlabel('Importance distribution')
s

## 3. Evaluation

### 3.1 Configure evaluation

In [None]:
# Step between evaluated target positions in the metric loop below
# (1 evaluates every generated token).
metric_stride = 1

### 3.2 Compute metrics

In [None]:
# Both evaluators are imported from the same `evaluation.evaluator` package
# (the sufficiency import previously used a different path, `src.evaluator`).
from ReAGent.src.evaluation.evaluator.soft_norm_sufficiency import SoftNormalizedSufficiencyEvaluator
from ReAGent.src.evaluation.evaluator.soft_norm_comprehensiveness import SoftNormalizedComprehensivenessEvaluator
soft_norm_suff_evaluator = SoftNormalizedSufficiencyEvaluator(model)
soft_norm_comp_evaluator = SoftNormalizedComprehensivenessEvaluator(model)

# Per-step results: "source" uses the computed importance scores, "random"
# uses a random score distribution evaluated on the same inputs.
source_soft_ns_all = []
source_soft_nc_all = []
random_soft_ns_all = []
random_soft_nc_all = []
target_token_all = []

# First row is the header; one row is appended per evaluated target position.
table_details = [ ["target_pos", "target_token", "source_soft_ns", "source_soft_nc", "rand_soft_ns", "rand_soft_nc"] ]

for target_pos in torch.arange(input_ids.shape[0], generated_ids.shape[0], metric_stride):

    # 1-based position as a plain int, used for the table and the log line
    target_pos_display = target_pos.item() + 1

    target_token = tokenizer.decode(generated_ids[target_pos])
    target_token_all.append(target_token)

    # Inputs for this step, each with a leading batch dimension of 1: the
    # tokens before the target, the target token, and the matching score row.
    input_ids_step = torch.unsqueeze(generated_ids[:target_pos], 0)
    target_id_step = torch.unsqueeze(generated_ids[target_pos], 0)
    importance_score_step = torch.unsqueeze(importance_score_map[target_pos - input_ids.shape[0], :target_pos], 0)
    # Random scores of the same shape, normalized with softmax over the last dimension.
    random_importance_score_step = torch.softmax(torch.rand(importance_score_step.shape, device=device), dim=-1)

    # compute Soft-NS and Soft-NC on source importance score

    source_soft_ns_step = soft_norm_suff_evaluator.evaluate(input_ids_step, target_id_step, importance_score_step)
    source_soft_ns_all.append(source_soft_ns_step)

    source_soft_nc_step = soft_norm_comp_evaluator.evaluate(input_ids_step, target_id_step, importance_score_step)
    source_soft_nc_all.append(source_soft_nc_step)

    # compute Soft-NS and Soft-NC on random importance score

    random_soft_ns_step = soft_norm_suff_evaluator.evaluate(input_ids_step, target_id_step, random_importance_score_step)
    random_soft_ns_all.append(random_soft_ns_step)

    random_soft_nc_step = soft_norm_comp_evaluator.evaluate(input_ids_step, target_id_step, random_importance_score_step)
    random_soft_nc_all.append(random_soft_nc_step)

    table_details.append([
        target_pos_display, target_token, 
        f"{source_soft_ns_step.item():.3f}", f"{source_soft_nc_step.item():.3f}", 
        f"{random_soft_ns_step.item():.3f}", f"{random_soft_nc_step.item():.3f}", 
        ])
    print(f"target_pos: {target_pos_display}, target_token: {target_token}, Source Soft-NS: {source_soft_ns_step}, Source Soft-NC: {source_soft_nc_step}, Random Soft-NS: {random_soft_ns_step}, Random Soft-NC: {random_soft_nc_step}")

# compute metrics on Soft-NS and Soft-NC:
# log of (sum of source scores / sum of random-baseline scores)

metric_soft_ns = torch.log(torch.sum(torch.tensor(source_soft_ns_all, device=device)) / torch.sum(torch.tensor(random_soft_ns_all, device=device)))
metric_soft_nc = torch.log(torch.sum(torch.tensor(source_soft_nc_all, device=device)) / torch.sum(torch.tensor(random_soft_nc_all, device=device)))

print(f"metric_soft_ns: {metric_soft_ns}, metric_soft_nc: {metric_soft_nc}")

### 3.3 Show metrics in tables

In [None]:
import tabulate
from IPython.display import HTML, display

# Per-token table: render the rows as an HTML table and show it.
table_details_html = tabulate.tabulate(table_details, tablefmt='html')
display(HTML(table_details_html))

# Summary table: a header row plus one row holding the evaluated tokens
# joined with "$" and the two aggregate metrics with three decimals.
mean_header = [ "target_tokens", "metric_soft_ns", "metric_soft_nc" ]
mean_values = [
    "$".join(target_token_all),
    f"{metric_soft_ns.item():.3f}",
    f"{metric_soft_nc.item():.3f}",
]
table_mean = [ mean_header, mean_values ]

table_mean_html = tabulate.tabulate(table_mean, tablefmt='html')
display(HTML(table_mean_html))

### 3.4 Save results to file

In [None]:
import csv

# Write both tables as CSV. The encoding is set explicitly so decoded token
# text containing non-ASCII characters can be written regardless of the
# runtime's default encoding; newline='' is what the csv module expects for
# files handed to csv.writer.
for csv_path, csv_rows in (
    ('notebook_details.csv', table_details),
    ('notebook_mean.csv', table_mean),
):
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        csvWriter = csv.writer(csvfile)
        csvWriter.writerows(csv_rows)