# Contrastive Learning Pairs

This notebook contains the code for pre-building and saving pairs for contrastive learning. How the pairs are chosen is explained in the playground notebook in the contrastive learning folder.

A single Python function that does this, returning the positive and negative pairs for contrastive learning, can be found in `src/contrastive_pairs.py`.

In [3]:
import pickle

import numpy as np
import pandas as pd

## Training Data

### Load Data

In [2]:
# Load the training set; the cells below group it by `author_email` and read
# its `message` column.
# NOTE: unpickling runs arbitrary code, so only load files you trust.
train_set_path = '../data/04a_Train_Set.pkl'
train_data = pd.read_pickle(train_set_path)

### Positive Pairs

In [3]:
# Positive pairs: every unordered pair of messages that share an author,
# stored as [message_1, message_2, 1].
positive_training_pairs = []

for _, author_frame in train_data.groupby("author_email"):
    author_messages = author_frame['message']
    for next_position, message_1 in enumerate(author_messages, start=1):
        # Pair only with the messages that follow, so each combination
        # is emitted exactly once and never with itself.
        positive_training_pairs.extend(
            [message_1, message_2, 1]
            for message_2 in author_messages.iloc[next_position:]
        )

### Negative Pairs

In [4]:
# Negative pairs: messages written by two different authors, stored as
# [message_1, message_2, -1].
#
# Every unordered pair of authors is visited exactly once: an author is only
# paired with the authors that come after it in the groupby iteration order.
# For each message sampled from the first author, a fresh sample is drawn from
# the second author, so one author pair contributes up to cut_amount ** 2 rows.
negative_training_pairs = []

train_data_groups = train_data.groupby("author_email")

# Number of messages sampled per author on each draw.
# NOTE(review): presumably tuned so the negative count roughly matches the
# positive count - confirm before changing.
cut_amount = 369

# A single seeded RandomState is shared by every sample() call: the run as a
# whole is reproducible, while successive calls still draw different samples.
SAMPLING_SEED = 42
sampling_rng = np.random.RandomState(SAMPLING_SEED)

# Extract each author's messages once instead of re-iterating (and
# re-filtering) the groupby for every author.
messages_by_author = [
    author_frame['message'] for _, author_frame in train_data_groups
]

for position, anchor_messages in enumerate(messages_by_author):
    later_authors = messages_by_author[position + 1:]
    # min() keeps authors with fewer than `cut_amount` messages from raising
    # in sample(); such authors simply contribute all of their messages.
    anchor_sample = anchor_messages.sample(
        n=min(cut_amount, len(anchor_messages)),
        random_state=sampling_rng,
    )
    for message_1 in anchor_sample:
        for other_messages in later_authors:
            other_sample = other_messages.sample(
                n=min(cut_amount, len(other_messages)),
                random_state=sampling_rng,
            )
            for message_2 in other_sample:
                negative_training_pairs.append([message_1, message_2, -1])

### Check sum

In [5]:
# Report how many pairs of each label were generated for the training set.
positive_count = len(positive_training_pairs)
negative_count = len(negative_training_pairs)
print(f'Number of positive pairs: {positive_count}')
print(f'Number of negative pairs: {negative_count}')

Number of positive pairs: 51268322
Number of negative pairs: 51468858


### Save

In [6]:
# Persist all training pairs (positives first, then negatives) as one pickled list.
#
# The combined list gets its own name instead of extending
# `positive_training_pairs` in place: re-running this cell no longer appends
# the negatives a second time, and the positives list stays usable by other
# cells. It is also built before the output file is opened (and truncated).
# Concatenation copies only the outer list of references; the pair objects
# themselves are shared, not duplicated.
training_pairs = positive_training_pairs + negative_training_pairs

with open('../data/06a_Contrastive_Train_Pairs.pkl', 'wb') as f:
    pickle.dump(training_pairs, f, protocol=pickle.HIGHEST_PROTOCOL)

## Validation Data

### Load Data

In [None]:
# Load the validation set; the cells below group it by `author_email` and read
# its `message` column.
# NOTE: unpickling runs arbitrary code, so only load files you trust.
validate_set_path = '../data/04b_Validate_Set.pkl'
validate_data = pd.read_pickle(validate_set_path)

### Positive Pairs

In [None]:
# Positive pairs: every unordered pair of messages that share an author,
# stored as [message_1, message_2, 1].
positive_validate_pairs = []

for _, author_frame in validate_data.groupby("author_email"):
    author_messages = author_frame['message']
    for next_position, message_1 in enumerate(author_messages, start=1):
        # Pair only with the messages that follow, so each combination
        # is emitted exactly once and never with itself.
        positive_validate_pairs.extend(
            [message_1, message_2, 1]
            for message_2 in author_messages.iloc[next_position:]
        )

### Negative Pairs

In [None]:
# Negative pairs: messages written by two different authors, stored as
# [message_1, message_2, -1].
#
# Every unordered pair of authors is visited exactly once: an author is only
# paired with the authors that come after it in the groupby iteration order.
# For each message sampled from the first author, a fresh sample is drawn from
# the second author, so one author pair contributes up to cut_amount ** 2 rows.
negative_validate_pairs = []

validate_data_groups = validate_data.groupby("author_email")

# Number of messages sampled per author on each draw.
# NOTE(review): presumably tuned so the negative count roughly matches the
# positive count - confirm before changing.
cut_amount = 650

# A single seeded RandomState is shared by every sample() call: the run as a
# whole is reproducible, while successive calls still draw different samples.
SAMPLING_SEED = 42
sampling_rng = np.random.RandomState(SAMPLING_SEED)

# Extract each author's messages once instead of re-iterating (and
# re-filtering) the groupby for every author.
messages_by_author = [
    author_frame['message'] for _, author_frame in validate_data_groups
]

for position, anchor_messages in enumerate(messages_by_author):
    later_authors = messages_by_author[position + 1:]
    # min() keeps authors with fewer than `cut_amount` messages from raising
    # in sample(); such authors simply contribute all of their messages.
    anchor_sample = anchor_messages.sample(
        n=min(cut_amount, len(anchor_messages)),
        random_state=sampling_rng,
    )
    for message_1 in anchor_sample:
        for other_messages in later_authors:
            other_sample = other_messages.sample(
                n=min(cut_amount, len(other_messages)),
                random_state=sampling_rng,
            )
            for message_2 in other_sample:
                negative_validate_pairs.append([message_1, message_2, -1])

### Check sum

In [None]:
# Report how many pairs of each label were generated for the validation set.
positive_count = len(positive_validate_pairs)
negative_count = len(negative_validate_pairs)
print(f'Number of positive pairs: {positive_count}')
print(f'Number of negative pairs: {negative_count}')

Number of positive pairs: 8863415
Number of negative pairs: 8872500


### Save

In [None]:
# Persist all validation pairs (positives first, then negatives) as one pickled list.
# `pickle` is already imported in the notebook's first code cell.
#
# The combined list gets its own name instead of extending
# `positive_validate_pairs` in place: re-running this cell no longer appends
# the negatives a second time, and the positives list stays usable by other
# cells. It is also built before the output file is opened (and truncated).
# Concatenation copies only the outer list of references; the pair objects
# themselves are shared, not duplicated.
validate_pairs = positive_validate_pairs + negative_validate_pairs

with open('../data/06b_Contrastive_Validate_Pairs.pkl', 'wb') as f:
    pickle.dump(validate_pairs, f, protocol=pickle.HIGHEST_PROTOCOL)

## Test Data

### Load Data

In [None]:
# Load the test set; the cells below group it by `author_email` and read its
# `message` column.
# NOTE: unpickling runs arbitrary code, so only load files you trust.
test_set_path = '../data/04c_Test_Set.pkl'
test_data = pd.read_pickle(test_set_path)

### Positive Pairs

In [None]:
# Positive pairs: every unordered pair of messages that share an author,
# stored as [message_1, message_2, 1].
positive_test_pairs = []

for _, author_frame in test_data.groupby("author_email"):
    author_messages = author_frame['message']
    for next_position, message_1 in enumerate(author_messages, start=1):
        # Pair only with the messages that follow, so each combination
        # is emitted exactly once and never with itself.
        positive_test_pairs.extend(
            [message_1, message_2, 1]
            for message_2 in author_messages.iloc[next_position:]
        )

### Negative Pairs

In [None]:
# Negative pairs: messages written by two different authors, stored as
# [message_1, message_2, -1].
#
# Every unordered pair of authors is visited exactly once: an author is only
# paired with the authors that come after it in the groupby iteration order.
# For each message sampled from the first author, a fresh sample is drawn from
# the second author, so one author pair contributes up to cut_amount ** 2 rows.
negative_test_pairs = []

test_data_groups = test_data.groupby("author_email")

# Number of messages sampled per author on each draw.
# NOTE(review): presumably tuned so the negative count roughly matches the
# positive count - confirm before changing.
cut_amount = 647

# A single seeded RandomState is shared by every sample() call: the run as a
# whole is reproducible, while successive calls still draw different samples.
SAMPLING_SEED = 42
sampling_rng = np.random.RandomState(SAMPLING_SEED)

# Extract each author's messages once instead of re-iterating (and
# re-filtering) the groupby for every author.
messages_by_author = [
    author_frame['message'] for _, author_frame in test_data_groups
]

for position, anchor_messages in enumerate(messages_by_author):
    later_authors = messages_by_author[position + 1:]
    # min() keeps authors with fewer than `cut_amount` messages from raising
    # in sample(); such authors simply contribute all of their messages.
    anchor_sample = anchor_messages.sample(
        n=min(cut_amount, len(anchor_messages)),
        random_state=sampling_rng,
    )
    for message_1 in anchor_sample:
        for other_messages in later_authors:
            other_sample = other_messages.sample(
                n=min(cut_amount, len(other_messages)),
                random_state=sampling_rng,
            )
            for message_2 in other_sample:
                negative_test_pairs.append([message_1, message_2, -1])

### Check sum

In [None]:
# Report how many pairs of each label were generated for the test set.
positive_count = len(positive_test_pairs)
negative_count = len(negative_test_pairs)
print(f'Number of positive pairs: {positive_count}')
print(f'Number of negative pairs: {negative_count}')

Number of positive pairs: 8788005
Number of negative pairs: 8790789


### Save

In [None]:
# Persist all test pairs (positives first, then negatives) as one pickled list.
# `pickle` is already imported in the notebook's first code cell.
#
# The combined list gets its own name instead of extending
# `positive_test_pairs` in place: re-running this cell no longer appends the
# negatives a second time, and the positives list stays usable by other
# cells. It is also built before the output file is opened (and truncated).
# Concatenation copies only the outer list of references; the pair objects
# themselves are shared, not duplicated.
test_pairs = positive_test_pairs + negative_test_pairs

with open('../data/06c_Contrastive_Test_Pairs.pkl', 'wb') as f:
    pickle.dump(test_pairs, f, protocol=pickle.HIGHEST_PROTOCOL)