## Concatenate relationships with the parent node

#### Imports and helper functions:

In [1]:
import penman
from penman import layout
from penman.graph import Graph
from penman.transform import reify_attributes
import re
from pathlib import Path
from collections import defaultdict


def pprint(l, reified=False, **args):
    if isinstance(l, dict):
        print('Key\tValue')
        for k, v in l.items():            
            print(f'{k}\t{v}', **args)
            
    elif isinstance(l, list) or isinstance(l, tuple) or isinstance(l, set):
        for el in l:
            print(el, **args)
            
    elif isinstance(l, penman.Graph):
        if reified:
            l = penman.encode(l)
            l = reify_rename_graph_from_string(l)
        print(penman.encode(l), **args)
        
    elif isinstance(l, penman.Tree):
        if reified:
            l = penman.format(l)
            l = reify_rename_graph_from_string(l)
            print(penman.encode(l), **args)
        else:
            print(penman.format(l), **args)
        
    elif isinstance(l, str):
        if reified:
            l = reify_rename_graph_from_string(l)
            print(penman.encode(l), **args)
        else:
            print(penman.format(penman.parse(l)), **args)
            
    else:
        raise ValueError('Unknown type')
    print(**args)

In [18]:
### An alternative (arguably better, since it reuses the aligner's own code)
### way to map tree-position labels to node structures/concepts and
### alignments, e.g. '0.0.0.0': |0.0.0.0, score, (9, 10)|

from AMR2Text.toolkit.tamr_aligner.amr.aligned import Alignment

amr2text_alingnment_path = Path('.')/'AMR2Text'/'processed'/'corpus_a.mrp.new_aligned.txt'
# AMR blocks are separated by blank lines; keep every block as a list of lines.
# Read as UTF-8 explicitly instead of relying on the platform default.
with open(amr2text_alingnment_path, encoding='utf-8') as f:
    amrs = [block.split('\n') for block in f.read().strip().split('\n\n')]

amr_id3 = amrs[3]
al = Alignment(amr_id3)
al.alignments.alignments
al.nodes_by_levels

{'0': |0, possible-01, (2, 3)|,
 '0.0': |0.0, do-02, (4, 5)|,
 '0.0.0': |0.0.0, improve-01, (6, 7)|,
 '0.0.0.0': |0.0.0.0, score, (9, 10)|,
 '0.0.0.0.0': |0.0.0.0.0, credit, (8, 9)|,
 '0.0.0.0.1': |0.0.0.0.1, i, (3, 4)|,
 '0.0.1': |0.0.1, amr-unknown, (0, 1)|,
 '0.0.1.0': |0.0.1.0, more, (1, 2)|,
 '_ROOT_': |_ROOT_, _ROOT_, (-1, -1)|}

#### Main class:

In [138]:
class AMRAnalysis:
    """Load aligned AMRs and concatenate relation subtrees into their parent node.

    The input file holds AMR blocks separated by blank lines. As parsed by
    ``extract_info``: the first line ends with the AMR id, line 2 is the
    ``# ::tok`` line, ``# ::node`` lines carry the node alignments and the last
    line is the PENMAN graph.
    """

    def __init__(self, amr2text_alingnment_path, keep_meta=False, concat_rel=True):
        """
        Args:
            amr2text_alingnment_path: path to the aligned AMR file.
            keep_meta: also store the first three lines of every block under
                ``info_dict[amr_id]['meta']``.
            concat_rel: immediately run ``concat_rel()`` with its default
                relation (which also runs ``extract_info()``).
        """
        self.amr2text_alingnment_path = amr2text_alingnment_path
        self.keep_meta = keep_meta
        # amr_id -> {'amr_string', 'toks', 'alignments_dict', 'labels_dict',
        #            'amr_graph'[, 'meta']}
        self.info_dict = {}
        if concat_rel:
            self.concat_rel()

    @staticmethod
    def reify_rename_graph_from_string(amr_string):
        """Decode *amr_string*, reify its attributes and rename variables to ``MRPNode-{i}``."""
        g1 = reify_attributes(penman.decode(amr_string))
        t1 = layout.configure(g1)
        t1.reset_variables(fmt='MRPNode-{i}')
        g1 = layout.interpret(t1)

        return g1

    @staticmethod
    def alignment_labels2mrp_labels(amr_string):
        """Map dotted tree-position labels to the variables of the reified graph.

        The top node is labelled ``'0'``, its first child ``'0.0'``, its second
        child ``'0.1'`` and so on. Labels are reconstructed by replaying the
        ``Push``/``Pop`` layout markers stored in the graph's epidata.
        Currently works only on reified graphs.

        Returns:
            ``(labels_dict, amr_graph)``: ``{label: variable}`` and the reified,
            renamed ``penman.Graph``.
        """
        amr_graph = AMRAnalysis.reify_rename_graph_from_string(amr_string)
        epidata, triples = amr_graph.epidata, amr_graph.triples
        # cur_label: label for the next pushed node (after adjustment below);
        # popped: whether the last marker seen was a Pop.
        cur_label, popped = '0', False
        labels_dict = {cur_label:amr_graph.top}
        for triple in triples:
            cur_node = triple[0]
            epi = epidata[triple]
            if epi and isinstance(epi[0], penman.layout.Push):
                cur_node = epi[0].variable
                # After a push the new node is a first child ('.0'); after a
                # pop cur_label already points at the next sibling slot.
                if not popped:
                    cur_label += '.0'
                labels_dict[cur_label] = cur_node
                popped = False
            elif epi and isinstance(epi[0], penman.layout.Pop):
                pops_count = epi.count(epi[0])
                split = cur_label.split('.')
                if popped:
                    split = split[:len(split)-pops_count]
                else:
                    split = split[:len(split)-pops_count+1]
                # Advance to the next sibling slot at the level popped back to.
                split[-1] = str(int(split[-1])+1)
                cur_label = '.'.join(split)
                popped = True

        return labels_dict, amr_graph

    @staticmethod
    def get_alignments_dict_from_string(alignments_string, alignment_pattern, toks, labels_dict):
        """Build ``{variable: token_span}`` from the ``# ::alignments`` line.

        Kept for reference only: the alignments string in the
        ``new_aligned`` files does not contain every aligned node, so
        ``get_alignments_dict`` is used instead. ``toks`` is unused.

        Raises:
            ValueError: if *alignment_pattern* does not match.
            KeyError: if an aligned label is missing from *labels_dict*.
        """
        matches = re.match(alignment_pattern, alignments_string)
        if not matches:
            raise ValueError(f'Alignments string "{alignments_string}" has wrong format!\nCould not find alignments.')
        alignments_dict = {}
        for alignment in matches.group(1).split():
            token_span, nodes = alignment.split('|')[:2]
            for node in nodes.split('+'):
                alignments_dict[labels_dict[node]] = token_span
        return alignments_dict

    @staticmethod
    def get_alignments_dict(nodes_block, labels_dict):
        """Build ``{variable: token_span}`` from split ``# ::node`` lines.

        Only entries with exactly three fields are aligned; the first field is
        the tree-position label and the third the token span (e.g. ``'4-5'``).
        The middle field is presumably the concept (unused here).

        Raises:
            KeyError: if a label is missing from *labels_dict*.
        """
        alignments_dict = {}
        for spl_line in nodes_block:
            if len(spl_line) != 3:
                continue  # unaligned node
            node = labels_dict[spl_line[0]]  # '0.0.0' --> 'MRPNode-2'
            alignments_dict[node] = spl_line[2]
        return alignments_dict

    def extract_info(self):
        """Parse every AMR block of the input file into ``self.info_dict``.

        Returns:
            self, to allow chaining.

        Raises:
            KeyError: if a ``# ::node`` label is not among the reconstructed
                labels; the AMR id, graph and labels are printed first.
        """
        with open(self.amr2text_alingnment_path, encoding='utf-8') as f:
            amrs = [amr.split('\n') for amr in f.read().strip().split('\n\n')]

        for amr_analysis in amrs:
            amr_id = amr_analysis[0].split()[-1]

            toks = amr_analysis[2].split()[2:]  # first 2 tokens are: '# ::tok'
            toks = [tok.lower() for tok in toks]

            amr_string = amr_analysis[-1]
            labels_dict, amr_graph = AMRAnalysis.alignment_labels2mrp_labels(amr_string)

            # first 2 tokens are: '# ::node'
            nodes_block = [line.split()[2:] for line in amr_analysis if line.startswith('# ::node')]
            # The '# ::alignments' line misses some aligned nodes, so the
            # '# ::node' lines are used (see get_alignments_dict_from_string).
            try:
                alignments_dict = AMRAnalysis.get_alignments_dict(nodes_block, labels_dict)
            except KeyError:
                print(amr_id)
                pprint(penman.encode(amr_graph))  # the reified, renamed graph
                pprint(labels_dict)
                raise
            # Unaligned nodes map to None instead of raising KeyError.
            alignments_dict = defaultdict(lambda: None, alignments_dict)

            self.info_dict[amr_id] = {'amr_string': penman.encode(amr_graph),
                                      'toks': toks,
                                      'alignments_dict': alignments_dict,
                                      'labels_dict': labels_dict,
                                      'amr_graph': amr_graph}
            if self.keep_meta:
                meta = amr_analysis[:3]  # first three lines of the block, e.g. '# ::id', '# ::snt'
                self.info_dict[amr_id]['meta'] = '\n'.join(meta)
        return self

    @staticmethod
    def find_below(labels_dict):
        """Map every node to the list of nodes in its subtree.

        *labels_dict* maps dotted tree-position labels to variables, e.g.
        (``info_dict[amr_id]['labels_dict']``)::

            0          MRPNode-0
            0.0        MRPNode-1
            0.0.0      MRPNode-2
            0.0.0.0    MRPNode-3
            0.0.1      MRPNode-6

        A label is below another when it extends it by at least one
        ``'.'``-separated component, so ``'0.10'`` is *not* below ``'0.1'``.

        Returns:
            ``defaultdict(list)`` mapping a variable (e.g. ``'MRPNode-2'``) to
            the variables below it, in *labels_dict* order.
        """
        nodes_below_dict = defaultdict(list)
        for key, value in labels_dict.items():
            # The trailing '.' prevents sibling labels such as '0.10' from
            # matching the prefix '0.1'.
            prefix = key + '.'
            for k, v in labels_dict.items():
                if k.startswith(prefix):
                    nodes_below_dict[value].append(v)
        return nodes_below_dict

    @staticmethod
    def full_span(subtree_token_spans):
        """Check whether token spans (``'start-end'``, end exclusive) cover a gap-free range.

        Returns:
            The sorted list of covered token indices if they are contiguous,
            otherwise ``False`` (also when no index is covered at all).
        """
        toks_indices = set()
        for token_span in subtree_token_spans:
            spl = token_span.split('-')
            toks_indices.update(range(int(spl[0]), int(spl[1])))
        if not toks_indices:
            return False
        toks_indices = sorted(toks_indices)
        if toks_indices == list(range(toks_indices[0], toks_indices[-1] + 1)):
            return toks_indices
        return False

    @staticmethod
    def _plan_merges(g, info, rel):
        """Decide which nodes to merge; return ``(changed_instances, nodes_to_delete)``.

        A node that is the source of a *rel* edge is merged when the tokens
        aligned to it and to every node below it form one contiguous span and
        none of the nodes below it is reentrant.
        """
        toks = info['toks']
        alignments_dict = info['alignments_dict']
        nodes_below_dict = AMRAnalysis.find_below(info['labels_dict'])
        reentrancies = g.reentrancies()

        changed_instances = {}  # merged variable -> new concept
        nodes_to_delete = set()
        for source, role, target in g.triples:
            if role != rel or source in nodes_to_delete or target in nodes_to_delete:
                continue
            if source in changed_instances:
                continue  # already merged through an earlier `rel` edge
            nodes_below = nodes_below_dict[source]
            if any(reentrancies.get(node) for node in nodes_below):
                continue
            spans = [alignments_dict[node] for node in nodes_below + [source] if alignments_dict[node]]
            subtree_token_span = AMRAnalysis.full_span(spans)
            if subtree_token_span:
                changed_instances[source] = '_'.join(toks[i] for i in subtree_token_span)
                nodes_to_delete.update(nodes_below)
        return changed_instances, nodes_to_delete

    @staticmethod
    def _build_merged_graph(g, rel, changed_instances, nodes_to_delete):
        """Apply the planned merges to *g* and return the new ``Graph``.

        Dropped triples may carry ``Pop`` markers that close contexts still
        open in the kept graph (the merged node and possibly its ancestors).
        Their net count (pops minus pushes) is re-attached to the last kept
        triple so the layout stays correct.
        """
        triples, epidata = [], {}

        def attach_pops(count):
            if count > 0 and triples:
                last = triples[-1]
                epidata[last] = epidata[last] + [penman.layout.Pop() for _ in range(count)]

        pending_pops = 0
        for triple in g.triples:
            source, role, target = triple
            epi = g.epidata.get(triple, [])
            dropped = (source in nodes_to_delete or target in nodes_to_delete
                       or (role == rel and source in changed_instances))
            if dropped:
                pending_pops += sum(isinstance(e, penman.layout.Pop) for e in epi)
                pending_pops -= sum(isinstance(e, penman.layout.Push) for e in epi)
                continue
            attach_pops(pending_pops)
            pending_pops = 0
            if role == ':instance' and source in changed_instances:
                triple = (source, role, changed_instances[source])
            triples.append(triple)
            epidata[triple] = list(epi)
        attach_pops(pending_pops)
        return Graph(triples=triples, epidata=epidata)

    def concat_rel(self, rel=':mod'):
        """Collapse each mergeable subtree under a *rel* source into one concept.

        A merged node keeps its variable and incoming edge. Its concept becomes
        the covered (lowercased) tokens joined with ``'_'``, and all nodes below
        it are removed. Runs ``extract_info()`` first if nothing was loaded yet.

        Results are stored in ``self.graphs_concat_rel[amr_id]`` as
        ``(original_graph, concatenated_graph)``.

        Args:
            rel: the relation whose source nodes are merge candidates.

        Returns:
            self, to allow chaining.
        """
        if not self.info_dict:
            self.extract_info()
        self.graphs_concat_rel = {}

        for amr_id, info in self.info_dict.items():
            g = info['amr_graph']
            # Decide all merges first, then filter: a single pass would keep
            # triples emitted before a later merge deleted their nodes.
            changed_instances, nodes_to_delete = AMRAnalysis._plan_merges(g, info, rel)
            new_g = AMRAnalysis._build_merged_graph(g, rel, changed_instances, nodes_to_delete)
            self.graphs_concat_rel[amr_id] = (g, new_g)

        return self

In [140]:
# Both corpora live next to each other in AMR2Text/processed.
amr_path_a, amr_path_b = (
    Path('.') / 'AMR2Text' / 'processed' / file_name
    for file_name in ('corpus_a.mrp.new_aligned.txt', 'corpus_b.mrp.new_aligned.txt')
)

# Parsing + :mod concatenation happen inside the constructor.
amr_analysis_a, amr_analysis_b = (
    AMRAnalysis(path, keep_meta=True, concat_rel=True)
    for path in (amr_path_a, amr_path_b)
)

graphs_concat_rel_a = amr_analysis_a.graphs_concat_rel
graphs_concat_rel_b = amr_analysis_b.graphs_concat_rel

In [105]:
info_dicts = [analysis.info_dict for analysis in (amr_analysis_a, amr_analysis_b)]

# Inspect example '4' of corpus A: its tokens, the reified graph and the
# graph after concatenating :mod subtrees.
print(info_dicts[0]['4']['toks'])
pprint(info_dicts[0]['4']['amr_string'])
pprint(graphs_concat_rel_a['4'][1])

['chinese', 'lunar', 'rover', 'lands', 'on', 'moon']
(MRPNode-0 / land-01
           :ARG1 (MRPNode-1 / rover
                            :mod (MRPNode-2 / country)
                            :mod (MRPNode-3 / moon)
                            :mod (MRPNode-4 / name
                                            :op1 (MRPNode-5 / china)))
           :location (MRPNode-6 / moon))

(MRPNode-0 / land-01
           :ARG1 (MRPNode-1 / chinese_lunar_rover)
           :location (MRPNode-6 / moon))



In [143]:
path_a = Path('.')/'amr_suite'/'py3-Smatch-and-S2match'/'amr_data'/'corpus_a_concat.amr'
path_b = Path('.')/'amr_suite'/'py3-Smatch-and-S2match'/'amr_data'/'corpus_b_concat.amr'

def save_concatenation_results(path, amr_analysis):
    """Write every concatenated graph of *amr_analysis* to *path*.

    Each entry is the stored metadata block, followed by the PENMAN graph and
    a blank line (as written by ``pprint``). The metadata block exists only
    when the analysis was created with ``keep_meta=True``; otherwise the
    graph is written alone instead of raising KeyError.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for amr_id, (_, g_concat) in amr_analysis.graphs_concat_rel.items():
            meta_block = amr_analysis.info_dict[amr_id].get('meta')
            if meta_block is not None:
                print(meta_block, file=f)
            pprint(g_concat, file=f)

save_concatenation_results(path_a, amr_analysis_a)
save_concatenation_results(path_b, amr_analysis_b)

#### Basic version of the concatenation function:

In [94]:
def concat_rel(g, rel=':mod'):
    """Basic version: fold each *rel* child's concept into its parent's concept.

    For every ``(parent, rel, child)`` triple, the child's concept is prepended
    to the parent's concept with ``'_'`` (unless the parent's concept is
    ``amr-unknown``). The *rel* edge is dropped, and every triple whose source
    is the child is removed. Triples of deeper descendants are not removed.

    Args:
        g: a ``penman.Graph``; it is not modified.
        rel: the relation to fold.

    Returns:
        A new ``penman.Graph``.
    """
    # First concept of every node (looked up once instead of per edge).
    instances = {}
    for node, _, concept in g.instances():
        instances.setdefault(node, concept)

    forbidden_nodes_with_instances = {}  # child -> (child concept or '', parent)
    # [current triple, original triple]; the original is kept so its epidata
    # can still be found after the concept was renamed.
    triples_filtered = []
    for triple in g.triples:
        if triple[1] == rel:
            invoked, forbidden_node = triple[0], triple[2]
            forbidden_nodes_with_instances[forbidden_node] = (instances.get(forbidden_node, ''), invoked)
        else:
            triples_filtered.append([triple, triple])

    for instance, invoked in forbidden_nodes_with_instances.values():
        for pair in triples_filtered:
            n, r, c = pair[0]
            if n == invoked and r == ':instance' and c != 'amr-unknown':
                pair[0] = (n, r, f'{instance}_{c}')

    kept = [(new, old) for new, old in triples_filtered if new[0] not in forbidden_nodes_with_instances]
    epidata = {new: g.epidata[old] for new, old in kept}
    new_g = Graph(triples=[new for new, _ in kept], epidata=epidata)
    return new_g

# NOTE(review): `g` is not defined in any visible cell of this notebook; it
# must be a penman.Graph left over from an earlier session (presumably an
# entry's 'amr_graph' from AMRAnalysis.info_dict) — define it before running.
new_g = concat_rel(g, ':mod')
print(penman.encode(new_g), '\n')
print(penman.encode(g))

(MRPNode-0 / possible-01
           :ARG1 (MRPNode-1 / wrong-02
                            :ARG1 (MRPNode-2 / amr-unknown)
                            :ARG2 (MRPNode-3 / air_i_conditioner))) 

(MRPNode-0 / possible-01
           :ARG1 (MRPNode-1 / wrong-02
                            :ARG1 (MRPNode-2 / amr-unknown)
                            :ARG2 (MRPNode-3 / conditioner
                                             :mod (MRPNode-4 / i)
                                             :mod (MRPNode-5 / air))))


## Similarity Measures

#### SBert models vs GloVe 6B.100d:

In [4]:
import numpy as np
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer, util

In [5]:
def cos(a, b):
    """Return the cosine similarity of *a* and *b*, floored at 0.

    ``scipy``'s ``cosine`` is a distance (1 - cosine similarity); distances
    above 1 are clipped to 1, so the returned similarity never drops below 0.
    """
    return 1 - min(1, cosine(a, b))

def load_glove(fp):
    """Load GloVe text vectors into a ``{word: np.ndarray}`` dict.

    Every non-blank line is ``word v1 v2 ...``; blank lines are skipped.
    A falsy *fp* returns an empty dict.
    """
    dic = {}
    if not fp:
        return dic
    # GloVe files are UTF-8 and contain non-ASCII tokens; don't depend on
    # the platform's default encoding.
    with open(fp, "r", encoding="utf-8") as f:
        for line in f:
            ls = line.split()
            if not ls:
                continue
            dic[ls[0]] = np.array([float(x) for x in ls[1:]])
    return dic

def vecs_of_sents(m, sents):
    """Return an array with one row per sentence: the mean of its token vectors.

    Sentences are split on whitespace and each token is looked up in *m*
    (a missing token raises KeyError).
    """
    rows = []
    for sentence in sents:
        words = sentence.split()
        total = np.sum([m[w] for w in words], axis=0)
        rows.append(total / len(words))
    return np.asarray(rows)

def print_scores(s1, s2, cosine_scores):
    """Print one tab-separated line per sentence pair with its score.

    Each sentence is left-aligned and padded to the longest sentence of its
    list; ``cosine_scores[i, j]`` is printed with four decimals.
    """
    width_left = max(len(sent) for sent in s1)
    width_right = max(len(sent) for sent in s2)

    n_rows, n_cols = cosine_scores.shape[0], cosine_scores.shape[1]
    for row in range(n_rows):
        left = f'{s1[row]:{width_left}}'
        for col in range(n_cols):
            right = f'{s2[col]:{width_right}}'
            print(f'{left}\t{right}\tScore: {cosine_scores[row, col]:.4f}')
            
def sbert_sim(model, s1, s2):
    """Print the SBERT cosine similarity of every pair ``(s1[i], s2[j])``.

    Both lists are encoded to tensors with *model* and compared with
    ``util.pytorch_cos_sim``.
    """
    scores = util.pytorch_cos_sim(
        model.encode(s1, convert_to_tensor=True),
        model.encode(s2, convert_to_tensor=True),
    )
    print_scores(s1, s2, scores)
            
def glove_sim(model, s1, s2):
    """Print the cosine similarity of averaged GloVe vectors for every pair.

    *model* is a ``{word: vector}`` dict (see ``load_glove``); every sentence
    becomes the mean of its token vectors (see ``vecs_of_sents``).
    """
    left, right = vecs_of_sents(model, s1), vecs_of_sents(model, s2)
    print_scores(s1, s2, util.pytorch_cos_sim(left, right))


# Two pretrained SBERT checkpoints, loaded in this order.
sbert1, sbert2 = (
    SentenceTransformer(model_name)
    for model_name in ('distilbert-base-nli-stsb-mean-tokens', 'paraphrase-distilroberta-base-v1')
)
# GloVe 6B vectors (100-dimensional, per the file name).
glove = load_glove('amr_suite/vectors/glove.6B.100d.txt')

In [22]:
# Example paraphrase pair; the commented-out pair was an earlier
# word-level check.
#s1 = ['french fries']
#s2 = ['chip', 'chips']
s1 = ['How do I pump up water pressure in my shower?']
s2 = ['How can I boost the water pressure in my shower?']

# Same pair prepared for GloVe: vecs_of_sents splits on whitespace and looks
# every token up directly, so the text is lowercased and '?' is separated
# (presumably to match the glove.6B vocabulary — confirm).
s1_glove = ['how do i pump up water pressure in my shower ?']
s2_glove = ['how can i boost the water pressure in my shower ?']

# Print the pairwise score under each model.
print('"paraphrase-distilroberta-base-v1":')
sbert_sim(sbert2, s1, s2)
print('\n')
print('"distilbert-base-nli-stsb-mean-tokens":')
sbert_sim(sbert1, s1, s2)
print('\n')
print('"GloVe average":')
glove_sim(glove, s1_glove, s2_glove)

"paraphrase-distilroberta-base-v1":
How do I pump up water pressure in my shower?	How can I boost the water pressure in my shower?	Score: 0.9333


"distilbert-base-nli-stsb-mean-tokens":
How do I pump up water pressure in my shower?	How can I boost the water pressure in my shower?	Score: 0.9025


"GloVe average":
how do i pump up water pressure in my shower ?	how can i boost the water pressure in my shower ?	Score: 0.9868
