# Data Pre-processing for CoV Classification

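This notebook downloads the CoV classification dataset, removes rows whose combined `h_sequence` + `l_sequence` length exceeds 315, and writes stratified 5-fold cross-validation train/test splits for classifier-head training.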

In [1]:
from pathlib import Path

import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Run-wide settings used by the cells below:
#   seed -- random_state for the StratifiedKFold shuffle and for the per-fold row shuffle
#   k    -- number of cross-validation folds
seed, k = 42, 5

## Dataset download

In [2]:
%%bash

# download the CoV classification dataset if it doesn't already exist
if [ ! -e "../data/E_hd-0_cov-1.csv" ]; then
    curl -o 'CoV_classification.tar.gz' -L 'https://zenodo.org/records/14019655/files/CoV_classification.tar.gz?download=1'
    tar xzvf 'CoV_classification.tar.gz' -C ../data
    rm 'CoV_classification.tar.gz'
fi

## 5-fold CV splits for classifier training

In [3]:
# Load the CoV dataset and keep only rows whose combined `h_sequence` +
# `l_sequence` length is at most MAX_PAIR_LENGTH characters.
# NOTE(review): 315 presumably matches the downstream model's maximum input
# length -- confirm before changing it.
MAX_PAIR_LENGTH = 315

df_E_raw = pd.read_csv("../data/E_hd-0_cov-1.csv")
pair_length = df_E_raw["h_sequence"].str.len() + df_E_raw["l_sequence"].str.len()

# Boolean-mask filtering (no inplace drop), so re-running the cell is idempotent.
# Rows with a missing sequence get a NaN length, compare False, and are dropped too.
# reset_index gives a contiguous 0..n-1 index, so positional indices (e.g. those
# yielded by StratifiedKFold.split) coincide with the row labels.
df_E = df_E_raw[pair_length <= MAX_PAIR_LENGTH].reset_index(drop=True)
len(df_E)

24969


In [4]:
# Build shuffled, stratified k-fold train/test splits of df_E and write each fold
# to CSV. Stratifying on `label` keeps every fold's class balance close to that of
# the full dataset; rows are additionally shuffled within each split.
SPLIT_DIR = Path("./train-test_splits")
SPLIT_DIR.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories

skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
X = df_E.drop("label", axis=1)
y = df_E.loc[:, "label"].astype("int64")

for i, (train_index, test_index) in enumerate(skf.split(X, y)):
    print(f"Fold {i}:")

    # skf.split yields *positional* indices, so select with iloc; loc would look
    # them up as index labels and pick wrong rows (or raise) if df_E's index has gaps.
    train = df_E.iloc[train_index].sample(frac=1, random_state=seed)
    test = df_E.iloc[test_index].sample(frac=1, random_state=seed)

    print(train["label"].value_counts())
    print(test["label"].value_counts(), "\n")

    # give the shuffled in-memory frames a clean 0..n-1 index
    train = train.reset_index(drop=True)
    test = test.reset_index(drop=True)

    # save as csvs (index=False: the index column is not written)
    train.to_csv(SPLIT_DIR / f"E_hd-0_cov-1_train{i}.csv", index=False)
    test.to_csv(SPLIT_DIR / f"E_hd-0_cov-1_test{i}.csv", index=False)

Fold 0:
0    9988
1    9987
Name: label, dtype: int64
1    2497
0    2497
Name: label, dtype: int64 

Fold 1:
0    9988
1    9987
Name: label, dtype: int64
1    2497
0    2497
Name: label, dtype: int64 

Fold 2:
0    9988
1    9987
Name: label, dtype: int64
1    2497
0    2497
Name: label, dtype: int64 

Fold 3:
0    9988
1    9987
Name: label, dtype: int64
1    2497
0    2497
Name: label, dtype: int64 

Fold 4:
1    9988
0    9988
Name: label, dtype: int64
0    2497
1    2496
Name: label, dtype: int64 

