In [39]:
import datetime

import pandas as pd
import pywt
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, FloatType, IntegerType, ArrayType
from pyspark.sql import functions as F
from pyspark.ml.linalg import Vectors, VectorUDT

In [2]:
# Obtain a Spark session via the builder's getOrCreate().
spark = (
    SparkSession.builder
    .getOrCreate()
)

In [24]:
# Row layout: three integer columns (a, b, c) followed by d, an array of floats.
schema = (
    StructType()
    .add("a", IntegerType())
    .add("b", IntegerType())
    .add("c", IntegerType())
    .add("d", ArrayType(FloatType()))
)

In [25]:
# Three identical sample rows matching the schema: a, b, c, then the float array d.
rows = [[1, 2, 3, [0.1, 0.2, 0.3, 0.4]] for _ in range(3)]

df = spark.createDataFrame(schema=schema, data=rows)
df.show()

+---+---+---+--------------------+
|  a|  b|  c|                   d|
+---+---+---+--------------------+
|  1|  2|  3|[0.1, 0.2, 0.3, 0.4]|
|  1|  2|  3|[0.1, 0.2, 0.3, 0.4]|
|  1|  2|  3|[0.1, 0.2, 0.3, 0.4]|
+---+---+---+--------------------+



In [74]:
import numpy as np

@F.udf(returnType=VectorUDT())
def tovector(x):
    """Convert one row's array value into a dense ML vector.

    Args:
        x: The array value for a single row (a sequence of numbers), or
            None when the column is null for that row.

    Returns:
        A dense vector built from ``x``, or None when ``x`` is None.
    """
    # Keep null rows null rather than handing None to Vectors.dense.
    if x is None:
        return None
    return Vectors.dense(x)

@F.udf(returnType=ArrayType(FloatType()))
def dwt(x):
    """Apply a single-level Haar DWT to one row's array value.

    Only the first of the two coefficient arrays returned by ``pywt.dwt``
    is kept; the second is discarded.

    Args:
        x: The array value for a single row (a sequence of numbers), or
            None when the column is null for that row.

    Returns:
        The kept coefficients as a plain Python list. A null or empty
        input is returned unchanged (null stays null, empty stays empty).
    """
    # Avoid handing pywt a missing or zero-length signal.
    if not x:
        return x
    approx, _detail = pywt.dwt(x, wavelet="haar")
    # tolist() converts the numpy array to plain Python floats for Spark.
    return approx.tolist()

In [78]:
# Run the dwt UDF on column d, wrap the result with tovector, and print untruncated.
(
    df
    .select(tovector(dwt("d")))
    .show(truncate=False)
)

+----------------------------------------+
|tovector(dwt(d))                        |
+----------------------------------------+
|[0.2121320375169779,0.49497475947463787]|
|[0.2121320375169779,0.49497475947463787]|
|[0.2121320375169779,0.49497475947463787]|
+----------------------------------------+

