In [4]:
import numpy as np

from typing import List

from scipy import stats
from scipy.spatial.distance import hamming

import editdistance # Levenshtein distance

from sklearn.feature_extraction.text import CountVectorizer

## Utils

In [7]:
def transform_corpus(corpus, vocabulary):
    """
    inputs
    ------------------------------------
    corpus: list/array of setences
    vocabulary: list

    outputs
    ------------------------------------
    matrix of token counts and its feature names
    """
    vectorizer = CountVectorizer()
    vectorizer.fit(vocabulary)
    matrix = vectorizer.transform(corpus)

    return vectorizer.get_feature_names_out(), matrix.toarray()

# 1.0 Data

## Data for tests

In [8]:
"""
https://github.com/tomekkorbak/measuring-non-trivial-compositionality/blob/master/protocols.py
"""

input_A = ['blue circle', 'silver box', 'green circle']
compositional_representation_A = ['a!x_', 'jx__', 'c!x_']

input_B = ['blue box','red circle', 'green box']
compositional_representation_B = ['ax__', 'b!x_','cx__']

vocabulary = input_A + input_B

_,messages_A = transform_corpus(input_A, vocabulary)
print("Count vector of A: \n {} \n".format(messages_A))

_,messages_B = transform_corpus(input_B, vocabulary)
print("Count vector of B: \n {}".format(messages_B))

Count vector of A: 
 [[1 0 1 0 0 0]
 [0 1 0 0 0 1]
 [0 0 1 1 0 0]] 

Count vector of B: 
 [[1 1 0 0 0 0]
 [0 0 1 0 1 0]
 [0 1 0 1 0 0]]


In [None]:
# get the txt and generate the representation

# 2.0 TopSim

Topographic Similarity

In [None]:
from typing import Callable, Tuple, Union, Dict, List

from scipy.stats import spearmanr

"""
Adapted: https://github.com/tomekkorbak/measuring-non-trivial-compositionality
"""

class TopographicSimilarity():

    def __init__(self, message_metric: Callable, representation_metric: Callable):
        self.message_metric = message_metric 
        self.representation_metric = representation_metric

    def measure(self, compositional_representation_A, messages_A, compositional_representation_B, messages_B):
        # print('distance_messages')
        distance_messages = self._compute_distances(
            sequence_A = messages_A,
            sequence_B = messages_B,
            metric = self.message_metric)
        
        # print('distance_representation')
        distance_representation = self._compute_distances( 
            sequence_A = compositional_representation_A,
            sequence_B = compositional_representation_B,
            metric = self.representation_metric)

        topsim = spearmanr(distance_representation, distance_messages, nan_policy="raise").correlation
        
        return topsim

    def _compute_distances(self, sequence_A: List[str], sequence_B: List[str], metric: Callable) -> List[float]:
        distances = []
        for i in range(len(sequence_A)): 
            distances.append(metric(sequence_A[i], sequence_B[i]))
        # print(distances, '\n')
        return distances

## Test

In [None]:
topsim_class = TopographicSimilarity(message_metric = hamming,
                               representation_metric = editdistance.eval
                               )


topsim = topsim_class.measure(compositional_representation_A, messages_A, 
                   compositional_representation_B,  messages_B)
print("Topsim: {}".format(topsim))

distance_messages
[0.3333333333333333, 0.6666666666666666, 0.3333333333333333] 

distance_representation
[2, 3, 2] 

Topsim: 1.0


# 3.0 Pos

In [None]:
"""
Adapted: https://github.com/tomekkorbak/measuring-non-trivial-compositionality
"""

example of representation: https://github.com/facebookresearch/EGG/blob/main/egg/zoo/compo_vs_generalization/intervention.py

## Entropy

In [None]:
from collections import defaultdict

def compute_entropy(symbols: List[str]) -> float:
    frequency_table = defaultdict(float)
    for symbol in symbols:
        frequency_table[symbol] += 1.0
    H = 0
    for symbol in frequency_table:
        p = frequency_table[symbol]/len(symbols)
        H += -p * np.log2(p)
    return H

## Mutual information

In [None]:
def compute_mutual_information(concepts: List[str], symbols: List[str]) -> float:
    concept_entropy = compute_entropy(concepts)  # H[p(concepts)]
    symbol_entropy = compute_entropy(symbols)  # H[p(symbols)]
    symbols_and_concepts = [symbol + '_' + concept for symbol, concept in zip(symbols, concepts)]
    symbol_concept_joint_entropy = compute_entropy(symbols_and_concepts)  # H[p(concepts, symbols)]
    return concept_entropy + symbol_entropy - symbol_concept_joint_entropy

## Pos Class

In [None]:
class PositionalDisentanglement():

    def __init__(self, max_message_length: int, num_concept_slots: int):
        self.max_message_length = max_message_length
        self.num_concept_slots = num_concept_slots
        self.permutation_invariant = False

    def measure(self, compositional_representation, messages):
        disentanglement_scores = []
        non_constant_positions = 0

        for j in range(self.max_message_length):
            symbols_j = compositional_representation[j]
            symbol_mutual_info = []
            symbol_entropy = compute_entropy(symbols_j)
            for i in range(self.num_concept_slots):
                concepts_i = messages[i]
                mutual_info = compute_mutual_information(concepts_i, symbols_j)
                symbol_mutual_info.append(mutual_info)
            symbol_mutual_info.sort(reverse=True)

            if symbol_entropy > 0:
                disentanglement_score = (symbol_mutual_info[0] - symbol_mutual_info[1]) / symbol_entropy
                disentanglement_scores.append(disentanglement_score)
                non_constant_positions += 1
            if non_constant_positions > 0:
                return sum(disentanglement_scores)/non_constant_positions
            else:
                return np.nan

## Test

In [None]:
pos_class = PositionalDisentanglement(max_message_length = 4, num_concept_slots = 2) 
# 4: characters numbers in representation
# 2: 2 words per sentence

pos = pos_class.measure(compositional_representation_A, input_A)
print("Pos: {}".format(pos))

Pos: 0.2039755108523047
