# Train/Test/Validation Splitting

This notebook was used to split the data for training, validation, and testing.

In [None]:
import sys; sys.path.append("..")

import pandas as pd
import plotly.express as px

import Bio.PDB.PDBParser
from Bio.PDB.Polypeptide import protein_letters_3to1

import gvpgnn.paths as paths
import gvpgnn.data_models as dm

In [None]:
# Read the CATH dataset CSV into `df`; the file's first column becomes the row index.
df = pd.read_csv(
  '../data/cath_w_seqs_share.csv',
  index_col=0,
)

## What does the distribution of superfamilies look like?

In [None]:
# How many distinct homologous superfamilies appear in the dataset?
# `dropna=False` keeps the count identical to `len(unique())`, which would
# count a missing value as its own entry.
n_sf_unique = df["superfamily"].nunique(dropna=False)
print(f"In training set of size {len(df)}, there are {n_sf_unique} unique homologous superfamilies")

# Display every row that belongs to superfamily 1 as a concrete example.
df.loc[df["superfamily"] == 1]

In [39]:
# Number of dataset rows in each homologous superfamily.
df_sf_counts = df.groupby("superfamily").size()
df_sf_counts.name = "members"

# Turn the superfamily ids into strings so the chart treats them as category
# labels rather than as positions on a numeric axis.
df_sf_counts.index = df_sf_counts.index.astype("str")

# Column chart: one bar per superfamily.
fig = px.bar(
  df_sf_counts,
  title="Members of each homologous superfamily in the dataset",
)
fig.show()

# Show as a pie chart:
# fig = px.pie(df_sf_counts.to_frame(), values="members", title="Members of each homologous superfamily in the dataset")
# fig.show()

## Dataset Splitting Strategy

- We want to avoid splitting homologous superfamilies across the training and test splits, since this would allow the model to train on examples that are closely related to those in the test set.
- The simplest solution I can think of is to choose superfamilies to put in the test set, and not let the model see any examples from those families in the training set.
- The open question is whether we want more diversity in the training set or in the test set. There are a few huge superfamilies — the largest, `10`, contains about `28.7%` of all examples — so whichever split receives those huge families will have somewhat less variety.
- My initial decision is to fill the test set with small superfamilies to maximize the diversity there. We want to see how the model generalizes to many different unseen superfamilies, and I'm guessing that's how this task will be evaluated.

In [38]:
fraction_test = 0.2 # of the entire set
fraction_val = 0.2 # of the REMAINING training set

# Nominal (target) sizes. The validation set is carved out of whatever is left
# once the test set has been removed.
n_test = round(len(df) * fraction_test)
n_training = len(df) - n_test
n_val = round(n_training * fraction_val)
n_train = n_training - n_val

print("\n-- NOMINAL SPLITS:")
print("Dataset size:", len(df))
print("test:", n_test)
print("train:", n_train)
print("val:", n_val)

# (superfamily, member count) pairs ordered from the smallest superfamily to the largest.
sf_count_low_to_high: list[tuple[int, int]] = \
  sorted(df.groupby(by="superfamily").size().to_dict().items(), key=lambda x: x[1])


def _take_superfamilies(sf_counts, start, target):
  """Assign whole superfamilies to a split until it holds at least `target` examples.

  Superfamilies are never divided, so the split usually ends up slightly
  larger than `target`.

  Args:
    sf_counts: (superfamily, member count) pairs, in the order they are consumed.
    start: position in `sf_counts` of the first superfamily still unassigned.
    target: minimum number of examples the split should contain.

  Returns:
    A `(superfamilies, n_examples, next_start)` tuple: the superfamilies taken,
    the total number of examples they contain, and the position of the first
    superfamily left unassigned.

  Raises:
    ValueError: if the superfamilies run out before `target` is reached.
  """
  taken = []
  n_examples = 0
  position = start

  while n_examples < target:
    # Without this guard an unreachable target would index past the end of the list.
    if position >= len(sf_counts):
      raise ValueError(
        f"Ran out of superfamilies after {n_examples} examples; "
        f"cannot reach the requested {target}. Lower the split fractions."
      )
    sf, count = sf_counts[position]
    taken.append(sf)
    n_examples += count
    position += 1

  return taken, n_examples, position


split_superfamilies = dict(train=[], val=[], test=[])
split_counts = dict(train=0, val=0, test=0)

idx = 0

# Build the test set first so that it receives many small superfamilies.
split_superfamilies["test"], split_counts["test"], idx = \
  _take_superfamilies(sf_count_low_to_high, idx, n_test)

# Then build the validation set, which we also want to have a diverse set of families.
split_superfamilies["val"], split_counts["val"], idx = \
  _take_superfamilies(sf_count_low_to_high, idx, n_val)

# Every remaining example goes in the training set.
remaining = sf_count_low_to_high[idx:]
split_superfamilies["train"] = [sf for sf, _count in remaining]
split_counts["train"] = sum(count for _sf, count in remaining)

print("\n-- ACTUAL SPLITS:")
print("test:", split_counts["test"])
print("val:", split_counts["val"])
print("train:", split_counts["train"])

print("(Note: Due to the size of superfamilies, the actual splits may be slightly different in size.)")

# Make sure that every example is accounted for!
assert sum(split_counts.values()) == len(df), \
  "split sizes do not add up to the dataset size"

# Write the splits to files:
for split_name in ("train", "val", "test"):
  sf_list = split_superfamilies[split_name]
  df_split = df[df.superfamily.isin(sf_list)]
  # The rows selected for the file must match the size counted above.
  assert len(df_split) == split_counts[split_name], \
    f"{split_name} split has {len(df_split)} rows, expected {split_counts[split_name]}"
  df_split.to_csv(paths.data_folder(f"{split_name}_cath_w_seqs_share.csv"), index=False)


-- NOMINAL SPLITS:
Dataset size: 6273
test: 1255
train: 4014
val: 1004

-- ACTUAL SPLITS:
test: 1261
val: 1016
train: 3996
(Note: Due to the size of superfamilies, the actual splits may be slightly different in size.)


In [None]:
# Sanity check: re-read the split files from disk and confirm that no CATH id
# appears in more than one of them.
cath_ids = {}

for split_name in ("train", "val", "test"):
  split_csv = paths.data_folder(f"{split_name}_cath_w_seqs_share.csv")
  df_split = pd.read_csv(split_csv)
  cath_ids[split_name] = set(df_split["cath_id"].unique())

# Compare every pair of splits; a shared id raises an AssertionError.
for split_a, split_b in (("train", "test"), ("train", "val"), ("val", "test")):
  assert cath_ids[split_a].isdisjoint(cath_ids[split_b])

## Distribution of Labels

Are the labels fairly balanced across the possible categories? Yes! So there is probably no need for special class weighting in the loss function.

In [None]:
from gvpgnn.data_models import architecture_labels

In [None]:
# Combine each row's `class` and `architecture` values into a single tuple key.
df['label_tuple'] = list(zip(df['class'], df['architecture']))

# Look each tuple up in `architecture_labels` to get the label stored in `label_integer`.
df['label_integer'] = df['label_tuple'].map(lambda key: architecture_labels[key])

In [None]:
# Number of rows in the full dataset for each (class, architecture) label.
label_counts = df.groupby("label_tuple").size()
label_counts

In [None]:
# For each split file, print the percentage of examples carrying each label.
for split_name in ("train", "val", "test"):
  print("-----",  split_name, "-----")
  split_csv = paths.data_folder(f"{split_name}_cath_w_seqs_share.csv")
  df_split = pd.read_csv(split_csv)

  # Build the (class, architecture) tuple key and its mapped label for this split.
  df_split['label_tuple'] = list(zip(df_split['class'], df_split['architecture']))
  df_split['label_integer'] = df_split['label_tuple'].map(lambda key: architecture_labels[key])

  # Share of the split (in percent) that falls under each label.
  label_counts = df_split.groupby("label_tuple").size()
  print("Distribution of labels:")
  print(100 * label_counts / len(df_split))


# What is the distribution of labels in the train+val set?
# df_train_val = pd.concat([
#   pd.read_csv(paths.data_folder(f"train_cath_w_seqs_share.csv")),
#   pd.read_csv(paths.data_folder(f"val_cath_w_seqs_share.csv"))
# ])

# df_train_val['label_tuple'] = list(zip(df_train_val['class'], df_train_val['architecture']))

# print("Distribution of labels:")
# print(100 * df_train_val.groupby(by="label_tuple").size() / len(df_train_val))

In [None]:
import torch.utils.data as data

In [None]:
# Scratch check of how WeightedRandomSampler behaves: request 100 draws over
# three items weighted 0.9 / 0.05 / 0.05 and display the sampled indices.
# No generator or seed is passed, so the output differs from run to run.
class_weights = [0.9, 0.05, 0.05]
list(data.WeightedRandomSampler(weights=class_weights, num_samples=100))