In [1]:
import pandas as pd
import networkx as nx
from tqdm import tqdm

In [2]:
def split_triples(triples, fracs):
    assert len(fracs) == 3
    assert sum(fracs) == 1
    
    train, remainder = split_triples_into_two(triples, fracs[0])
    valid, test = split_triples_into_two(remainder, fracs[1]/(fracs[1] + fracs[2]))
    
    return train, valid, test

def split_triples_into_two(triples, frac):
    group1 = triples.sample(frac=frac, random_state=42)
    group2 = triples.drop(group1.index)
    assert group1.shape[0] + group2.shape[0] == triples.shape[0]
    return group1, group2

# Note that we are using Graph instead of MultiDiGraph here since networkx does not support
# computing connected components on a directed graph 
def construct_networkx_object(df_triples, df_entities):
    G = nx.Graph()

    # add nodes
    print('adding nodes')
    e_tuples = [(row['name'], {'id': row['id']}) for i, row in df_entities.iterrows()]
    G.add_nodes_from(e_tuples)
    
    # add edges
    print('adding edges')
    for i, row in tqdm(df_triples.iterrows(), total=df_triples.shape[0]):
        G.add_edge(row.e1, row.e2, relation=row.rel)
        
    assert nx.number_of_nodes(G) == df_entities.shape[0]
        
    return G

def restrict_triples_to_entities(triples, entities_to_keep):
    assert all(triples.columns == ['e1', 'rel', 'e2'])
    
    idx_to_keep = []
    for idx, row in triples.iterrows():
        if row.e1 in entities_to_keep and row.e2 in entities_to_keep:
            idx_to_keep.append(idx)
    
    print(f'Keeping {len(idx_to_keep)} rows out of {triples.shape[0]}')
    
    return triples.loc[idx_to_keep,]

def get_all_entities_from_triples(triples):
    return set(triples.e1).union(set(triples.e2))

### Read in the full set of triples and split randomly (80/10/10)

In [3]:
data_dir = '../data/fb13/'
graph = pd.read_csv(f'{data_dir}/graph.txt', sep='\t', names=['e1', 'rel', 'e2'])
graph.rel.value_counts()

gender            66663
nationality       54451
profession        54182
place_of_death    40579
place_of_birth    37970
location          28783
institution       18358
cause_of_death    12857
religion           9635
parents            6268
children           6041
ethnicity          5622
spouse             4464
Name: rel, dtype: int64

In [4]:
train, valid, test = split_triples(graph, fracs=[0.8, 0.1, 0.1])
#train.rel.value_counts()
#valid.rel.value_counts()
#test.rel.value_counts()

### Identify largest connected component in training data and restrict all splits to entities from this component

In [5]:
# Read in entity info and construct networkX object from training data
entities = pd.read_csv(f'{data_dir}/entity2id.txt', sep='\t', names=['name', 'id'])
G = construct_networkx_object(train, entities)

adding nodes


  1%|          | 1811/276698 [00:00<00:30, 9064.06it/s]

adding edges


100%|██████████| 276698/276698 [00:30<00:00, 8993.24it/s]


In [6]:
# Identify largest connected component of this
largest_component = max(nx.connected_components(G), key=len)
G_connected = G.subgraph(largest_component).copy()

# Confirm that this is a single connected component
[len(c) for c in sorted(nx.connected_components(G_connected), key=len, reverse=True)]

[74845]

In [7]:
# Extract the nodes from this component and confirm that there are fewer than in the original graph
entities_to_keep = set(G_connected.nodes())
all_entities = set(G.nodes())
assert len(entities_to_keep) < len(all_entities)

In [14]:
# Filter train, valid, test, and entities based on this
train_filtered = restrict_triples_to_entities(train, entities_to_keep)
valid_filtered = restrict_triples_to_entities(valid, entities_to_keep)
test_filtered = restrict_triples_to_entities(test, entities_to_keep)
entities_filtered = entities[entities.name.isin(entities_to_keep)]

Keeping 276690 rows out of 276698
Keeping 34292 rows out of 34588
Keeping 34287 rows out of 34587


In [15]:
# Confirm that valid and test entities are a subset of train entities
train_entities = get_all_entities_from_triples(train_filtered)
valid_entities = get_all_entities_from_triples(valid_filtered)
test_entities = get_all_entities_from_triples(test_filtered)
assert valid_entities.issubset(train_entities)
assert test_entities.issubset(train_entities)

In [30]:
# Reset entity id mapping
entities_filtered.id = range(74845)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[name] = value


In [31]:
# Write to file
train_filtered.to_csv(f'{data_dir}resplit/train.txt', sep='\t', index=False, header=None)
valid_filtered.to_csv(f'{data_dir}resplit/valid.txt', sep='\t', index=False, header=None)
test_filtered.to_csv(f'{data_dir}resplit/test.txt', sep='\t', index=False, header=None)
entities_filtered.to_csv(f'{data_dir}resplit/entity2id.txt', sep='\t', index=False, header=None)