<a href="https://colab.research.google.com/github/hsabaghpour/PySpark_MLlib_repo/blob/main/PySpark_MLlib_Correlation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from pyspark.sql import SparkSession
# Get the current SparkSession, or create one if none exists yet;
# 'corr' is the application name passed to the builder.
spark = (
    SparkSession.builder
    .appName('corr')
    .getOrCreate()
)

In [4]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation

# Four sample rows, each a 1-tuple holding a 4-element feature vector.
# Sparse vectors are given as (size, indices, values); positions that are
# not listed are 0.0.
data = [
    (Vectors.sparse(4, [0, 3], [1.0, -2.0]),),
    (Vectors.dense([4.0, 5.0, 0.0, 3.0]),),
    (Vectors.dense([6.0, 7.0, 0.0, 8.0]),),
    (Vectors.sparse(4, [0, 3], [9.0, 1.0]),),
]


In [5]:
# Show the raw Python rows (a list of one-element tuples) before they are
# loaded into a DataFrame.
data

[(SparseVector(4, {0: 1.0, 3: -2.0}),),
 (DenseVector([4.0, 5.0, 0.0, 3.0]),),
 (DenseVector([6.0, 7.0, 0.0, 8.0]),),
 (SparseVector(4, {0: 9.0, 3: 1.0}),)]

In [7]:
# Build a single-column DataFrame: each 1-tuple in `data` becomes one row
# whose only field is named "features".
df = spark.createDataFrame(data, schema=["features"])

# Print the rows as a text table to confirm how the vectors were stored.
df.show()


+--------------------+
|            features|
+--------------------+
|(4,[0,3],[1.0,-2.0])|
|   [4.0,5.0,0.0,3.0]|
|   [6.0,7.0,0.0,8.0]|
| (4,[0,3],[9.0,1.0])|
+--------------------+



In [8]:
# Print the DataFrame's column names and types as a tree.
df.printSchema()

root
 |-- features: vector (nullable = true)



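# **Correlation**
Calculating the correlation between two series of data is a common operation in statistics. `spark.ml` provides the flexibility to calculate pairwise correlations among many series. The currently supported correlation methods are **Pearson’s and Spearman’s correlation**. Note that a feature with zero variance (here the third one, which is 0.0 in every row) has an undefined correlation with every other feature, which shows up as `nan` in the matrices below.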



In [9]:
# Correlation.corr returns a DataFrame; head() fetches its first row, and the
# first field of that row is the correlation matrix of the "features" column.

# Pearson correlation (the method used when none is given).
r1 = Correlation.corr(df, "features").head()
pearson_matrix = r1[0]
print("Pearson correlation matrix:\n" + str(pearson_matrix))

# Spearman (rank) correlation over the same column.
r2 = Correlation.corr(df, "features", method="spearman").head()
spearman_matrix = r2[0]
print("Spearman correlation matrix:\n" + str(spearman_matrix))

Pearson correlation matrix:
DenseMatrix([[1.        , 0.05564149,        nan, 0.40047142],
             [0.05564149, 1.        ,        nan, 0.91359586],
             [       nan,        nan, 1.        ,        nan],
             [0.40047142, 0.91359586,        nan, 1.        ]])
Spearman correlation matrix:
DenseMatrix([[1.        , 0.10540926,        nan, 0.4       ],
             [0.10540926, 1.        ,        nan, 0.9486833 ],
             [       nan,        nan, 1.        ,        nan],
             [0.4       , 0.9486833 ,        nan, 1.        ]])
