# Filter to peak timepoints + Make cv divisions.

Upstream steps have already removed specimens with very few sequences or without all isotypes, and have already sampled one sequence per clone per isotype per specimen.

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from malid import config, helpers, logger

In [2]:
import dask
import dask.dataframe as dd

In [3]:
from dask.distributed import Client

# Start a local dask.distributed cluster with a multi-processing backend:
# 4 worker processes x 8 threads each, on fixed scheduler and dashboard ports.
# Access the dashboard at http://127.0.0.1:61083
client = Client(
    scheduler_port=61084,
    dashboard_address=":61083",
    n_workers=4,
    processes=True,
    threads_per_worker=8,
    memory_limit="125GB",  # per worker
)
# Show the cluster summary as the cell output.
display(client)
# for debugging: client.restart()

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:61083/status,

0,1
Dashboard: http://127.0.0.1:61083/status,Workers: 4
Total threads: 32,Total memory: 465.66 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:61084,Workers: 4
Dashboard: http://127.0.0.1:61083/status,Total threads: 32
Started: Just now,Total memory: 465.66 GiB

0,1
Comm: tcp://127.0.0.1:33177,Total threads: 8
Dashboard: http://127.0.0.1:43873/status,Memory: 116.42 GiB
Nanny: tcp://127.0.0.1:33201,
Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-wgccqmbl,Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-wgccqmbl
GPU: NVIDIA A100 80GB PCIe,GPU memory: 80.00 GiB

0,1
Comm: tcp://127.0.0.1:41177,Total threads: 8
Dashboard: http://127.0.0.1:38119/status,Memory: 116.42 GiB
Nanny: tcp://127.0.0.1:34641,
Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-1e4pq84b,Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-1e4pq84b
GPU: NVIDIA A100 80GB PCIe,GPU memory: 80.00 GiB

0,1
Comm: tcp://127.0.0.1:46625,Total threads: 8
Dashboard: http://127.0.0.1:36431/status,Memory: 116.42 GiB
Nanny: tcp://127.0.0.1:42079,
Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-6go1wpww,Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-6go1wpww
GPU: NVIDIA A100 80GB PCIe,GPU memory: 80.00 GiB

0,1
Comm: tcp://127.0.0.1:34099,Total threads: 8
Dashboard: http://127.0.0.1:34145/status,Memory: 116.42 GiB
Nanny: tcp://127.0.0.1:38407,
Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-l7s129eg,Local directory: /users/maximz/code/boyd-immune-repertoire-classification/notebooks/dask-worker-space/worker-l7s129eg
GPU: NVIDIA A100 80GB PCIe,GPU memory: 80.00 GiB


In [4]:
# Load the sampled sequences as a dask dataframe.
# Don't use fastparquet, because it changes specimen labels like M54-049 to 2049-01-01 00:00:54 -- i.e. it coerces partition names to numbers or dates
df = dd.read_parquet(config.paths.sequences_sampled, engine="pyarrow")
df

Unnamed: 0_level_0,v_gene,j_gene,disease,disease_subtype,cdr1_seq_aa_q_trim,cdr2_seq_aa_q_trim,cdr3_seq_aa_q_trim,cdr3_aa_sequence_trim_len,extracted_isotype,isotype_supergroup,v_mut,num_reads,igh_or_tcrb_clone_id,total_clone_num_reads,num_clone_members,participant_label,specimen_label
npartitions=522,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
,category[unknown],category[unknown],category[unknown],category[unknown],object,object,object,int64,category[unknown],category[unknown],float64,int64,int64,int64,int64,category[known],category[known]
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [5]:
# Each partition is a specimen, so the partition count is the number of specimens loaded.
df.npartitions

522

In [6]:
# List the columns available for each sequence row.
df.columns

Index(['v_gene', 'j_gene', 'disease', 'disease_subtype', 'cdr1_seq_aa_q_trim',
       'cdr2_seq_aa_q_trim', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len',
       'extracted_isotype', 'isotype_supergroup', 'v_mut', 'num_reads',
       'igh_or_tcrb_clone_id', 'total_clone_num_reads', 'num_clone_members',
       'participant_label', 'specimen_label'],
      dtype='object')

In [7]:
# Count the sequences belonging to each specimen. Participant, disease, and
# disease subtype are part of the grouping key so that they are carried along
# as columns of the result.
specimen_key_columns = [
    "participant_label",
    "specimen_label",
    "disease",
    "disease_subtype",
]
sequence_counts = df.groupby(specimen_key_columns, observed=True).size()
specimens = sequence_counts.rename("total_sequence_count").reset_index()
specimens

Unnamed: 0_level_0,participant_label,specimen_label,disease,disease_subtype,total_sequence_count
npartitions=1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
,category[known],category[known],category[unknown],category[unknown],int64
,...,...,...,...,...


In [8]:
# Execute the dask computation and keep the result in memory.
# Note: this rebinds `specimens` to the computed result, so re-run the groupby cell
# above before re-running this cell.
specimens = specimens.compute()
specimens

Unnamed: 0,participant_label,specimen_label,disease,disease_subtype,total_sequence_count
0,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707
1,BFI-0000254,M111-S037,HIV,HIV Broad Neutralizing,74203
2,BFI-0000255,M111-S033,HIV,HIV Broad Neutralizing,69648
3,BFI-0000256,M111-S038,HIV,HIV Broad Neutralizing,100383
4,BFI-0000258,M111-S055,HIV,HIV Non Neutralizing,106357
...,...,...,...,...,...
517,BFI-0010240,M464-S041,Healthy/Background,Healthy/Background (children),16518
518,BFI-0010241,M464-S042,Healthy/Background,Healthy/Background (children),92181
519,BFI-0010243,M464-S044,Healthy/Background,Healthy/Background (children),132934
520,BFI-0010244,M464-S045,Healthy/Background,Healthy/Background (children),62354


In [9]:
# One specimen per partition: the specimen table must have as many rows as there are partitions.
assert len(specimens) == df.npartitions

In [10]:
# Each specimen label must appear exactly once.
assert specimens["specimen_label"].is_unique

In [11]:
# Export the list of specimens included in this full anndata.
# Not all specimens survived to this step - some are thrown out in the run_embedding notebooks for not having enough sequences or not having all isotypes.
# Note that these aren't yet filtered to is_peak timepoints.
specimens.to_csv(
    config.paths.dataset_specific_metadata / "specimens_kept_in_embedding_anndatas.tsv",
    sep="\t",
    index=None,
)

In [12]:
# Load specimen metadata and keep only Boydlab specimens that are in the training set.
# This is how we filter to is_peak timepoints, in addition to the
# is_valid / survived_filters filter already here.
all_specimen_info = helpers.get_all_specimen_info(
    # CV fold is not available yet, so must set this to False
    add_cv_fold_information=False
)
specimen_metadata = (
    all_specimen_info.sort_values("disease")
    .loc[lambda metadata: metadata["cohort"] == "Boydlab"]
    .loc[lambda metadata: metadata["in_training_set"]]
)

# sanity check the definitions: everything left must be a peak timepoint that survived filters
assert specimen_metadata["is_peak"].all()
assert specimen_metadata["survived_filters"].all()

specimen_metadata

Unnamed: 0,participant_label,specimen_label,disease,specimen_time_point,participant_description,cohort,study_name,available_gene_loci,disease_subtype,age,...,age_group_pediatric,cmv,disease_severity,specimen_time_point_days,survived_filters,is_peak,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup
0,BFI-0007450,M369-S001,Covid19,9 days,COVID-19 project,Boydlab,Covid19-buffycoat,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Sero-positive (ICU),73.0,...,18+,,ICU,9.0,True,True,True,False,Covid19,Covid19
83,BFI-0009122,M418-S001,Covid19,10 days,COVID-19 sample from Sam Yang,Boydlab,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - ICU,63.0,...,18+,,ICU,10.0,True,True,True,False,Covid19,Covid19
82,BFI-0009121,M418-S193,Covid19,11 days,COVID-19 sample from Sam Yang,Boydlab,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Admit,43.0,...,18+,,Admit,11.0,True,True,True,False,Covid19,Covid19
81,BFI-0009120,M418-S192,Covid19,12 days,COVID-19 sample from Sam Yang,Boydlab,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Admit,25.0,...,18+,,Admit,12.0,True,True,True,False,Covid19,Covid19
80,BFI-0009112,M418-S184,Covid19,11 days,COVID-19 sample from Sam Yang,Boydlab,Covid19-Stanford,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Covid19 - Admit,86.0,...,,,Admit,11.0,True,True,True,False,Covid19,Covid19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
629,BFI-0005434,M281redo-S035,Lupus,,formerly recorded as BFI-0000763,Boydlab,Lupus,(((GeneLocus.BCR))),SLE One aAbs / SLE dsDNA WITHOUT Nephritis,50.0,...,18+,,,,True,True,True,False,Lupus,Lupus
626,BFI-0005428,M281redo-S029,Lupus,,formerly recorded as BFI-0000781,Boydlab,Lupus,(((GeneLocus.BCR))),SLE Multiple aAbs / SLE dsDNA WITH Nephritis,26.0,...,18+,,,,True,True,True,False,Lupus,Lupus
625,BFI-0005425,M281redo-S026,Lupus,,formerly recorded as BFI-0000814,Boydlab,Lupus,(((GeneLocus.BCR))),SLE One aAbs,32.0,...,18+,,,,True,True,True,False,Lupus,Lupus
638,BFI-0009807,M447-S038,Lupus,00:00:00,Pediatric SLE - no nephritis,Boydlab,Lupus Pediatric,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Pediatric SLE - no nephritis,12.0,...,under 18,,,0.0,True,True,True,False,Lupus,Lupus


In [13]:
# Apply the metadata filter by inner-joining the per-specimen sequence counts
# against the filtered metadata. The join must be one-to-one.
merge_keys = ["participant_label", "specimen_label", "disease", "disease_subtype"]
specimens_merged = specimens.merge(
    specimen_metadata,
    how="inner",
    on=merge_keys,
    validate="1:1",
)

# Every filtered metadata row must have found its specimen, and filtering can only shrink the list.
assert specimens_merged.shape[0] == specimen_metadata.shape[0]
assert specimen_metadata.shape[0] <= specimens.shape[0]
specimens_merged

Unnamed: 0,participant_label,specimen_label,disease,disease_subtype,total_sequence_count,specimen_time_point,participant_description,cohort,study_name,available_gene_loci,...,age_group_pediatric,cmv,disease_severity,specimen_time_point_days,survived_filters,is_peak,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup
0,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707,,Location: USA,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,18+,,,,True,True,True,False,Healthy/Background,Healthy/Background
1,BFI-0000254,M111-S037,HIV,HIV Broad Neutralizing,74203,,Location: Tanzania,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,18+,,,,True,True,True,False,HIV,HIV
2,BFI-0000255,M111-S033,HIV,HIV Broad Neutralizing,69648,,Location: Tanzania,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,18+,,,,True,True,True,False,HIV,HIV
3,BFI-0000256,M111-S038,HIV,HIV Broad Neutralizing,100383,,Location: Tanzania,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,18+,,,,True,True,True,False,HIV,HIV
4,BFI-0000258,M111-S055,HIV,HIV Non Neutralizing,106357,,Location: Tanzania,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,18+,,,,True,True,True,False,HIV,HIV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,BFI-0010240,M464-S041,Healthy/Background,Healthy/Background (children),16518,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,under 18,,,,True,True,True,False,Healthy/Background,Healthy/Background
476,BFI-0010241,M464-S042,Healthy/Background,Healthy/Background (children),92181,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,under 18,,,,True,True,True,False,Healthy/Background,Healthy/Background
477,BFI-0010243,M464-S044,Healthy/Background,Healthy/Background (children),132934,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,under 18,,,,True,True,True,False,Healthy/Background,Healthy/Background
478,BFI-0010244,M464-S045,Healthy/Background,Healthy/Background (children),62354,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,under 18,,,,True,True,True,False,Healthy/Background,Healthy/Background


In [14]:
# Set the cohort column to the constant "Boydlab" for every remaining specimen.
# NOTE(review): specimen_metadata was already filtered to cohort == "Boydlab" before the
# inner merge, so this looks like it should not change any values - confirm.
specimens_merged = specimens_merged.assign(cohort="Boydlab")

In [15]:
# Count participants per disease subtype: keep one row per
# (participant, disease subtype) pair, then tally the subtypes.
one_row_per_participant_subtype = specimens_merged.drop_duplicates(
    ["participant_label", "disease_subtype"]
)
subtype_counts = one_row_per_participant_subtype["disease_subtype"].value_counts()
tmp = subtype_counts.rename("number of participants")

# Save the tally alongside the other outputs.
tmp.to_csv(
    f"{config.paths.base_output_dir}/all_data_combined.participants_by_disease_subtype.tsv",
    sep="\t",
)

tmp

Healthy/Background - CMV Negative                  53
HIV Non Neutralizing                               50
Healthy/Background - CMV Positive                  48
HIV Broad Neutralizing                             45
Healthy/Background - HIV Negative                  43
Healthy/Background (children)                      43
Covid19 - Admit                                    33
Pediatric SLE - nephritis                          23
Healthy/Background - SLE Negative                  23
Pediatric SLE - no nephritis                       20
SLE Patient                                        20
Covid19 - ICU                                      15
Covid19 - Sero-positive (ICU)                       7
SLE Multiple aAbs                                   7
Unaffected Control                                  6
SLE Multiple aAbs / SLE dsDNA WITHOUT Nephritis     5
Covid19 - Acute 2                                   5
SLE Multiple aAbs / SLE dsDNA WITH Nephritis        4
Covid19 - Sero-positive (Adm

In [16]:
# Total number of sequences per (disease, cohort) pair.
grouped_by_disease_cohort = specimens_merged.groupby(
    ["disease", "cohort"], observed=True
)
sequence_totals = grouped_by_disease_cohort["total_sequence_count"].sum()
tmp = sequence_totals.reset_index(name="number of sequences")

# Save the totals alongside the other outputs.
tmp.to_csv(
    f"{config.paths.base_output_dir}/all_data_combined.sequences_by_disease_and_cohort.tsv",
    sep="\t",
    index=None,
)

tmp

Unnamed: 0,disease,cohort,number of sequences
0,Covid19,Boydlab,1783861
1,HIV,Boydlab,5927445
2,Healthy/Background,Boydlab,18912732
3,Lupus,Boydlab,6983523


# Make CV splits

**Strategy:**

- Split on patients.
- Split into (train+validation) / test first with a 2:1 ratio (because 3 folds total). Every patient is in 1 test fold. Then split (train+validation) into train (2/3) and validation (1/3), by taking the first of 3 stratified splits.
- Also create a "global fold" that has no test set, but has the same 2:1 train/validation ratio.
- Stratified by disease label

**How to handle varying gene loci:**

* We want to include BCR+TCR samples as well as single-loci (e.g. BCR-only) samples
* All data is used for single loci models. For example, the BCR sequence model and the BCR-only metamodel will include specimens whether they're BCR+TCR or BCR-only. (Any TCR-only would be excluded)
* Only BCR+TCR samples are used in BCR+TCR metamodel. (The input BCR models will have been trained on any and all BCR data, but the second stage metamodel will be trained on only samples that have both BCR and TCR components.)

The wrong way to design the cross validation split: split patients up all together, regardless of if they are BCR-only, TCR-only, or BCR+TCR.

Example of why this is wrong: consider just one disease for now; suppose you have a BCR-only set of 3 patients and a BCR+TCR set of a different 3 patients. How do you split those 6 patients into 3 cross validation folds?

The wrong strategy might split as follows:

```
# wrong way - possible result
how many
included
patients
             fold 1              fold 2              fold 3
BCR/TCR      3 train/0 test      2 train/1 test      1 train/2 test
BCR only     1 train/2 test      2 train/1 test      3 train/0 test
```

The BCR-only vs BCR+TCR are not spread evenly - we have imbalanced folds.

**The right way to split would be: split BCR+TCR patients first, then separately split the BCR-only patients, and combine the resulting folds:**

```
# right way - result
how many
included
patients
             fold 1              fold 2              fold 3
BCR/TCR      2 train/1 test      2 train/1 test      2 train/1 test
BCR only     2 train/1 test      2 train/1 test      2 train/1 test
```



In [17]:
# Build one row per distinct combination of participant-level attributes.
# These are the units that get assigned to cross validation folds.
participant_columns = [
    "participant_label",
    "disease",
    "cohort",
    "past_exposure",
    "disease.separate_past_exposures",
    "available_gene_loci",
]
unique_participants_all = (
    specimens_merged.loc[:, participant_columns]
    .drop_duplicates()
    .reset_index(drop=True)
)
unique_participants_all

Unnamed: 0,participant_label,disease,cohort,past_exposure,disease.separate_past_exposures,available_gene_loci
0,BFI-0000234,Healthy/Background,Boydlab,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
1,BFI-0000254,HIV,Boydlab,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
2,BFI-0000255,HIV,Boydlab,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
3,BFI-0000256,HIV,Boydlab,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
4,BFI-0000258,HIV,Boydlab,False,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
...,...,...,...,...,...,...
456,BFI-0010240,Healthy/Background,Boydlab,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
457,BFI-0010241,Healthy/Background,Boydlab,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
458,BFI-0010243,Healthy/Background,Boydlab,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"
459,BFI-0010244,Healthy/Background,Boydlab,False,Healthy/Background,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))"


In [18]:
def make_splits(participants, n_splits):
    """Split participants into stratified folds.

    Parameters
    ----------
    participants : dataframe with "participant_label" and "disease" columns.
    n_splits : number of folds to create.

    Returns
    -------
    List of ``(fold_number, (train_positions, test_positions))`` tuples, one per
    fold. The positions are positional indices into ``participants``.
    """
    # Stratify on disease to preserve the imbalanced class distribution in each fold.
    # Shuffling uses a fixed seed.
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    labels = participants["participant_label"]
    diseases = participants["disease"]
    return [
        (fold_number, index_pair)
        for fold_number, index_pair in enumerate(splitter.split(labels, diseases))
    ]

In [19]:
# Number of stratified splits used when carving a participant set into
# train_smaller + validation. Only the first split is used, so validation gets
# 1/N of the participants and train_smaller gets the rest (N=3 -> 2:1 ratio).
# The same constant is used inside each cross validation fold and for the global fold.
N_TRAIN_VALIDATION_SPLITS = 3


def _assert_no_shared_participants(left, right):
    """Assert that no participant_label appears in both dataframes."""
    shared_labels = set(left["participant_label"].unique()).intersection(
        set(right["participant_label"].unique())
    )
    assert len(shared_labels) == 0


def _split_train_smaller_validation(participants):
    """Split participants into (train_smaller, validation) dataframes.

    Uses the first of N_TRAIN_VALIDATION_SPLITS splits from make_splits, and
    confirms each patient is entirely on one or the other side of the divide.
    """
    mini_folds = make_splits(participants, n_splits=N_TRAIN_VALIDATION_SPLITS)
    # unpack first of the splits
    _, (train_smaller_index, validation_index) = mini_folds[0]
    train_smaller = participants.iloc[train_smaller_index].reset_index(drop=True)
    validation = participants.iloc[validation_index].reset_index(drop=True)
    _assert_no_shared_participants(train_smaller, validation)
    return train_smaller, validation


def _label_participants(participants, fold_id, fold_label):
    """Return one row per unique participant_label, tagged with fold_id and fold_label."""
    return pd.DataFrame(
        {"participant_label": participants["participant_label"].unique()}
    ).assign(fold_id=fold_id, fold_label=fold_label)


fold_participant_frames = []

# Split each gene-locus group separately, then combine the resulting folds.
for gene_locus, unique_participants in unique_participants_all.groupby(
    "available_gene_loci", observed=True, sort=False
):
    print(
        f"{gene_locus}: total number of unique participants {unique_participants.shape[0]}"
    )

    trainbig_test_splits = make_splits(unique_participants, n_splits=config.n_folds)
    assert len(trainbig_test_splits) == config.n_folds

    for (
        fold_id,
        (trainbig_participant_index, test_participant_index),
    ) in trainbig_test_splits:
        trainbig_participants = unique_participants.iloc[
            trainbig_participant_index
        ].reset_index(drop=True)
        test_participants = unique_participants.iloc[
            test_participant_index
        ].reset_index(drop=True)

        # confirm each patient is entirely on one or the other side of the train-big vs test divide
        _assert_no_shared_participants(trainbig_participants, test_participants)

        # split trainbig into trainsmaller + validation
        (
            train_smaller_participants,
            validation_participants,
        ) = _split_train_smaller_validation(trainbig_participants)

        # get list of participant labels
        for participants, fold_label in zip(
            [train_smaller_participants, validation_participants, test_participants],
            ["train_smaller", "validation", "test"],
        ):
            fold_participant_frames.append(
                _label_participants(participants, fold_id, fold_label)
            )

    # also create global fold (fold_id=-1), which has no test set:
    # split entire set into train_smaller + validation, with the same % split of those as we used when carving up train_big
    (
        train_smaller_participants,
        validation_participants,
    ) = _split_train_smaller_validation(unique_participants)

    # get list of participant labels
    for participants, fold_label in zip(
        [train_smaller_participants, validation_participants],
        ["train_smaller", "validation"],
    ):
        fold_participant_frames.append(
            _label_participants(participants, -1, fold_label)
        )


# One long table: participant_label, fold_id, fold_label.
# The row index is not reset, so index values repeat across the concatenated pieces.
fold_participants = pd.concat(fold_participant_frames, axis=0)
fold_participants

GeneLocus.BCR|TCR: total number of unique participants 410
GeneLocus.BCR: total number of unique participants 51


Unnamed: 0,participant_label,fold_id,fold_label
0,BFI-0000255,0,train_smaller
1,BFI-0000258,0,train_smaller
2,BFI-0002852,0,train_smaller
3,BFI-0002857,0,train_smaller
4,BFI-0002859,0,train_smaller
...,...,...,...
12,BFI-0005453,-1,validation
13,BFI-0005455,-1,validation
14,BFI-0005457,-1,validation
15,BFI-0005460,-1,validation


In [20]:
# sanity checks:

# each participant is in each fold (including the global fold), in one group or another
rows_per_participant = fold_participants.groupby("participant_label").size()
assert (rows_per_participant == config.n_folds + 1).all()

# masks: rows that belong to the cross validation scheme (not the global fold), and test rows
in_cv_scheme = fold_participants["fold_id"] != -1
in_test_set = fold_participants["fold_label"] == "test"

# within the cross validation scheme, each participant is in a non-test set in all folds but one
non_test_rows = fold_participants[in_cv_scheme & ~in_test_set]
assert (
    non_test_rows.groupby("participant_label").size() == config.n_folds - 1
).all()

# within the cross validation scheme, each participant is in one test set
test_rows = fold_participants[in_cv_scheme & in_test_set]
assert (test_rows.groupby("participant_label").size() == 1).all()

In [21]:
# The fold columns must not exist yet on the specimen table.
assert not {"fold_id", "fold_label"}.intersection(specimens_merged.columns)

In [22]:
# Attach fold assignments to specimens: every specimen receives the
# fold_id / fold_label rows of its participant.
specimens_by_fold = specimens_merged.merge(
    fold_participants,
    how="inner",
    on="participant_label",
)
specimens_by_fold

Unnamed: 0,participant_label,specimen_label,disease,disease_subtype,total_sequence_count,specimen_time_point,participant_description,cohort,study_name,available_gene_loci,...,disease_severity,specimen_time_point_days,survived_filters,is_peak,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup,fold_id,fold_label
0,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707,,Location: USA,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,0,validation
1,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707,,Location: USA,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,1,test
2,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707,,Location: USA,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,2,validation
3,BFI-0000234,M124-S014,Healthy/Background,Healthy/Background - HIV Negative,67707,,Location: USA,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,-1,train_smaller
4,BFI-0000254,M111-S037,HIV,HIV Broad Neutralizing,74203,,Location: Tanzania,Boydlab,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,HIV,HIV,0,test
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1915,BFI-0010244,M464-S045,Healthy/Background,Healthy/Background (children),62354,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,-1,validation
1916,BFI-0010245,M464-S046,Healthy/Background,Healthy/Background (children),48979,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,0,train_smaller
1917,BFI-0010245,M464-S046,Healthy/Background,Healthy/Background (children),48979,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,1,train_smaller
1918,BFI-0010245,M464-S046,Healthy/Background,Healthy/Background (children),48979,,,Boydlab,healthy_children,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",...,,,True,True,True,False,Healthy/Background,Healthy/Background,2,test


In [23]:
# Each specimen should appear once per cross validation fold plus once for the global fold.
assert len(specimens_by_fold) == len(specimens_merged) * (config.n_folds + 1)

In [24]:
# sanity checks:

# each specimen is in each fold (including the global fold), in one group or another
rows_per_specimen = specimens_by_fold.groupby("specimen_label").size()
assert (rows_per_specimen == config.n_folds + 1).all()

# masks: rows that belong to the cross validation scheme (not the global fold), and test rows
in_cv_scheme = specimens_by_fold["fold_id"] != -1
in_test_set = specimens_by_fold["fold_label"] == "test"

# within the cross validation scheme, each specimen is in a non-test set in all folds but one
non_test_rows = specimens_by_fold[in_cv_scheme & ~in_test_set]
assert (non_test_rows.groupby("specimen_label").size() == config.n_folds - 1).all()

# within the cross validation scheme, each specimen is in one test set
test_rows = specimens_by_fold[in_cv_scheme & in_test_set]
assert (test_rows.groupby("specimen_label").size() == 1).all()

In [25]:
# sanity checks (counting distinct folds per participant, since a participant can have several specimens):

# each participant is in each fold (including the global fold), in one group or another
folds_per_participant = specimens_by_fold.groupby("participant_label")[
    "fold_id"
].nunique()
assert (folds_per_participant == config.n_folds + 1).all()

# masks: rows that belong to the cross validation scheme (not the global fold), and test rows
in_cv_scheme = specimens_by_fold["fold_id"] != -1
in_test_set = specimens_by_fold["fold_label"] == "test"

# within the cross validation scheme, each participant is in a non-test set in all folds but one
non_test_rows = specimens_by_fold[in_cv_scheme & ~in_test_set]
assert (
    non_test_rows.groupby("participant_label")["fold_id"].nunique()
    == config.n_folds - 1
).all()

# within the cross validation scheme, each participant is in one test set
test_rows = specimens_by_fold[in_cv_scheme & in_test_set]
assert (test_rows.groupby("participant_label")["fold_id"].nunique() == 1).all()

In [26]:
# For every (fold, split) combination, show how many distinct participants and
# specimens each disease contributes.
for (fold_id, fold_label), grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    print(f"Fold {fold_id}-{fold_label}:")

    by_disease = grp.groupby("disease")
    summary_rows = [
        by_disease["participant_label"].nunique().rename("#participants"),
        by_disease["specimen_label"].nunique().rename("#specimens"),
    ]
    # One row per count type, one column per disease.
    display(pd.DataFrame(summary_rows))

    print()

Fold -1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,43,63,144,57
#specimens,43,65,146,65



Fold -1-validation:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,20,32,73,29
#specimens,20,33,75,33



Fold 0-test:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,20,32,73,29
#specimens,20,33,75,33



Fold 0-train_smaller:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,28,42,96,38
#specimens,28,44,98,40



Fold 0-validation:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,15,21,48,19
#specimens,15,21,48,25



Fold 1-test:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,22,31,72,29
#specimens,22,33,74,34



Fold 1-train_smaller:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,28,42,96,38
#specimens,28,43,97,44



Fold 1-validation:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,13,22,49,19
#specimens,13,22,50,20



Fold 2-test:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,21,32,72,28
#specimens,21,32,72,31



Fold 2-train_smaller:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,28,42,96,38
#specimens,28,45,98,44



Fold 2-validation:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,14,21,49,20
#specimens,14,21,51,23





In [27]:
# By gene locus: same per-disease participant/specimen counts as above, but
# broken out by which gene loci are available.
# Nest because can't sort on gene locus column
for (fold_id, fold_label), _grp in specimens_by_fold.groupby(
    ["fold_id", "fold_label"], observed=True
):
    # Group by the bare column name rather than a one-element list: newer pandas
    # versions return 1-tuples as keys for a list grouper, which would change the
    # printed label. A plain string key yields the scalar gene locus everywhere.
    for gene_locus, grp in _grp.groupby(
        "available_gene_loci", observed=True, sort=False
    ):
        print(f"Fold {fold_id}-{fold_label}-{gene_locus}:")

        # One row per count type, one column per disease.
        display(
            pd.DataFrame(
                [
                    grp.groupby("disease")["participant_label"]
                    .nunique()
                    .rename("#participants"),
                    grp.groupby("disease")["specimen_label"]
                    .nunique()
                    .rename("#specimens"),
                ]
            )
        )

        print()

Fold -1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,39,63,129,42
#specimens,39,65,129,42



Fold -1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,4,15,15
#specimens,4,17,23



Fold -1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,19,32,65,21
#specimens,19,33,65,22



Fold -1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,8,8
#specimens,1,10,11



Fold 0-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,19,32,65,21
#specimens,19,33,65,22



Fold 0-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,8,8
#specimens,1,10,11



Fold 0-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,26,42,86,28
#specimens,26,44,86,28



Fold 0-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,12



Fold 0-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,13,21,43,14
#specimens,13,21,43,14



Fold 0-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,5,5
#specimens,2,5,11



Fold 1-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,20,31,65,21
#specimens,20,33,65,21



Fold 1-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,7,8
#specimens,2,9,13



Fold 1-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,26,42,86,28
#specimens,26,43,86,29



Fold 1-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,11,15



Fold 1-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,12,22,43,14
#specimens,12,22,43,14



Fold 1-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,6,5
#specimens,1,7,6



Fold 2-test-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,19,32,64,21
#specimens,19,32,64,21



Fold 2-test-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,8,7
#specimens,2,8,10



Fold 2-train_smaller-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,26,42,86,28
#specimens,26,45,86,29



Fold 2-train_smaller-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,2,10,10
#specimens,2,12,15



Fold 2-validation-GeneLocus.BCR|TCR:


disease,Covid19,HIV,Healthy/Background,Lupus
#participants,13,21,44,14
#specimens,13,21,44,14



Fold 2-validation-GeneLocus.BCR:


disease,Covid19,Healthy/Background,Lupus
#participants,1,5,6
#specimens,1,7,9





In [28]:
# Save the participant-level fold assignments (participant_label, fold_id, fold_label)
# as a tab-separated file, without the row index.
fold_participants.to_csv(
    config.paths.dataset_specific_metadata
    / "cross_validation_divisions.participants.tsv",
    sep="\t",
    index=None,
)

In [29]:
# Save the specimen-level fold assignments (specimen metadata plus fold_id and fold_label)
# as a tab-separated file, without the row index.
specimens_by_fold.to_csv(
    config.paths.dataset_specific_metadata / "cross_validation_divisions.specimens.tsv",
    sep="\t",
    index=None,
)