In [None]:
# this takes infrequent MSCOCO frames and points them to a parent frame
# stores the result in: mscoco_framenet_parent_graph.pik

In [1]:
import spacy
import json
import sys
import numpy as np
from pprint import pprint
from collections import Counter, defaultdict
import nltk
from nltk.corpus import framenet as fn
import copy
import networkx as nx

In [2]:
# Load spaCy's English pipeline; count_verb_frames uses it for part-of-speech tags.
# NOTE(review): the 'en' shortcut name depends on the installed spaCy version — confirm.
nlp = spacy.load('en')

In [3]:
def read_framenet(fname):
    """Read a JSON-lines file of frame-annotated sentences.

    Every non-blank line must be a JSON object with a "tokens" entry and a
    "frames" entry.

    Args:
        fname: path of the JSON-lines file.

    Returns:
        A pair ``(tokens, frames)`` of parallel lists, one item per record,
        holding each record's "tokens" and "frames" value respectively.

    Raises:
        ValueError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks "tokens" or "frames".
    """
    tokens = []
    frames = []
    # The context manager guarantees the handle is closed, even when a line
    # fails to parse.
    with open(fname, "r") as handle:
        for line in handle:
            # Tolerate blank lines (e.g. a stray trailing newline) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            js = json.loads(line)
            tokens.append(js["tokens"])
            frames.append(js["frames"])
    return tokens, frames

def frames_to_name_span_dict(frames):
    """Index every sentence's frame annotations by the span of their target.

    Args:
        frames: one list of annotations per sentence; each annotation holds a
            'target' dict with a 'name' and a non-empty 'spans' list whose
            items carry 'start' and 'end'.

    Returns:
        A list parallel to ``frames``. Each item maps ``(start, end)`` of the
        target's first span to the frame name. When two annotations share a
        span, the later one wins.
    """
    def index_by_span(annotations):
        span_to_name = {}
        for annotation in annotations:
            target = annotation['target']
            # Only the first span of a target is used as the lookup key.
            first_span = target['spans'][0]
            key = (first_span['start'], first_span['end'])
            span_to_name[key] = target['name']
        return span_to_name

    return [index_by_span(sentence_frames) for sentence_frames in frames]

In [None]:
# Load the frame-annotated COCO captions and index each caption's frames by
# target span. The commented-out lines are alternative input files
# (presumably a smaller random sample — confirm before re-enabling).
#tokens, frames = read_framenet("/data/DataSets/COCO/train_sents_small_rnd.json")
#tokens, frames = read_framenet("/localdata/u4534172/COCO/train_sents_small_rnd.json")
tokens, frames = read_framenet("/localdata/u4534172/COCO/train_sents.json")
frame_dicts = frames_to_name_span_dict(frames)

In [None]:
def count_verb_frames(tokens, frame_dicts, nlp):
    """Count the frames evoked by verbs and collect their surface forms.

    Each token list is joined with single spaces and run through ``nlp.pipe``.
    For every token tagged 'VERB' at index j, the sentence's span dictionary
    is probed with (j, j+1), then (j, j+2), then (j-1, j+1); the first span
    found decides which frame is credited.

    Args:
        tokens: list of token lists, one per sentence.
        frame_dicts: list parallel to ``tokens``; each item maps
            ``(start, end)`` spans to frame names.
        nlp: object exposing ``pipe(texts, n_threads=..., batch_size=...)``
            that yields one tagged document per text.

    Returns:
        ``(counter, wordlist)``: a Counter of frame name -> number of verb
        hits, and a defaultdict of frame name -> Counter of the matched
        surface strings.

    NOTE(review): span lookups assume the pipeline's token indices line up
    with the indices used in ``frame_dicts`` — confirm.
    """
    counter = Counter()
    wordlist = defaultdict(Counter)

    texts = [" ".join(toks) for toks in tokens]
    docs = nlp.pipe(texts, n_threads=8, batch_size=50000)

    for i, doc in enumerate(docs):
        for j, tok in enumerate(doc):
            if tok.pos_ != 'VERB':
                continue

            spans_to_frames = frame_dicts[i]
            # Probe order matters: the verb alone, verb plus the following
            # token, then the preceding token plus the verb.
            for start, end in ((j, j + 1), (j, j + 2), (j - 1, j + 1)):
                if (start, end) not in spans_to_frames:
                    continue
                frame_name = spans_to_frames[(start, end)]
                counter[frame_name] += 1
                surface = " ".join([t.orth_ for t in doc[start:end]])
                wordlist[frame_name].update([surface])
                break
    return counter, wordlist

In [None]:
# Count how often each frame is matched on a verb and collect the matched surface forms.
frame_counts, frame_to_words = count_verb_frames(tokens, frame_dicts, nlp)

In [7]:
def get_framenet_graph():
    """Build a directed multigraph over all FrameNet frames.

    Every frame returned by ``fn.frames()`` becomes a node initialised with
    ``count=0`` and ``words=[]``. For each 'Inheritance', 'Using' or
    'Perspective_on' relation listed on a frame, an edge is added from that
    frame to the relation's super frame, unless the frame is itself the super
    frame. The relation type is stored on the edge as ``edge_type``.

    Returns:
        The populated ``nx.MultiDiGraph``.
    """
    # Only these relation types are turned into edges.
    parent_relations = ('Inheritance', 'Using', 'Perspective_on')

    graph = nx.MultiDiGraph()

    # First pass: create all nodes up front so each one carries the default
    # attributes (add_edge alone would create bare nodes).
    for frame in fn.frames():
        graph.add_node(frame.name, count=0, words=[])

    # Second pass: link each frame to its super frames.
    for frame in fn.frames():
        for relation in frame['frameRelations']:
            relation_type = relation['type']['name']
            if relation_type not in parent_relations:
                continue
            super_name = relation['superFrameName']
            # Skip relations in which this frame is the super frame.
            if super_name == frame.name:
                continue
            graph.add_edge(frame.name, super_name, key=None,
                           edge_type=relation_type)

    return graph

def add_words_and_count_to_graph(frame_graph, frame_counts, frame_to_words):
    """Copy per-frame counts and word collections onto matching graph nodes.

    Frames that are not nodes of ``frame_graph`` are ignored. The graph is
    modified in place; the word collections are stored by reference, not
    copied.

    Args:
        frame_graph: graph whose nodes are frame names.
        frame_counts: mapping of frame name -> value for the 'count' attribute.
        frame_to_words: mapping of frame name -> value for the 'words' attribute.

    Returns:
        The same ``frame_graph`` object.
    """
    # Counts are written first, then words, each onto its own node attribute.
    updates = (('count', frame_counts), ('words', frame_to_words))
    for attribute, per_frame in updates:
        for frame_name, value in per_frame.items():
            if frame_name in frame_graph:
                frame_graph.nodes[frame_name][attribute] = value
    return frame_graph

In [None]:
# Build the FrameNet relation graph, then annotate its nodes with the
# verb-frame counts and word collections gathered from the captions.
frame_graph = get_framenet_graph()
frame_graph = add_words_and_count_to_graph(frame_graph, frame_counts, frame_to_words)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Draw the full (uncompressed) frame graph.
plt.figure(figsize=(16,9))
nx.draw(frame_graph, node_size=50)

In [None]:
def compress_framenet_graph(frame_graph, th=200):
    """Point every frame at the frame that should represent it.

    The result has the same nodes (with attributes) as ``frame_graph`` and
    one outgoing edge per node:

    * a node whose 'count' is at least ``th`` points to itself;
    * otherwise it points to the nearest reachable node whose 'count' is at
      least ``th`` (first one found on ties);
    * if no reachable node qualifies, it points to the reachable node that
      is farthest away, which is the node itself when nothing else is
      reachable.

    Args:
        frame_graph: directed graph whose nodes carry a 'count' attribute.
        th: minimum count for a frame to be kept as a target.

    Returns:
        A new ``nx.DiGraph`` holding the frame -> representative edges.
    """
    compressed = nx.DiGraph()
    compressed.add_nodes_from(frame_graph.nodes(data=True))

    all_distances = list(nx.all_pairs_shortest_path_length(frame_graph))
    for frame, reachable in all_distances:

        # Frequent frames represent themselves.
        if frame_graph.nodes[frame]['count'] >= th:
            compressed.add_edge(frame, frame)
            continue

        # Reachable frames over the threshold; distances of 9999 or more are
        # never accepted, mirroring the sentinel used for the initial best.
        frequent = [
            (node, dist)
            for node, dist in reachable.items()
            if frame_graph.nodes[node]['count'] >= th and dist < 9999
        ]

        if frequent:
            # min() keeps the first entry among equally close candidates.
            target = min(frequent, key=lambda item: item[1])[0]
        else:
            # No frequent ancestor: fall back to the farthest reachable node
            # (the stable sort keeps the first one among ties).
            by_distance = sorted(reachable.items(),
                                 key=lambda item: item[1], reverse=True)
            target = by_distance[0][0]

        compressed.add_edge(frame, target)
    return compressed


In [None]:
# Map every frame to its representative frame using the default threshold.
parent_graph = compress_framenet_graph(frame_graph)

In [None]:
# Draw the compressed frame -> representative graph.
plt.figure(figsize=(16,9))
nx.draw(parent_graph, node_size=50)

In [None]:
# In-degree = number of frames mapped onto each node (self-loops included).
parent_graph.in_degree()

In [22]:
# Spot check: which frame does 'Motion' get mapped to?
list(parent_graph.neighbors('Motion'))

['Motion']

In [None]:
# Total count received by every representative frame: each node adds its own
# 'count' to the frame(s) it points to in parent_graph.
counts = defaultdict(int)
for n, v in parent_graph.nodes.items():
    for sc in parent_graph.successors(n):
        counts[sc] += v['count']
# Show the non-zero totals, largest first. A list comprehension is used
# because filter() is lazy on Python 3 and the cell would only display a
# "<filter object ...>" instead of the data.
[(frame, total)
 for frame, total in sorted(counts.items(), key=lambda x: -x[1])
 if total > 0]

In [30]:
def remove_empty_wordlists(graph):
    """Normalise every node's 'words' attribute in place.

    * an empty list becomes ``False``;
    * a set becomes a list of its members;
    * a Counter becomes a list of ``(word, count)`` pairs;
    * anything else is left untouched.

    Args:
        graph: graph whose nodes each carry a 'words' attribute.

    Returns:
        The same ``graph`` object.
    """
    for node in graph.nodes:
        attributes = graph.nodes[node]
        words = attributes['words']
        # The three cases are mutually exclusive, so one chain suffices.
        if words == []:
            attributes['words'] = False
        elif isinstance(words, set):
            attributes['words'] = list(words)
        elif isinstance(words, Counter):
            attributes['words'] = list(words.items())
    return graph

In [39]:
# Persist the compressed graph as a pickle (protocol 2).
# NOTE(review): the GML export below is disabled; remove_empty_wordlists looks
# like preparation for it — confirm before re-enabling.
#parent_graph = remove_empty_wordlists(parent_graph)
#nx.write_gml(parent_graph, "/localdata/u4534172/COCO/mscoco_framenet_parent_graph.gml")
nx.write_gpickle(parent_graph, "/localdata/u4534172/COCO/mscoco_framenet_parent_graph.pik", 2)

In [40]:
# Reload the pickle just written, to check that it round-trips.
# Pickle files execute code on load: only read files you produced yourself.
parent_graph2 = nx.read_gpickle("/localdata/u4534172/COCO/mscoco_framenet_parent_graph.pik")

In [44]:
# Spot check on the reloaded graph: which frame does 'Closure' point to?
parent_graph2.adj["Closure"]

AtlasView({'Closure': {}})