In [None]:
import os
import json
import pandas as pd
import numpy as np
np.random.seed(123)

In [None]:
def create_negative_paris(train_rltshps, unique_persons, neg_per_person=10):
    '''
    Create negative pairs: for every person of every positive pair, draw random
    persons with whom they are not in a relationship.

    Parameters
    ----------
    train_rltshps : iterable of set
        Positive pairs; each pair is a set of hashable person identifiers.
    unique_persons : sequence
        Pool the negative partners are drawn from (indexed by position).
    neg_per_person : int, default 10
        Number of negative pairs requested for each person each time that
        person appears in a positive pair.

    Returns
    -------
    list of set
        Negative pairs ``{person, partner}``. No pair is produced twice, none
        of them is a positive pair and nobody is paired with themselves.

    Notes
    -----
    Partners are drawn with ``np.random.randint``, i.e. from NumPy's global
    random state; seed it beforehand for repeatable draws.
    When fewer than ``neg_per_person`` eligible partners are left for a
    person, only the remaining ones are used (and a summary is printed)
    instead of sampling forever.
    '''
    positive_pairs = list(train_rltshps)
    n = len(unique_persons)
    negative_rltshps = []
    if n == 0:
        # Nothing to sample from.
        return negative_rltshps

    pool = set(unique_persons)

    # person -> everybody that person must not be paired with: themselves,
    # their positive partners and (added below) their negative partners.
    # Set lookups replace the linear scans over both pair lists.
    blocked_by_person = {}
    for pair_set in positive_pairs:
        for person in pair_set:
            blocked_by_person.setdefault(person, set()).update(pair_set)

    shortfall = 0
    for pair_set in positive_pairs:
        # Sets have no defined order; sorting keeps the sequence of random
        # draws independent of set iteration order.
        for person in sorted(pair_set, key=str):
            blocked = blocked_by_person[person]

            # Cap the quota by the number of partners that are still eligible,
            # otherwise the rejection loop below could never finish.
            available = len(pool) - len(blocked & pool)
            quota = min(neg_per_person, available)
            shortfall += neg_per_person - quota

            for _ in range(quota):
                candidate = unique_persons[np.random.randint(n)]
                while candidate in blocked:
                    candidate = unique_persons[np.random.randint(n)]

                negative_rltshps.append({person, candidate})
                # The pair is unordered, so block it from both sides.
                blocked.add(candidate)
                blocked_by_person.setdefault(candidate, {candidate}).add(person)

        # Print current negative samples count
        negative_cnt = len(negative_rltshps)
        if negative_cnt and negative_cnt % 1000 == 0:
            print(f'Current negative samples count: {negative_cnt}')

    if shortfall:
        print(f'Candidate pool exhausted: {shortfall} negative pairs could not be created')

    return negative_rltshps

In [None]:
# Load the positive (related) pairs from CSV; the split cell below reads the
# 'p1' and 'p2' columns of this frame.
train_rltshps = pd.read_csv("data/train_relationships.csv")
# Preview the first rows as the cell output
train_rltshps.head()

In [None]:
# Put the rows in random order; the fixed random_state keeps the shuffle repeatable
train_rltshps = (
    train_rltshps
    .sample(frac=1, random_state=123)
    .reset_index(drop=True)
)

# Turn each (p1, p2) row into an unordered pair, stored as a set
train_rlt_list = list(map(set, zip(train_rltshps['p1'], train_rltshps['p2'])))
valid_rlt_list = []

# Split the pairs into train/valid so that no person appears in both datasets;
# VAL_FACTOR is the valid/train ratio the split has to exceed
VAL_FACTOR = 0.12

def get_related_pairs(person, pairs, dest):
    '''
    Move every pair that is directly or transitively connected to `person`
    from `pairs` to `dest`.

    Two pairs are connected when they share a member, so the whole group of
    pairs reachable from `person` is transferred.

    Parameters
    ----------
    person : hashable
        Person the search starts from.
    pairs : list of set
        Candidate pairs; modified in place (moved pairs are removed, the
        relative order of the rest is kept).
    dest : list
        Receives the moved pairs, in the order they had in `pairs`.

    Returns
    -------
    None
    '''
    # Index the pairs by member once, so the list is never modified while it
    # is being iterated (which would skip elements) and no recursion is
    # needed (long chains would exceed the recursion limit).
    pair_ids_by_member = {}
    for idx, pair in enumerate(pairs):
        for member in pair:
            pair_ids_by_member.setdefault(member, []).append(idx)

    moved = set()
    visited = {person}
    to_visit = [person]
    while to_visit:
        current = to_visit.pop()
        for idx in pair_ids_by_member.get(current, ()):
            if idx in moved:
                continue
            moved.add(idx)
            # Everyone in a moved pair has to be explored as well.
            for member in pairs[idx]:
                if member not in visited:
                    visited.add(member)
                    to_visit.append(member)

    if moved:
        dest.extend(pairs[idx] for idx in sorted(moved))
        pairs[:] = [pair for idx, pair in enumerate(pairs) if idx not in moved]
            
# Keep moving a seed pair, plus every pair get_related_pairs finds for its
# members, into the validation list until the valid/train ratio exceeds
# VAL_FACTOR. The emptiness check stops the loop if the training list runs
# out, where the ratio would otherwise divide by zero.
while train_rlt_list and (len(valid_rlt_list) / len(train_rlt_list)) <= VAL_FACTOR:
    initial_p = train_rlt_list.pop(0)
    valid_rlt_list.append(initial_p)
    # Iterating over the members (instead of unpacking exactly two) also
    # copes with a pair whose two persons are identical.
    for member in list(initial_p):
        get_related_pairs(member, train_rlt_list, valid_rlt_list)

print(f'Train pairs length: {len(train_rlt_list)}')
print(f'Valid pairs length: {len(valid_rlt_list)}')
valid_train_ratio = (
    len(valid_rlt_list) / len(train_rlt_list) if train_rlt_list else float('inf')
)
print(f'Valid/Train ratio: {valid_train_ratio}')

# Persons of each dataset; the intersection printed last shows any person
# that ended up in both datasets.
unique_train_persons = {person for pair in train_rlt_list for person in pair}
unique_valid_persons = {person for pair in valid_rlt_list for person in pair}
print(f'Train unique persons: {len(unique_train_persons)}')
print(f'Valid unique persons: {len(unique_valid_persons)}')
print(f'Persons intersection: {unique_train_persons & unique_valid_persons}')

In [None]:
# Create negative relationships.
# The sampler picks partners by position in the list it is given. A set has no
# defined iteration order (for strings it can change between interpreter runs
# because of hash randomisation), so list(set) would undermine the fixed NumPy
# seed. Sorting gives a stable candidate order; key=str keeps this working
# even if the identifiers are not mutually comparable.
neg_train_rltshps = create_negative_paris(train_rlt_list, sorted(unique_train_persons, key=str))
neg_valid_rltshps = create_negative_paris(valid_rlt_list, sorted(unique_valid_persons, key=str))

In [None]:
# Sets are not JSON serialisable, so every pair is converted to a list.
# The keys keep the names of the variables they were built from.
train_val_set = {
    key: list(map(list, pair_collection))
    for key, pair_collection in (
        ('train_rlt_list', train_rlt_list),
        ('neg_train_rltshps', neg_train_rltshps),
        ('valid_rlt_list', valid_rlt_list),
        ('neg_valid_rltshps', neg_valid_rltshps),
    )
}

# Write the complete split to disk as a single JSON document
with open('train_val_set.json', 'w') as f:
    json.dump(train_val_set, f)