# Train / validation / test split

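This notebook was used to split the Common Criteria (CC) certificate dataset into train/validation/test samples for the reference-annotation NLP task. Only certificates that directly reference other certificates, and whose corresponding text (security target or certification report) was extracted, are included.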

In [1]:
import json
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

from sec_certs.dataset import CCDataset
from sec_certs.sample import CCCertificate

In [2]:
def _has_text_backed_references(cert) -> bool:
    """Return True when `cert` has a direct reference backed by extracted text.

    A direct reference found in the security target only counts when the
    security-target text path is set; likewise, a direct reference found in the
    certification report only counts when the report text path is set.
    """
    if cert.heuristics.st_references.directly_referencing and cert.state.st_txt_path:
        return True
    return bool(cert.heuristics.report_references.directly_referencing and cert.state.report_txt_path)


dset = CCDataset.from_web()
df = dset.to_pandas()

# Restrict the frame to certificates (keyed by digest) that carry usable references.
reference_rich_certs = {cert.dgst for cert in dset if _has_text_backed_references(cert)}
df = df[df.index.isin(reference_rich_certs)]

# These categories are too rare to be split across train/valid/test, so all of
# their certificates go straight to the test set and are removed from `df`.
RARE_CATEGORIES = {"Multi-Function Devices", "Mobility", "Data Protection"}
rare_mask = df["category"].isin(RARE_CATEGORIES)
certs_from_rare_categories = df.index[rare_mask.to_numpy()].tolist()
df = df.loc[~df.index.isin(certs_from_rare_categories)]

In [4]:
# Split 30/20/50 (train, valid, test), stratified by category.
# The first split holds out 50 % of `df` as test. The second split takes 40 % of the
# remaining 50 % (= 20 % overall) as validation, which leaves 30 % for training.
# NOTE: the originally published split files were produced without a seed, so a
# re-run reproduces the same *procedure* but not those exact files.
SPLIT_SEED = 42  # fixed so that Restart & Run All yields the same split
SPLIT_DIR = Path("../../../data/reference_annotations_split")

x_train, x_test = train_test_split(
    df.index, test_size=0.5, shuffle=True, stratify=df.category, random_state=SPLIT_SEED
)
# `stratify` must be positionally aligned with the samples being split. `x_train` is
# shuffled, so look the labels up *by* x_train. Filtering with df.index.isin(x_train)
# would return them in df's original order and silently break stratification.
x_train, x_valid = train_test_split(
    x_train, test_size=0.4, shuffle=True, stratify=df.loc[x_train, "category"], random_state=SPLIT_SEED
)
# Certificates from rare categories could not be stratified; they all go to test.
x_test = x_test.tolist() + list(certs_from_rare_categories)

# Sanity checks: splits are pairwise disjoint and together cover every certificate.
assert set(x_train).isdisjoint(x_valid), "train/valid overlap"
assert set(x_train).isdisjoint(x_test), "train/test overlap"
assert set(x_valid).isdisjoint(x_test), "valid/test overlap"
assert len(x_train) + len(x_valid) + len(x_test) == len(df) + len(certs_from_rare_categories)

SPLIT_DIR.mkdir(parents=True, exist_ok=True)
for split_name, split_ids in {"train": x_train.tolist(), "valid": x_valid.tolist(), "test": x_test}.items():
    with open(SPLIT_DIR / f"{split_name}.json", "w") as handle:
        json.dump(split_ids, handle, indent=4)