In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join
import umap as um
import pathlib as pl
import seaborn as sns
from scipy.io import mmread
import scipy.sparse as scs
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA


def sigmoid(x):
    """Logistic function rescaled onto the unit interval.

    Computes 1 / (1 + exp(-(12*x - 6))), i.e. a sigmoid centred at x = 0.5
    (where it equals 0.5) with slope factor 12. Works element-wise on
    numpy arrays as well as on scalars.
    """
    z = 12 * x - 6
    return 1 / (1 + np.exp(-z))


In [None]:
save_path = '../results_manuscript/UMAPs_encoding/'
data_dir = 'sce_normalized_data_inflate'
path = '../data/' + data_dir + '/'
# Datasets to process: regular files in `path` whose name ends with 'real_logcounts'
# followed by 4 more characters (presumably an extension such as '.csv'), sorted by name.
files = np.sort([f for f in listdir(path)
                 if isfile(join(path, f)) and f[-18:-4] == 'real_logcounts'])

In [None]:
# Per-dataset 2D density plots of the UMAP of the VAE encoding. Rows of the encoding are
# real cells followed by simulated doublets; the panels show all points, real singlets
# (label 0), real doublets (label 1) and simulated doublets (label 2).
enc_dir = '../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/'
panels = [
    (None, 'crest', 'Combined'),   # (label to select, or None for all points; colormap; title)
    (0, 'Reds', 'Singlets'),
    (1, 'Blues', 'Doublets'),
    (2, 'Greens', 'Simulated'),
]
sns.set_style("white")

for file in files:

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print(data_name)

    enc_real = np.load(enc_dir + data_name + '_embedding_real.npy')
    enc_sim = np.load(enc_dir + data_name + '_embedding_sim.npy')
    num_doubs = enc_sim.shape[0]
    enc = np.vstack([enc_real, enc_sim])
    del enc_real, enc_sim

    # UMAP coordinates are cached on disk; compute (and cache) them only when missing.
    umap_file = save_path + 'UMAP_cords/' + data_name + '_encoding_UMAP.npy'
    if pl.Path(umap_file).exists():
        proj = np.load(umap_file)
    else:
        proj = um.UMAP(n_neighbors=7).fit_transform(enc)
        np.save(umap_file, proj)

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/' + data_dir + '/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp
    true = np.concatenate([true, np.full(num_doubs, 2)])
    # A stale cached UMAP would otherwise fail later with an opaque boolean-index error.
    assert proj.shape[0] == len(true), (
        f'{data_name}: {proj.shape[0]} UMAP rows vs {len(true)} labelled cells')

    fig, axes = plt.subplots(1, len(panels), sharex=True, figsize=(20, 5))
    for ax, (label, cmap, title) in zip(axes, panels):
        pts = proj if label is None else proj[true == label]
        # `fill` replaces the deprecated `shade` argument (seaborn >= 0.11).
        sns.kdeplot(ax=ax, x=pts[:, 0], y=pts[:, 1], cmap=cmap, fill=True, bw_adjust=.5)
        ax.set_title(title)
        ax.set(yticklabels=[])
        if label is None:
            ax.set(xticklabels=[])
            # x is shared across panels; align y with the combined panel explicitly.
            y0, y1 = ax.get_ylim()
        else:
            ax.set(ylim=(y0, y1))

    fig.suptitle(data_name + ' encoding UMAP', fontsize=20)
    fig.savefig(save_path + data_name + '_encoding_UMAP_density.png', dpi=100)
    plt.show()
    plt.close(fig)
        

    

In [None]:
save_path = '../results_manuscript/UMAPs_encoding/'
data_dir = 'sce_normalized_data_inflate'
path = '../data/' + data_dir + '/'
files = [f for f in listdir(path) if isfile(join(path, f)) and f[-18:-4] == 'real_logcounts']
files = np.sort(files)

# One figure for all datasets: one row per dataset, one column per panel.
enc_dir = '../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/'
panels = [
    (None, 'crest', 'Combined'),   # (label to select, or None for all points; colormap; column title)
    (0, 'Reds', 'Real Singlets'),
    (1, 'Blues', 'Real Doublets'),
    (2, 'Greens', 'Simulated Doublets'),
]

cols = len(panels)
rows = len(files)
assert rows > 0, 'no *real_logcounts files found in ' + path

width = 18.267717  # total figure width in inches
width_p_plot = width / cols
height = rows*width_p_plot + rows*0.3

sns.set_style("white")
# Create the whole grid up front (squeeze=False keeps `axes` 2-D even for one dataset).
# Drawing with plt.subplot() over a figure made by a bare plt.subplots() leaves that
# figure-wide placeholder axes behind the panels in recent matplotlib versions.
# Axes are deliberately not shared: each row takes its limits from its 'Combined' panel.
fig, axes = plt.subplots(rows, cols, figsize=(width, height), dpi=300, squeeze=False)

for row, file in enumerate(files):

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print(data_name)

    enc_real = np.load(enc_dir + data_name + '_embedding_real.npy')
    enc_sim = np.load(enc_dir + data_name + '_embedding_sim.npy')
    num_doubs = enc_sim.shape[0]
    enc = np.vstack([enc_real, enc_sim])
    del enc_real, enc_sim

    # UMAP coordinates are cached on disk; compute (and cache) them only when missing.
    umap_file = save_path + 'UMAP_cords/' + data_name + '_encoding_UMAP.npy'
    if pl.Path(umap_file).exists():
        proj = np.load(umap_file)
    else:
        proj = um.UMAP(n_neighbors=7).fit_transform(enc)
        np.save(umap_file, proj)

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/' + data_dir + '/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp
    true = np.concatenate([true, np.full(num_doubs, 2)])
    assert proj.shape[0] == len(true), (
        f'{data_name}: {proj.shape[0]} UMAP rows vs {len(true)} labelled cells')

    for ax, (label, cmap, title) in zip(axes[row], panels):
        pts = proj if label is None else proj[true == label]
        # `fill` replaces the deprecated `shade` argument (seaborn >= 0.11).
        sns.kdeplot(ax=ax, x=pts[:, 0], y=pts[:, 1], cmap=cmap, fill=True, bw_adjust=.5)
        ax.set(yticklabels=[], xticklabels=[])
        if label is None:
            x0, x1 = ax.get_xlim()
            y0, y1 = ax.get_ylim()
            ax.set_ylabel(data_name, fontsize=20)
        else:
            ax.set(ylim=(y0, y1), xlim=(x0, x1))
        if row == 0:
            ax.set_title(title, fontsize=20)

fig.tight_layout()
fig.savefig(save_path + 'ALL_encoding_UMAP_density.png', dpi=300)
    
    

In [None]:
# Per-dataset density plots of a UMAP computed on a 5-component PCA of the expression
# data (real cells followed by the selected simulated doublets).
res_dir = '../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/'
# Both on-disk caches are currently bypassed (always recompute);
# set to False to reuse previously saved coordinates.
RECOMPUTE_PCA = True
RECOMPUTE_PCA_UMAP = True
panels = [
    (None, 'crest', 'Combined'),   # (label to select, or None for all points; colormap; title)
    (0, 'Reds', 'Singlets'),
    (1, 'Blues', 'Doublets'),
    (2, 'Greens', 'Simulated'),
]
sns.set_style("white")

for file in files:

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print(data_name)

    pca_file = save_path + 'PCA_cords/' + data_name + '_PCA.npy'
    if pl.Path(pca_file).exists() and not RECOMPUTE_PCA:
        pca_proj = np.load(pca_file)
    else:
        real_path = '../data/mtx_files/' + data_name + '.mtx'
        print('loading in real mtx')
        Xr = scs.csr_matrix(mmread(real_path)).toarray().T

        npz_sim = res_dir + data_name + '_sim_doubs.npz'
        which_doubs = res_dir + data_name + '_which_sim_doubs.npy'

        print('loading in sim npz')
        dat_sim = scs.csr_matrix(scs.load_npz(npz_sim))
        ind = np.load(which_doubs)
        # Select the kept simulated doublets while still sparse, then densify only those rows.
        Xs = dat_sim[ind, :].toarray()
        del dat_sim
        print(Xs.shape)

        X = np.vstack([Xr, Xs])
        del Xr, Xs

        #Filter genes: keep columns that are non-zero in more than 1% of the rows.
        thresh = 0.01 * X.shape[0]
        X = X[:, np.sum(X > 0, axis=0) > thresh]

        #- HVGs: keep the 2000 highest-variance columns (variance of the un-logged values).
        if X.shape[1] > 2000:
            var = np.var(X, axis=0)
            hvgs = np.argpartition(var, -2000)[-2000:]
            X = X[:, hvgs]

        #SCALING: log2(x + 1), then fitting StandardScaler on X.T standardizes each row of X.
        X = np.log2(X + 1)
        X = StandardScaler().fit(X.T).transform(X.T).T

        #PCA: seeded because scikit-learn's 'auto' solver may pick randomized SVD on large inputs.
        np.random.seed(42)
        pca_proj = PCA(n_components=5).fit_transform(X)

        np.save(pca_file, pca_proj)

    #read in UMAP coords
    umap_file = save_path + 'UMAP_cords/' + data_name + '_PCA_UMAP.npy'
    if pl.Path(umap_file).exists() and not RECOMPUTE_PCA_UMAP:
        proj = np.load(umap_file)
    else:
        proj = um.UMAP(n_neighbors=7).fit_transform(pca_proj)
        np.save(umap_file, proj)

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/' + data_dir + '/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp

    # Rows after the annotated real cells are simulated doublets. Deriving the count from
    # the projection keeps it correct whether the PCA was recomputed or loaded from cache.
    num_doubs = proj.shape[0] - len(true)
    assert num_doubs >= 0, (
        f'{data_name}: {proj.shape[0]} projected rows but {len(true)} annotated cells')
    true = np.concatenate([true, np.full(num_doubs, 2)])

    fig, axes = plt.subplots(1, len(panels), sharex=True, figsize=(20, 5))
    for ax, (label, cmap, title) in zip(axes, panels):
        pts = proj if label is None else proj[true == label]
        # `fill` replaces the deprecated `shade` argument (seaborn >= 0.11).
        sns.kdeplot(ax=ax, x=pts[:, 0], y=pts[:, 1], cmap=cmap, fill=True, bw_adjust=.5)
        ax.set_title(title)
        ax.set(yticklabels=[])
        if label is None:
            ax.set(xticklabels=[])
            # x is shared across panels; align y with the combined panel explicitly.
            y0, y1 = ax.get_ylim()
        else:
            ax.set(ylim=(y0, y1))

    fig.suptitle(data_name + ' PCA UMAP', fontsize=20)
    fig.savefig(save_path + data_name + '_PCA_UMAP_density.png', dpi=100)
    plt.show()
    plt.close(fig)
        

In [None]:
# Debug display: path of the last dataset's '_which_sim_doubs.npy' index file, left over
# from the PCA loop above (only defined if that loop recomputed the PCA).
which_doubs

In [None]:
# One figure for all datasets showing UMAPs of the PCA coordinates: one row per dataset,
# one column per panel.
panels = [
    (None, 'crest', 'Combined'),   # (label to select, or None for all points; colormap; column title)
    (0, 'Reds', 'Real Singlets'),
    (1, 'Blues', 'Real Doublets'),
    (2, 'Greens', 'Simulated Doublets'),
]

cols = len(panels)
rows = len(files)
assert rows > 0, 'no datasets in `files`'

width = 18.267717  # total figure width in inches
width_p_plot = width / cols
height = rows*width_p_plot + rows*0.3

sns.set_style("white")
# Create the whole grid up front (squeeze=False keeps `axes` 2-D even for one dataset).
# Drawing with plt.subplot() over a figure made by a bare plt.subplots() leaves that
# figure-wide placeholder axes behind the panels in recent matplotlib versions.
# Axes are deliberately not shared: each row takes its limits from its 'Combined' panel.
fig, axes = plt.subplots(rows, cols, figsize=(width, height), dpi=300, squeeze=False)

for row, file in enumerate(files):

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print(data_name)

    #read in UMAP coords
    umap_file = save_path + 'UMAP_cords/' + data_name + '_PCA_UMAP.npy'
    if pl.Path(umap_file).exists():
        proj = np.load(umap_file)
    else:
        # Fit on this dataset's saved PCA coordinates, not on whatever encoding happens
        # to be left in the kernel from another cell.
        pca_proj = np.load(save_path + 'PCA_cords/' + data_name + '_PCA.npy')
        proj = um.UMAP(n_neighbors=7).fit_transform(pca_proj)
        np.save(umap_file, proj)

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/' + data_dir + '/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp

    # Rows after the annotated real cells are simulated doublets.
    num_doubs = proj.shape[0] - len(true)
    assert num_doubs >= 0, (
        f'{data_name}: {proj.shape[0]} projected rows but {len(true)} annotated cells')
    true = np.concatenate([true, np.full(num_doubs, 2)])

    for ax, (label, cmap, title) in zip(axes[row], panels):
        pts = proj if label is None else proj[true == label]
        # `fill` replaces the deprecated `shade` argument (seaborn >= 0.11).
        sns.kdeplot(ax=ax, x=pts[:, 0], y=pts[:, 1], cmap=cmap, fill=True, bw_adjust=.5)
        ax.set(yticklabels=[], xticklabels=[])
        if label is None:
            x0, x1 = ax.get_xlim()
            y0, y1 = ax.get_ylim()
            ax.set_ylabel(data_name, fontsize=20)
        else:
            ax.set(ylim=(y0, y1), xlim=(x0, x1))
        if row == 0:
            ax.set_title(title, fontsize=20)

fig.tight_layout()
fig.savefig(save_path + 'ALL_PCA_UMAP_density.png', dpi=300)
    
    

In [None]:
# Overlap scores for the two projections (PCA and encoding): simulated vs. real doublets, and simulated doublets vs. real cells.

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import silhouette_score
from cluster import cluster, fast_cluster
import sklearn

In [None]:
# Compare the PCA and encoding representations on every dataset, using labels
# real singlet = 0, real doublet = 1, simulated doublet = 1:
#   ARI - adjusted Rand index between a 2-cluster KMeans partition and the labels (big = good)
#   ASW - silhouette width of the labels in the representation (small = good)
# `import sklearn` alone does not guarantee sklearn.cluster is loaded, so import KMeans explicitly.
from sklearn.cluster import KMeans

enc_dir = '../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/'

ari_pca = []
asw_pca = []

ari_enc = []
asw_enc = []

for file in files:

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print(data_name)

    pca_proj = np.load(save_path + 'PCA_cords/' + data_name + '_PCA.npy')

    enc_real = np.load(enc_dir + data_name + '_embedding_real.npy')
    enc_sim = np.load(enc_dir + data_name + '_embedding_sim.npy')
    num_doubs = enc_sim.shape[0]
    enc = np.vstack([enc_real, enc_sim])
    del enc_real, enc_sim

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/' + data_dir + '/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp
    true = np.concatenate([true, np.full(num_doubs, 1)])

    # Both representations must describe the same real cells + simulated doublets.
    assert enc.shape[0] == len(true) and pca_proj.shape[0] == len(true), (
        f'{data_name}: rows enc={enc.shape[0]}, pca={pca_proj.shape[0]}, labels={len(true)}')

    # Fixed seed and explicit n_init (its default changed in scikit-learn 1.4) make the
    # clustering reproducible across runs and library versions.
    clust_enc = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(enc)
    clust_pca = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(pca_proj)

    ari_pca.append(adjusted_rand_score(true, clust_pca))  # big = good
    asw_pca.append(silhouette_score(pca_proj, true))      # small = good

    ari_enc.append(adjusted_rand_score(true, clust_enc))  # big = good
    asw_enc.append(silhouette_score(enc, true))           # small = good
    
    

In [None]:
# Total over datasets of ARI(PCA) - ARI(encoding); want negative (encoding scores higher).
np.subtract(ari_pca, ari_enc).sum()

In [None]:
# Total over datasets of ASW(PCA) - ASW(encoding); want positive (encoding has the smaller silhouette).
np.subtract(asw_pca, asw_enc).sum()

In [None]:
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
from scipy.io import mmread
import scipy.sparse as scs
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [None]:
# Local label information: how well does each real cell's k-NN neighbourhood predict its
# own singlet/doublet label in the raw counts, the encoding and the PCA coordinates?
# Only real (annotated) cells are used here.
enc_dir = '../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/'
neighbors = 5


def _knn_label_scores(features, labels, n_neighbors):
    """Fit a k-NN classifier on (features, labels) and score it on the same points.

    Returns a tuple (accuracy, average label agreement):
      accuracy                - fraction of points whose predicted label equals their own label;
      average label agreement - mean predicted probability of each point's own label.
    Assumes binary 0/1 labels, so column 1 of predict_proba is the probability of label 1.

    NOTE(review): the query points are the training points, so each point counts itself
    among its own neighbours; both scores are optimistic compared with a leave-one-out estimate.
    """
    knn = KNeighborsClassifier(n_neighbors=n_neighbors).fit(features, labels)
    preds = knn.predict(features)
    probs = knn.predict_proba(features)[:, 1]
    acc = np.sum(preds == labels) / len(labels)
    # Turn P(label 1) into P(own label) for the label-0 points.
    probs[labels == 0] = 1 - probs[labels == 0]
    return acc, np.mean(probs)


dfs = []
for file in files:

    data_name = file[:-19]  # drop the 19-char '_real_logcounts' + 4-char extension suffix
    print('-------', data_name, '-------')

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv('../data/mtx_files/' + data_name + '_anno.csv')
    true = pd.factorize(ano.x)[0]
    # factorize numbers labels by first appearance; when the first barcode is a doublet,
    # swap codes 0 <-> 1 so that doublets get code 1 (assumes exactly two labels).
    if ano.x[0] == 'doublet':
        tmp = true + 3
        tmp[tmp == 3] = 1
        tmp[tmp == 4] = 0
        true = tmp

    print('loading in real mtx')
    X = scs.csr_matrix(mmread('../data/mtx_files/' + data_name + '.mtx')).toarray().T

    # Only real cells are scored, so the simulated-doublet encodings are not loaded.
    enc = np.load(enc_dir + data_name + '_embedding_real.npy')

    # The saved PCA stacks real cells first, then simulated doublets; keep the real rows.
    pca_proj = np.load(save_path + 'PCA_cords/' + data_name + '_PCA.npy')[:enc.shape[0], :]

    assert X.shape[0] == enc.shape[0] == pca_proj.shape[0] == len(true), (
        f'{data_name}: rows raw={X.shape[0]}, enc={enc.shape[0]}, '
        f'pca={pca_proj.shape[0]}, labels={len(true)}')

    acc_raw, agree_raw = _knn_label_scores(X, true, neighbors)
    acc_enc, agree_enc = _knn_label_scores(enc, true, neighbors)
    acc_pca, agree_pca = _knn_label_scores(pca_proj, true, neighbors)

    df = pd.DataFrame({'knn scores': [acc_raw, agree_raw,
                                      acc_enc, agree_enc,
                                      acc_pca, agree_pca],
                       'which': ['accuracy', 'average label agreement'] * 3,
                       'representation': ['raw', 'raw', 'enc', 'enc', 'pca', 'pca']})
    df['data_name'] = data_name

    dfs.append(df)
    
    

In [None]:
# Display the list of per-dataset k-NN score tables (one DataFrame per dataset).
dfs

In [None]:
# Stack the per-dataset score tables. ignore_index gives the result a unique 0..N-1 index
# instead of repeating each table's own 0..5 index for every dataset.
df = pd.concat(dfs, ignore_index=True)
df

In [None]:
# index=False: the positional index carries no information and would otherwise be
# read back by pd.read_csv as an extra 'Unnamed: 0' column.
df.to_csv(save_path + 'df.csv', index=False)

In [None]:
# Reload the saved k-NN score table so the plotting cells below can run from disk.
df = pd.read_csv(f'{save_path}df.csv')

In [None]:
# Display the reloaded k-NN score table.
df

In [None]:
# Keep only the k-NN accuracy rows for the raw and encoding representations (drop PCA).
df = df.loc[(df.which == 'accuracy') & (df.representation != 'pca')]
df

In [None]:
MAX_WIDTH = 6.726  # figure width in inches

sns.set(rc={"figure.figsize": (MAX_WIDTH, 3)})
sns.set_style("white")

# Violin + swarm plot of per-dataset k-NN accuracy for each representation, with the
# per-representation mean as a black dot. Every layer uses the same category order
# (order of first appearance) so each mean lands on its own violin; groupby alone would
# sort the representations alphabetically.
rep_order = list(df['representation'].unique())
# One colour per dataset, however many datasets there are.
palette = sns.color_palette("hls", df['data_name'].nunique())

fig, ax1 = plt.subplots(figsize=(MAX_WIDTH, 3), dpi=300)
sns.violinplot(ax=ax1, x='representation', y='knn scores', data=df, order=rep_order,
               linewidth=1, edgecolor='black', color='white', inner=None)
sns.swarmplot(ax=ax1, x='representation', y='knn scores', data=df, order=rep_order,
              marker='o', s=7, dodge=False, linewidth=0, hue='data_name', palette=palette)
df_means = df.groupby(['representation'])['knn scores'].agg('mean').reset_index()
sns.swarmplot(ax=ax1, x='representation', y='knn scores', data=df_means, order=rep_order,
              marker='o', s=3, dodge=True, linewidth=1, edgecolor='black', color='black')

ax1.legend(bbox_to_anchor=(1, 1), loc="upper left", fontsize=6, markerscale=0.75)
ax1.set_title('Local Label Information')
ax1.set_ylabel('local label accuracy')
ax1.set_xlabel('representation')
ax1.tick_params(axis='both', which='major', labelsize=10)

# Make the plotting area square: aspect = x-range / y-range.
x0, x1 = ax1.get_xlim()
y0, y1 = ax1.get_ylim()
ax1.set_aspect(abs(x1 - x0) / abs(y1 - y0))

fig.tight_layout()
fig.savefig(save_path + 'knn.png', dpi=300)

In [None]:
# Mean score per representation. Select the score column explicitly: the string columns
# ('which', 'data_name') make a whole-frame mean raise a TypeError in pandas >= 2.0.
df.groupby('representation', as_index=False)['knn scores'].mean()

In [None]:
# Violin + swarm plot of the average label agreement per representation (raw vs encoding).
# ignore_index: unique index labels (seaborn may fail on duplicate index labels).
df = pd.concat(dfs, ignore_index=True)
df = df.loc[(df.which == 'average label agreement') & (df.representation != 'pca')]

sns.set(rc={"figure.figsize": (7, 5)})
sns.set_style("white")

# Use one category order for every layer so each mean marker sits on its own violin
# (groupby sorts representations alphabetically, which differs from the data order).
rep_order = list(df['representation'].unique())
# One colour per dataset, however many datasets there are.
palette = sns.color_palette("hls", df['data_name'].nunique())

fig, ax1 = plt.subplots()
sns.violinplot(ax=ax1, x='representation', y='knn scores', data=df, order=rep_order,
               linewidth=1, edgecolor='black', color='white')
sns.swarmplot(ax=ax1, x='representation', y='knn scores', data=df, order=rep_order,
              marker='o', s=10, dodge=False, linewidth=0, hue='data_name', palette=palette)
df_means = df.groupby(['representation'])['knn scores'].agg('mean').reset_index()
sns.swarmplot(ax=ax1, x='representation', y='knn scores', data=df_means, order=rep_order,
              marker='o', s=5, dodge=True, linewidth=1, edgecolor='black', color='black')

ax1.legend(bbox_to_anchor=(1, 1), loc="upper left")

fig.tight_layout()
# Not saved: writing to 'knn.png' would overwrite the accuracy figure saved above.


In [None]:
# Mean score per representation. Select the score column explicitly: the string columns
# ('which', 'data_name') make a whole-frame mean raise a TypeError in pandas >= 2.0.
df.groupby('representation', as_index=False)['knn scores'].mean()