# Analysing the class distribution across tissues

In [74]:
import os
import pickle
import torch
import pandas as pd
from itertools import compress

from utils.knowledge_db import TISSUES

In [43]:
# Root directory holding all pickled splits.
base_path = "/export/share/krausef99dm/data/"

# Training splits live in "data_train"; validation and test splits share "data_test".
train_dir = os.path.join(base_path, "data_train")
test_dir = os.path.join(base_path, "data_test")

# Codon-encoded splits
cod_train_data_file = os.path.join(train_dir, "codon_train_8.1k_data.pkl")
cod_val_data_file = os.path.join(test_dir, "codon_val_8.1k_data.pkl")
cod_test_data_file = os.path.join(test_dir, "codon_test_8.1k_data.pkl")

# Nucleotide-encoded splits
nuc_train_data_file = os.path.join(train_dir, "train_9.0k_data.pkl")
nuc_val_data_file = os.path.join(test_dir, "val_9.0k_data.pkl")
nuc_test_data_file = os.path.join(test_dir, "test_9.0k_data.pkl")

In [79]:
def get_tissue_distrib(path, max_seq_len=None):
    """Count the retained samples per tissue and binary class for one pickled split.

    The file at ``path`` must unpickle to the 4-tuple
    ``(rna_data, tissue_ids, targets, targets_bin)``. A sample is retained only
    if its binary label is > 0 and its sequence holds at most ``max_seq_len``
    elements. Retained binary labels are shifted by -1 so they are 0/1 encoded.

    Args:
        path: Path of the pickle file to read.
        max_seq_len: Maximum allowed ``len()`` of a sequence in ``rna_data``.
            When ``None`` (default) it is derived from the file name: 2700 if
            the name contains "cod", otherwise 9000.

    Returns:
        pd.Series indexed by ``(tissue_id, targets_bin)`` holding the number of
        retained samples in each group.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(path, 'rb') as f:
        rna_data, tissue_ids, targets, targets_bin = pickle.load(f)

    if max_seq_len is None:
        # Look at the file name only: a "cod" substring in a parent directory
        # (e.g. ".../code/...") must not switch a nucleotide file to the codon limit.
        is_codon = "cod" in os.path.basename(path)
        max_seq_len = 2700 if is_codon else 9000

    # Only keep low-/high-PTR samples (label 0 is excluded).
    # Assumes tissue_ids, targets and targets_bin are torch tensors - TODO confirm.
    mask_bin = targets_bin > 0
    # Explicit bool dtype keeps the `&` below valid even when the split is empty.
    mask_len = torch.tensor([len(d) <= max_seq_len for d in rna_data], dtype=torch.bool)
    mask = (mask_bin & mask_len).tolist()

    df = pd.DataFrame({
        "tissue_id": list(compress(tissue_ids.tolist(), mask)),
        # Shift out-of-place so the unpickled tensor itself is left untouched.
        "targets_bin": list(compress((targets_bin - 1).tolist(), mask)),
        "targets": list(compress(targets.tolist(), mask)),
    })
    return df.groupby(['tissue_id', 'targets_bin']).targets.count()

In [80]:
# Compute the per-tissue class counts for every split, codon files first.
data_files = [
    cod_train_data_file, cod_val_data_file, cod_test_data_file,
    nuc_train_data_file, nuc_val_data_file, nuc_test_data_file,
]
(cod_train_distrib, cod_val_distrib, cod_test_distrib,
 nuc_train_distrib, nuc_val_distrib, nuc_test_distrib) = map(get_tissue_distrib, data_files)

In [90]:
# Side-by-side counts of the codon splits, aligned on (tissue_id, targets_bin).
cod_combined = pd.concat(
    [cod_train_distrib, cod_val_distrib, cod_test_distrib],
    axis=1,
    keys=["train", "val", "test"],  # column label for each split
)
cod_combined

Unnamed: 0_level_0,Unnamed: 1_level_0,train,val,test
tissue_id,targets_bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0,116,32,24
0,1,149,44,23
1,0,418,86,89
1,1,194,37,43
2,0,204,49,38
2,1,523,124,81
3,0,207,54,44
3,1,199,33,41
4,0,161,32,34
4,1,144,30,25


In [None]:
# Side-by-side counts of the nucleotide splits, aligned on (tissue_id, targets_bin).
# The index is deliberately left as the numeric (tissue_id, targets_bin) MultiIndex:
# the frame has two rows per tissue, so assigning TISSUES (one label per tissue)
# as the index would raise a length-mismatch ValueError.
nuc_combined = pd.concat([nuc_train_distrib, nuc_val_distrib, nuc_test_distrib], ignore_index=False, axis=1)
nuc_combined.columns = ["train", "val", "test"]
nuc_combined

In [108]:
# Binary-class level of the codon table (the tissue level is dropped).
inner_levels = cod_combined.index.droplevel(0)

# Repeat every tissue name twice: one row for each binary class.
new_outer = []
for tissue in TISSUES:
    new_outer.extend([tissue, tissue])

# Create the new (tissue, targets_bin) MultiIndex and use it for both tables.
tissue_index = pd.MultiIndex.from_arrays(
    [new_outer, inner_levels.get_level_values(0)],
    names=['tissue', 'targets_bin']
)

cod_combined.index = tissue_index
nuc_combined.index = tissue_index

In [118]:
# Join the codon and nucleotide tables column-wise and label the columns
# with a two-level (Encoding, Dataset) header.
combined = pd.concat([cod_combined, nuc_combined], axis=1)
combined.columns = pd.MultiIndex.from_product(
    [["Codon", "Nucleotide"], ["train", "val", "test"]],
    names=['Encoding', 'Dataset'],
)
combined

Unnamed: 0_level_0,Encoding,Codon,Codon,Codon,Nucleotide,Nucleotide,Nucleotide
Unnamed: 0_level_1,Dataset,train,val,test,train,val,test
tissue,targets_bin,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
Adrenal,0,116,32,24,114,31,24
Adrenal,1,149,44,23,144,41,23
Appendices,0,418,86,89,406,84,89
Appendices,1,194,37,43,181,36,43
Brain,0,204,49,38,196,47,38
Brain,1,523,124,81,501,119,81
Colon,0,207,54,44,200,53,44
Colon,1,199,33,41,191,31,41
Duodenum,0,161,32,34,157,32,34
Duodenum,1,144,30,25,141,27,25


In [124]:
# Number of low-PTR samples (targets_bin == 0) per encoding and split
is_low_ptr = combined.index.get_level_values('targets_bin') == 0
combined.loc[is_low_ptr].sum()

Encoding    Dataset
Codon       train      6808
            val        1476
            test       1390
Nucleotide  train      6535
            val        1423
            test       1390
dtype: int64

In [125]:
# Number of high-PTR samples (targets_bin == 1) per encoding and split
is_high_ptr = combined.index.get_level_values('targets_bin') == 1
combined.loc[is_high_ptr].sum()

Encoding    Dataset
Codon       train      6788
            val        1475
            test       1383
Nucleotide  train      6520
            val        1420
            test       1383
dtype: int64

In [121]:
# Total number of samples (both classes, all tissues) per encoding and split
combined.sum(axis=0)

Encoding    Dataset
Codon       train      13596
            val         2951
            test        2773
Nucleotide  train      13055
            val         2843
            test        2773
dtype: int64