In [1]:
import os

import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import pairwise_distances
import phate
import scprep

os.chdir('../../src')
import mappings
import data_loader

In [2]:
#import hail as hl
from pyplink import PyPlink

In [3]:
import tqdm

In [4]:
# Load MHI
exp_path = '/lustre06/project/6065672/shared/DietNet/1KGB_POP24/CaG/gsa.17k'
fname = 'gsa.17k.final.WR_hg38-updated.missing10perc.noMAF0.common1000G.noHLA'
data_path = os.path.join(exp_path, fname)

pedfile = PyPlink(data_path)

In [5]:
all_samples = pedfile.get_fam()

In [6]:
all_samples

Unnamed: 0,fid,iid,father,mother,gender,status
0,11112892,11112892,0,0,2,-9
1,11117598,11117598,0,0,2,-9
2,11109696,11109696,0,0,2,-9
3,11106652,11106652,0,0,1,-9
4,11114212,11114212,0,0,2,-9
...,...,...,...,...,...,...
17281,11105584,11105584,0,0,1,-9
17282,11130905,11130905,0,0,1,-9
17283,11140411,11140411,0,0,1,-9
17284,11111395,11111395,0,0,2,-9


In [7]:
all_markers = pedfile.get_bim()
all_markers.head()

Unnamed: 0_level_0,chrom,pos,cm,a1,a2
snp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
chr1:858952:G:A,chr1,858952,0.0,A,G
chr1:905373:T:C,chr1,905373,0.0,C,T
chr1:911428:C:T,chr1,911428,0.0,T,C
chr1:918870:A:G,chr1,918870,0.0,G,A
chr1:931513:T:C,chr1,931513,0.0,C,T


In [8]:
genotypes_array = np.load(data_path + '_raw_genotypes.npy')

In [9]:
#subset = all_samples['fid'].isin(["11112892", "11117598", "11109696", "11106652", 
#                         "11114212", "11130295", "11119566", "11142328",
#                        "11132046", "11118772"])

genotypes_array = np.zeros([pedfile.get_nb_samples(), pedfile.get_nb_markers()], dtype=np.int8)

for i, (marker_id, genotypes) in tqdm.tqdm(enumerate(pedfile)):
    genotypes_array[:,i] = genotypes

np.save(data_path + '_raw_genotypes.npy', genotypes_array)

229986it [00:23, 9868.36it/s] 


In [10]:
column_means = [genotypes_array[:,i][genotypes_array[:,i] != -1].mean() for i in tqdm.tqdm(range(pedfile.get_nb_markers()))]
column_means = np.array(column_means)

100%|██████████| 229986/229986 [00:15<00:00, 14998.81it/s]


In [11]:
genotypes_array = genotypes_array.astype(np.float16)
for i in tqdm.tqdm(range(pedfile.get_nb_markers())):
    genotypes_array[:,i][genotypes_array[:,i] == -1] = column_means[i]

100%|██████████| 229986/229986 [00:16<00:00, 14274.54it/s]


In [12]:
genotypes_array = genotypes_array - column_means.reshape(1,-1)

In [13]:
np.save(data_path + '_proc_genotypes.npy', genotypes_array)

In [15]:
genotypes_array.shape

(17286, 229986)

In [None]:
from sklearn.decomposition import PCA

# Fit PCA model
pca_1K = PCA(n_components=100)  # Number of components to keep
pca_1K.fit(inputs_raw_1K)
pca_input_1K = pca_1K.transform(inputs_raw_1K)

In [None]:
metadata_mhi['selfreported_pop']

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
scprep.plot.scatter2d(pca_input_1K[:, [0, 1]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax,
                      c = labels_1K['population'], xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PCA Plot',fontsize=8)

In [None]:
# Project MHI onto 1000G
pca_input_mhi = pca_1K.transform(inputs_mhi)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
scprep.plot.scatter2d(pca_input_1K[:, [0, 1]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax,
                      c = labels_1K['population'], xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PCA Plot',fontsize=8)

scprep.plot.scatter2d(pca_input_mhi[:, [0, 1]], c='black', marker='x', s = 20, ax=ax, xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PCA Plot',fontsize=8)

In [None]:
fig, ax = plt.subplots(nrows=6, ncols=5, figsize=(10, 10))

comp = 1
for i in range(6):
    for j in range(5):
        scprep.plot.scatter2d(pca_input_1K[:, [comp-1, comp]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax[i,j],
                              c = labels_1K['population'], xticks=False, yticks=False,
                              legend=False, title='PCA Plot',fontsize=8)
        scprep.plot.scatter2d(pca_input_mhi[:, [comp-1, comp]], s = 20, c='black', marker='x', ax=ax[i,j],
                              xticks=False, yticks=False,
                              legend=False, title='PCA Plot')

        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        ax[i, j].set_title('{} vs {}'.format(comp, comp+1), fontsize=12)

        comp += 1

plt.tight_layout()


## Do Embedding

First just do PHATE on 1000G

In [None]:
phate_operator = phate.PHATE(random_state=42, knn=5, t=5, n_pca=None)
phate_operator.fit(pca_input_1K)
phate_emb_1K = phate_operator.transform(pca_input_1K)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
scprep.plot.scatter2d(phate_emb_1K, s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax,
                      c = labels_1K['population'], xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PCA Plot',fontsize=8)

Now embed 2D MHI onto this

In [None]:
phate_emb_mhi = phate_operator.transform(pca_input_mhi)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
scprep.plot.scatter2d(phate_emb_1K, s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax,
                      c = labels_1K['population'], xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PHATE Plot',fontsize=8)

scprep.plot.scatter2d(phate_emb_mhi, c='black', marker='x', s = 20, ax=ax, xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PHATE Plot',fontsize=8)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
scprep.plot.scatter2d(phate_emb_mhi, c = metadata_mhi, cmap=pop_pallette_mhi, s = 20, ax=ax, xticks=False, yticks=False,
                      legend=True, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                      title='PHATE Plot',fontsize=8)

## Try to get 10D PHATE

In [None]:
phate_operator.set_params(n_components=10, verbose=3)
phate_emb_1K_10d = phate_operator.transform(pca_input_1K)
phate_emb_mhi_10d = phate_operator.transform(pca_input_mhi)

In [None]:
fig, ax = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))

for i in range(10):
    for j in range(10):
        scprep.plot.scatter2d(phate_emb_1K_10d[:,[i, j]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax[i,j],
                              c = labels_1K['population'], xticks=False, yticks=False,
                              legend=False, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                              title='PHATE Plot',fontsize=8)

        scprep.plot.scatter2d(phate_emb_mhi_10d[:,[i, j]], c='black', marker='x', s = 20, ax=ax[i,j], xticks=False, yticks=False,
                              legend=False, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                              title='PHATE Plot',fontsize=8)

        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        ax[i, j].set_title('{} vs {}'.format(i, j), fontsize=12)

plt.tight_layout()

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))

comp = 1
for i in range(3):
    for j in range(3):
        scprep.plot.scatter2d(phate_emb_1K_10d[:, [comp-1, comp]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax[i,j],
                              c = labels_1K['population'], xticks=False, yticks=False,
                              legend=False, title='PCA Plot',fontsize=8)
        scprep.plot.scatter2d(phate_emb_mhi_10d[:, [comp-1, comp]], s = 20, c='black', marker='x', ax=ax[i,j],
                              xticks=False, yticks=False,
                              legend=False, title='PCA Plot')

        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        ax[i, j].set_title('{} vs {}'.format(comp, comp+1), fontsize=12)

        comp += 1

plt.tight_layout()


In [None]:
pd.DataFrame(phate_emb_mhi_10d, index=samples_mhi).to_csv('/lustre06/project/6065672/shared/trajGWAS/mhi_10d_phate.csv')
pd.DataFrame(pca_input_mhi, index=samples_mhi).to_csv('/lustre06/project/6065672/shared/trajGWAS/mhi_100d_pca.csv')

## Maybe Eigendecompose Diffusion Operator Instead?

In [None]:
transitions_1K = phate_operator.graph.extend_to_data(pca_input_1K)
transitions_mhi = phate_operator.graph.extend_to_data(pca_input_mhi)

In [None]:
# Fit PCA model
pca_1K_diff_op = PCA(n_components=100)  # Number of components to keep
pca_1K_diff_op.fit(transitions_1K.toarray())
pca_diff_op_1K = pca_1K_diff_op.transform(transitions_1K.toarray())
pca_diff_op_mhi = pca_1K_diff_op.transform(transitions_mhi.toarray())

In [None]:
fig, ax = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))

for i in range(10):
    for j in range(10):
        scprep.plot.scatter2d(pca_diff_op_1K[:,[i, j]], s = 20, cmap = mappings.pop_pallette_1000G_fine, ax=ax[i,j],
                              c = labels_1K['population'], xticks=False, yticks=False,
                              legend=False, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                              title='PHATE Plot',fontsize=8)

        scprep.plot.scatter2d(pca_diff_op_mhi[:,[i, j]], c='black', marker='x', s = 20, ax=ax[i,j], xticks=False, yticks=False,
                              legend=False, legend_loc='lower center', legend_anchor=(0.5, -0.15), legend_ncol=8,
                              title='PHATE Plot',fontsize=8)

        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
        ax[i, j].set_title('{} vs {}'.format(i, j), fontsize=12)

plt.tight_layout()