In [1]:
%matplotlib inline

import os
import json
import numpy as np
import collections

import networkx as nx
import matplotlib.pyplot as plt

### Setup File Paths

In [17]:
# Dataset root: keep exactly one of the following lines active.
# root_dir = '/data2/csz/MSRVTT'
# root_dir = '/data2/csz/TGIF'
root_dir = '/data2/csz/VATEX'

# Annotation directory for this notebook; created on demand.
anno_dir = os.path.join(root_dir, 'annotation', 'RET')
os.makedirs(anno_dir, exist_ok=True)

# input: JSON file read below as a mapping sentence -> SRL parse
sent2srl_file = os.path.join(anno_dir, 'sent2srl.json')

# outputs: role graphs before and after the spaCy-based augmentation
sent2rg_file, sent2rga_file = (
  os.path.join(anno_dir, file_name)
  for file_name in ('sent2rolegraph.json', 'sent2rolegraph.augment.json')
)

### Convert Sentence to Role Graph

In [18]:
# Load the sentence -> SRL parse mapping; the context manager closes the
# file handle instead of leaving it to the garbage collector.
with open(sent2srl_file) as srl_fp:
  sent2srl = json.load(srl_fp)

In [4]:
def _group_tag_spans(tags):
  """Group the token positions of one BIO-style tag sequence by role span.

  Args:
    tags: list of tags for one verb frame; each tag is either 'O' or a
      two-character prefix followed by the role name (e.g. 'B-ARG0', 'I-V').

  Returns:
    A pair ``(tag2idxs, tagname_counter)``. ``tag2idxs`` maps
    '<role>:<k>' (the k-th span of that role, counted from 1) to the list of
    token indices in that span. ``tagname_counter`` maps each role to the
    number of spans found for it.
  """
  tag2idxs = {}
  tagname_counter = {}  # multiple args of the same role
  for t, tag in enumerate(tags):
    if tag == 'O':
      continue
    tagname = tag[2:]
    # A span opens on an explicit B- tag. It also opens on a continuation tag
    # that does not continue anything: either its role differs from the
    # previous token's role (parsing mistakes such as (B-ARG0, I-ARG1)), or it
    # is the very first token. The latter case used to raise KeyError because
    # no counter existed yet for the role.
    opens_span = tag[0] == 'B' or t == 0 or tagname != tags[t-1][2:]
    if opens_span:
      tagname_counter[tagname] = tagname_counter.get(tagname, 0) + 1
    span_key = '%s:%d' % (tagname, tagname_counter[tagname])
    tag2idxs.setdefault(span_key, []).append(t)
  return tag2idxs, tagname_counter


def create_role_graph_data(srl_data):
  """Convert the SRL parse of one sentence into role-graph nodes and edges.

  Args:
    srl_data: dict with key 'words' (list of tokens) and key 'verbs' (list of
      dicts, each holding a 'tags' list aligned with the tokens).

  Returns:
    A pair ``(graph_nodes, graph_edges)``.
    ``graph_nodes`` maps a node name to {'role', 'spans', 'words'}: the node
    'ROOT' covers the whole sentence, every other node is named by a running
    integer (as a string) and covers one verb or argument span.
    ``graph_edges`` is a list of (verb_node_name, arg_node_name, role) tuples.

  Raises:
    KeyError: if 'words', 'verbs' or a frame's 'tags' key is missing.
  """
  words = srl_data['words']
  verb_items = srl_data['verbs']

  graph_nodes = {
    'ROOT': {'words': words, 'spans': list(range(0, len(words))), 'role': 'ROOT'},
  }
  graph_edges = []

  # Keep only frames that have exactly one verb span plus at least one other role.
  phrase_items = []
  for verb_item in verb_items:
    tag2idxs, tagname_counter = _group_tag_spans(verb_item['tags'])
    if len(tagname_counter) > 1 and tagname_counter.get('V') == 1:
      phrase_items.append(tag2idxs)

  node_idx = 1
  # '<idx>-<idx>-...-<role>' -> node name, so identical (span, role) pairs
  # coming from different frames share one node.
  spanrole2nodename = {}
  for phrase_item in phrase_items:
    # add verb node to graph
    verb_spans = phrase_item['V:1']
    verb_key = '-'.join([str(x) for x in verb_spans] + ['V'])
    if verb_key in spanrole2nodename:
      # the same verb span was already processed: skip the whole frame
      continue
    verb_node_name = str(node_idx)
    graph_nodes[verb_node_name] = {
      'role': 'V', 'spans': verb_spans, 'words': [words[idx] for idx in verb_spans],
    }
    spanrole2nodename[verb_key] = verb_node_name
    node_idx += 1

    # add arg nodes and edges of the verb node
    for tagname, spans in phrase_item.items():
      role = tagname.split(':')[0]
      if role == 'V':
        continue
      spanrole = '-'.join([str(x) for x in spans] + [role])
      if spanrole in spanrole2nodename:
        node_name = str(spanrole2nodename[spanrole])
      else:
        # add new node or duplicate a node with a different role
        node_name = str(node_idx)
        graph_nodes[node_name] = {
          'role': role, 'spans': spans, 'words': [words[idx] for idx in spans],
        }
        spanrole2nodename[spanrole] = node_name
        node_idx += 1
      # add edge
      graph_edges.append((verb_node_name, node_name, role))

  return graph_nodes, graph_edges

In [19]:
# Build a role graph for every sentence. Conversion stays best-effort: a
# sentence whose parse cannot be converted is reported and left out of
# sent2graph. Only ordinary errors are caught, so KeyboardInterrupt and
# SystemExit still stop the loop.
sent2graph = {}
for sent, srl in sent2srl.items():
    try:
        graph_nodes, graph_edges = create_role_graph_data(srl)
        sent2graph[sent] = (graph_nodes, graph_edges)
    except Exception as err:
        # show which sentence failed and why
        print('%s\t[%s: %s]' % (sent, type(err).__name__, err))

In [20]:
# Persist the role graphs; the context manager flushes and closes the file
# so the JSON on disk is complete once this cell finishes.
with open(sent2rg_file, 'w') as rg_fp:
  json.dump(sent2graph, rg_fp)

In [21]:
# Count the graphs whose node dict holds a single entry (only ROOT).
n = sum(len(graph[0]) == 1 for graph in sent2graph.values())
print('#sents without non-root nodes:', n)

#sents without non-root nodes: 3504


### Augment Graph if no SRL is detected (no verb)

In [9]:
import spacy

# The English pipeline must be installed once before it can be loaded, e.g.:
# ! python -m spacy download en_core_web_sm
# `nlp` is called on raw sentences further down to obtain tokens, POS tags
# and noun chunks for sentences whose role graph has no non-root node.
nlp = spacy.load("en_core_web_sm")


In [10]:
# Graphs that hold nothing but ROOT received no SRL frame. For those, parse
# the sentence with spaCy and add NOUN / V nodes (no edges are added).
for sent, (nodes, edges) in sent2graph.items():
  node_idx = len(nodes)

  if len(nodes) == 1:
    doc = nlp(sent)
    # node spans index into ROOT's word list, so the tokenizations must agree
    assert len(doc) == len(nodes['ROOT']['words']), sent

    # noun nodes: one per noun chunk ...
    for chunk in doc.noun_chunks:
      chunk_spans = list(range(chunk.start, chunk.end))
      nodes[str(node_idx)] = {
        'role': 'NOUN',
        'spans': chunk_spans,
        'words': [doc[j].text for j in chunk_spans],
      }
      node_idx += 1

    # ... or, when no chunk was found, one per token tagged NN*
    if len(nodes) == 1:
      for token in doc:
        if not token.tag_.startswith('NN'):
          continue
        nodes[str(node_idx)] = {
          'role': 'NOUN', 'spans': [token.i], 'words': [token.text],
        }
        node_idx += 1

    # verb nodes: one per token tagged VB*
    for token in doc:
      if not token.tag_.startswith('VB'):
        continue
      nodes[str(node_idx)] = {
        'role': 'V', 'spans': [token.i], 'words': [token.text],
      }
      node_idx += 1

  sent2graph[sent] = (nodes, edges)

print(len(sent2graph))

124654


In [11]:
# Persist the augmented role graphs; the context manager flushes and closes
# the file so the JSON on disk is complete once this cell finishes.
with open(sent2rga_file, 'w') as rga_fp:
  json.dump(sent2graph, rga_fp)

### Statistics

In [12]:
# Tally how often each node role occurs over all graphs (ROOT included).
role_types = collections.Counter(
  node['role']
  for graph in sent2graph.values()
  for node in graph[0].values()
)
print(len(role_types))

39


In [13]:
# Display every role with its node count, most frequent first.
role_types.most_common()

[('V', 209170),
 ('ARG1', 137677),
 ('ARG0', 131884),
 ('ROOT', 124654),
 ('ARG2', 38411),
 ('ARGM-DIR', 30282),
 ('ARGM-LOC', 23348),
 ('ARGM-TMP', 18048),
 ('ARGM-MNR', 16621),
 ('ARGM-ADV', 4797),
 ('NOUN', 3551),
 ('ARG4', 1856),
 ('ARGM-PRP', 1817),
 ('ARG3', 1725),
 ('R-ARG0', 1598),
 ('ARGM-PRD', 1455),
 ('R-ARG1', 1037),
 ('ARGM-GOL', 771),
 ('ARGM-COM', 689),
 ('C-ARG1', 389),
 ('ARGM-NEG', 259),
 ('ARGM-CAU', 237),
 ('ARGM-EXT', 204),
 ('ARGM-MOD', 112),
 ('ARGM-DIS', 87),
 ('R-ARGM-LOC', 47),
 ('ARGM-LVB', 45),
 ('ARGM-ADJ', 43),
 ('C-ARG0', 43),
 ('ARGM-REC', 36),
 ('ARGM-PNC', 29),
 ('R-ARG2', 25),
 ('C-ARGM-ADV', 3),
 ('ARG5', 2),
 ('C-ARG2', 2),
 ('C-ARG4', 2),
 ('R-ARGM-MOD', 2),
 ('R-ARGM-TMP', 2),
 ('R-ARGM-MNR', 1)]

In [14]:
# noun per sent
nouns_per_sent = []
for sent, graph in sent2graph.items():
  n_nouns = 0
  for node_id, node in graph[0].items():
    if node['role'] != 'ROOT' and node['role'] != 'V':
      n_nouns += 1
  if n_nouns == 0:
    print(sent)
  nouns_per_sent.append(n_nouns)
nouns_per_sent = np.array(nouns_per_sent)
print(np.sum(nouns_per_sent == 0), np.min(nouns_per_sent), np.mean(nouns_per_sent), np.max(nouns_per_sent),
     np.percentile(nouns_per_sent, 90), np.percentile(nouns_per_sent, 95))

0 1 3.346358720939561 14 5.0 6.0


In [15]:
# verb per sent
verbs_per_sent = []
for sent, graph in sent2graph.items():
  n_verbs = 0
  for node_id, node in graph[0].items():
    if node['role'] == 'V':
      n_verbs += 1
  verbs_per_sent.append(n_verbs)
verbs_per_sent = np.array(verbs_per_sent)
print(np.sum(verbs_per_sent == 0), np.min(verbs_per_sent), np.mean(verbs_per_sent), np.max(verbs_per_sent),
     np.percentile(verbs_per_sent, 90), np.percentile(verbs_per_sent, 96))

624 0 1.6780047170568133 6 3.0 3.0


In [19]:
# Spot-check: display the (nodes, edges) pair stored for one sentence.
sent2graph['a egg has been broken and dropped into the cup and a water is boiling in the sauce pan']

({'ROOT': {'words': ['a',
    'egg',
    'has',
    'been',
    'broken',
    'and',
    'dropped',
    'into',
    'the',
    'cup',
    'and',
    'a',
    'water',
    'is',
    'boiling',
    'in',
    'the',
    'sauce',
    'pan'],
   'spans': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
   'role': 'ROOT'},
  '1': {'role': 'V', 'spans': [4], 'words': ['broken']},
  '2': {'role': 'ARG1', 'spans': [0, 1], 'words': ['a', 'egg']},
  '3': {'role': 'V', 'spans': [6], 'words': ['dropped']},
  '4': {'role': 'ARG4', 'spans': [7, 8, 9], 'words': ['into', 'the', 'cup']},
  '5': {'role': 'V', 'spans': [14], 'words': ['boiling']},
  '6': {'role': 'ARG1', 'spans': [11, 12], 'words': ['a', 'water']},
  '7': {'role': 'ARGM-LOC',
   'spans': [15, 16, 17, 18],
   'words': ['in', 'the', 'sauce', 'pan']}},
 [('1', '2', 'ARG1'),
  ('3', '2', 'ARG1'),
  ('3', '4', 'ARG4'),
  ('5', '6', 'ARG1'),
  ('5', '7', 'ARGM-LOC')])