In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import umap
from scipy.spatial import distance
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Genotype codes stored in the matrix:
#   -1 -> -/-  (missing; noted as equivalent to 0/0, although the one-hot
#               encoding further down keeps -1 and 1 as separate categories)
#    1 -> 0/0
#    2 -> 0/1
#    3 -> 1/1
#
# Layout: rows are positions (SNPs), columns are samples
# (for gene_10: 77 positions x 21 samples).
GENE_FILE = 'results/gt_genes/gene_10.npy'
g = np.load(GENE_FILE)
g

In [None]:
# Cohort code of each sample, one entry per column of g, in file order.
# Codes: 0 = ARA, 1 = GAM, 2 = HYB, 3 = CTR.
# Layout: 5 x ARA, then 10 x GAM, then 1 x CTR, then 5 x HYB.
sample_cohorts = np.repeat([0, 1, 3, 2], [5, 10, 1, 5])
sample_cohorts.shape

In [None]:
# Distinct genotype codes present anywhere in the matrix (sorted ascending).
np.unique(g.ravel())

In [None]:
# Build the samples x positions table used for encoding: one row per sample,
# column 'y' = cohort code, columns pos0..pos{n-1} = genotype code per position.
# g is (positions, samples), so it is transposed to put samples on rows.
assert g.shape[1] == sample_cohorts.shape[0], "need exactly one cohort label per sample (column of g)"
y = np.expand_dims(sample_cohorts, 0)  # shape (1, n_samples); reused when building embedding_df
position_cols = [f"pos{i}" for i in range(g.shape[0])]
# Create the genotype columns with a categorical dtype up front. Assigning via
# `df.iloc[:, 1:] = ....astype(object)` sets values in place in pandas >= 2.0 and
# leaves the columns int64, which pd.get_dummies would then skip by default.
df = pd.DataFrame(g.T, columns=position_cols).astype('category')
df.insert(0, 'y', sample_cohorts)

In [None]:
# Preview the first rows: 'y' is the cohort code, pos* are genotype codes.
# (The full table is wide — one column per position — so avoid dumping it all.)
df.head()

In [None]:
# One-hot encode every genotype column (everything except the label 'y').
# The columns are named explicitly so they are encoded regardless of dtype;
# by default get_dummies only encodes object/category columns.
position_cols = [c for c in df.columns if c != 'y']
df_w_dummies = pd.get_dummies(df, columns=position_cols)

# If every position observed all 4 genotype codes (-1, 1, 2, 3) there would be
# n_positions * 4 dummy columns; fewer means some positions never show some codes.
n_codes = 4
n_dummy_cols = df_w_dummies.shape[1] - 1  # exclude the 'y' column
print(f"max possible dummy columns: {len(position_cols) * n_codes}, observed: {n_dummy_cols}")
df_w_dummies.shape

In [None]:
# Preview only the first rows; the one-hot table has hundreds of columns.
df_w_dummies.head()

In [None]:
# https://umap-learn.readthedocs.io/en/latest/basic_usage.html
# UMAP is stochastic: fix the seed so the embedding (and the nearest-cluster
# calls derived from it) are reproducible across runs. Per the umap-learn docs,
# setting random_state trades away parallelism for determinism.
UMAP_SEED = 42
reducer = umap.UMAP(random_state=UMAP_SEED)
# Embed only the one-hot genotype columns; drop the label by name rather than
# relying on it being the first column.
embedding = reducer.fit_transform(df_w_dummies.drop(columns='y'))

In [None]:
# Pair each sample's UMAP coordinates with its cohort code. Built directly
# (instead of np.append into a reused `arr`) so the integer labels don't
# round-trip through float and no earlier variable name is overwritten.
emb_cols = [f"emb{i}" for i in range(embedding.shape[1])]
embedding_df = pd.DataFrame(embedding.astype(float), columns=emb_cols)
embedding_df.insert(0, 'label', pd.Categorical(y.ravel()))

In [None]:
# 2-D UMAP embedding of the samples, coloured by cohort code.
fig = px.scatter(
    embedding_df,
    x="emb0",
    y="emb1",
    color="label",
    width=1000,
    height=600,
    title="UMAP of one-hot encoded genotypes (gene_10)",
    labels={"emb0": "UMAP 1", "emb1": "UMAP 2", "label": "cohort (0=ARA, 1=GAM, 2=HYB, 3=CTR)"},
)
fig.show()


In [None]:
# Centroid of each parental cohort in the 2-D embedding. The embedding columns
# are selected explicitly: 'label' is categorical, and DataFrame.mean() over it
# raises TypeError in pandas >= 2.0 (older versions silently dropped it).
emb_cols = ['emb0', 'emb1']
cluster_center_0 = embedding_df.loc[embedding_df.label == 0, emb_cols].mean().to_numpy()  # ARA
cluster_center_1 = embedding_df.loc[embedding_df.label == 1, emb_cols].mean().to_numpy()  # GAM

In [None]:
# Assign each hybrid (2 = HYB) and control (3 = CTR) sample to the nearer
# parental centroid (0 = ARA, 1 = GAM) in the 2-D embedding.
test_cases = embedding_df[embedding_df.label.isin([2, 3])]
for _, row in test_cases.iterrows():
    # iterrows yields object-dtype rows (category + float columns), so cast explicitly.
    sample_pos = row[['emb0', 'emb1']].to_numpy(dtype=float)
    dist0 = distance.euclidean(sample_pos, cluster_center_0)
    dist1 = distance.euclidean(sample_pos, cluster_center_1)
    # The smaller distance is the closer cluster; ties are reported as label 1.
    nearest_label = 0 if dist0 < dist1 else 1
    print(f"Sample {row.to_numpy()} is closer to label {nearest_label}.")