# Create a Human Abstraction Graph Relevant to the LLM Task
To measure how well LLMs align with human abstractions, we create a human abstraction graph that represents their task. The S-TEST dataset contains occupation (P106) and location (P131 and P19) tasks. To measure alignment on these tasks, we map the words in the dataset to concepts in the WordNet graph.

In [1]:
# IPython autoreload: re-import edited modules (such as the local `graph`
# module imported below) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import json
import random
import pickle
import numpy as np
from tqdm import tqdm
from scipy import stats
from queue import Queue
from collections import Counter
from itertools import combinations
from nltk.corpus import wordnet as wn

sys.path.append(os.path.abspath(os.path.join(os.path.dirname('__file__'), '../../')))
from graph import Graph, Node

## Load the S-TEST Specificity Testing Dataset
The S-TEST dataset contains sentences that test the model on a subject's occupation, location, or place of birth. For instance, occupation sentences are in the format "Cher is a [MASK] by profession." Each sentence has a corresponding specific label (e.g., "artist"). Here we load the S-TEST data as well as the prediction results for five models (BERT and RoBERTa in base and large sizes, and GPT-2).

In [3]:
# Locations of the S-TEST data files and of the per-model prediction outputs.
DATA_DIR = 'S-TEST/data/S-TEST/'
RESULTS_DIR = 'S-TEST/output/results/'

# Model names; each one is used as a sub-directory of RESULTS_DIR.
MODELS = [
    'bert_base',
    'bert_large',
    'roberta.base',
    'roberta.large',
    'gpt2',
]

# One entry per S-TEST relation. 'up_fn' / 'down_fn' name the WordNet Synset
# methods used to walk towards more general / more specific concepts, and
# 'root' (when not None) is the synset where traversal stops and from which
# the task DAG is built.
TASKS = [
    {
        'name': 'occupation',
        'id': 'P106',
        'up_fn': 'hypernyms',
        'down_fn': 'hyponyms',
        'root': wn.synset('person.n.01'),
        'threshold': 0.01,
    },
    {
        'name': 'location',
        'id': 'P131',
        'up_fn': 'part_holonyms',
        'down_fn': 'part_meronyms',
        'root': None,
        'threshold': 0.01,
    },
    {
        'name': 'place of birth',
        'id': 'P19',
        'up_fn': 'part_holonyms',
        'down_fn': 'part_meronyms',
        'root': None,
        'threshold': 0.01,
    },
]

In [4]:
def _read_jsonl(task_id, data_dir=None):
    """Return the parsed records of ``<data_dir>/<task_id>.jsonl`` in file order.

    ``data_dir`` defaults to ``DATA_DIR``. The default is looked up at call time,
    so re-running the configuration cell takes effect without redefining this
    helper. Blank lines (e.g. a trailing empty line) are skipped because
    ``json.loads`` rejects them.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    path = os.path.join(data_dir, f'{task_id}.jsonl')
    # Read explicitly as UTF-8: labels may contain non-ASCII characters, and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def load_data_instances(task_id, data_dir=None):
    """Return ``(sub_label, obj_label, obj2_label)`` tuples for ``task_id``.

    Records keep their file order, and duplicate subjects are kept.
    """
    return [
        (d['sub_label'], d['obj_label'], d['obj2_label'])
        for d in _read_jsonl(task_id, data_dir)
    ]

def load_data(task_id, data_dir=None):
    """Load the data instances for a task_id, keyed by ``sub_label``.

    There can be duplicates, but we handle them the same way the S-TEST repo
    does: a later record for the same subject overwrites an earlier one.
    """
    return {datum['sub_label']: datum for datum in _read_jsonl(task_id, data_dir)}

In [5]:
# Work on the first task (occupation) and the first model (bert_base).
TASK = TASKS[0]
MODEL = MODELS[0]

DATA = load_data(TASK['id'])
print(f"{len(DATA)} instances for {TASK['name']} prediction task.")
print("Example data:")
# Show the first record without materialising the whole list of keys.
print(next(iter(DATA.values())))

4999 instances for occupation prediciton task.
Example data:
{'sub_uri': 'Q39074561', 'sub_label': 'Joe Carter', 'obj_uri': 'Q1371925', 'obj_label': 'announcer', 'obj_value': 2.0, 'obj2_uri': 'Q1930187', 'obj2_label': 'journalist', 'obj2_value': 3.0, 'predicate_id': 'P106'}


In [6]:
def load_model_results(model_name, task_id, results_dir=None):
    """Load the per-instance prediction results for one model and task.

    Reads ``<results_dir>/<model_name>/<task_id>/result.pkl`` and returns its
    ``'list_of_results'`` entry.

    ``results_dir`` defaults to ``RESULTS_DIR``. It is resolved at call time
    rather than at definition time, so re-running the configuration cell is
    picked up without also re-running this one.

    NOTE: ``pickle.load`` can execute arbitrary code. Only point this at result
    files you produced yourself.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR
    path = os.path.join(results_dir, model_name, task_id, 'result.pkl')
    with open(path, 'rb') as f:
        return pickle.load(f)['list_of_results']

In [7]:
RESULTS = load_model_results(MODEL, TASK['id'])
print(f"{len(RESULTS)} predictions for {MODEL} on {TASK['name']} prediction task.")
# Sanity check: convert the first result's top-k log-probabilities back to
# probabilities and report their total.
print(f"Predictions for results[0] sum to {np.sum([np.exp(w['log_prob']) for w in RESULTS[0]['masked_topk']['topk']])}")
print(f"Computed probabilities for {len(RESULTS[0]['masked_topk']['topk'])} words.")

5000 predictions for bert_base on occupation prediciton task.
Predictions for results[0] sum to 0.8970748439562531
Computed probabilities for 18173 words.


In [8]:
# Gold (specific) label for every prediction, looked up through its subject.
ALL_LABELS = [DATA[result['sample']['sub_label']]['obj_label'] for result in RESULTS]
with open(os.path.join(DATA_DIR, f"{TASK['id']}_synsets.json"), 'r', encoding='utf-8') as f:
    label_synsets = json.load(f)
# Entries are unpacked as (label, synset name) pairs. Labels whose synset is
# null have no WordNet mapping, and their instances are dropped below.
label_to_synset = {label: wn.synset(synset) for label, synset in label_synsets if synset is not None}
IDX_TO_KEEP = [i for i, label in enumerate(ALL_LABELS) if label in label_to_synset]
# The removed labels are exactly those without a synset. Testing dict
# membership avoids scanning IDX_TO_KEEP once per label, which is quadratic.
print(f'Removed {len(RESULTS) - len(IDX_TO_KEEP)} instances with labels {[label for label in ALL_LABELS if label not in label_to_synset]}')

LABELS = np.array(ALL_LABELS)[IDX_TO_KEEP]
SYNSETS = [label_to_synset[label] for label in LABELS]
print(f'Resulting dataset has {len(LABELS)}/{len(ALL_LABELS)} labels mapping to {len(SYNSETS)} total/{len(set(SYNSETS))} unique synsets')
print(f'First 5 labels: {LABELS[:5]}')
print(f'First 5 synsets: {SYNSETS[:5]}')

Removed 5 instances with labels ['man', 'female', 'position', 'humour', 'man']
Resulting dataset has 4995/5000 labels mapping to 4995 total/103 unique synsets
First 5 labels: ['architect' 'philosopher' 'poet' 'mechanic' 'priest']
First 5 synsets: [Synset('architect.n.01'), Synset('philosopher.n.01'), Synset('poet.n.01'), Synset('machinist.n.01'), Synset('priest.n.01')]


## Create the WordNet Human Abstraction Graph
We use WordNet to represent the human abstraction graph for each task. For instance, for occupation prediction, we take the subset of WordNet that is connected to any of the labels in the S-TEST dataset. This creates a human abstraction of occupations ranging from low-level labels (e.g., "poet") to high-level profession concepts (e.g., "communicator").

#### First, get all the WordNet concepts that are related to any task label in the dataset.

In [9]:
def get_synset_name(synset):
    """Return the lemma part of a WordNet synset name.

    Synset names have the form ``<lemma>.<pos>.<sense>``, e.g. ``'poet.n.01'``
    gives ``'poet'``. The lemma can itself contain dots (``'st._louis.n.01'``),
    so only the last two dot-separated fields are stripped. This mirrors how
    NLTK's ``wn.synset()`` parses names.
    """
    return synset.name().rsplit('.', 2)[0]

def get_synsets_relatives(synset, traversal_fn_name, root=None, include_self=True):
    """Return the names of all synsets transitively reachable from ``synset``.

    Args:
        synset: Starting synset.
        traversal_fn_name: Name of the synset method that lists direct
            relatives (e.g. ``'hypernyms'`` or ``'part_meronyms'``). It is
            applied repeatedly.
        root: Optional synset where traversal stops. If it is reached, it is
            included in the result, but its own relatives are not followed.
        include_self: Whether to include ``synset``'s own name.

    Returns:
        Set of synset names (see ``get_synset_name``).
    """
    words = set()
    if include_self:
        words.add(get_synset_name(synset))
    # Iterative walk with a visited set. Shared ancestors (WordNet is a DAG)
    # are expanded once instead of once per path, and a cycle cannot cause
    # unbounded recursion.
    visited = {synset}
    stack = [synset]
    while stack:
        current = stack.pop()
        if root is not None and current == root:
            continue  # do not walk past the task root
        for relative in getattr(current, traversal_fn_name)():
            if relative in visited:
                continue
            visited.add(relative)
            words.add(get_synset_name(relative))
            stack.append(relative)
    return words

def get_task_synsets(task, synset_labels):
    """Return the names of every label synset plus all of its relatives.

    Relatives are collected in both directions: down via ``task['down_fn']``
    and up via ``task['up_fn']``, with traversal stopping at ``task['root']``.
    """
    task_synsets = set()
    # The label list repeats each synset many times, so process each distinct
    # synset once. (Testing ``synset in task_synsets`` never matched, because
    # that set holds name strings rather than synsets.)
    for synset in dict.fromkeys(synset_labels):
        task_synsets.add(get_synset_name(synset))
        task_synsets.update(get_synsets_relatives(synset, task['down_fn'], root=task['root'], include_self=False))
        task_synsets.update(get_synsets_relatives(synset, task['up_fn'], root=task['root'], include_self=False))
    return task_synsets

task_synsets = get_task_synsets(TASK, SYNSETS)
print(f"{len(task_synsets)} concepts synsets related to {len(set(SYNSETS))} synsets from the {TASK['name']} prediction task.")
# random.sample() requires a sequence (a set raises TypeError on Python 3.11+),
# so sample from a sorted list. The sample size is capped in case fewer than
# five concepts were found.
print(f'Example concepts: {random.sample(sorted(task_synsets), min(5, len(task_synsets)))}')

1566 concepts synsets related to 103 synsets from the occupation prediction task.
Example concepts: ['bibliotist', 'letterman', 'assassin', 'kerb_crawler', 'paster']


#### Next, create a DAG representing all of the concepts and their relationships

In [10]:
def create_wordnet_dag(task, task_synsets):
    """Build a finalized ``Graph`` of the task's WordNet concepts.

    Walks breadth-first from ``task['root']`` along ``task['down_fn']`` and
    keeps only relatives whose name is in ``task_synsets``. Nodes are keyed by
    synset name (via ``get_synset_name``), so distinct synsets that share a
    name map to the same node.

    Args:
        task: A ``TASKS`` entry; its ``'name'``, ``'root'`` and ``'down_fn'``
            keys are used.
        task_synsets: Names of the synsets to keep.

    Returns:
        The finalized ``Graph``, rooted at the root synset's name.

    Raises:
        ValueError: If ``task['root']`` is None (as for the location tasks).
    """
    root = task['root']
    if root is None:
        raise ValueError(f"Task {task['name']!r} has no root synset; cannot build a rooted DAG.")
    nodes = {}
    expanded = set()  # synsets whose children have already been enqueued
    queue = Queue()
    queue.put((root, None))  # Queue contains the synset and its parent DAG node.
    while not queue.empty():
        synset, parent_node = queue.get()
        synset_name = get_synset_name(synset)

        # Create a node for the synset if it does not already exist.
        synset_node = nodes.get(synset_name)
        if synset_node is None:
            synset_node = Node(synset_name)
            nodes[synset_name] = synset_node

        # Connect the node to its parent. Only the root node has parent None.
        if parent_node is not None:
            parent_node.connect_child(synset_node)

        # A synset with several parents is dequeued once per incoming edge.
        # Enqueue its children only the first time, so each (parent, child)
        # edge is connected once and shared subtrees are not re-walked.
        if synset in expanded:
            continue
        expanded.add(synset)
        for next_synset in getattr(synset, task['down_fn'])():
            if get_synset_name(next_synset) in task_synsets:  # skip relatives unrelated to the task
                queue.put((next_synset, synset_node))
    wordnet_dag = Graph(nodes, get_synset_name(root))
    wordnet_dag.finalize()
    return wordnet_dag

wordnet_dag = create_wordnet_dag(TASK, task_synsets)
print(f"Created WordNet DAG with root node '{wordnet_dag.root_id}' and {len(wordnet_dag.nodes)} synset concepts.")

Created WordNet DAG with root node 'person' and 1527 synset concepts.


#### Finally, map each concept in the WordNet DAG that has several parents to its most relevant parent

In [11]:
# Distribution of how many parents each DAG node has.
print(f"Number of parents per node -- {Counter(len(n.parents) for n in wordnet_dag.nodes.values())}")

Number of parents per node -- Counter({1: 1380, 2: 133, 3: 13, 0: 1})


In [12]:
# Manually select the most relevant parent for each concept node that has more
# than one parent. Answers are cached as JSON, so the interactive step runs
# only when the cache file is missing.
# NOTE(review): the cache path does not include the task id, so a mapping saved
# for one task would be reused for another. Consider adding TASK['id'] to the
# filename (and renaming the existing cache file accordingly).
name_to_parent_filename = os.path.join(RESULTS_DIR, 'name_to_parent.json')
if os.path.isfile(name_to_parent_filename):
    with open(name_to_parent_filename, 'r', encoding='utf-8') as f:
        node_name_to_parent_name = json.load(f)
    print(f"Loaded node to parent mapping for {len(node_name_to_parent_name)} nodes.")
else:
    node_name_to_parent_name = {}
    for node_name, node in wordnet_dag.nodes.items():
        if len(node.parents) <= 1:
            continue  # nothing to disambiguate
        # Compare against the DAG's own root instead of a hard-coded name, so
        # this also works for tasks rooted elsewhere.
        non_root_parents = [parent for parent in node.parents if parent.name != wordnet_dag.root_id]
        if len(non_root_parents) == 1:
            # Exactly one non-root parent: choose it without asking.
            node_name_to_parent_name[node_name] = non_root_parents[0].name
            continue
        # Otherwise ask the user, listing candidates by descending depth.
        parents = sorted(((parent.name, parent.depth) for parent in node.parents), key=lambda x: x[1], reverse=True)
        selected_parent = ''
        while selected_parent not in wordnet_dag.nodes:
            print(f"Parent options for {node_name} are {parents}. Children are {node.children}.")
            selected_parent = input()
        node_name_to_parent_name[node_name] = selected_parent
    print(f"Created node to parent mapping for {len(node_name_to_parent_name)} nodes.")
    with open(name_to_parent_filename, 'w', encoding='utf-8') as f:
        json.dump(node_name_to_parent_name, f, indent=4)

Loaded node to parent mapping for 146 nodes.
