In [3]:
import numpy as np
import pandas as pd
import glob
import sys
import h5py
#from netCDF4 import Dataset
from datetime import datetime
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree

import pyarrow as pa
import pyarrow.parquet as pqw

from functools import reduce
import operator
import gc

In [4]:
# Matplotlib defaults: serif text (Times New Roman) at 16 pt, STIX glyphs for math
plt.rc('font', family='serif', serif='Times New Roman', size=16)
plt.rc('mathtext', fontset='stix')

# Initiate a spark session

In [5]:
# PySpark packages
from pyspark import Row, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.window import Window as W

import pyspark.sql.functions as F
import pyspark.sql.types as T

# Earlier YARN-based configurations, kept for reference only (not executed):
#   master "yarn", appName "spark-shell",
#   spark.driver.maxResultSize = 32g, spark.driver.memory = 32g,
#   spark.executor.memory = 6g (alt. 14g), spark.executor.cores = 1 (alt. 2),
#   spark.executor.instances = 30 (alt. 60),
#   spark.jars.packages = graphframes:graphframes:0.7.0-spark2.4-s_2.11

# Connect to the standalone master; getOrCreate() returns the already-running
# session if there is one instead of starting a second.
spark = (
    SparkSession.builder
    .appName("MyApp")
    .master("spark://sohnic:7077")
    .config("spark.driver.memory", "100g")
    .getOrCreate()
)

# Keep a handle on the SparkContext and point RDD/DataFrame checkpoints at HDFS
sc = spark.sparkContext
sc.setCheckpointDir("hdfs://sohnic:54310/tmp/checkpoints")

# Runtime SQL options for this session
for option_name, option_value in (
    ("spark.sql.debug.maxToStringFields", 500),
    ("spark.sql.execution.arrow.pyspark.enabled", "true"),
):
    spark.conf.set(option_name, option_value)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/21 14:58:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
# Peek at the first ten (key, value) pairs of the active Spark configuration
spark_conf_pairs = sc.getConf().getAll()
spark_conf_pairs[:10]

[('spark.app.submitTime', '1740117521067'),
 ('spark.rdd.compress', 'True'),
 ('spark.app.startTime', '1740117521260'),
 ('spark.master', 'spark://sohnic:7077'),
 ('spark.driver.port', '40333'),
 ('spark.app.name', 'MyApp'),
 ('spark.sql.warehouse.dir',
  'file:/home/nowan/shape_measurement/spark-warehouse'),
 ('spark.executor.id', 'driver'),
 ('spark.submit.pyFiles', ''),
 ('spark.driver.host', 'sohnic')]

# Define particle data directories to be read

In [19]:
# Paths of the particle data files, sorted lexicographically
particle_glob = '/data/TNG300/TNG300/snap_099/snap099_*.csv'
ptl_files = np.sort(glob.glob(particle_glob))

print(f"The number of particle data files: {len(ptl_files)}")
print("First five data files:", ptl_files[:5], sep="\n")

The number of particle data files: 600
First five data files:
['../TNG100/output/snap_099/sorted/combined/snap099_sorted_x0_y0_z0.csv'
 '../TNG100/output/snap_099/sorted/combined/snap099_sorted_x0_y0_z1.csv'
 '../TNG100/output/snap_099/sorted/combined/snap099_sorted_x0_y0_z2.csv'
 '../TNG100/output/snap_099/sorted/combined/snap099_sorted_x0_y0_z3.csv'
 '../TNG100/output/snap_099/sorted/combined/snap099_sorted_x0_y0_z4.csv']


In [20]:
%%time
# An example data from a particle data file
df = pd.read_csv(ptl_files[0])
df.head()
#h5f.keys()

CPU times: user 33.3 ms, sys: 17.2 ms, total: 50.5 ms
Wall time: 49.1 ms


Unnamed: 0,px,py,pz,vx,vy,vz,mass,ID,Formation_time
0,4816.95087,943.885776,11228.531968,306.1135,78.89162,114.42368,7.5e-05,122500826831,0.989542
1,4817.317051,945.470532,11228.469476,296.45004,84.64811,94.15483,8.1e-05,121755914935,0.871483
2,4817.384923,943.359975,11228.005229,276.48822,97.95877,118.08153,5.1e-05,118229119755,0.591693
3,4817.536942,944.688338,11229.09531,270.81726,99.322174,101.76142,7.4e-05,117845451181,0.55012
4,4816.251953,943.217905,11229.008699,314.01993,72.1817,106.111336,5.5e-05,119612834109,0.802193


# Save selected features as a parquet
We will convert all particle data files from .csv to .parquet formats.

The cells below show how to convert a single particle data file into the .parquet format.

If you need to convert all particle data files and do not need to know how to convert them individually, skip this section and go to the next section

## Construct a Spark DataFrame
First, we save the particle data from the particle data files into a Spark Data Frame named sparkdf

In [9]:
# Schema of the Spark DataFrame.
#   T.StructType([...])                        -> structured schema made of a list of fields
#   T.StructField(name, data_type, nullable)   -> one column:
#       name      : column name
#       data_type : e.g. IntegerType(), StringType(), ShortType(), DoubleType(), BooleanType(), ...
#       nullable  : if True, the column may contain Null values
# Every particle quantity below is declared as a nullable 32-bit float.
float_field_names = ('px', 'py', 'pz', 'vx', 'vy', 'vz', 'mass')
schema = T.StructType(
    [T.StructField(field_name, T.FloatType(), True) for field_name in float_field_names]
)

In [21]:
%%time
# Generate a Spark Data Frame named spakrdf according to the data in the pandas dataframe and the data type of schema
sparkdf = spark.createDataFrame(df[['px', 'py', 'pz', 'vx', 'vy', 'vz', 'mass']],schema)

  elif is_categorical_dtype(s.dtype):


CPU times: user 47.1 ms, sys: 11.1 ms, total: 58.1 ms
Wall time: 275 ms


In [18]:
%%time
print("Top three rows of sparkdf:")
sparkdf.show(3,truncate=True)
print()

nrow = sparkdf.count()
print(f"The number of rows: {nrow}")
print()

print("The structure of sparkdf:")
print(sparkdf.printSchema())
print()

Top three rows of sparkdf:
+---------+--------+---------+---------+--------+---------+-----------+
|       px|      py|       pz|       vx|      vy|       vz|       mass|
+---------+--------+---------+---------+--------+---------+-----------+
|4816.9507|943.8858|11228.532| 306.1135|78.89162|114.42368|7.473642E-5|
| 4817.317|945.4705| 11228.47|296.45004|84.64811| 94.15483|8.121476E-5|
| 4817.385|  943.36|11228.005|276.48822|97.95877|118.08153|5.140894E-5|
+---------+--------+---------+---------+--------+---------+-----------+
only showing top 3 rows

The number of rows: 25271
The structure of sparkdf:
root
 |-- px: float (nullable = true)
 |-- py: float (nullable = true)
 |-- pz: float (nullable = true)
 |-- vx: float (nullable = true)
 |-- vy: float (nullable = true)
 |-- vz: float (nullable = true)
 |-- mass: float (nullable = true)

None
CPU times: user 5.78 ms, sys: 6.49 ms, total: 12.3 ms
Wall time: 479 ms


# Converting all particle files (here, .csv files) into one Parquet dataset

In [None]:
%%time

schema = T.StructType([\
                       T.StructField('px',T.FloatType(), True),\
                       T.StructField('py',T.FloatType(), True),\
                       T.StructField('pz',T.FloatType(), True),\
                       T.StructField('vx',T.FloatType(), True),\
                       T.StructField('vy',T.FloatType(), True),\
                       T.StructField('vz',T.FloatType(), True),\
                       T.StructField('mass',T.FloatType(), True),\
                      ])

# create an empty spark DataFrame for saving all particle data
sparkdf = spark.createDataFrame([], schema)

for i in tqdm(range(len(ptl_files))):
    # read i-th data file (.csv) of the particle data files and convert to a spark DataFrame
    df = pd.read_csv(ptl_files[i])
    tempdf = spark.createDataFrame(df[['px', 'py', 'pz', 'vx', 'vy', 'vz', 'mass']], schema)
    # Append to the spark DataFrame for all particle data
    sparkdf = sparkdf.union(tempdf)  



In [None]:
# Save to Parquet
# The path must start with 'hdfs://sohnic:54310' because the output goes to the
# Hadoop file system (HDFS), not to a directory on the local disk.
outfile = "hdfs://sohnic:54310/data/TNG100/output/snap_099/particle.parquety.snappy"

# .parquet(path) states the output format explicitly; a bare .save(path) would use
# whatever `spark.sql.sources.default` happens to be set to.
sparkdf.write.mode("overwrite").option("compression", "snappy").parquet(outfile)
# .option(key, value): set a writer option.
#   "compression": compression codec used for the Parquet files.
#     "snappy" is Spark's default Parquet codec: very fast, moderate compression ratio.
#     Other Parquet codecs include "gzip", "lz4" and "zstd" (and "none"/"uncompressed");
#     see the Spark SQL Parquet data source documentation for the full list.
# .mode(...): what to do if the output path already exists.
#   "overwrite": replace the existing files at the output path.
#   "append": add the new data to the existing files.
#   "ignore": skip writing if the output path already exists.
#   "error" or "errorifexists": raise an error if the output path exists.

# Read the saved parquet file

In [18]:
%%time
newsparkdf = spark.read.option("header","true").option("recursiveFileLookup","true").parquet(outname)
# hearder=="true": use the first row to infer column names. Actually, it is unnecessary if you read a Parquet file because the Parquet file independently stores column names in its format.
# recursiveFileLookup="true"
#  Allows Spark to recursively search through subdirectories within the specified path (outname) for Parquet files.
#  If outname is a directory with nested folders, Spark will read all Parquet files within those folders.

CPU times: user 3.39 ms, sys: 0 ns, total: 3.39 ms
Wall time: 301 ms


In [None]:
%%time
# summarize the spark Data Frame
newsparkdf.describe().toPandas()
# describe(): summarize the data in the spark Data Frame
# .toPandas(): convert the spark Data Frame into a pandas data frame

# Saving subhalo catalog

In [32]:
# Load the subhalo catalogue, a space-separated text table (takes a few seconds)
subhalo_catalog_path = 'subhalocat300.txt'
t300subhalo = pd.read_csv(subhalo_catalog_path, sep=' ')
t300subhalo.head()

Unnamed: 0,SubfindID,px,py,pz,vx,vy,vz,StarHalfMass
0,0,43718.8125,48813.640625,147594.953125,472.196198,450.850006,-260.746918,265.473969
1,1,45442.273438,51850.199219,146416.5,-209.056656,-735.888916,400.641724,126.83107
2,2,44490.761719,49091.714844,147870.578125,2021.729492,1495.440186,-1797.082153,28.68231
3,3,43820.785156,50939.398438,147711.046875,925.150391,-473.445465,-275.925934,11.954713
4,4,44302.578125,49630.972656,147869.484375,-260.21463,-2221.625244,-563.641296,11.029386


In [34]:
%%time

schema_sub = T.StructType([\
                       T.StructField('px',T.FloatType(), True),\
                       T.StructField('py',T.FloatType(), True),\
                       T.StructField('pz',T.FloatType(), True),\
                       T.StructField('vx',T.FloatType(), True),\
                       T.StructField('vy',T.FloatType(), True),\
                       T.StructField('vz',T.FloatType(), True),\
                       T.StructField('StarHalfMass',T.FloatType(), True),\
                       T.StructField('sub_id',T.IntegerType(), True)
                      ])

SubhaloDf = spark.createDataFrame(t300subhalo[['px', 'py', 'pz', 'vx', 'vy', 'vz', 'StarHalfMass', 'SubfindID']], schema_sub)
SubhaloFile = 'hdfs://sohnic:54310/data/TNG300/snap99/subhalo.parquet.snappy'
SubhaloDf.write.option("compression", "snappy").mode("overwrite").save(SubhaloFile)

  elif is_categorical_dtype(s.dtype):
                                                                                

CPU times: user 3.79 s, sys: 443 ms, total: 4.24 s
Wall time: 14.6 s


----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 58990)
Traceback (most recent call last):
  File "/usr/lib/python3.10/socketserver.py", line 316, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/usr/lib/python3.10/socketserver.py", line 347, in process_request
    self.finish_request(request, client_address)
  File "/usr/lib/python3.10/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/usr/lib/python3.10/socketserver.py", line 747, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 281, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 253, in poll
    if func():
  File "/usr/local/spark/python/pyspark/accumulators.py", line 257, in accum_updates
    num_updates = read_int(self.rfile)
  File "/usr/local/spark/python/pyspark/serializers.py",

# repartition

In [24]:
%%time
# similar to "refresh"
newsparkdf.cache()
newsparkdf.repartition(10,"px").count()

CPU times: user 4.09 ms, sys: 40 µs, total: 4.13 ms
Wall time: 743 ms


765683