In [6]:
# pip install -U openai scipy plotly-express scikit-learn umap-learn
# pip install numpy==1.23.0 # umap requires an older version of numpy (no spaces around '==' or pip misparses the specifier)

In [1]:
import os
import pandas as pd
import numpy as np
from openai import OpenAI
from sklearn.cluster import KMeans
from scipy.spatial import distance
import plotly.express as px
import umap.umap_ as umap

In [4]:
# Performs AI embedding on data

# load data

data = pd.read_csv('../data/log2_data_v3.csv')
display(data.head())

labels = pd.read_csv('../data/labels_v3.csv')
display(labels.head())

# Later cells pair rows of `data` with rows of `labels` purely by position
# (plot symbols, exported cluster table), so both files must list the same samples.
# TODO(review): also check that the sample IDs match row-for-row once the ID column
# name in labels_v3.csv is confirmed.
assert len(data) == len(labels), (
    f'data has {len(data)} rows but labels has {len(labels)}; samples are misaligned'
)

   Unnamed: 0   5S_rRNA       7SK      A1BG  A1BG-AS1      A1CF       A2M  \
0  SRR7344546  0.000000  1.164935  0.177984  0.646144  0.051743  7.388509   
1  SRR7344554  0.128807  1.358727  0.070845  0.784005  0.015067  7.796801   
2  SRR7344556  0.162194  0.143445  0.000000  0.000000  0.012671  4.623303   
3  SRR7344564  0.267733  0.030576  0.326047  2.220248  0.105718  7.145439   
4  SRR7344565  0.135665  0.053568  0.241141  1.811070  0.097685  6.680064   

    A2M-AS1     A2ML1     A2MP1  ...   snoU2-30  snoU2_19  snoU83B   snoZ196  \
0  1.137378  0.051041  0.129775  ...   3.130928  3.287475  0.00000  0.000000   
1  2.642253  0.221139  0.119809  ...  13.681881  1.575217  0.00000  0.000000   
2  1.391299  0.016639  0.306234  ...   4.235644  0.264728  0.00000  0.000000   
3  0.736532  0.055363  0.562818  ...   0.226744  0.000000  0.36073  2.853414   
4  0.427896  0.927511  0.202232  ...   0.000000  0.437982  0.00000  1.574765   

   snoZ278  snoZ40  snoZ6  snosnR66    uc_338  yR211F11.

In [5]:
# OpenAI API client.
# The key is read from the OPENAI_API_KEY environment variable instead of being
# hardcoded in the notebook, where it would leak into version control and shared copies.
# Set it before starting the kernel, e.g. `export OPENAI_API_KEY=...`.

client = OpenAI(
    api_key=os.environ.get('OPENAI_API_KEY'),
)


In [6]:
# embedding transcriptomic data

def get_embedding(df, embedding_model='text-embedding-ada-002'):
    """Embed a single text string with the OpenAI embeddings endpoint.

    Parameters
    ----------
    df : str
        The text to embed. Despite the name, callers in this notebook pass one
        sample's concatenated expression values as a string, not a DataFrame.
    embedding_model : str, optional
        Name of the OpenAI embedding model. Defaults to 'text-embedding-ada-002'.

    Returns
    -------
    list of float
        The embedding vector for `df`.

    Notes
    -----
    Uses the module-level `client` (an `OpenAI` instance) defined in an earlier cell.
    """
    response = client.embeddings.create(
        model=embedding_model,
        input=[df],
    )
    # Exactly one input was sent, so the response holds exactly one embedding.
    return response.data[0].embedding

In [7]:
# Build the text fed to the embedding model: one sentence-like string per sample.

# testing: embed only the first `keep_thresh` gene columns to keep requests small
keep_thresh = 100

# The first column holds the sample IDs; only expression values are embedded.
embedded_data = (
    data
    .drop(data.columns[0], axis = 1)
    .iloc[:, :keep_thresh]
)
print(embedded_data.shape)

# Embedding models take text, so render every value as a string and
# space-join each row into a single string.
embedded_data['concatenated'] = embedded_data.astype(str).apply(' '.join, axis = 1)
print(embedded_data['concatenated'])

# One embeddings API call per sample (row).
embedded_data['embedding'] = embedded_data['concatenated'].apply(get_embedding)



(125, 100)
0      0.0 1.1649346867152932 0.17798414980298533 0.6...
1      0.12880729789026002 1.358726836522174 0.070844...
2      0.16219362419847294 0.14344533937100276 0.0 0....
3      0.2677332475807369 0.03057601299596717 0.32604...
4      0.13566506647355595 0.053567868543348565 0.241...
                             ...                        
120    0.18652448727387372 0.0 0.1414241399271249 0.9...
121    0.020039051950253725 0.0 0.06217935087677976 0...
122    0.07587424138298625 0.0 0.2033058122184411 0.6...
123    0.12628848729363462 0.0 0.2261974310538789 1.1...
124    0.1012481118342196 0.0 0.04323628856435453 0.7...
Name: concatenated, Length: 125, dtype: object


In [8]:
# run kmeans based on embeddings

# 2 clusters, intended to separate responders from non-responders.
# KMeans labels (0/1) are arbitrary: cluster 0 is not necessarily "responder".
# random_state fixes the centroid initialisation, so the cluster assignment (and the
# exported clusters_embedding.csv) is reproducible across Restart & Run All.

kmeans = KMeans(n_clusters = 2, n_init = 'auto', random_state = 0)
model = kmeans.fit(embedded_data['embedding'].tolist())

In [9]:
# dimensionality reduction and visualization using UMAP

# UMAP is stochastic; a fixed random_state makes the 2-D layout reproducible
# (at the cost of disabling UMAP's parallel execution).
um = umap.UMAP(random_state = 42)
embedded_data_2d = um.fit_transform(embedded_data['embedding'].tolist())

# k-means cluster label (as string) -> marker colour
color_map = {
    '0' : 'orange',
    '1' : 'blue'
}

# Colour = k-means cluster, marker symbol = observed response.
# Relies on `labels` being in the same row order as `embedded_data`.
fig = px.scatter(
    x = embedded_data_2d[:, 0],
    y = embedded_data_2d[:, 1],
    color = model.labels_.astype(str),
    color_discrete_map = color_map,
    symbol = labels['Response'],
)
fig.update_layout(
    title = 'UMAP of OpenAI embeddings: k-means cluster vs. response',
    xaxis_title = 'umap 1',
    yaxis_title = 'umap 2',
    legend_title_text = 'cluster, response',
)

fig.show()

In [24]:
# export cluster assignments

# pd.concat(axis=1) aligns on the index; if the row counts differed, the extra rows
# would silently get NaN clusters. Fail loudly instead.
assert len(labels) == len(model.labels_), (
    f'labels has {len(labels)} rows but there are {len(model.labels_)} cluster assignments'
)

clusters_embedding = pd.concat([labels, pd.DataFrame(model.labels_, columns = ['cluster'])], axis = 1)

clusters_embedding.to_csv('../data/clusters_embedding.csv', index = False)

# export data for L1 log reg
# NOTE: each 'embedding' cell is a Python list, so it is written as its string repr
# (e.g. "[0.01, -0.02, ...]"); the consumer must parse it back into numbers.
embedded_data['embedding'].to_csv('../data/embedded_data.csv', index = False)

