# Calculating Confidence Scores on Dev and Test Sets

In [None]:
# System imports
import math
import os
import sys

from copy import deepcopy

sys.path.append("../")

# External imports
import allennlp.nn.util as util
import numpy as np
import pandas as pd
import torch

from allennlp.commands.train import train_model
from allennlp.common import Params
from allennlp.common.util import import_submodules
from allennlp.data.dataset import Batch
from allennlp.models.archival import load_archive
from allennlp.nn.util import logsumexp
from allennlp.training.trainer import Trainer
from allennlp.training.optimizers import Optimizer
from allennlp.training.util import datasets_from_params

import_submodules("streusle_tagger")

from streusle_tagger.dataset_readers import StreusleDatasetReader

# Restore the trained tagger from its archive; the experiment configuration is
# loaded separately because it is needed again later to rebuild the datasets.
archive = load_archive("../saved_models/no_constraints/model.tar.gz")
model = archive.model
params = Params.from_file(
    "../training_config/streusle_bert_large/streusle_bert_large_cased_no_constraints.jsonnet"
)

In [None]:
# Lookup tables between label strings and their vocabulary indexes.
index_to_label = model.vocab.get_index_to_token_vocabulary(model._label_namespace)
label_to_index = {label: index for index, label in index_to_label.items()}

# Save the label/index table so the numeric tag columns of the per-sentence
# CSV files can be decoded later. The frame is built from (label, index) pairs:
# handing the dict itself to DataFrame together with ``columns`` treats the
# labels as column names and yields an empty table. The target must also be a
# file, not a directory, and lives under the same ``../confidence_results``
# root that the per-dataset results are written to.
labels_df = pd.DataFrame(list(label_to_index.items()), columns=["Label", "Index"])
os.makedirs("../confidence_results", exist_ok=True)
labels_df.to_csv("../confidence_results/labels.csv")
reader = StreusleDatasetReader()
# ``deepcopy`` keeps ``params`` itself untouched by ``datasets_from_params``.
datasets = datasets_from_params(deepcopy(params))

In [None]:
def denominator(crf, logits):
    """Run forward-backward over one sentence of a linear-chain CRF.

    Args:
        crf: Object exposing ``start_transitions`` and ``end_transitions``
            (one score per tag) and ``transitions``, a square matrix where
            ``transitions[i, j]`` scores moving from tag ``i`` to tag ``j``.
        logits: Emission scores, either ``(sequence_length, num_tags)`` or a
            1-D ``(num_tags,)`` tensor for a single-word sentence.

    Returns:
        A tuple ``(forward_trellis, backward_trellis, log_partition)``:

        * ``forward_trellis[i][t]`` is the log-sum of the scores of all tag
          prefixes for words ``0..i`` that end in tag ``t`` (emission of word
          ``i`` included). One extra final entry holds the last forward vector
          plus the end transitions.
        * ``backward_trellis[i][t]`` is the log-sum of the scores of all tag
          suffixes for words ``i+1..end`` given tag ``t`` at word ``i``
          (emission of word ``i`` excluded, end transitions included).
        * ``log_partition`` is the log of the summed scores of all tag
          sequences, as a 0-dim tensor.
    """
    if logits.dim() == 1:
        # A squeezed single-word sentence: restore the word dimension.
        logits = logits.view(1, -1)
    sequence_length, num_tags = logits.size()
    transitions = crf.transitions

    forward_trellis = [crf.start_transitions + logits[0]]
    for i in range(1, sequence_length):
        # scores[j, k] = alpha[i - 1][j] + transitions[j, k]; the previous tag
        # j is summed out so that the trellis becomes indexed by the current tag.
        scores = forward_trellis[i - 1].view(num_tags, 1) + transitions
        forward_trellis.append(torch.logsumexp(scores, dim=0) + logits[i])

    stops = forward_trellis[sequence_length - 1] + crf.end_transitions
    forward_trellis.append(stops)

    # Built from the last word backwards, then flipped into word order.
    backward_trellis = [crf.end_transitions]
    for i in range(sequence_length - 2, -1, -1):
        # following[k] = emission of word i + 1 for tag k + beta[i + 1][k];
        # the next tag k is summed out.
        following = (logits[i + 1] + backward_trellis[-1]).view(1, num_tags)
        backward_trellis.append(torch.logsumexp(transitions + following, dim=1))
    backward_trellis.reverse()

    return forward_trellis, backward_trellis, torch.logsumexp(stops, dim=0)

def numerator(crf, logits, forward_trellis, backward_trellis, tag_num, word_num):
    """Log-sum of the scores of all tag sequences with ``tag_num`` at ``word_num``.

    Subtracting the log partition returned by ``denominator`` and
    exponentiating gives the marginal probability of that tag at that word.

    Args:
        crf: Unused; kept so existing callers keep working.
        logits: Unused; kept so existing callers keep working.
        forward_trellis: Forward trellis returned by ``denominator``.
        backward_trellis: Backward trellis returned by ``denominator``.
        tag_num: Index of the tag being scored.
        word_num: Index of the word being scored.

    Returns:
        The unnormalised log score as a 0-dim tensor.
    """
    # Every sequence through (word_num, tag_num) factors into a prefix ending
    # in that tag and a suffix starting from it.
    return forward_trellis[word_num][tag_num] + backward_trellis[word_num][tag_num]

In [None]:
def sentence_confidence(crf, sequence_logits):
    """Calculates matrix of confidence scores with num_words rows and num_tags columns.

    Args:
        crf: CRF layer providing ``num_tags`` and the transition scores used
            by ``denominator`` and ``numerator``.
        sequence_logits: Emission scores for one sentence; a 1-D tensor is
            treated as a single-word sentence.

    Returns:
        A list of ``num_words`` rows, each a list of ``num_tags`` floats, where
        entry ``[word][tag]`` is ``exp(numerator - denominator)`` for that
        word/tag pair.
    """
    num_tags = crf.num_tags
    num_words = 1 if len(sequence_logits.size()) == 1 else sequence_logits.size()[0]

    # Score with the CRF that was passed in, not the notebook-global model's.
    forward_trellis, backward_trellis, denom = denominator(crf, sequence_logits)

    return [
        [
            math.exp(
                numerator(crf, sequence_logits, forward_trellis, backward_trellis, tag, word)
                - denom
            )
            for tag in range(num_tags)
        ]
        for word in range(num_words)
    ]

In [None]:
def dataset_confidence(dataset, dataset_name):
    """Creates one CSV file per sentence, containing metadata and confidence scores for all tag-token pairs.

    Each file is written to ``../confidence_results/<dataset_name>/<i>.csv``,
    where ``i`` is the position of the sentence in ``dataset``. A file has one
    row per token: the token, its ground-truth tag and tag index, the tag and
    tag index predicted by the global ``model``, and one confidence column per
    tag.

    Args:
        dataset: Iterable of instances with ``tokens`` and ``tags`` fields.
        dataset_name: Name of the sub-directory the CSV files are written to.
    """
    save_path = f"../confidence_results/{dataset_name}"
    os.makedirs(save_path, exist_ok=True)

    # The column layout is the same for every sentence.
    columns = [
        "Tokens",
        "Ground Truth",
        "Ground Truth Indexes",
        "Predicted Tags",
        "Predicted Tag Indexes",
    ] + list(range(model.crf.num_tags))

    for i, instance in enumerate(dataset):
        instance_batch = Batch([instance])
        instance_batch.index_instances(model.vocab)
        tensor_dict = instance_batch.as_tensor_dict()

        print(f"Calculating confidence scores for instance {i}...")
        # Pure inference: no autograd graph is needed for the scores.
        with torch.no_grad():
            # Confidence scores
            embedded_tokens = model.text_field_embedder(tensor_dict["tokens"])
            # squeeze() drops the batch dimension (and the word dimension of a
            # one-token sentence, which sentence_confidence handles).
            logits = model.tag_projection_layer(embedded_tokens).squeeze()
            confidence_matrix = sentence_confidence(model.crf, logits)
            predicted_tags_indexes = model.forward(**tensor_dict)["tags"][0]

        # Metadata, each piece shaped as a single column
        tokens_list = np.array([[str(t) for t in instance.get("tokens").tokens]]).transpose()
        ground_truth_tags = instance.get("tags").labels
        ground_truth_tags_indexes = np.array(
            [[label_to_index[tag] for tag in ground_truth_tags]]
        ).transpose()
        ground_truth_tags = np.array([ground_truth_tags]).transpose()
        predicted_tags = np.array(
            [[index_to_label[index] for index in predicted_tags_indexes]]
        ).transpose()
        predicted_tags_indexes = np.array([predicted_tags_indexes]).transpose()
        metadata = np.concatenate(
            (
                tokens_list,
                ground_truth_tags,
                ground_truth_tags_indexes,
                predicted_tags,
                predicted_tags_indexes,
            ),
            axis=1,
        )

        # Combine metadata and confidence scores
        data = np.concatenate((metadata, confidence_matrix), axis=1)

        # Write to file
        df = pd.DataFrame(data, columns=columns)
        df.to_csv(f"{save_path}/{i}.csv")

In [None]:
# Write the per-sentence confidence CSV files for the validation split.
dataset_confidence(datasets["validation"], "validation")