In [1]:
import os
import sys
import pandas
import numpy

import findspark
findspark.init("/usr/local/spark/spark")

import pyspark

from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vector, Vectors
from pyspark.ml.feature import StandardScaler
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

In [2]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

In [3]:
file_name = "/Users/simondi/Desktop/test_ba/kmeans_transform-cells_sample_10_normalized_cut_100_K005"

In [4]:
conf = pyspark.SparkConf().setMaster("local[*]").set("spark.driver.memory", "4G").set("spark.executor.memory", "4G")
sc = pyspark.SparkContext(conf=conf)
spark = pyspark.sql.SparkSession(sc)

In [5]:
data = spark.read.parquet(file_name)

In [6]:
data = data.limit(30)

In [7]:
data.take(1)

[Row(study='infectx_published', pathogen='listeria', library='a', design='u', replicate='1', plate='kb2-02-1x', well='a01', gene='chka', sirna='s3008', well_type='sirna', image_idx='4', object_idx='144', prediction=0, features=DenseVector([-0.8044, 0.0121, 1.1159, 1.4749, -0.9369, -0.7485, -1.0209, -0.703, 0.0, 0.4476, 1.2809, 1.3916, 0.1489, 0.6694, -1.2335, -0.0825, 0.0106, -0.5078, 1.2455, 0.1357]))]

In [8]:
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=True)
scalerModel = scaler.fit(data)
data = scalerModel.transform(data)

In [9]:
pca = PCA(k=2, inputCol="scaledFeatures", outputCol="pcs")

In [10]:
model = pca.fit(data)

In [11]:
data = model.transform(data)

In [12]:
model.explainedVariance

DenseVector([0.263, 0.2029])

In [13]:
counts = data.groupBy(["pathogen", "gene", "sirna"]).count()

In [14]:
window = Window.partitionBy(["pathogen", "gene", "sirna"]).rowsBetween(0, 10)

In [15]:
data = data.withColumn(
    "row_num",
    row_number().over(Window.partitionBy(["pathogen", "gene"]).orderBy(["pathogen", "gene"])))

In [22]:
data.select(["pathogen", "gene", "row_num"]).take(5)

[Row(pathogen='listeria', gene='chka', row_num=1),
 Row(pathogen='listeria', gene='chka', row_num=2),
 Row(pathogen='listeria', gene='chka', row_num=3),
 Row(pathogen='listeria', gene='chka', row_num=4),
 Row(pathogen='listeria', gene='chka', row_num=5)]

In [57]:
data = data.filter("row_num <= 10")

In [58]:
genes  = [i.gene for i in data.select("gene").distinct().sample(False, fraction=.5).limit(100).collect()]

In [64]:
genes

['chka', 'grk5']

In [70]:
data.where(data.gene.isin(genes)).take(5)

[Row(study='infectx_published', pathogen='listeria', library='a', design='u', replicate='1', plate='kb2-02-1x', well='a01', gene='chka', sirna='s3008', well_type='sirna', image_idx='4', object_idx='144', prediction=0, features=DenseVector([-0.8044, 0.0121, 1.1159, 1.4749, -0.9369, -0.7485, -1.0209, -0.703, 0.0, 0.4476, 1.2809, 1.3916, 0.1489, 0.6694, -1.2335, -0.0825, 0.0106, -0.5078, 1.2455, 0.1357]), scaledFeatures=DenseVector([-0.7828, 0.1618, 1.1262, 1.6892, -0.9413, -0.6545, -1.0412, -0.885, 0.0, 0.4695, 1.3531, 1.2438, 0.1884, 1.1718, -1.3546, -0.1607, -0.004, -0.7102, 1.6329, 0.1423]), pcs=DenseVector([2.1838, 1.6822]), row_num=1),
 Row(study='infectx_published', pathogen='listeria', library='a', design='u', replicate='1', plate='kb2-02-1x', well='a01', gene='chka', sirna='s3008', well_type='sirna', image_idx='5', object_idx='168', prediction=0, features=DenseVector([-0.9474, -0.0399, 0.1262, 0.6198, -1.0081, -0.7959, -0.9485, -0.6324, 0.0, 0.4263, -0.2104, -0.3939, 0.3941, -0.3

In [17]:
datap = data.select(
      ["pathogen", "gene", "sirna", "prediction", "pcs", "scaledFeatures"]).toPandas()

In [18]:
new_col_names = ["Feature_{}".format(x) for x in range(len(datap.loc[0, "scaledFeatures"]))]

['chka', 'chuk', 'grk5']

In [26]:
data_p = data.select(["pathogen", "gene", "sirna", "prediction", "scaledFeatures", "pcs"]).toPandas()
data_p[['pc1','pc2']] = pandas.DataFrame(data_p.pcs.values.tolist())
data_p[ new_col_names ] = pandas.DataFrame(data_p.scaledFeatures.values.tolist())

In [27]:
data_p

Unnamed: 0,pathogen,gene,sirna,prediction,scaledFeatures,pcs,pc1,pc2,Feature_0,Feature_1,...,Feature_10,Feature_11,Feature_12,Feature_13,Feature_14,Feature_15,Feature_16,Feature_17,Feature_18,Feature_19
0,listeria,chka,s3008,0,"[-0.782821944903, 0.161762490353, 1.1261790614...","[2.18380198934, 1.68224190455]",2.183802,1.682242,-0.782822,0.161762,...,1.353126,1.243849,0.188372,1.171781,-1.354577,-0.160689,-0.00398,-0.710206,1.63292,0.142251
1,listeria,chka,s3008,0,"[-0.936524096425, 0.117235546461, 0.0718197915...","[1.8602040896, 0.265992544394]",1.860204,0.265993,-0.936524,0.117236,...,-0.060292,-0.983632,0.43091,-0.056873,1.104428,-0.763271,0.246773,-1.290418,-0.641438,0.439121
2,listeria,chka,s3008,2,"[-0.00830048862969, 1.11416647686, 1.813571601...","[1.24837827101, -1.7653434898]",1.248378,-1.765343,-0.0083,1.114166,...,-0.91813,0.380662,1.000875,0.17764,1.17198,1.647058,1.058735,0.506101,-0.719555,0.958645
3,listeria,chka,s3008,4,"[-0.255597805045, 0.756608747896, -0.159959519...","[1.76119647336, -0.371681379815]",1.761196,-0.371681,-0.255598,0.756609,...,0.293918,0.179947,0.928114,0.053641,-1.418306,0.441894,0.95127,0.931444,1.013221,0.921536
4,listeria,chka,s3008,2,"[-0.272771229796, -1.55467060958, 1.7289851216...","[1.94836107216, -2.96134530907]",1.948361,-2.961345,-0.272771,-1.554671,...,-0.189831,1.227685,1.352556,0.013928,1.610767,1.647058,1.416954,0.264799,-1.375989,1.366842
5,listeria,chka,s3008,4,"[0.299533150033, -0.711028249524, -0.309593293...","[1.19544783543, -1.65043423843]",1.195448,-1.650434,0.299533,-0.711028,...,0.835372,-0.518081,0.891733,-0.810791,-0.985015,-0.763271,1.034854,1.076197,-0.698528,0.946275
6,listeria,chka,s3008,1,"[2.65315101217, -1.90727663834, -0.90018094079...","[-4.58228009064, -3.91064787259]",-4.58228,-3.910648,2.653151,-1.907277,...,0.643223,-1.463657,-0.114802,0.011814,1.659834,2.24964,0.067664,1.515402,0.908926,-0.142251
7,listeria,chka,s3008,4,"[-1.16450131, 0.0120308320822, 0.711968717046,...","[3.05146373346, -0.273217793642]",3.051464,-0.273218,-1.164501,0.012031,...,-0.776496,-0.725707,0.879606,-0.390989,0.088138,-0.763271,0.903507,-0.052281,-0.996563,0.983384
8,listeria,chka,s3008,3,"[0.08658268312, 1.2546556183, -1.39021119143, ...","[-1.87303336432, 2.83981669016]",-1.873033,2.839817,0.086583,1.254656,...,1.435098,1.130797,-1.278986,1.346561,-0.284924,-0.160689,-1.556262,-1.10466,0.198373,-1.379212
9,listeria,chka,s3008,4,"[-1.26367783793, 0.511720058032, -1.1636993586...","[2.49649619964, 0.810092587286]",2.496496,0.810093,-1.263678,0.51172,...,-0.720857,-1.110951,0.491545,-0.313517,-0.708409,-0.763271,0.58111,0.503796,-0.654743,0.562817


In [75]:
uniq_genes = list(set(data_p["gene"]))
uniq_pathogen = list(set(data_p["pathogen"]))
uniq_sirnas = list(set(data_p["sirna"]))
uniq_clusts = list(set(data_p["prediction"]))

In [79]:
uniq_clusts

[0, 1, 2, 3, 4]

In [103]:
colors = plt.cm.rainbow(numpy.linspace(0, 1, len(uniq_clusts)))
colors

array([[  5.00000000e-01,   0.00000000e+00,   1.00000000e+00,
          1.00000000e+00],
       [  1.96078431e-03,   7.09281308e-01,   9.23289106e-01,
          1.00000000e+00],
       [  5.03921569e-01,   9.99981027e-01,   7.04925547e-01,
          1.00000000e+00],
       [  1.00000000e+00,   7.00543038e-01,   3.78411050e-01,
          1.00000000e+00],
       [  1.00000000e+00,   1.22464680e-16,   6.12323400e-17,
          1.00000000e+00]])

In [108]:
data_p.loc[:, "color"] = colors[data_p.prediction]

ValueError: Must have equal len keys and value when setting with an ndarray

In [106]:
data_p

Unnamed: 0,pathogen,gene,sirna,prediction,pcs,pc1,pc2
0,listeria,chka,s3008,0,"[1.76487012551, 1.61088575943]",1.76487,1.610886
1,listeria,chka,s3008,0,"[1.63468301131, 0.594261298524]",1.634683,0.594261
2,listeria,chka,s3008,2,"[1.75772043796, -1.00642121456]",1.75772,-1.006421
3,listeria,chka,s3008,4,"[2.02350481383, 0.100451937143]",2.023505,0.100452
4,listeria,chka,s3008,2,"[2.25783110217, -2.77525388643]",2.257831,-2.775254
5,listeria,chka,s3008,4,"[1.53234179511, -1.49438655759]",1.532342,-1.494387
6,listeria,chka,s3008,1,"[-3.37821265856, -4.81525978572]",-3.378213,-4.81526
7,listeria,chka,s3008,4,"[2.81887922754, 0.382042590338]",2.818879,0.382043
8,listeria,chka,s3008,3,"[-2.03657176404, 2.7128186201]",-2.036572,2.712819
9,listeria,chka,s3008,4,"[2.16869401353, 1.386068262]",2.168694,1.386068


In [100]:
hot = plt.get_cmap('hot')
colors = plt.cm.rainbow(np.linspace(0, 1, len(uniq_clusts)))

plt.figure()

plt.scatter(data_p.loc[:, "pc1"], data_p.loc[:, "pc2"], color=colors)
    
#uid = str(uuid.uuid1())
plt.show()

NameError: name 'np' is not defined

In [66]:
font = {'weight': 'normal',
            'family': 'sans-serif',
            'size': 14}
plt.rc('font', **font)
plt.figure()
ax = plt.subplot(111)
plt.tick_params(axis="both", which="both", bottom="off", top="off",
                labelbottom="on", left="off", right="off", labelleft="on")
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(True)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(True)

In [None]:
data

In [None]:
ax.plot(ks, score, "black")
    ax.plot(ks, score, "or")
    plt.xlabel('K', fontsize=15)
    plt.ylabel(axis_label, fontsize=15)
    plt.title('')
    ax.grid(True)
    logger.info("Saving plot to: {}".format(plotfile))
    plt.savefig(plotfile, bbox_inches="tight")