# Intersection of predictions

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.utils import get_spark
from pyspark.sql import functions as F

spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/17 15:27:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/17 15:27:14 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [25]:
# get dataframes
gcs_path = "gs://dsgt-clef-plantclef-2024"


def get_test_path_csv(folder_name: str) -> str:
    pred_data_path = f"models/pretrained-dino/{folder_name}/dsgt_run_{folder_name}.csv"
    return pred_data_path


# paths to dataframe
grid_2 = "top_5_species_grid_2x2"
grid_3 = "top_5_species_grid_3x3"
grid_5 = "top_5_species_grid_5x5"

grid_2_path = f"{gcs_path}/{get_test_path_csv(grid_2)}"
grid_3_path = f"{gcs_path}/{get_test_path_csv(grid_3)}"
grid_5_path = f"{gcs_path}/{get_test_path_csv(grid_5)}"
# read data
grid_2_spark_df = spark.read.csv(grid_2_path, sep=";", header=True)
grid_3_spark_df = spark.read.csv(grid_3_path, sep=";", header=True)
grid_5_spark_df = spark.read.csv(grid_5_path, sep=";", header=True)
# show
grid_2_spark_df.show(n=5, truncate=50)
grid_3_spark_df.show(n=5, truncate=50)
grid_5_spark_df.show(n=5, truncate=50)

+--------------------+---------------------------------------------+
|             plot_id|                                  species_ids|
+--------------------+---------------------------------------------+
|CBN-PdlC-A1-20130807|         [1412857, 1392608, 1361281, 1394911]|
|CBN-PdlC-A1-20130903|         [1412857, 1361281, 1394641, 1742052]|
|CBN-PdlC-A1-20140721|         [1412857, 1664563, 1361281, 1393619]|
|CBN-PdlC-A1-20140811|         [1392608, 1412857, 1395807, 1519650]|
|CBN-PdlC-A1-20140901|[1392608, 1742052, 1361281, 1412857, 1667408]|
+--------------------+---------------------------------------------+
only showing top 5 rows

+--------------------+---------------------------------------------+
|             plot_id|                                  species_ids|
+--------------------+---------------------------------------------+
|CBN-PdlC-A1-20130807|[1395807, 1412857, 1392608, 1392535, 1392611]|
|CBN-PdlC-A1-20130903|         [1392608, 1742052, 1362271, 1667408]|
|CBN-PdlC

In [26]:
from pyspark.sql.types import IntegerType


# Function to convert species_ids from string to array of integers
def convert_species_ids_to_array(df):
    return df.withColumn(
        "species_ids",
        F.split(F.regexp_replace(F.col("species_ids"), "[\\[\\]]", ""), ",\s*").cast(
            "array<int>"
        ),
    )


# Convert species_ids column in each dataframe
grid_2_df = convert_species_ids_to_array(grid_2_spark_df)
grid_3_df = convert_species_ids_to_array(grid_3_spark_df)
grid_5_df = convert_species_ids_to_array(grid_5_spark_df)

In [32]:
# Explode the species_ids into individual species_id rows
grid_2_exploded = grid_2_df.withColumn("species_id", F.explode("species_ids"))
grid_3_exploded = grid_3_df.withColumn("species_id", F.explode("species_ids"))
grid_5_exploded = grid_5_df.withColumn("species_id", F.explode("species_ids"))
# show
grid_2_exploded.show(n=5, truncate=50)

+--------------------+------------------------------------+----------+
|             plot_id|                         species_ids|species_id|
+--------------------+------------------------------------+----------+
|CBN-PdlC-A1-20130807|[1412857, 1392608, 1361281, 1394911]|   1412857|
|CBN-PdlC-A1-20130807|[1412857, 1392608, 1361281, 1394911]|   1392608|
|CBN-PdlC-A1-20130807|[1412857, 1392608, 1361281, 1394911]|   1361281|
|CBN-PdlC-A1-20130807|[1412857, 1392608, 1361281, 1394911]|   1394911|
|CBN-PdlC-A1-20130903|[1412857, 1361281, 1394641, 1742052]|   1412857|
+--------------------+------------------------------------+----------+
only showing top 5 rows



In [35]:
# Perform an inner join on plot_id and species_id
intersect_df = (
    grid_2_exploded.join(grid_3_exploded, ["plot_id", "species_id"], "inner")
    .join(grid_5_exploded, ["plot_id", "species_id"], "inner")
    .select("plot_id", "species_id")
    .distinct()
)

# Group by plot_id and collect the intersected species_ids into a list
intersected_species_df = intersect_df.groupBy("plot_id").agg(
    F.collect_list("species_id").alias("intersected_species_ids")
)

# Join with grid_3_df to get the species_ids when there are no common species_ids
result_df = (
    grid_3_df.alias("g3")
    .join(intersected_species_df.alias("inter"), "plot_id", "left")
    .withColumn(
        "species_ids",
        F.when(
            F.col("inter.intersected_species_ids").isNull(), F.col("g3.species_ids")
        ).otherwise(F.col("inter.intersected_species_ids")),
    )
)

# Show the result
result_df.show(n=5, truncate=50)

+--------------------+------------------+-----------------------+
|             plot_id|       species_ids|intersected_species_ids|
+--------------------+------------------+-----------------------+
|CBN-PdlC-A1-20140721|         [1412857]|              [1412857]|
|CBN-PdlC-A1-20130807|         [1392608]|              [1392608]|
|CBN-PdlC-A1-20140901|[1392608, 1742052]|     [1392608, 1742052]|
|CBN-PdlC-A1-20150701|         [1392608]|              [1392608]|
|CBN-PdlC-A1-20130903|         [1742052]|              [1742052]|
+--------------------+------------------+-----------------------+
only showing top 5 rows



In [48]:
import pandas as pd


def format_species_ids(species_ids: list) -> str:
    """Formats the species IDs in single square brackets, separated by commas."""
    formatted_ids = ", ".join(str(id) for id in species_ids)
    return f"[{formatted_ids}]"


def prepare_df_submission(spark_df):
    records = []
    for row in spark_df.collect():
        logits = row["species_ids"]
        formatted_species = format_species_ids(logits)
        records.append({"plot_id": row["plot_id"], "species_ids": formatted_species})

    pandas_df = pd.DataFrame(records)
    return pandas_df


result_df = result_df.orderBy("plot_id")
pandas_df = prepare_df_submission(result_df)
pandas_df.head(20)

Unnamed: 0,plot_id,species_ids
0,CBN-PdlC-A1-20130807,[1392608]
1,CBN-PdlC-A1-20130903,[1742052]
2,CBN-PdlC-A1-20140721,[1412857]
3,CBN-PdlC-A1-20140811,"[1412857, 1392608]"
4,CBN-PdlC-A1-20140901,"[1392608, 1742052]"
5,CBN-PdlC-A1-20150701,[1392608]
6,CBN-PdlC-A1-20150720,"[1412857, 1742052, 1392608]"
7,CBN-PdlC-A1-20150831,"[1392608, 1412857, 1742052]"
8,CBN-PdlC-A1-20160705,[1412857]
9,CBN-PdlC-A1-20160726,"[1549015, 1742052]"


In [49]:
import csv

# write cvs
pandas_df.to_csv("dsgt_run_intersect.csv", sep=";", index=False, quoting=csv.QUOTE_NONE)