# PCA

### In this notebook we tackle the problem of sparsity in the feature representations of the different modes (chromatin accessibility, gene expression, surface protein levels). 

As noted in https://www.kaggle.com/code/leohash/complete-eda-of-mmscel-integration-data/notebook, the DNA data has between 1k and 30k of its 229k features nonzero, the RNA data has between 2k and 8k of its ~28k features nonzero, and the protein data has only a small number of features (and is sparse), so this notebook does not consider it :)

We will apply PCA (and maybe some other techniques) to investigate whether we can usefully move to a lower-dimensional, densely populated representation for either of these two modes (DNA and RNA).

In [None]:
# imports
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
from sklearn.decomposition import PCA, SparsePCA, IncrementalPCA

from datasets import SparseDataset, H5Dataset

In [None]:
# -----------------------------------------------------------

# Da Work

## DNA (Chromatin Accessibility)

In [13]:
## check if we can easily find which dna codes for which rna
# Load only the first row (one cell) of the multiome inputs (ATAC peaks / DNA)
# and of the multiome targets (gene expression / RNA) so their column names
# can be compared.
# NOTE: despite its name, `cite_df` holds the *multiome* RNA targets here;
# the name is kept because the next cell reads `cite_df`.
import pandas as pd

multi_df, cite_df = (
    pd.read_hdf(path, start=0, stop=1)
    for path in ('data/train_multi_inputs.h5', 'data/train_multi_targets.h5')
)

display(multi_df.head(), cite_df.head())


# Earlier, incomplete draft of the key-matching step (a working version of the
# cite_keys / multi_idxs / cite_idxs matching appears in a later cell).
# Kept for reference only:
#     k = k.split('.')[0]
#     for kc in cite_keys:
#         if k in kc:
#             print(k, kc)
#             break


# for i in range(len(cite_keys)):
#     cite_keys[i] = cite_keys[i].split('_')[0]

# multi_idxs = []
# cite_idxs = []
# for i, s in enumerate(multi_keys):
#     if s in cite_keys:
#         multi_idxs.append(i)
#         cite_idxs.append(cite_keys.index(s))

gene_id,GL000194.1:114519-115365,GL000194.1:55758-56597,GL000194.1:58217-58957,GL000194.1:59535-60431,GL000195.1:119766-120427,GL000195.1:120736-121603,GL000195.1:137437-138345,GL000195.1:15901-16653,GL000195.1:22357-23209,GL000195.1:23751-24619,...,chrY:7722278-7723128,chrY:7723971-7724880,chrY:7729854-7730772,chrY:7731785-7732664,chrY:7810142-7811040,chrY:7814107-7815018,chrY:7818751-7819626,chrY:7836768-7837671,chrY:7869454-7870371,chrY:7873814-7874709
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
56390cf1b95e,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,4.428336,0.0,0.0,0.0,0.0


gene_id,ENSG00000121410,ENSG00000268895,ENSG00000175899,ENSG00000245105,ENSG00000166535,ENSG00000256661,ENSG00000184389,ENSG00000128274,ENSG00000094914,ENSG00000081760,...,ENSG00000086827,ENSG00000174442,ENSG00000122952,ENSG00000198205,ENSG00000198455,ENSG00000070476,ENSG00000203995,ENSG00000162378,ENSG00000159840,ENSG00000074755
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
56390cf1b95e,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,4.893861,0.0,0.0,0.0,0.0,5.583255,0.0,4.893861


In [20]:
# Column names of the first-row frames loaded above.
multi_keys = list(multi_df.keys())
cite_keys = list(cite_df.keys())

# Print the first multiome input column whose name contains 'chr', together
# with its position (the leading columns, e.g. 'GL000194.1:...', do not).
for i, k in enumerate(multi_keys):
    if 'chr' in k:
        print(k, i)
        break

chr10:100001240-100002159 64


In [None]:
# find max variance columns
# Stream every multiome DNA row once and accumulate per-column sums, sums of
# squares and non-zero counts, so that each column's population variance can
# be computed without materialising the whole matrix.
from datasets import H5Dataset, SparseDataset
import numpy as np
import tqdm
import torch

N_MULTI_FEATURES = 228942  # number of DNA (ATAC) input columns

sums = np.zeros(N_MULTI_FEATURES)
squared_sums = np.zeros(N_MULTI_FEATURES)
num_nonzero = np.zeros(N_MULTI_FEATURES)
# Count the rows actually summed instead of hard-coding the dataset size, so
# the mean below always divides by the number of rows that contributed.
n_rows = 0

d = SparseDataset('all', 'multi')
s = d.get_dataloader(512)
for x, day, y in tqdm.tqdm(s):
    sums += x.sum(dim=0).numpy()
    squared_sums += torch.square(x).sum(dim=0).numpy()
    num_nonzero += (x != 0).sum(dim=0).numpy()
    n_rows += x.shape[0]

# population variance: E[x^2] - E[x]^2
variances = squared_sums / n_rows - np.square(sums / n_rows)

In [None]:
# DNA column indices ordered by descending variance and by descending
# non-zero count (np.flip reverses the ascending argsort order).
var_idxs = np.flip(np.argsort(variances))
nz_idxs = np.flip(np.argsort(num_nonzero))

In [None]:
# Combined rank per DNA column: its position in the descending-variance order
# plus its position in the descending non-zero-count order. Lower is better on
# both, so best_idxs lists columns from best to worst combined rank.
ranks = np.zeros(len(var_idxs))
# var_idxs / nz_idxs are argsort outputs (permutations), so every index occurs
# exactly once and fancy-index += matches the per-element loop it replaces.
ranks[var_idxs] += np.arange(len(var_idxs))
ranks[nz_idxs] += np.arange(len(nz_idxs))
best_idxs = np.argsort(ranks)

In [None]:
# Persist the three DNA column orderings under data/.
for out_path, order in (
    ('data/multi_best_idxs.npy', best_idxs),
    ('data/multi_var_idxs.npy', var_idxs),
    ('data/multi_nz_idxs.npy', nz_idxs),
):
    np.save(out_path, order)

## RNA (Gene Expression)

## Variances

In [None]:
# find max variance columns
# Per-column sums, sums of squares and non-zero counts for the CITEseq RNA
# inputs, pooled over the H5Dataset rows and the SubmissionDataset rows
# (presumably the test set), then turned into population variances.
from datasets import H5Dataset, SparseDataset, SubmissionDataset
import numpy as np
import tqdm
import torch

N_CITE_FEATURES = 22050  # number of CITEseq RNA input columns

sums = np.zeros(N_CITE_FEATURES)
squared_sums = np.zeros(N_CITE_FEATURES)
num_nonzero = np.zeros(N_CITE_FEATURES)

# Rows are counted as they are summed, so the mean divides by exactly the rows
# that contributed (len(dataset) would over-count if a loader skipped rows).
# NOTE(review): a later cell unpacks H5Dataset batches as `(x, day), y`;
# confirm which form get_dataloader yields.
d = H5Dataset('all', 'cite')
s = d.get_dataloader(1024)
n_data_train = 0
for x, day, y in tqdm.tqdm(s):
    sums += x.sum(dim=0).numpy()
    squared_sums += torch.square(x).sum(dim=0).numpy()
    num_nonzero += (x != 0).sum(dim=0).numpy()
    n_data_train += x.shape[0]

d = SubmissionDataset('cite', 0)
s = d.get_dataloader(512)
n_data_test = 0
for x, day in tqdm.tqdm(s):
    sums += x.sum(dim=0).numpy()
    squared_sums += torch.square(x).sum(dim=0).numpy()
    num_nonzero += (x != 0).sum(dim=0).numpy()
    n_data_test += x.shape[0]

# population variance: E[x^2] - E[x]^2
n = n_data_train + n_data_test
variances = squared_sums / n - np.square(sums / n)

In [None]:
# CITEseq RNA column indices ordered by descending variance and by descending
# non-zero count (np.flip reverses the ascending argsort order).
var_idxs = np.flip(np.argsort(variances))
nz_idxs = np.flip(np.argsort(num_nonzero))

In [None]:
# Combined rank per RNA column: its position in the descending-variance order
# plus its position in the descending non-zero-count order. Lower is better on
# both, so best_idxs lists columns from best to worst combined rank.
ranks = np.zeros(len(var_idxs))
# var_idxs / nz_idxs are argsort outputs (permutations), so every index occurs
# exactly once and fancy-index += matches the per-element loop it replaces.
ranks[var_idxs] += np.arange(len(var_idxs))
ranks[nz_idxs] += np.arange(len(nz_idxs))
best_idxs = np.argsort(ranks)

In [None]:
# Keep only the ranked columns whose names are in neither
# CITESEQ_CODING_GENES nor CITESEQ_CONSTANT_GENES (both from utils).
# Column names are read from a slice of the CITEseq training inputs.
from utils import CITESEQ_CODING_GENES, CITESEQ_CONSTANT_GENES
import pandas as pd

cite_df = pd.read_hdf('data/train_cite_inputs.h5', start=1000, stop=2000)
cite_keys = list(cite_df.keys())

# One set lookup per column instead of scanning both collections each time.
# NOTE(review): assumes both constants are collections of hashable names
# (e.g. lists of strings) — confirm in utils.
excluded_genes = set(CITESEQ_CODING_GENES) | set(CITESEQ_CONSTANT_GENES)

# Comprehensions keep the original ranking order.
other_best_idxs = [i for i in best_idxs if cite_keys[i] not in excluded_genes]
other_var_idxs = [i for i in var_idxs if cite_keys[i] not in excluded_genes]

In [None]:
# Show (variance, non-zero count) for the first 22 filtered columns in
# descending-variance order (positions 0..21).
for i in other_var_idxs[:22]:
    print(variances[i], num_nonzero[i])

In [None]:
# Pair up the columns shared by the multiome RNA targets and the CITEseq RNA
# inputs: multi_idxs[j] and cite_idxs[j] point at the same column id.
import pandas as pd

multi_df = pd.read_hdf('data/train_multi_targets.h5', start=1000, stop=2000)
cite_df = pd.read_hdf('data/train_cite_inputs.h5', start=1000, stop=2000)

multi_keys = list(multi_df.keys())
# Keep only the part of each CITEseq column name before the first '_' so it
# can be compared with the multiome target names.
cite_keys = [k.split('_')[0] for k in cite_df.keys()]

# First position of every id, matching list.index semantics; a dict lookup
# replaces the linear scan of cite_keys done for every multiome column.
cite_pos = {}
for j, k in enumerate(cite_keys):
    cite_pos.setdefault(k, j)

multi_idxs = []
cite_idxs = []
for i, s in enumerate(multi_keys):
    if s in cite_pos:
        multi_idxs.append(i)
        cite_idxs.append(cite_pos[s])

In [None]:
# multi_rna = np.asarray(H5Dataset('all', 'multi').targets_h5)
# cite_rna = np.asarray(H5Dataset('all', 'cite').inputs_h5)
# multi_shared = multi_rna[:, multi_idxs]
# cite_shared = cite_rna[:, cite_idxs]
# # shared = np.concatenate((multi_shared, cite_shared), axis=0)

In [None]:
batch_size=5120

multi_shared_loader = H5Dataset('all', 'multi').get_dataloader(batch_size)
cite_shared_loader = H5Dataset('all', 'cite').get_dataloader(batch_size)

# PCA on binarised (non-zero -> 1.0) RNA restricted to the shared columns.
p = IncrementalPCA(5000, batch_size=batch_size)

# IncrementalPCA.fit() discards any previous state on every call, so
# batch-by-batch fitting must use partial_fit(). partial_fit() also needs at
# least n_components rows per call, so a short batch is merged with its
# neighbour instead of being fitted on its own.
pending = None
for (x, day), y in tqdm.tqdm(multi_shared_loader):
    rna = y.numpy()
    rna = (rna[:, multi_idxs] != 0).astype(float)
    if pending is not None:
        if min(len(pending), len(rna)) < p.n_components:
            rna = np.concatenate((pending, rna))
        else:
            p.partial_fit(pending)
    pending = rna
if pending is not None:
    p.partial_fit(pending)

In [None]:
# for (x, day), y in tqdm.tqdm(multi_shared_loader):
#     rna = y.numpy()
#     rna = (rna[:, multi_idxs] != 0).astype(float)
#     t = p.transform(rna)
#     r = p.inverse_transform(t)
#     print(r[0, :20])
#     print(rna[0, :20])
#     break

In [None]:
# Continue fitting `p` on the binarised CITEseq RNA inputs (shared columns).
# partial_fit() builds on the state already in `p`, whereas fit() would reset
# it; a short batch is merged with its neighbour because partial_fit() needs
# at least n_components rows per call.
pending = None
for (x, day), y in tqdm.tqdm(cite_shared_loader):
    rna = x.numpy()
    rna = (rna[:, cite_idxs] != 0).astype(float)
    if pending is not None:
        if min(len(pending), len(rna)) < p.n_components:
            rna = np.concatenate((pending, rna))
        else:
            p.partial_fit(pending)
    pending = rna
if pending is not None:
    p.partial_fit(pending)

In [None]:
# Sanity check: for the first batch of each dataset, binarise the shared RNA
# columns, round-trip them through the PCA and print the first 20 values of
# the reconstruction next to the binarised input (multiome first, then CITEseq).
# `use_targets` selects where the RNA lives: multiome batches carry it in `y`,
# CITEseq batches in `x`.
for loader, use_targets, cols in (
    (multi_shared_loader, True, multi_idxs),
    (cite_shared_loader, False, cite_idxs),
):
    for (x, day), y in tqdm.tqdm(loader):
        rna = (y if use_targets else x).numpy()
        rna = (rna[:, cols] != 0).astype(float)
        t = p.transform(rna)
        r = p.inverse_transform(t)
        print(r[0, :20])
        print(rna[0, :20])
        break

In [None]:
# Load a previously pickled PCA model from data/pca.pkl into `p2`.
# NOTE(review): no cell shown here writes data/pca.pkl — confirm which step
# produces it. pickle.load can execute arbitrary code, so only load a file
# this project created itself.
with open('data/pca.pkl', 'rb') as f:
    p2 = pickle.load(f)