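# Train species subset for the PyTorch webinar

Builds a small training subset: keeps only the species listed in `train_species_ids.txt`, samples at most 10 rows per species, and writes the result to parquet.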

In [1]:
# Enable IPython's autoreload extension in mode 2: modules are reloaded before
# each cell runs, so edits to imported project code take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

# Resources requested from the project's get_spark helper for this notebook
spark_resources = {"cores": 4, "memory": "20g"}
spark = get_spark(**spark_resources)

# Show the session object in the cell output
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/23 15:50:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/03/23 15:50:09 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
import os
from pathlib import Path

# Resolve the current user's home directory; the data paths used in this
# notebook are built relative to it
home_dir = os.path.expanduser("~")
root = Path(home_dir)
# Shell escape: print the system date/time (handy as a timestamp for this run)
! date

Sun Mar 23 03:50:11 PM EDT 2025


In [4]:
# Base directory of the shared parquet datasets.
# No trailing slash here: sub-paths are appended with "/" below, and a trailing
# slash would yield ".../parquet//train".
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/parquet"

# Location of the training-set parquet files
train_path = f"{data_path}/train"

# Read the parquet files into a Spark DataFrame
train_df = spark.read.parquet(train_path)

# Inspect the schema and preview the first rows
train_df.printSchema()
train_df.show(n=5)

root
 |-- image_name: string (nullable = true)
 |-- path: string (nullable = true)
 |-- data: binary (nullable = true)
 |-- organ: string (nullable = true)
 |-- species_id: integer (nullable = true)
 |-- obs_id: long (nullable = true)
 |-- license: string (nullable = true)
 |-- partner: string (nullable = true)
 |-- author: string (nullable = true)
 |-- altitude: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- gbif_species_id: string (nullable = true)
 |-- species: string (nullable = true)
 |-- genus: string (nullable = true)
 |-- family: string (nullable = true)
 |-- dataset: string (nullable = true)
 |-- publisher: string (nullable = true)
 |-- references: string (nullable = true)
 |-- url: string (nullable = true)
 |-- learn_tag: string (nullable = true)
 |-- image_backup_url: string (nullable = true)



                                                                                

+--------------------+--------------------+--------------------+-----+----------+----------+--------------------+-------+----------------+--------+-----------------+------------------+---------------+--------------------+-------------+----------+--------+-----------+--------------------+--------------------+---------+--------------------+
|          image_name|                path|                data|organ|species_id|    obs_id|             license|partner|          author|altitude|         latitude|         longitude|gbif_species_id|             species|        genus|    family| dataset|  publisher|          references|                 url|learn_tag|    image_backup_url|
+--------------------+--------------------+--------------------+-----+----------+----------+--------------------+-------+----------------+--------+-----------------+------------------+---------------+--------------------+-------------+----------+--------+-----------+--------------------+--------------------+---------

In [5]:
# Text file with the species ids to keep, parsed below as one integer per line
txt_file_path = f"{root}/clef/plantclef-2025/plantclef/train_species_ids.txt"

# int() already ignores surrounding whitespace/newlines. Blank lines (e.g. an
# empty line at the end of the file) are skipped instead of raising ValueError.
with open(txt_file_path, "r") as file:
    train_species_ids = [int(line) for line in file if line.strip()]

# Preview the first ten ids
train_species_ids[:10]

[1395807,
 1361281,
 1394311,
 1741880,
 1397468,
 1392407,
 1397535,
 1390793,
 1392323,
 1722440]

In [6]:
from pyspark.sql import functions as F

# Keep only the training rows whose species_id appears in the selected id list
is_selected_species = F.col("species_id").isin(train_species_ids)
subset_df = train_df.filter(is_selected_species)

# Number of rows in the subset
subset_df.count()

                                                                                

42068

In [7]:
# Row count per species, largest groups first
species_counts = subset_df.groupBy("species_id").count()
grouped_df = species_counts.orderBy(F.desc("count"))
grouped_df.show()



+----------+-----+
|species_id|count|
+----------+-----+
|   1414366|  674|
|   1722433|  625|
|   1356576|  597|
|   1412585|  594|
|   1358613|  578|
|   1414356|  558|
|   1741903|  556|
|   1360260|  551|
|   1722522|  541|
|   1722501|  532|
|   1422217|  529|
|   1414367|  512|
|   1397475|  474|
|   1397535|  462|
|   1422218|  440|
|   1357630|  440|
|   1356286|  418|
|   1419076|  415|
|   1722625|  414|
|   1392608|  410|
+----------+-----+
only showing top 20 rows



                                                                                

In [8]:
# grouped_df is sorted by count descending, so the last 20 rows are the
# species with the fewest rows in the subset
grouped_df.tail(20)

                                                                                

[Row(species_id=1391313, count=50),
 Row(species_id=1397463, count=50),
 Row(species_id=1392241, count=47),
 Row(species_id=1398772, count=46),
 Row(species_id=1392732, count=44),
 Row(species_id=1392323, count=42),
 Row(species_id=1393659, count=42),
 Row(species_id=1743968, count=36),
 Row(species_id=1420558, count=35),
 Row(species_id=1697384, count=35),
 Row(species_id=1651363, count=35),
 Row(species_id=1363722, count=31),
 Row(species_id=1741587, count=22),
 Row(species_id=1741661, count=19),
 Row(species_id=1390899, count=16),
 Row(species_id=1361275, count=15),
 Row(species_id=1399800, count=15),
 Row(species_id=1651485, count=12),
 Row(species_id=1580587, count=8),
 Row(species_id=1743474, count=2)]

In [9]:
from pyspark.sql.window import Window

# Maximum number of rows kept per species
truncate_rows = 10
# Fixed seed for the random ordering, so the count below and later actions on
# filtered_df (and reruns of the notebook) pick the same rows instead of drawing
# a fresh sample each time.
# NOTE(review): Spark's rand() is only repeatable while the input data and its
# partitioning stay the same — confirm if strict reproducibility is required.
sample_seed = 42

# Within each species, shuffle the rows and number them 1..N
window_spec = Window.partitionBy("species_id").orderBy(F.rand(seed=sample_seed))
subset_row_df = subset_df.withColumn("row_number", F.row_number().over(window_spec))

# Keep the first `truncate_rows` rows of every species, then drop the helper
# column so the schema matches the input again
filtered_df = subset_row_df.filter(F.col("row_number") <= truncate_rows).drop(
    "row_number"
)
print(filtered_df.count())
filtered_df.printSchema()



2020
root
 |-- image_name: string (nullable = true)
 |-- path: string (nullable = true)
 |-- data: binary (nullable = true)
 |-- organ: string (nullable = true)
 |-- species_id: integer (nullable = true)
 |-- obs_id: long (nullable = true)
 |-- license: string (nullable = true)
 |-- partner: string (nullable = true)
 |-- author: string (nullable = true)
 |-- altitude: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- gbif_species_id: string (nullable = true)
 |-- species: string (nullable = true)
 |-- genus: string (nullable = true)
 |-- family: string (nullable = true)
 |-- dataset: string (nullable = true)
 |-- publisher: string (nullable = true)
 |-- references: string (nullable = true)
 |-- url: string (nullable = true)
 |-- learn_tag: string (nullable = true)
 |-- image_backup_url: string (nullable = true)



                                                                                

In [10]:
# Output location for the sampled subset, under the shared parquet directory
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/parquet"
output_path = f"{data_path}/train_pytorch_webinar_filtered"

# Redistribute the rows over 20 partitions, then write them as parquet,
# replacing any output already present at that path
filtered_df = filtered_df.repartition(20)
filtered_df.write.parquet(output_path, mode="overwrite")

print(f"Filtered DataFrame saved to {output_path}")

                                                                                

Filtered DataFrame saved to /storage/home/hcoda1/9/mgustineli3/p-dsgt_clef2025-0/shared/plantclef/data/parquet/train_pytorch_webinar_filtered
