In [30]:
# findspark.init() locates the Spark installation and adds its `pyspark`
# package to sys.path, so it must run *before* any pyspark import; calling it
# after the imports (as this cell originally did) cannot rescue an import that
# already failed.
import findspark

findspark.init()

from pyspark import SparkContext  # noqa: E402  (must follow findspark.init())
from pyspark.sql import SparkSession  # noqa: E402
from pyspark.sql import functions as F  # noqa: E402
from pyspark.sql import types  # noqa: E402

# Local Spark session using all available cores. getOrCreate() returns the
# already-running session when this cell is re-executed.
spark: SparkSession = (
    SparkSession.builder
    .appName("JupyterLocalSpark")
    .master("local[*]")
    .getOrCreate()
)

In [28]:
STARS_CSV_PATH = "data/m13-stars.csv"

# Explicit schema, identical to the one Spark inferred for this file (see the
# printSchema output of this cell). Declaring it avoids the extra full scan of
# the CSV that `inferSchema=True` performs and pins the column types.
STARS_SCHEMA = types.StructType([
    types.StructField("source_id", types.LongType(), True),
    types.StructField("ra", types.DoubleType(), True),
    types.StructField("dec", types.DoubleType(), True),
    types.StructField("parallax", types.DoubleType(), True),
    types.StructField("parallax_error", types.DoubleType(), True),
    types.StructField("phot_g_mean_mag", types.DoubleType(), True),
    types.StructField("bp_rp", types.DoubleType(), True),
    types.StructField("target_separation", types.DoubleType(), True),
])

# enforceSchema=False makes Spark validate the CSV header names against the
# schema instead of silently applying the schema by position.
df = spark.read.csv(
    STARS_CSV_PATH,
    header=True,
    schema=STARS_SCHEMA,
    enforceSchema=False,
)

df.printSchema()
print("Rows number", df.count())

root
 |-- source_id: long (nullable = true)
 |-- ra: double (nullable = true)
 |-- dec: double (nullable = true)
 |-- parallax: double (nullable = true)
 |-- parallax_error: double (nullable = true)
 |-- phot_g_mean_mag: double (nullable = true)
 |-- bp_rp: double (nullable = true)
 |-- target_separation: double (nullable = true)

Rows number 69994


In [40]:
# Per-column missing-value counts: each column's isNull() mask is cast to
# 0/1 and summed, yielding a single-row frame with one count per column.
null_count_exprs = [
    F.sum(F.col(col_name).isNull().cast("int")).alias(col_name)
    for col_name in df.columns
]
null_count_df = df.select(*null_count_exprs)
null_count_df.show()

+---------+---+---+--------+--------------+---------------+-----+-----------------+
|source_id| ra|dec|parallax|parallax_error|phot_g_mean_mag|bp_rp|target_separation|
+---------+---+---+--------+--------------+---------------+-----+-----------------+
|        0|  0|  0|   15420|         15420|            123|17980|                0|
+---------+---+---+--------+--------------+---------------+-----+-----------------+



In [41]:
# Remove every row that has a null in any column (na.drop defaults to how="any").
df = df.na.drop()
print("Rows number", df.count())

Rows number 45996


In [52]:
# Columns pulled into pandas for local inspection.
sample_columns = [
    "parallax",
    "parallax_error",
    "phot_g_mean_mag",
    "bp_rp",
    "target_separation",
]

# Roughly 10% of rows (Spark's fraction is approximate, not an exact count),
# with a fixed seed so the sample is repeatable.
# NOTE(review): the saved output (indices up to 3614) matches a ~10% sample of
# the 35407 rows left after the later `parallax_error` filter, not of the 45996
# rows after dropna -- this cell appears to have been re-run after that filter
# rebound `df`. Restart & Run All to refresh it.
df_sample = df.select(*sample_columns).sample(fraction=0.1, seed=42)

# Collect the small sample to the driver and render it.
df_pd = df_sample.toPandas()
display(df_pd)

Unnamed: 0,parallax,parallax_error,phot_g_mean_mag,bp_rp,target_separation
0,0.433953,0.041290,16.440634,0.797240,0.671807
1,0.320204,0.034686,16.151880,0.750322,0.700445
2,0.379702,0.011681,13.321111,1.152290,0.713399
3,0.270067,0.491039,20.291859,1.519379,0.729202
4,0.167937,0.166044,18.942867,1.016001,0.732583
...,...,...,...,...,...
3611,0.397041,0.061824,17.212860,0.900965,0.541075
3612,0.519271,0.089306,17.915375,1.039572,0.564509
3613,0.514464,0.046244,16.582232,0.979399,0.571518
3614,-0.118273,0.111627,18.408377,0.968681,0.580732


In [53]:
# Upper bound on the parallax uncertainty (same units as the `parallax_error`
# column). Named so the cut can be tuned in one place.
MAX_PARALLAX_ERROR = 0.5

# NOTE(review): this rebinds `df`, so re-running earlier cells that read `df`
# (e.g. the sampling cell) after this one will see the filtered frame.
df = df.where(F.col("parallax_error") < MAX_PARALLAX_ERROR)
df.count()

35407