# Filter to selected samples (e.g. peak timepoints) and make cross validation divisions, based on active CrossValidationSplitStrategy

**Run this notebook separately for each CrossValidationSplitStrategy!**

We've already removed specimens with very few sequences or without all isotypes. And we've already sampled one sequence per clone per isotype per specimen.

If you're editing `config.py`'s active CrossValidationSplitStrategy, make sure to run `python scripts/make_dirs.py`.

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split
from malid import config, helpers, logger

In [2]:
# Import the list of specimens remaining after QC filtering.
# Not all specimens survived to this step - some are thrown out for not having enough sequences or not having all isotypes.
# However, these aren't yet filtered to is_selected_for_cv_strategy specimens that are particular to the selected cross validation strategy.

# This list is generated by record_which_specimens_remain_after_qc_filtering.ipynb
specimens = pd.read_csv(
    config.paths.dataset_specific_metadata
    / "specimens_that_survived_qc_filters_in_sample_sequences_notebook.tsv",
    sep="\t",
)
specimens

Unnamed: 0,participant_label,specimen_label,disease,total_sequence_count
0,BFI-0000234,M124-S014,Healthy/Background,67139
1,BFI-0000254,M111-S037,HIV,73771
2,BFI-0000255,M111-S033,HIV,69192
3,BFI-0000256,M111-S038,HIV,99595
4,BFI-0000258,M111-S055,HIV,105617
...,...,...,...,...
2324,towlerton-2022-hiv_1026,towlerton-2022-hiv_015V08002633_CFAR,HIV,25999
2325,towlerton-2022-hiv_1027,towlerton-2022-hiv_015V09002862_CFAR,HIV,44435
2326,towlerton-2022-hiv_1028,towlerton-2022-hiv_015V11002805_CFAR,HIV,23025
2327,towlerton-2022-hiv_1029,towlerton-2022-hiv_015V11001839_CFAR,HIV,47858


## Apply `is_selected_for_cv_strategy` filter using currently active `config.cross_validation_split_strategy`

In [3]:
config.cross_validation_split_strategy

<CrossValidationSplitStrategy.in_house_peak_disease_timepoints: CrossValidationSplitStrategyValue(data_sources_keep=[<DataSource.in_house: 1>], stratify_by='disease', diseases_to_keep_all_subtypes=['Healthy/Background', 'HIV', 'Lupus', 'T1D'], subtypes_keep=['Covid19 - Sero-positive (ICU)', 'Covid19 - Sero-positive (Admit)', 'Covid19 - Acute 2', 'Covid19 - Admit', 'Covid19 - ICU', 'Influenza vaccine 2021 - day 7'], filter_specimens_func_by_study_name={'Covid19-buffycoat': <function acute_disease_choose_most_peak_timepoint at 0x7f75dd6a2f70>, 'Covid19-Stanford': <function acute_disease_choose_most_peak_timepoint at 0x7f75dd6a2f70>}, gene_loci_supported=<GeneLocus.BCR|TCR: 3>, exclude_study_names=['IBD pre-pandemic Yoni'], include_study_names=None, filter_out_specimens_funcs_global=[], study_names_for_held_out_set=None)>

In [4]:
stratify_by_column_name = config.cross_validation_split_strategy.value.stratify_by
stratify_by_column_name

'disease'

In [5]:
# filter to is_selected_for_cv_strategy specimens, in addition to the is_valid / survived_filters filter already here
specimen_metadata = helpers.get_all_specimen_info(
    # CV fold is not available yet, so must set this to False
    add_cv_fold_information=False
).sort_values("disease")
specimen_metadata = specimen_metadata[specimen_metadata["in_training_set"]]

# sanity check the definitions
assert specimen_metadata["is_selected_for_cv_strategy"].all()
assert specimen_metadata["survived_filters"].all()

assert stratify_by_column_name in specimen_metadata.columns

specimen_metadata

Unnamed: 0,participant_label,specimen_label,disease,specimen_time_point,participant_description,data_source,study_name,available_gene_loci,disease_subtype,age,...,symptoms_Lupus_sm_nRNP +/-,symptoms_cmv,symptoms_healthy_in_resequencing_experiment,specimen_time_point_days,survived_filters,is_selected_for_cv_strategy,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup
1818,BFI-0009122,M418-S001,Covid19,10 days,COVID-19 sample from Sam Yang,DataSource.in_house,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - ICU,63.0,...,,,,10.0,True,True,True,False,Covid19,Covid19
1821,BFI-0009128,M418-S009,Covid19,9 days,COVID-19 sample from Sam Yang,DataSource.in_house,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Admit,43.0,...,,,,9.0,True,True,True,False,Covid19,Covid19
1822,BFI-0009129,M418-S010,Covid19,7 days,COVID-19 sample from Sam Yang,DataSource.in_house,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Admit,41.0,...,,,,7.0,True,True,True,False,Covid19,Covid19
1823,BFI-0009131,M418-S012,Covid19,8 days,COVID-19 sample from Sam Yang,DataSource.in_house,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - ICU,61.0,...,,,,8.0,True,True,True,False,Covid19,Covid19
1827,BFI-0009139,M418-S021,Covid19,9 days,COVID-19 sample from Sam Yang,DataSource.in_house,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - ICU,48.0,...,,,,9.0,True,True,True,False,Covid19,Covid19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2309,BFI-0010772,M491-S129,T1D,,Diabetes 35453-study biobank: adult; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - adult,,...,,,,,True,True,True,False,T1D,T1D
2306,BFI-0010768,M491-S124,T1D,,Diabetes 35453-study biobank: adult; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - adult,34.0,...,,,,,True,True,True,False,T1D,T1D
2305,BFI-0010767,M491-S156,T1D,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,16.0,...,,,,,True,True,True,False,T1D,T1D
2304,BFI-0010767,M491-S123,T1D,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,16.0,...,,,,,True,True,True,False,T1D,T1D


In [6]:
# merge back to apply the filter
specimens_merged = pd.merge(
    specimens,
    specimen_metadata,
    on=["participant_label", "specimen_label", "disease"],
    how="inner",
    validate="1:1",
)
assert specimens_merged.shape[0] == specimen_metadata.shape[0] <= specimens.shape[0]
specimens_merged

Unnamed: 0,participant_label,specimen_label,disease,total_sequence_count,specimen_time_point,participant_description,data_source,study_name,available_gene_loci,disease_subtype,...,symptoms_Lupus_sm_nRNP +/-,symptoms_cmv,symptoms_healthy_in_resequencing_experiment,specimen_time_point_days,survived_filters,is_selected_for_cv_strategy,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup
0,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,,,True,True,True,False,Healthy/Background,Healthy/Background
1,BFI-0000254,M111-S037,HIV,73771,,Location: Tanzania,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV Broad Neutralizing,...,,,,,True,True,True,False,HIV,HIV
2,BFI-0000255,M111-S033,HIV,69192,,Location: Tanzania,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV Broad Neutralizing,...,,,,,True,True,True,False,HIV,HIV
3,BFI-0000256,M111-S038,HIV,99595,,Location: Tanzania,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV Broad Neutralizing,...,,,,,True,True,True,False,HIV,HIV
4,BFI-0000258,M111-S055,HIV,105617,,Location: Tanzania,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV Non Neutralizing,...,,,,,True,True,True,False,HIV,HIV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
611,BFI-0010791,M491-S149,T1D,34236,,Diabetes 35453-study biobank: pediatric; T1D +...,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,,,True,True,True,False,T1D,T1D
612,BFI-0010792,M491-S150,T1D,30509,,Diabetes 35453-study biobank: adult; TID,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - adult,...,,,,,True,True,True,False,T1D,T1D
613,BFI-0010794,M491-S153,T1D,23487,,Diabetes 35453-study biobank: pediatric; TID,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,,,True,True,True,False,T1D,T1D
614,BFI-0010800,M491-S160,T1D,13932,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,,,True,True,True,False,T1D,T1D


In [7]:
specimens_merged["data_source"].value_counts()

DataSource.in_house    616
Name: data_source, dtype: int64

In [8]:
# Expressed as specimen counts:

In [9]:
specimens_merged[stratify_by_column_name].value_counts()

Healthy/Background    224
HIV                    98
Lupus                  98
T1D                    96
Covid19                63
Influenza              37
Name: disease, dtype: int64

In [10]:
specimens_merged["disease"].value_counts()

Healthy/Background    224
HIV                    98
Lupus                  98
T1D                    96
Covid19                63
Influenza              37
Name: disease, dtype: int64

In [11]:
specimens_merged["disease_subtype"].value_counts()

T1D - pediatric                                       67
HIV Non Neutralizing                                  53
Healthy/Background - CMV-                             53
Healthy/Background - CMV+                             48
HIV Broad Neutralizing                                45
Healthy/Background - HIV Negative                     43
Healthy/Background (children)                         43
Influenza vaccine 2021 - day 7                        37
Covid19 - Admit                                       33
T1D - adult                                           29
Healthy/Background - SLE Negative                     27
Pediatric SLE - nephritis                             23
SLE Patient                                           21
Pediatric SLE - no nephritis                          20
Covid19 - ICU                                         15
SLE Multiple aAbs / SLE dsDNA WITHOUT Nephritis        9
SLE Multiple aAbs / SLE dsDNA WITH Nephritis           9
SLE Multiple aAbs              

In [12]:
# Expressed as participant counts:

In [13]:
(
    specimens_merged.drop_duplicates(["participant_label", "disease"])["disease"]
    .value_counts()
    .rename("number of participants")
)

Healthy/Background    220
HIV                    95
T1D                    92
Lupus                  86
Covid19                63
Influenza              37
Name: number of participants, dtype: int64

In [14]:
(
    specimens_merged.drop_duplicates(["participant_label", stratify_by_column_name])[
        stratify_by_column_name
    ]
    .value_counts()
    .rename("number of participants")
)

Healthy/Background    220
HIV                    95
T1D                    92
Lupus                  86
Covid19                63
Influenza              37
Name: number of participants, dtype: int64

In [15]:
tmp = (
    specimens_merged.drop_duplicates(["participant_label", "disease_subtype"])[
        "disease_subtype"
    ]
    .value_counts()
    .rename("number of participants")
)

tmp.to_csv(
    config.paths.base_output_dir_for_selected_cross_validation_strategy
    / "all_data_combined.participants_by_disease_subtype.tsv",
    sep="\t",
)

tmp

T1D - pediatric                                       63
Healthy/Background - CMV-                             53
HIV Non Neutralizing                                  50
Healthy/Background - CMV+                             48
HIV Broad Neutralizing                                45
Healthy/Background - HIV Negative                     43
Healthy/Background (children)                         43
Influenza vaccine 2021 - day 7                        37
Covid19 - Admit                                       33
T1D - adult                                           29
Healthy/Background - SLE Negative                     23
Pediatric SLE - nephritis                             23
SLE Patient                                           20
Pediatric SLE - no nephritis                          20
Covid19 - ICU                                         15
SLE Multiple aAbs                                      7
Covid19 - Sero-positive (ICU)                          7
Unaffected Control             

In [16]:
tmp = (
    specimens_merged.groupby(["disease", "data_source"], observed=True)[
        "total_sequence_count"
    ]
    .sum()
    .reset_index(name="number of sequences")
)

tmp.to_csv(
    config.paths.base_output_dir_for_selected_cross_validation_strategy
    / "all_data_combined.sequences_by_disease_and_cohort.tsv",
    sep="\t",
    index=None,
)

tmp.sort_values("number of sequences")

Unnamed: 0,disease,data_source,number of sequences
0,Covid19,DataSource.in_house,1773823
3,Influenza,DataSource.in_house,2088387
5,T1D,DataSource.in_house,3481423
1,HIV,DataSource.in_house,5888513
4,Lupus,DataSource.in_house,6938520
2,Healthy/Background,DataSource.in_house,19612409


In [17]:
tmp = (
    specimens_merged.groupby([stratify_by_column_name, "data_source"], observed=True)[
        "total_sequence_count"
    ]
    .sum()
    .reset_index(name="number of sequences")
)

tmp.sort_values("number of sequences")

Unnamed: 0,disease,data_source,number of sequences
0,Covid19,DataSource.in_house,1773823
3,Influenza,DataSource.in_house,2088387
5,T1D,DataSource.in_house,3481423
1,HIV,DataSource.in_house,5888513
4,Lupus,DataSource.in_house,6938520
2,Healthy/Background,DataSource.in_house,19612409


## Make CV splits

**Strategy:**

```
Full data →  [Test | Rest] x 3, producing 3 folds

In each fold:
Rest →  [ Train | Validation]
Train → [Train1 | Train2]
```

- Split on patients.
- Split into (train+validation) / test first with a 2:1 ratio (because 3 folds total). Every patient is in 1 test fold. Then split (train+validation) into train (2/3) and validation (1/3).
- Also create a "global fold" that has no test set, but has the same 2:1 train/validation ratio.
- Each train set is further subdivided into Train1 and Train2 (used for training model 3 rollup step)
- All splits are stratified by disease label (or other specified `stratify_by_column_name`)

**How to handle varying gene loci:**

* We want to include BCR+TCR samples as well as single-loci (e.g. BCR-only) samples
* All data is used for single loci models. For example, the BCR sequence model and the BCR-only metamodel will include specimens whether they're BCR+TCR or BCR-only. (Any TCR-only would be excluded)
* Only BCR+TCR samples are used in BCR+TCR metamodel. (The input BCR models will have been trained on any and all BCR data, but the second stage metamodel will be trained on only samples that have both BCR and TCR components.)

The wrong way to design the cross validation split: split patients up all together, regardless of if they are BCR-only, TCR-only, or BCR+TCR.

Example of why this is wrong: consider just one disease for now; suppose you have a BCR-only set of 3 patients and a BCR+TCR set of a different 3 patients. How do you split those 6 patients into 3 cross validation folds?

The wrong strategy might split as follows:

```
# wrong way - possible result
how many
included
patients
             fold 1              fold 2              fold 3
BCR/TCR      3 train/0 test      2 train/1 test      1 train/2 test
BCR only     1 train/2 test      2 train/1 test      3 train/0 test
```

The BCR-only vs BCR+TCR are not spread evenly - we have imbalanced folds.

**The right way to split would be: split BCR+TCR patients first, then separately split the BCR-only patients, and combine the resulting folds:**

```
# right way - result
how many
included
patients
             fold 1              fold 2              fold 3
BCR/TCR      2 train/1 test      2 train/1 test      2 train/1 test
BCR only     2 train/1 test      2 train/1 test      2 train/1 test
```

**Note: we need to respect `study_names_for_held_out_set` on the CrossValidationSplitStrategy**. If set, that reduces us to a single fold with a pre-determined test set. (We still need to make a random train/validation split that keeps each patient's data segregated to one or the other)

In [13]:
cols = [
    "participant_label",
    "disease",
    "data_source",
    "past_exposure",
    "disease.separate_past_exposures",
    "available_gene_loci",
    "study_name",
]
if stratify_by_column_name not in cols:
    cols.append(stratify_by_column_name)
unique_participants_all = (
    specimens_merged[cols]
    .drop_duplicates(subset=["participant_label"])
    .reset_index(drop=True)
)
unique_participants_all

Unnamed: 0,participant_label,disease,data_source,past_exposure,disease.separate_past_exposures,available_gene_loci,study_name
0,BFI-0000234,Healthy/Background,DataSource.in_house,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV
1,BFI-0000254,HIV,DataSource.in_house,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV
2,BFI-0000255,HIV,DataSource.in_house,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV
3,BFI-0000256,HIV,DataSource.in_house,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV
4,BFI-0000258,HIV,DataSource.in_house,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",HIV
...,...,...,...,...,...,...,...
588,BFI-0010791,T1D,DataSource.in_house,False,T1D,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Diabetes biobank
589,BFI-0010792,T1D,DataSource.in_house,False,T1D,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Diabetes biobank
590,BFI-0010794,T1D,DataSource.in_house,False,T1D,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Diabetes biobank
591,BFI-0010800,T1D,DataSource.in_house,False,T1D,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Diabetes biobank


In [14]:
def make_splits(participants, n_splits):
    # preserve imbalanced class distribution in each fold
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

    folds = list(
        enumerate(
            cv.split(
                participants["participant_label"],
                participants[stratify_by_column_name],
            )
        )
    )
    return folds

In [15]:
def make_splits_single(participants: pd.DataFrame, test_proportion: float):
    # Single separation. Returns train and test integer indices into the participants dataframe.
    return train_test_split(
        np.arange(participants.shape[0]),
        test_size=test_proportion,
        random_state=0,
        shuffle=True,
        # preserve imbalanced class distribution in each fold
        stratify=participants[stratify_by_column_name],
    )

In [16]:
fold_participants = []


for gene_locus, unique_participants in unique_participants_all.groupby(
    "available_gene_loci", observed=True, sort=False
):
    print(
        f"{gene_locus}: total number of unique participants {unique_participants.shape[0]}"
    )

    if not config.cross_validation_split_strategy.value.is_single_fold_only:
        # Default case
        trainbig_test_splits = make_splits(unique_participants, n_splits=config.n_folds)
    else:
        # Special case: single fold with a pre-determined held-out test set (see study_names_for_held_out_set docs).
        held_out_bool_array = unique_participants["study_name"].isin(
            config.cross_validation_split_strategy.value.study_names_for_held_out_set
        )
        trainbig_test_splits = [
            (0, (np.where(~held_out_bool_array)[0], np.where(held_out_bool_array)[0]))
        ]

    assert len(trainbig_test_splits) == config.n_folds

    for (
        fold_id,
        (trainbig_participant_index, test_participant_index),
    ) in trainbig_test_splits:
        trainbig_participants = unique_participants.iloc[
            trainbig_participant_index
        ].reset_index(drop=True)
        test_participants = unique_participants.iloc[
            test_participant_index
        ].reset_index(drop=True)

        # confirm each patient is entirely on one or the other side of the train-big vs test divide
        assert (
            len(
                set(trainbig_participants["participant_label"]).intersection(
                    set(test_participants["participant_label"])
                )
            )
            == 0
        )

        # split trainbig into trainsmaller + validation
        train_smaller_index, validation_index = make_splits_single(
            trainbig_participants, test_proportion=1 / 3
        )
        train_smaller_participants = trainbig_participants.iloc[
            train_smaller_index
        ].reset_index(drop=True)
        validation_participants = trainbig_participants.iloc[
            validation_index
        ].reset_index(drop=True)

        # confirm each patient is entirely on one or the other side of the train-smaller vs validation divide
        assert (
            len(
                set(train_smaller_participants["participant_label"]).intersection(
                    set(validation_participants["participant_label"])
                )
            )
            == 0
        )

        # split train-smaller into train-smaller1 and train-smaller2
        train_smaller1_indices, train_smaller2_indices = make_splits_single(
            train_smaller_participants, test_proportion=1 / 3
        )
        train_smaller1_participants = train_smaller_participants.iloc[
            train_smaller1_indices
        ]
        train_smaller2_participants = train_smaller_participants.iloc[
            train_smaller2_indices
        ]
        # confirm each patient is entirely on one or the other side of the train-smaller1 vs train-smaller2 divide
        assert (
            len(
                set(train_smaller1_participants["participant_label"]).intersection(
                    set(train_smaller2_participants["participant_label"])
                )
            )
            == 0
        )

        # get list of participant labels
        for participants, fold_label in zip(
            [
                train_smaller_participants,
                train_smaller1_participants,
                train_smaller2_participants,
                validation_participants,
                test_participants,
            ],
            ["train_smaller", "train_smaller1", "train_smaller2", "validation", "test"],
        ):
            fold_participants.append(
                pd.DataFrame(
                    {"participant_label": participants["participant_label"].unique()}
                ).assign(fold_id=fold_id, fold_label=fold_label)
            )

    # also create global fold
    if config.use_global_fold:
        # split entire set into train_smaller + validation, aiming for same % split of those as we did when carving up train_big
        train_smaller_index, validation_index = make_splits_single(
            unique_participants, test_proportion=1 / 3
        )
        train_smaller_participants = unique_participants.iloc[
            train_smaller_index
        ].reset_index(drop=True)
        validation_participants = unique_participants.iloc[
            validation_index
        ].reset_index(drop=True)

        # confirm each patient is entirely on one or the other side of the train-smaller vs validation divide
        assert (
            len(
                set(train_smaller_participants["participant_label"]).intersection(
                    set(validation_participants["participant_label"])
                )
            )
            == 0
        )

        # split train-smaller into train-smaller1 and train-smaller2
        train_smaller1_indices, train_smaller2_indices = make_splits_single(
            train_smaller_participants, test_proportion=1 / 3
        )
        train_smaller1_participants = train_smaller_participants.iloc[
            train_smaller1_indices
        ]
        train_smaller2_participants = train_smaller_participants.iloc[
            train_smaller2_indices
        ]
        # confirm each patient is entirely on one or the other side of the train-smaller1 vs train-smaller2 divide
        assert (
            len(
                set(train_smaller1_participants["participant_label"]).intersection(
                    set(train_smaller2_participants["participant_label"])
                )
            )
            == 0
        )

        # get list of participant labels
        for participants, fold_label in zip(
            [
                train_smaller_participants,
                train_smaller1_participants,
                train_smaller2_participants,
                validation_participants,
            ],
            ["train_smaller", "train_smaller1", "train_smaller2", "validation"],
        ):
            fold_participants.append(
                pd.DataFrame(
                    {"participant_label": participants["participant_label"].unique()}
                ).assign(fold_id=-1, fold_label=fold_label)
            )


fold_participants = pd.concat(fold_participants, axis=0)
fold_participants

GeneLocus.BCR|TCR: total number of unique participants 542
GeneLocus.BCR: total number of unique participants 51


Unnamed: 0,participant_label,fold_id,fold_label
0,BFI-0009836,0,train_smaller
1,BFI-0009806,0,train_smaller
2,BFI-0002863,0,train_smaller
3,BFI-0009839,0,train_smaller
4,BFI-0003470,0,train_smaller
...,...,...,...
12,BFI-0005402,-1,validation
13,BFI-0005462,-1,validation
14,BFI-0005438,-1,validation
15,BFI-0005436,-1,validation


In [17]:
# sanity checks:

# each participant is in each fold (either in train_smaller, validation, or test - ignore the further subdivisions of train_smaller)
assert all(
    fold_participants[
        ~(fold_participants["fold_label"].isin(["train_smaller1", "train_smaller2"]))
    ]
    .groupby("participant_label")
    .size()
    == config.n_folds_including_global_fold
)

# within the cross validation scheme, each participant is in two non-test sets
# (i.e. shows up either in train_smaller or validation twice).
# ignore the further subdivisions of train_smaller for this check.
assert all(
    fold_participants[
        (fold_participants["fold_id"] != -1)
        & ~(
            fold_participants["fold_label"].isin(
                ["test", "train_smaller1", "train_smaller2"]
            )
        )
    ]
    .groupby("participant_label")
    .size()
    # special case for single fold due to study_names_for_held_out_set restriction
    == (config.n_folds - 1 if config.n_folds > 1 else 1)
)

# within the cross validation scheme, each participant is in one test set
assert all(
    fold_participants[
        (fold_participants["fold_id"] != -1)
        & (fold_participants["fold_label"] == "test")
    ]
    .groupby("participant_label")
    .size()
    == 1
)

In [18]:
assert (
    "fold_id" not in specimens_merged.columns
    and "fold_label" not in specimens_merged.columns
)

In [19]:
specimens_by_fold = pd.merge(
    specimens_merged, fold_participants, on="participant_label", how="inner"
)
specimens_by_fold

Unnamed: 0,participant_label,specimen_label,disease,total_sequence_count,specimen_time_point,participant_description,data_source,study_name,available_gene_loci,disease_subtype,...,symptoms_healthy_in_resequencing_experiment,specimen_time_point_days,survived_filters,is_selected_for_cv_strategy,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup,fold_id,fold_label
0,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,True,True,True,False,Healthy/Background,Healthy/Background,0,validation
1,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,True,True,True,False,Healthy/Background,Healthy/Background,1,test
2,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,True,True,True,False,Healthy/Background,Healthy/Background,2,validation
3,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,True,True,True,False,Healthy/Background,Healthy/Background,-1,train_smaller
4,BFI-0000234,M124-S014,Healthy/Background,67139,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,...,,,True,True,True,False,Healthy/Background,Healthy/Background,-1,train_smaller1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3690,BFI-0010806,M491-S167,T1D,9625,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,True,True,True,False,T1D,T1D,0,train_smaller1
3691,BFI-0010806,M491-S167,T1D,9625,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,True,True,True,False,T1D,T1D,1,test
3692,BFI-0010806,M491-S167,T1D,9625,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,True,True,True,False,T1D,T1D,2,train_smaller
3693,BFI-0010806,M491-S167,T1D,9625,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,...,,,True,True,True,False,T1D,T1D,2,train_smaller1


In [20]:
# sanity check:
# one entry per specimen per fold (either in train_smaller, validation, or test - ignore the further subdivisions of train_smaller)
assert (
    specimens_by_fold[
        ~(specimens_by_fold["fold_label"].isin(["train_smaller1", "train_smaller2"]))
    ].shape[0]
    == specimens_merged.shape[0] * config.n_folds_including_global_fold
)

In [21]:
# sanity checks:

# each specimen is in each fold, either in train_smaller, validation, or test (ignore the further subdivisions of train_smaller)
assert all(
    specimens_by_fold[
        ~(specimens_by_fold["fold_label"].isin(["train_smaller1", "train_smaller2"]))
    ]
    .groupby("specimen_label")
    .size()
    == config.n_folds_including_global_fold
)

# within the cross validation scheme, each specimen is in two non-test sets
# (i.e. shows up either in train_smaller or validation twice).
# ignore the further subdivisions of train_smaller for this check.
assert all(
    specimens_by_fold[
        (specimens_by_fold["fold_id"] != -1)
        & ~(
            specimens_by_fold["fold_label"].isin(
                ["test", "train_smaller1", "train_smaller2"]
            )
        )
    ]
    .groupby("specimen_label")
    .size()
    # special case for single fold due to study_names_for_held_out_set restriction
    == (config.n_folds - 1 if config.n_folds > 1 else 1)
)

# within the cross validation scheme, each specimen is in one test set
assert all(
    specimens_by_fold[
        (specimens_by_fold["fold_id"] != -1)
        & (specimens_by_fold["fold_label"] == "test")
    ]
    .groupby("specimen_label")
    .size()
    == 1
)

In [22]:
# sanity checks:

# each participant is in each fold, in one group or another
assert all(
    specimens_by_fold.groupby("participant_label")["fold_id"].nunique()
    == config.n_folds_including_global_fold
)

# within the cross validation scheme, each participant is in two non-test sets
assert all(
    specimens_by_fold[
        (specimens_by_fold["fold_id"] != -1)
        & (specimens_by_fold["fold_label"] != "test")
    ]
    .groupby("participant_label")["fold_id"]
    .nunique()
    # special case for single fold due to study_names_for_held_out_set restriction
    == (config.n_folds - 1 if config.n_folds > 1 else 1)
)

# within the cross validation scheme, each participant is in one test set
assert all(
    specimens_by_fold[
        (specimens_by_fold["fold_id"] != -1)
        & (specimens_by_fold["fold_label"] == "test")
    ]
    .groupby("participant_label")["fold_id"]
    .nunique()
    == 1
)

In [23]:
for (fold_id, fold_label), grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    print(f"Fold {fold_id}-{fold_label}:")
    display(
        pd.DataFrame(
            [
                grp.groupby("disease")["participant_label"]
                .nunique()
                .rename("#participants"),
                grp.groupby("disease")["specimen_label"].nunique().rename("#specimens"),
            ]
        )
    )

    print()

Fold -1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,42,63,146,25,58,61
#specimens,42,64,148,25,67,63



Fold -1-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,97,17,38,40
#specimens,28,43,98,17,45,41



Fold -1-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,21,49,8,20,21
#specimens,14,21,50,8,22,22



Fold -1-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,21,32,74,12,28,31
#specimens,21,34,76,12,31,33



Fold 0-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,20,32,74,12,29,31
#specimens,20,33,76,12,33,32



Fold 0-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,97,17,38,40
#specimens,28,44,99,17,46,43



Fold 0-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,28,66,11,31,28



Fold 0-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,10,14,33,6,12,13
#specimens,10,16,33,6,15,15



Fold 0-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,15,21,49,8,19,21
#specimens,15,21,49,8,19,21



Fold 1-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,22,31,73,12,29,31
#specimens,22,33,75,12,34,34



Fold 1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,27,42,97,17,38,41
#specimens,27,43,99,17,40,42



Fold 1-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,28,64,11,28,28



Fold 1-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,33,6,12,14
#specimens,9,15,35,6,12,14



Fold 1-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,22,50,8,19,20
#specimens,14,22,50,8,24,20



Fold 2-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,21,32,73,13,28,30
#specimens,21,32,73,13,31,30



Fold 2-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,98,16,38,41
#specimens,28,45,100,16,45,43



Fold 2-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,31,65,11,31,28



Fold 2-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,10,14,34,5,12,14
#specimens,10,14,35,5,14,15



Fold 2-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,21,49,8,20,21
#specimens,14,21,51,8,22,23





In [24]:
for (fold_id, fold_label), grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    print(f"Fold {fold_id}-{fold_label}:")
    display(
        pd.DataFrame(
            [
                grp.groupby(stratify_by_column_name)["participant_label"]
                .nunique()
                .rename("#participants"),
                grp.groupby(stratify_by_column_name)["specimen_label"]
                .nunique()
                .rename("#specimens"),
            ]
        )
    )

    print()

Fold -1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,42,63,146,25,58,61
#specimens,42,64,148,25,67,63



Fold -1-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,97,17,38,40
#specimens,28,43,98,17,45,41



Fold -1-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,21,49,8,20,21
#specimens,14,21,50,8,22,22



Fold -1-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,21,32,74,12,28,31
#specimens,21,34,76,12,31,33



Fold 0-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,20,32,74,12,29,31
#specimens,20,33,76,12,33,32



Fold 0-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,97,17,38,40
#specimens,28,44,99,17,46,43



Fold 0-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,28,66,11,31,28



Fold 0-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,10,14,33,6,12,13
#specimens,10,16,33,6,15,15



Fold 0-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,15,21,49,8,19,21
#specimens,15,21,49,8,19,21



Fold 1-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,22,31,73,12,29,31
#specimens,22,33,75,12,34,34



Fold 1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,27,42,97,17,38,41
#specimens,27,43,99,17,40,42



Fold 1-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,28,64,11,28,28



Fold 1-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,33,6,12,14
#specimens,9,15,35,6,12,14



Fold 1-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,22,50,8,19,20
#specimens,14,22,50,8,24,20



Fold 2-test:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,21,32,73,13,28,30
#specimens,21,32,73,13,31,30



Fold 2-train_smaller:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,28,42,98,16,38,41
#specimens,28,45,100,16,45,43



Fold 2-train_smaller1:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,18,28,64,11,26,27
#specimens,18,31,65,11,31,28



Fold 2-train_smaller2:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,10,14,34,5,12,14
#specimens,10,14,35,5,14,15



Fold 2-validation:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,14,21,49,8,20,21
#specimens,14,21,51,8,22,23





In [25]:
# By gene locus
# Nest because can't sort on gene locus column
for (fold_id, fold_label), _grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    for gene_locus, grp in _grp.groupby(
        "available_gene_loci", observed=True, sort=False
    ):
        print(f"Fold {fold_id}-{fold_label}-{gene_locus}:")

        display(
            pd.DataFrame(
                [
                    grp.groupby("disease")["participant_label"]
                    .nunique()
                    .rename("#participants"),
                    grp.groupby("disease")["specimen_label"]
                    .nunique()
                    .rename("#specimens"),
                ]
            )
        )

        print()

Fold -1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,39,63,131,25,42,61
#specimens,39,64,131,25,43,63



Fold -1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,3,15,16
#specimens,3,17,24



Fold -1-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,87,17,28,40
#specimens,26,43,87,17,29,41



Fold -1-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,11,16



Fold -1-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,14,22



Fold -1-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,5,6
#specimens,1,6,8



Fold -1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,66,12,21,31
#specimens,19,34,66,12,21,33



Fold -1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,8,7
#specimens,2,10,10



Fold 0-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,66,12,21,31
#specimens,19,33,66,12,22,32



Fold 0-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,8,8
#specimens,1,10,11



Fold 0-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,87,17,28,40
#specimens,26,44,87,17,28,43



Fold 0-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,18



Fold 0-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,28,58,11,19,28



Fold 0-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,8,12



Fold 0-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,29,6,9,13
#specimens,9,16,29,6,9,15



Fold 0-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,4,6



Fold 0-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,14,21



Fold 0-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,5,5
#specimens,2,5,5



Fold 1-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,20,31,66,12,21,31
#specimens,20,33,66,12,21,34



Fold 1-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,7,8
#specimens,2,9,13



Fold 1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,25,42,87,17,28,41
#specimens,25,43,87,17,28,42



Fold 1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,12



Fold 1-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,28,58,11,19,28



Fold 1-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,6,9



Fold 1-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,8,14,29,6,9,14
#specimens,8,15,29,6,9,14



Fold 1-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,6,3



Fold 1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,22,44,8,14,20
#specimens,13,22,44,8,15,20



Fold 1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,5
#specimens,1,6,9



Fold 2-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,65,13,21,30
#specimens,19,32,65,13,21,30



Fold 2-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,8,7
#specimens,2,8,10



Fold 2-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,88,16,28,41
#specimens,26,45,88,16,28,43



Fold 2-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,17



Fold 2-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,31,58,11,19,28



Fold 2-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,7,12



Fold 2-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,30,5,9,14
#specimens,9,14,30,5,9,15



Fold 2-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,5,5



Fold 2-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,15,23



Fold 2-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,5,6
#specimens,1,7,7





In [26]:
# By gene locus
# Nest because can't sort on gene locus column
for (fold_id, fold_label), _grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    for gene_locus, grp in _grp.groupby(
        "available_gene_loci", observed=True, sort=False
    ):
        print(f"Fold {fold_id}-{fold_label}-{gene_locus}:")

        display(
            pd.DataFrame(
                [
                    grp.groupby(stratify_by_column_name)["participant_label"]
                    .nunique()
                    .rename("#participants"),
                    grp.groupby(stratify_by_column_name)["specimen_label"]
                    .nunique()
                    .rename("#specimens"),
                ]
            )
        )

        print()

Fold -1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,39,63,131,25,42,61
#specimens,39,64,131,25,43,63



Fold -1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,3,15,16
#specimens,3,17,24



Fold -1-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,87,17,28,40
#specimens,26,43,87,17,29,41



Fold -1-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,11,16



Fold -1-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,14,22



Fold -1-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,5,6
#specimens,1,6,8



Fold -1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,66,12,21,31
#specimens,19,34,66,12,21,33



Fold -1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,8,7
#specimens,2,10,10



Fold 0-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,66,12,21,31
#specimens,19,33,66,12,22,32



Fold 0-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,8,8
#specimens,1,10,11



Fold 0-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,87,17,28,40
#specimens,26,44,87,17,28,43



Fold 0-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,18



Fold 0-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,28,58,11,19,28



Fold 0-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,8,12



Fold 0-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,29,6,9,13
#specimens,9,16,29,6,9,15



Fold 0-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,4,6



Fold 0-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,14,21



Fold 0-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,5,5
#specimens,2,5,5



Fold 1-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,20,31,66,12,21,31
#specimens,20,33,66,12,21,34



Fold 1-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,7,8
#specimens,2,9,13



Fold 1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,25,42,87,17,28,41
#specimens,25,43,87,17,28,42



Fold 1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,12



Fold 1-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,28,58,11,19,28



Fold 1-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,6,9



Fold 1-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,8,14,29,6,9,14
#specimens,8,15,29,6,9,14



Fold 1-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,6,3



Fold 1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,22,44,8,14,20
#specimens,13,22,44,8,15,20



Fold 1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,5
#specimens,1,6,9



Fold 2-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,19,32,65,13,21,30
#specimens,19,32,65,13,21,30



Fold 2-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,8,7
#specimens,2,8,10



Fold 2-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,26,42,88,16,28,41
#specimens,26,45,88,16,28,43



Fold 2-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,17



Fold 2-train_smaller1-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,17,28,58,11,19,27
#specimens,17,31,58,11,19,28



Fold 2-train_smaller1-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,7
#specimens,1,7,12



Fold 2-train_smaller2-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,9,14,30,5,9,14
#specimens,9,14,30,5,9,15



Fold 2-train_smaller2-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,4,3
#specimens,1,5,5



Fold 2-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
#participants,13,21,44,8,14,21
#specimens,13,21,44,8,15,23



Fold 2-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,5,6
#specimens,1,7,7





In [27]:
fold_participants.to_csv(
    config.paths.dataset_specific_metadata_for_selected_cross_validation_strategy
    / "cross_validation_divisions.participants.tsv",
    sep="\t",
    index=None,
)

In [28]:
specimens_by_fold.to_csv(
    config.paths.dataset_specific_metadata_for_selected_cross_validation_strategy
    / "cross_validation_divisions.specimens.tsv.gz",
    sep="\t",
    index=None,
)