In [1]:
import json
import torch

from tqdm import tqdm
from fairseq.models.bart import BARTModel

#### Load Annotated Dataset

In [2]:
# Load the annotated train/test splits. Using `with` closes each file handle
# as soon as the JSON is parsed instead of leaving it to the garbage collector.
with open('../data/train.json', 'r') as train_file:
    train_set = json.load(train_file)
with open('../data/test.json', 'r') as test_file:
    test_set = json.load(test_file)

#### Load Weights

In [3]:
# Checkpoint directories (relative paths) for the two BART models loaded below.
MLM_MODEL_PATH = 'BART_models/bart.large'        # loaded as `prior_bart`
CMLM_MODEL_PATH = 'BART_models/xsum_cedar_cmlm'  # loaded as `bart`

# Data directory passed as `data_name_or_path` when loading the CMLM checkpoint.
DATA_NAME_OR_PATH = 'summarization/XSum/fairseq_files/xsum-bin'

In [4]:
# Load the CMLM checkpoint; it is wrapped as `model` further down and used to
# compute the posterior probabilities.
bart = BARTModel.from_pretrained(
    CMLM_MODEL_PATH,
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=DATA_NAME_OR_PATH,
)

2022-04-07 01:25:13 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/xsum_cedar_cmlm
2022-04-07 01:25:13 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/summarization/XSum/fairseq_files/xsum-bin
2022-04-07 01:25:22 | INFO | fairseq.tasks.translation | [source] dictionary: 50264 types
2022-04-07 01:25:22 | INFO | fairseq.tasks.translation | [target] dictionary: 50264 types


In [5]:
# Load the MLM checkpoint; it is wrapped as `prior_model` further down and used
# to compute the prior probabilities. The checkpoint directory doubles as the
# data directory here.
prior_bart = BARTModel.from_pretrained(
    MLM_MODEL_PATH,
    checkpoint_file='model.pt',
    data_name_or_path=MLM_MODEL_PATH,
)

2022-04-07 01:25:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large
2022-04-07 01:25:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large
2022-04-07 01:25:39 | INFO | fairseq.tasks.denoising | dictionary: 50264 types


#### Build Prior & Posterior Model

In [6]:
from EntFA.model import ConditionalSequenceGenerator
from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel

In [7]:
# Wrap both BART checkpoints in EntFA's ConditionalSequenceGenerator so they can
# be passed to get_probability_parallel: `model` (CMLM checkpoint) is used for
# the posteriors, `prior_model` (MLM checkpoint) for the priors.
model = ConditionalSequenceGenerator(bart)
prior_model = ConditionalSequenceGenerator(prior_bart)

#### Training

In [8]:
import numpy as np

from sklearn import neighbors
from sklearn.metrics import classification_report, f1_score, accuracy_score

In [9]:
def build_classifier(train_features, train_labels, n=30):
    """
    Fit a KNN classifier on std-normalized entity features.

    Args:
        train_features (List[List]): [[prior, posterior, overlap_feature], ...]
        train_labels (List[int]): one class label per row of ``train_features``
        n (int): number of neighbours used by the KNN classifier

    Returns:
        The fitted ``KNeighborsClassifier``.
    """
    # Honour the caller-supplied neighbour count (it used to be hard-coded to 30,
    # which silently ignored the ``n`` argument).
    classifier = neighbors.KNeighborsClassifier(n_neighbors=n, algorithm='auto')

    x_mat = np.array(train_features)
    # Put the three features on a comparable scale by dividing each column by
    # its own standard deviation; only columns 0..2 are used.
    stds = [np.std(x_mat[:, col]) for col in range(3)]
    x_mat = np.vstack([x_mat[:, col] / stds[col] for col in range(3)]).transpose()
    y_vec = np.array(train_labels)
    classifier.fit(x_mat, y_vec)

    return classifier

def infernece(test_features, classifier, stds=None):
    """
    Predict labels for entity features with a fitted classifier.

    Args:
        test_features (List[List]): [[prior, posterior, overlap_feature], ...]
        classifier: KNN classifier
        stds (Sequence[float], optional): three per-feature scales to divide
            the columns by, e.g. ``np.std(train_features, axis=0)`` so that test
            data is rescaled exactly like the training data. When omitted, the
            scales are the standard deviations of ``test_features`` itself
            (the original behaviour), which makes the result depend on the
            composition of the batch being predicted.

    Returns:
        The output of ``classifier.predict`` on the rescaled feature matrix.

    Raises:
        ValueError: if ``stds`` does not contain exactly three values, or if
            any scale is zero or NaN (e.g. a single-row batch with no explicit
            ``stds``), since dividing by it would feed NaN/inf to the classifier.
    """
    x_mat = np.array(test_features)
    if stds is None:
        # Fall back to per-batch statistics, as the function always did.
        stds = [np.std(x_mat[:, col]) for col in range(3)]

    scale = np.asarray(stds, dtype=float)
    if scale.shape != (3,):
        raise ValueError(
            'stds must contain exactly 3 values (prior, posterior, overlap); '
            'got shape {}'.format(scale.shape))
    # `scale > 0` is also False for NaN, so this catches both degenerate cases.
    if not np.all(scale > 0):
        raise ValueError(
            'cannot normalize features with scales {}: every scale must be '
            'positive; pass `stds` computed on the training features'.format(scale.tolist()))

    x_mat = np.vstack([x_mat[:, col] / scale[col] for col in range(3)]).transpose()
    return classifier.predict(x_mat)

In [10]:
def get_features(data_set, prior_model, model):
    """
    Build per-entity classifier features and binary labels for a data set.

    Args:
        data_set: iterable of dicts with keys 'source', 'prediction' and
            'entities'; every entity is a dict with at least 'ent' and 'label'.
        prior_model: model used with the MLM inputs (mask_filling=True) to get
            the prior probability of each entity.
        model: model used with the CMLM inputs to get the posterior probability
            of each entity.

    Returns:
        (features, labels): ``features`` is a list of
        (prior, posterior, overlap) tuples and ``labels`` the matching list of
        0/1 labels. Entities whose label is not in the mapping are skipped.
    """
    # 'Non-factual Hallucination' is the positive class; both other labels map to 0.
    label_mapping = {
        'Non-hallucinated': 0,
        'Factual Hallucination': 0,
        'Non-factual Hallucination': 1
    }

    features = []
    labels = []
    for example in tqdm(data_set):
        source = example['source']
        prediction = example['prediction']
        entities = example['entities']

        # Prior: entity probability from the MLM inputs, with mask filling.
        mlm_inputs = prepare_mlm_inputs(source, prediction, ent_parts=entities)
        priors = get_probability_parallel(
            prior_model,
            mlm_inputs[0], mlm_inputs[1], mlm_inputs[2], mlm_inputs[3],
            mask_filling=True,
        )

        # Posterior: entity probability from the CMLM inputs.
        cmlm_inputs = prepare_cmlm_inputs(source, prediction, ent_parts=entities)
        posteriors = get_probability_parallel(
            model,
            cmlm_inputs[0], cmlm_inputs[1], cmlm_inputs[2], cmlm_inputs[3],
        )

        # Overlap: 1.0 when the entity text occurs (case-insensitively) in the source.
        overlaps = [float(entity['ent'].lower() in source.lower()) for entity in entities]
        assert len(priors) == len(posteriors) == len(overlaps)

        for idx, entity in enumerate(entities):
            label = entity['label']
            if label not in label_mapping:
                continue
            features.append((priors[idx], posteriors[idx], overlaps[idx]))
            labels.append(label_mapping[label])

    return features, labels

In [11]:
# Extract (prior, posterior, overlap) features for the training split,
# then fit the KNN classifier on them.
train_features, train_labels = get_features(
    data_set=train_set,
    prior_model=prior_model,
    model=model,
)
classifier = build_classifier(
    train_features=train_features,
    train_labels=train_labels,
    n=30,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 460/460 [00:49<00:00,  9.22it/s]


#### Evaluation

In [12]:
# Extract features for the held-out split and predict a label per entity.
test_features, test_labels = get_features(
    data_set=test_set,
    prior_model=prior_model,
    model=model,
)
Z = infernece(test_features=test_features, classifier=classifier)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:25<00:00,  9.52it/s]


In [13]:
# Overall accuracy (4 significant digits), followed by the per-class report.
test_accuracy = accuracy_score(test_labels, Z)
print(f'accuracy: {test_accuracy:.4}\n\n')

report = classification_report(
    test_labels,
    Z,
    target_names=['Factual', 'Non-Factual'],
    digits=4,
)
print(report)

accuracy: 0.9102


              precision    recall  f1-score   support

     Factual     0.9323    0.9629    0.9474       701
 Non-Factual     0.7658    0.6343    0.6939       134

    accuracy                         0.9102       835
   macro avg     0.8490    0.7986    0.8206       835
weighted avg     0.9056    0.9102    0.9067       835



#### Save Classifier

In [14]:
import pickle

# Persist the fitted classifier. The context manager guarantees the file is
# flushed and closed even if pickling raises part-way through.
with open('knn_classifier.pkl', 'wb') as classifier_file:
    pickle.dump(classifier, classifier_file)