In [1]:
import numpy as np
import pandas as pd

import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession\
    .builder\
    .enableHiveSupport()\
    .config("spark.dynamicAllocation.enabled", False)\
    .config("spark.executor.memory", "4g")\
    .getOrCreate()

In [3]:
arrays = [
  ([1, 2], [2, 4]),
  ([1, 1], [-1, -1]),
  ([1,2], [2, 1])  
]

schema = T.StructType([
    T.StructField('a', T.ArrayType(T.IntegerType()), nullable=False),
    T.StructField('b', T.ArrayType(T.IntegerType()), nullable=False)
])

df = spark.createDataFrame(
  pd.DataFrame(
    arrays
  ),
  schema
)

In [4]:
df.show()

+------+--------+
|     a|       b|
+------+--------+
|[1, 2]|  [2, 4]|
|[1, 1]|[-1, -1]|
|[1, 2]|  [2, 1]|
+------+--------+



In [5]:
def cosine_similarity(a, b):
  return float(np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b)))

cosine_similarity_udf = F.udf(cosine_similarity, T.FloatType())

def compute_similarity(df):
  rst = df\
  .withColumn("cos_sim", cosine_similarity_udf(F.col("a"), F.col("b")))\
  .select("cos_sim")
  return rst

df_rst = compute_similarity(df)
df_rst.show()

+-------+
|cos_sim|
+-------+
|    1.0|
|   -1.0|
|    0.8|
+-------+



In [28]:
import unittest
import logging
import warnings

from pyspark.sql import SparkSession


class PysparkTestCase(unittest.TestCase):

  @classmethod
  def suppress_py4j_logging(cls):
    warnings.filterwarnings(
      action="ignore",
      message="unclosed",
      category=ResourceWarning)
    logger = logging.getLogger("py4j")
    logger.setLevel(logging.ERROR)

  @classmethod
  def setUpClass(cls):
    cls.suppress_py4j_logging()
    cls.spark = SparkSession.builder.getOrCreate()
    cls.spark.sparkContext.setLogLevel('WARN')

  @classmethod
  def tearDownClass(cls):
    cls.spark.stop()

  @classmethod
  def is_dataframe_equal(cls, df1, df2):
    if not df1.schema.simpleString() == df2.schema.simpleString():
      print(df1.schema.simpleString())
      print(df2.schema.simpleString())
      return False

    if df1.join(df2, df1.columns, "left_anti").count() > 0:
      return False
    return True

In [29]:
from utils import compute_similarity

In [30]:
class UtilsTest(PysparkTestCase):

    def test_compute_similarity(self):
        arrays = [
            ([1, 2], [2, 4], 1.0),
            ([1, 1], [-1, -1], -1.0),
            ([1, 2], [2, 1], 0.8)
        ]

        schema = T.StructType([
            T.StructField('a', T.ArrayType(T.IntegerType()), nullable=False),
            T.StructField('b', T.ArrayType(T.IntegerType()), nullable=False),
            T.StructField('c', T.FloatType(), nullable=False)
        ])
        
        df = self.spark.createDataFrame(
            pd.DataFrame(
                arrays
            ),
            schema
        )
        df_expect = df.select(F.col('c'))

        df_test = compute_similarity(df)

        self.assertTrue(
            self.is_dataframe_equal(
                df_test,
                df_expect
            )
        )

In [31]:
import unittest

In [32]:
runner = unittest.TextTestRunner()

In [33]:
runner.run(utils_test.test_compute_similarity())

AttributeError: 'UtilsTest' object has no attribute 'spark'