## Semantic Role Labelling: data cleaning and analysis

This notebook cleans the raw results from `oecd_semantic_role_labeling.ipynb` and prepares the data for analysis using graph and network analysis tools such as NetworkX, Gephi and iGraph.

In [1]:
import re
import json
import string

### 1. Load the RAW SRL predictions

In [1]:
import pickle
import re
with open('../data-files/srl_predictions_big.pkl', 'rb') as f:
    srl_results = pickle.load(f)

In [2]:
def get_srl_tag_words(sentence):
    tokens = re.findall(r'\[(.*?)\]', sentence)
    verb = None
    arg0 = None 
    arg1 = None
    for token in tokens:
        if  token.startswith('V:'):
            verb = token.replace('V:','').strip()
        if  token.startswith('ARG0:'):
            arg0 = token.replace('ARG0:','').strip()
        if  token.startswith('ARG1:'):
            arg1 = token.replace('ARG1:','').strip()

    return verb, arg0, arg1

triples = set()
for i in range(0, len(srl_results)):
    for j in range(0, len(srl_results[i]["verbs"])):
        verb, arg0, arg1 = get_srl_tag_words(srl_results[i]["verbs"][j]['description'])
        if (verb is not None) and (arg0 is not None) and (arg1 is not None):
            triples.add((verb.strip(), arg0.strip(), arg1.strip()))

tripleslst = list(triples)

print(len(set(tripleslst)))

95520


### 2. Load the named entities dataset from the NER analysis

In [3]:
import pandas as pd

df = pd.read_csv('../data-files/master-ner-results-singletokens.csv')
orgs_df = df[df['entity_type'] == 'ORG'] # only organisaations
orgs_df.head()

Unnamed: 0,entity,entity_type,sentence,span,docid,model,entity_as_single_token
0,OECD,ORG,"The OECD, UCLG –Africa, the World Water Counci...",4:8,31,flair - FLERT and XML embeddings,OECD
1,UCLG – Africa,ORG,"The OECD, UCLG –Africa, the World Water Counci...",10:22,31,flair - FLERT and XML embeddings,UCLG_–_Africa
2,the World Water Council,ORG,"The OECD, UCLG –Africa, the World Water Counci...",24:47,31,flair - FLERT and XML embeddings,World_Water_Council
4,the OECD,ORG,This report is published as part of the OECD P...,36:44,31,flair - FLERT and XML embeddings,OECD
5,Programme on Water Security for Sustainable De...,ORG,This report is published as part of the OECD P...,45:110,31,flair - FLERT and XML embeddings,Programme_on_Water_Security_for_Sustainable_De...


### 3. Import stopwords and human-specified custom false positive entities

In [4]:
from nltk.corpus import stopwords

file1 = open('../data-files/replace_stopwords_orgs.txt', 'r')
other_stopwords = file1.readlines()
other_stopwords2 = []
for item in other_stopwords:
    other_stopwords2.append(item.replace('\n',''))

irrelevant_tokens = ['technology', 'project', 'region', 'agricultural', "'s", 'infrastructure', 'entity', 'state', 'on', 'world', 'working', 'management', 'water_management', 'council', 'task', 'team', 'water', 'climate_change', 'policy', 'covid-19', 'city', 'the', 'et', 'al.', 'x', 'pdf', 'yes', 'abbrev','also','fe',
                    'page', 'pp', 'p', 'er', 'doi', 'can', 'b', 'c', 'd', 'e',
                    'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'o', 'q', 'r', 's', 'herein', 'furthermore',
                    't', 'u', 'v', 'w', 'y', 'z', 'www', 'com', 'org', 'de', 'dx', 'th', 'ii', 'le']

stop_words = set(stopwords.words('english')).union(set(irrelevant_tokens))
stop_words = stop_words.union(set(other_stopwords2))


### 4. Clean SRL results: Pass 1 

In [5]:
from mhs import hitting_sets


def clean_srl_results(tripleslst):
    global stop_words
    global orgs_df
    
    highest_quality_results = set()
    overall_results = set()
    
    idx = 1
    total = len(list(set(tripleslst)))
    for item in list(set(tripleslst)):
        if (idx % 2 == 0):
            print(idx,"/",total, end=' ')
        idx+=1
        verb = item[0]
        arg0 = item[1]
        arg1 = item[2]
        arg0_tokens = arg0.split()
        arg1_tokens = arg1.split()
        if (len(arg0_tokens) == 1 and len(arg1_tokens) == 1):
            if (arg0_tokens[0].strip() not in stop_words) and (arg1_tokens[0].strip() not in stop_words):
                if (arg0_tokens[0].strip() in orgs_df['entity_as_single_token'].tolist()) and (arg1_tokens[0].strip() in orgs_df['entity_as_single_token'].tolist()):
                    highest_quality_results.add((verb, arg0_tokens[0].strip(), arg1_tokens[0].strip()))
                    overall_results.add((verb, arg0_tokens[0].strip(), arg1_tokens[0].strip()))
        else:
            if (len(arg0_tokens) > 1) and (len(arg1_tokens) > 1):
                relevant_tokens_only_0 = [w for w in arg0_tokens if not w in stop_words]
                relevant_tokens_only_0 = [w for w in relevant_tokens_only_0 if w in orgs_df['entity_as_single_token'].tolist()]
                valid_ent_0 = False
                for item in relevant_tokens_only_0:
                    if item in orgs_df['entity_as_single_token'].tolist():
                        valid_ent_0 = True
                        break
                        
                new_arg0 = ''
                if valid_ent_0:
                    new_arg0 = ','.join(list(set(relevant_tokens_only_0)))
                
                relevant_tokens_only_1 = [w for w in arg1_tokens if not w in stop_words]
                relevant_tokens_only_1 = [w for w in relevant_tokens_only_1 if w in orgs_df['entity_as_single_token'].tolist()]
                valid_ent_1 = False
                for item in relevant_tokens_only_1:
                    if item in orgs_df['entity_as_single_token'].tolist():
                        valid_ent_1 = True
                        break
                        
                new_arg1 = ''
                if valid_ent_1:
                    new_arg1 = ','.join(list(set(relevant_tokens_only_1)))
                    
                if (len(new_arg0) > 0 and len(new_arg1) > 0):
                    if (len(relevant_tokens_only_0) > 1) or (len(relevant_tokens_only_1) > 1):
                        #hitting sets
                        if (len(relevant_tokens_only_0) > 6 or len(relevant_tokens_only_1) > 6):
                            print('*', len(relevant_tokens_only_0), len(relevant_tokens_only_1))
                        hs = hitting_sets(set(relevant_tokens_only_0[:25]), set(relevant_tokens_only_1[:25]))
                        # print()
                        # print()
                        # print("original sets:", '\n', set(relevant_tokens_only_0), ',\n', set(relevant_tokens_only_1))
                        # print()
                        # print("hitting sets:")
                        new_hs = []
                        for hs_item in hs:
                            if len(hs_item) == 2:
                                new_hs.append(list(hs_item))
                                print((verb, list(hs_item)[0], list(hs_item)[1]))
                                overall_results.add((verb, list(hs_item)[0], list(hs_item)[1]))
                        # print(list(new_hs))
                        # print()
                    else:   # print()
                        print((verb, new_arg0, new_arg1))
                        overall_results.add((verb, new_arg0, new_arg1))
                    
            elif (len(arg0_tokens) > 1) and (len(arg1_tokens) == 1):
                relevant_tokens_only_0 = [w for w in arg0_tokens if not w in stop_words]
                relevant_tokens_only_0 = [w for w in relevant_tokens_only_0 if w in orgs_df['entity_as_single_token'].tolist()]
                valid_ent_0 = False
                for item in relevant_tokens_only_0:
                    if item in orgs_df['entity_as_single_token'].tolist():
                        valid_ent_0 = True
                        break
                        
                new_arg0 = ''
                if valid_ent_0:
                    new_arg0 = ','.join(list(set(relevant_tokens_only_0)))
                    
                if (len(new_arg0) > 0) and (arg1_tokens[0].strip() not in stop_words) and (arg1_tokens[0].strip() in orgs_df['entity_as_single_token'].tolist()):
                    if len(relevant_tokens_only_0) > 1:
                        #hitting sets
                        if (len(relevant_tokens_only_0) > 6):
                            print('*', len(relevant_tokens_only_0))
                        
                        hs = hitting_sets(set(relevant_tokens_only_0[:50]), set({arg1_tokens[0].strip()}))
                        # print()
                        # print()
                        # print("original sets:", '\n', set(relevant_tokens_only_0), ',\n', set({arg1_tokens[0].strip()}))
                        # print()
                        # print("hitting sets:")
                        new_hs = []
                        for hs_item in hs:
                            if len(hs_item) == 2:
                                new_hs.append(list(hs_item))
                                print((verb, list(hs_item)[0], list(hs_item)[1]))
                                overall_results.add((verb, list(hs_item)[0], list(hs_item)[1]))
                        # print(list(new_hs))
                        # print()
                        # print()
                    else:
                        print((verb, new_arg0, arg1_tokens[0].strip()))
                        overall_results.add((verb, new_arg0, arg1_tokens[0].strip()))
                    
            elif (len(arg1_tokens) > 1) and (len(arg0_tokens) == 1):
                relevant_tokens_only_1 = [w for w in arg1_tokens if not w in stop_words]
                relevant_tokens_only_1 = [w for w in relevant_tokens_only_1 if w in orgs_df['entity_as_single_token'].tolist()]
                valid_ent_1 = False
                for item in relevant_tokens_only_1:
                    if item in orgs_df['entity_as_single_token'].tolist():
                        valid_ent_1 = True
                        break
                        
                new_arg1 = ''
                if valid_ent_1:
                    
                    new_arg1 = ','.join(list(set(relevant_tokens_only_1)))
                    
                if (len(new_arg1) > 0) and (arg0_tokens[0].strip() not in stop_words) and (arg0_tokens[0].strip() in orgs_df['entity_as_single_token'].tolist()):
                    if len(relevant_tokens_only_1) > 1:
                        #hitting sets
                        if (len(relevant_tokens_only_1) > 6):
                            print('*', len(relevant_tokens_only_1))
                        hs = hitting_sets(set({arg0_tokens[0].strip()}), set(relevant_tokens_only_1[:50]))
                        # print()
                        # print()
                        # print("original sets:", '\n', set({arg0_tokens[0].strip()}), ',\n', set(relevant_tokens_only_1))
                        # print()
                        # print("hitting sets:")
                        new_hs = []
                        for hs_item in hs:
                            if len(hs_item) == 2:
                                new_hs.append(list(hs_item))
                                print((verb, list(hs_item)[0], list(hs_item)[1]))
                                overall_results.add((verb, list(hs_item)[0], list(hs_item)[1]))

                        # print(list(new_hs))
                        # print()
                        # print()
                    else:
                        print((verb, arg0_tokens[0].strip(), new_arg1))
                        overall_results.add((verb, arg0_tokens[0].strip(), new_arg1))
        print()
    return list(highest_quality_results), list(overall_results)
                
        

In [6]:
new_hq_res, overall_res = clean_srl_results(tripleslst)


2 / 95520 

4 / 95520 

6 / 95520 

8 / 95520 

10 / 95520 

12 / 95520 

14 / 95520 

16 / 95520 

18 / 95520 

20 / 95520 

22 / 95520 

24 / 95520 

26 / 95520 

28 / 95520 

30 / 95520 

32 / 95520 

34 / 95520 

36 / 95520 

38 / 95520 

40 / 95520 

42 / 95520 

44 / 95520 

46 / 95520 

48 / 95520 

50 / 95520 

52 / 95520 

54 / 95520 

56 / 95520 

58 / 95520 

60 / 95520 

62 / 95520 

64 / 95520 

66 / 95520 

68 / 95520 

70 / 95520 

72 / 95520 

74 / 95520 

76 / 95520 

78 / 95520 

80 / 95520 

82 / 95520 

84 / 95520 

86 / 95520 

88 / 95520 

90 / 95520 

92 / 95520 

94 / 95520 

96 / 95520 

98 / 95520 

100 / 95520 

102 / 95520 

104 / 95520 

106 / 95520 

108 / 95520 

110 / 95520 

112 / 95520 

114 / 95520 

116 / 95520 

118 / 95520 

120 / 95520 

122 / 95520 

124 / 95520 

126 / 95520 

128 / 95520 

130 / 95520 

132 / 95520 

134 / 95520 

136 / 95520 

138 / 95520 

140 / 95520 

142 / 95520 

144 / 95520 

146 / 95520 

148 / 95520 

150 / 95520 

15


KeyboardInterrupt



In [None]:
print(len(set(new_hq_res)))
print(len(set(overall_res)))

### 5. Save Pass 1 cleaning results to file

In [None]:
import pickle
with open('../data-files/srl_predictions_cleaned_singletoken_entities_only.pkl', 'wb') as f:
    pickle.dump(new_hq_res, f)
    
import pickle
with open('../data-files/srl_predictions_cleaned_pass1.pkl', 'wb') as f:
    pickle.dump(overall_res, f)

### 6. Clean SRL results: Pass 2 

In [None]:
file1 = open('../data-files/replace_stopwords_orgs.txt', 'r')
other_stopwords = file1.readlines()
other_stopwords2 = []
for item in other_stopwords:
    other_stopwords2.append(item.replace('\n',''))
    
other_stopwords2 = list(set(other_stopwords2))

cleaned_pass2_hq = []
cleaned_pass2_overall = []

replacements = {
    "solution_water.org" : "water.org",
    "us" : "united_states",
    "eu" : "european_union",
    "un" : "united_nations",
    "ec" : "european_commission",
    "ea" : "environment_agency",
    "environmental_agency" : "environment_agency",
    "eap_task_force" : "oecd_eap_task_force",
    "wb" : "world_bank"
}

def replace_all(text, dic):
    for i, j in dic.items():
        if text.strip() == i:
            text = text.replace(text, j)
    return text

def clean_srl_list_pass2(lst, stopwords):
    cleaned_lst = []
    for item in lst:
        verb = item[0]
        arg0 = item[1]
        arg1 = item[2]

        if (arg0 not in stopwords) and (arg1 not in stopwords):
            if arg0.startswith('the_'):
                arg0 = arg0.replace(arg0[:4], '')

            if arg1.startswith('the_'):
                arg1 = arg1.replace(arg1[:4], '')

            if '_' in arg0:
                arg0tokens = arg0.split('_')
                # print('arg0:', arg0tokens)
                arg0relevant_tokens = []
                for token in arg0tokens:
                    new_token = replace_all(token, replacements)
                    arg0relevant_tokens.append(new_token)
                # print('arg0rel:', arg0relevant_tokens)
                arg0 = '_'.join(arg0relevant_tokens)
            else:
                arg0 = replace_all(arg0, replacements)
                
            if '_' in arg1:
                arg1tokens = arg1.split('_')
                # print('arg1:', arg1tokens)
                arg1relevant_tokens = []
                for token2 in arg1tokens:
                    new_token2 = replace_all(token2, replacements)
                    arg1relevant_tokens.append(new_token2)
                # print('arg1rel:', arg1relevant_tokens)
                arg1 = '_'.join(arg1relevant_tokens)
            else:
                arg1 = replace_all(arg1, replacements)
        
        if (arg0 not in stopwords) and (arg1 not in stopwords):
            cleaned_lst.append((verb, arg0, arg1))
        
    return list(set(cleaned_lst))

cleaned_pass2_hq = clean_srl_list_pass2(new_hq_res, other_stopwords2)
cleaned_pass2_overall = clean_srl_list_pass2(overall_res, other_stopwords2)

### 7. Convert cleaned SRL data to network analysis format

In [None]:
def convert_data_to_networkx_format(data):
    edge_data = []
    edge_data_with_labels = {}
    
    for item in data:
        edge_data.append([item[1], item[2]])
        
    for edge in edge_data:
        edge_data_with_labels[(''+edge[0]+'', ''+edge[1]+'')] = []
        
    for item in data:
        edge_data_with_labels[(''+item[1]+'', ''+item[2]+'')].append(item[0])
    
    for key in edge_data_with_labels:
        edge_data_with_labels[key] = ', '.join(list(set(edge_data_with_labels[key])))
        
    return edge_data, edge_data_with_labels
    
single_token_graph_edges, single_token_graph_edges_with_labels = convert_data_to_networkx_format(cleaned_pass2_hq)
full_graph_edges, full_graph_edges_with_labels = convert_data_to_networkx_format(cleaned_pass2_overall)


### 8. Save cleaned SRL results to file 

To import into Gephi.

In [None]:
import csv

header = ['verb', 'source', 'target']

with open('../data-files/single_token_actors_data.csv', 'w', encoding='UTF8') as f:
    writer = csv.writer(f)
    writer.writerow(header)

    for item in cleaned_pass2_hq:
        rowdata = [item[0], item[1], item[2]]        
        writer.writerow(rowdata)
        
with open('../data-files/single_token_actors_data_edgesonly.csv', 'w', encoding='UTF8') as f:
    writer = csv.writer(f)
    writer.writerow(header[-2:])

    for item in cleaned_pass2_hq:
        rowdata = [item[1], item[2]]        
        writer.writerow(rowdata)
    
with open('../data-files/full_actors_data.csv', 'w', encoding='UTF8') as f:
    writer = csv.writer(f)
    writer.writerow(header)

    for item in cleaned_pass2_overall:
        rowdata = [item[0], item[1], item[2]]        
        writer.writerow(rowdata)
        
        
with open('../data-files/full_actors_data_edgesonly.csv', 'w', encoding='UTF8') as f:
    writer = csv.writer(f)
    writer.writerow(header[-2:])

    for item in cleaned_pass2_overall:
        rowdata = [item[1], item[2]]        
        writer.writerow(rowdata)

### 9. Try to plot using NetworkX 

This should work for smaller graphs (less than 1000 nodes). But for larger graphs perhaps try Gephi, iGraph etc.

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

In [None]:
edges = single_token_graph_edges
G = nx.Graph()
G.add_edges_from(edges)
pos = nx.spring_layout(G, k=2, scale=1, iterations=50)

plt.figure()
nx.draw(
    G, pos, edge_color='black', width=1, linewidths=1,
    node_size=1000, node_color='lightblue', alpha=1,
    labels={node: node for node in G.nodes()}
)
# nx.draw_networkx_edge_labels(
#     G, pos,
#     # edge_labels={('A', 'B'): 'AB', ('B', 'C'): 'BC', ('B', 'D'): 'BD'},
#     # edge_labels=single_token_graph_edges_with_labels,
#     font_color='red'
# )
plt.axis('off')
plt.show()