In [None]:
# Load every train annotation whose category_id is greater than 1 from the lakehouse
df = spark.sql("SELECT * FROM SnapshotSerengeti_LH.train_annotations WHERE train_annotations.category_id > 1")

# Keep only the season, sequence ID, category ID, location, image ID and datetime columns
annotation_columns = ["season", "seq_id", "category_id", "location", "image_id", "datetime"]
df_train = df.select(*annotation_columns)

# Discard rows that have no image_id, then remove exact duplicate rows
df_train = df_train.filter(df_train["image_id"].isNotNull()).dropDuplicates()

In [None]:
# Import the required libraries
from pyspark.sql.functions import split, regexp_replace, col

# The season is the part of seq_id in front of the first '#'.
season_prefix = split(col("seq_id"), "#").getItem(0)

# Store the raw season, plus a shorter label with 'SER_' removed for readability.
df_train = (
    df_train
    .withColumn("season_extracted", season_prefix)
    .withColumn("season_label", regexp_replace(col("season_extracted"), "SER_", ""))
)

# Count the rows that fall under each season label and sort by label.
rows_per_season = df_train.groupBy("season_label").count()
df_counts = rows_per_season.orderBy("season_label")

# Render the Spark DataFrame directly in the notebook
display(df_counts)

In [None]:
from pyspark.sql import functions as F

# One row per sequence, with "count" holding how many rows (images) it has.
seq_counts = df_train.groupBy("seq_id").count()

# For every distinct sequence length, count how many sequences have that length.
sequences_per_length = seq_counts.groupBy("count").agg(
    F.count("seq_id").alias("Count of sequences")
)
sequence_length_counts = (
    sequences_per_length
    .select(F.col("count").alias("Number of images"), "Count of sequences")
    .orderBy("Number of images")
)

# Render the Spark DataFrame directly in the notebook
display(sequence_length_counts)

In [None]:
from pyspark.sql.functions import concat, lit

def transform_image_data(df, categories_df):
    """
    Replace category ids with category names and turn image ids into file names.

    Args:
        df: Spark DataFrame with "category_id" and "image_id" columns.
        categories_df: Spark DataFrame with "id" and "name" columns.

    Returns:
        ``df`` left-joined to the categories, where the category name is exposed
        as "label", "category_id" is dropped, and "image_id" is renamed to
        "filename" with '.JPG' appended.
    """
    # Expose the category key under the same name as in df so the join can use it.
    category_names = categories_df.select(col("id").alias("category_id"), col("name"))

    # Left join: rows whose category_id has no match are kept.
    labelled = df.join(category_names, on="category_id", how="left")

    return (
        labelled
        .withColumnRenamed("name", "label")
        .drop("category_id")
        .withColumnRenamed("image_id", "filename")
        .withColumn("filename", concat(col("filename"), lit(".JPG")))
    )

# Read the whole categories table into a Spark DataFrame
categories_query = "SELECT * FROM SnapshotSerengeti_LH.categories"
categories_df = spark.sql(categories_query)

# Swap category ids for labels and image ids for file names in the train data
df_train = transform_image_data(df=df_train, categories_df=categories_df)

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Rows are numbered inside each sequence (seq_id) in filename order
windowSpec = Window.partitionBy("seq_id").orderBy("filename")
frame_position = row_number().over(windowSpec)

# Keep only the first frame of every sequence; "row_num" is a helper column and is dropped again.
df_train = (
    df_train
    .withColumn("row_num", frame_position)
    .filter(col("row_num") == 1)
    .drop("row_num")
)

# Number of rows left after the filter
df_train.count()

In [None]:
# Count the rows per label, largest first, and give the columns display-friendly names
label_counts = (
    df_train
    .groupBy("label")
    .count()
    .orderBy(col("count").desc())
    .withColumnRenamed("label", "Label")
    .withColumnRenamed("count", "Number of images")
)

# Render the label counts in the notebook
display(label_counts)

In [None]:
def get_ImageUrl(filename):
    """Return the download URL for ``filename`` inside the snapshotserengeti-unzipped blob folder."""
    container_root = "https://lilawildlife.blob.core.windows.net/lila-wildlife/snapshotserengeti-unzipped"
    # format() renders the value the same way an f-string placeholder would.
    return "/".join((container_root, format(filename)))

In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Wrap get_ImageUrl as a Spark UDF that returns a string
get_ImageUrl_udf = udf(f=get_ImageUrl, returnType=StringType())

# Derive the image_url column from the filename column
image_url_column = get_ImageUrl_udf(col("filename"))
df_train = df_train.withColumn("image_url", image_url_column)

In [None]:
import urllib.request
import matplotlib.pyplot as plt

def display_random_image(label, random_state, width=500):
    """
    Download and show one randomly chosen image with the given label.

    Reads the notebook-level ``df_train`` DataFrame.

    Args:
        label: Value of the "label" column to pick an image from.
        random_state: Seed passed to ``F.rand`` for the random ordering.
        width: Not used by this function; kept so existing calls stay valid.

    Raises:
        ValueError: If ``df_train`` has no row with the requested label.
    """
    # Keep only rows with the requested label, shuffle them with the given seed
    # and bring a single row back to the driver.
    matches = (
        df_train.filter(col("label") == label)
        .orderBy(F.rand(random_state))
        .limit(1)
        .collect()
    )

    # An unknown or misspelled label yields no rows; say so instead of failing
    # with a bare IndexError.
    if not matches:
        raise ValueError(f"No rows with label {label!r} in df_train")

    # Get the image URL from the selected row and display the image
    url = matches[0]["image_url"]
    download_and_display_image(url, label)

def download_and_display_image(url, label):
    """
    Fetch the JPEG at ``url`` and show it with matplotlib.

    Args:
        url: Address of the image to download.
        label: Text shown in the plot title after "Label: ".
    """
    # The context manager closes the HTTP response once the image has been read.
    with urllib.request.urlopen(url) as response:
        image = plt.imread(response, format='jpg')
    plt.imshow(image)
    plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()

In [None]:
# Show one image labelled 'leopard', picked using random seed 12
display_random_image(label='leopard', random_state=12)

In [None]:
def proportional_allocation_percentage(df, percentage):
    """
    Proportionally allocate a sample of 'percentage'% of df across
    groups (label, season, location).

    Args:
        df: Spark DataFrame with "label", "season" and "location" columns.
        percentage: Share of rows to keep, in percent. The value is divided
            by 100, so 0.05 means 0.05 % of the rows, not 5 %.

    Returns:
        Spark DataFrame with the columns of ``df`` that keeps, for each group,
        at most the number of rows allocated to it. Within a group the rows
        with the lowest ``monotonically_increasing_id`` are kept, so the
        selection is not a random draw.
    """
    # Determine the total number of rows and desired sample size
    total_count = df.count()
    if total_count == 0:
        # An empty input has no groups to allocate over.
        return df
    sample_size = int(round(total_count * (percentage / 100.0)))

    # Number of rows in each (label, season, location) group
    group_counts = (
        df.groupBy("label", "season", "location")
          .count()
          .withColumnRenamed("count", "group_count")
    )

    # Share of each group in the data and its rounded share of the sample
    group_counts = (
        group_counts
        .withColumn("proportion", F.col("group_count") / F.lit(total_count))
        .withColumn("sample_needed", F.round(F.col("proportion") * sample_size).cast("int"))
    )

    # Collect just the group-level info to the driver for fine-grained adjustment
    group_counts_pd = group_counts.select(
        "label", "season", "location", "group_count", "sample_needed", "proportion"
    ).toPandas()

    # Rounding can leave the allocations above or below sample_size.
    difference = int(sample_size - group_counts_pd["sample_needed"].sum())

    # Move the allocations one row at a time until they add up to sample_size,
    # or until no group can be adjusted any further.
    while difference != 0:
        if difference > 0:
            # Short: top up groups that still have unsampled rows,
            # largest proportions first.
            adjustable = group_counts_pd["sample_needed"] < group_counts_pd["group_count"]
            step = 1
        else:
            # Too many: trim groups that actually contribute rows, smallest
            # proportions first. Taking 1 from a group already at 0 would
            # produce -1 and remove nothing from the sample.
            adjustable = group_counts_pd["sample_needed"] > 0
            step = -1

        candidates = group_counts_pd[adjustable].sort_values("proportion", ascending=(step < 0))
        if candidates.empty:
            break

        chosen = candidates.index[:abs(difference)]
        group_counts_pd.loc[chosen, "sample_needed"] += step
        difference -= step * len(chosen)

    # Create a Spark DataFrame of the final sample allocations
    allocations_sdf = spark.createDataFrame(group_counts_pd)

    # Join the allocations back to the main DataFrame so each row knows how many
    # rows from its group we want to keep.
    # NOTE(review): rows with a NULL label/season/location presumably do not match
    # in this join and are then dropped by the filter below - confirm if such rows exist.
    df_joined = (
        df.join(
            F.broadcast(allocations_sdf),
            on=["label", "season", "location"],
            how="left"
        )
    )

    # Number the rows inside each group so the group can be cut off at its allocation
    window_spec = Window.partitionBy("label", "season", "location").orderBy(F.monotonically_increasing_id())
    df_with_rn = df_joined.withColumn("rn", F.row_number().over(window_spec))

    # Keep only the rows whose number does not exceed the group's allocation
    df_sample = df_with_rn.filter(F.col("rn") <= F.col("sample_needed"))

    # Remove the helper columns so the result has the same columns as the input
    df_sample = df_sample.drop("proportion", "group_count", "sample_needed", "rn")

    return df_sample


In [None]:
# Share of the train rows to sample, in percent (the sampling function divides it by 100)
percent = 0.05
sampled_train = proportional_allocation_percentage(df=df_train, percentage=percent)

# Count the sampled rows per season label and order the result by label.
df_sampled_train_counts = (
    sampled_train
    .groupBy("season_label")
    .count()
    .orderBy("season_label")
)

# Render the Spark DataFrame directly in the notebook
display(df_sampled_train_counts)

In [None]:
import urllib.request
from PIL import Image
import os

def download_and_resize_image(url, path, kind, size=(224, 224)):
    """
    Download one image into the lakehouse Files area and shrink it in place.

    Args:
        url: Address the image is fetched from.
        path: Relative image path. Its directory part is recreated under
            /lakehouse/default/Files/images/<kind>/ and its base name becomes
            the file name.
        kind: Sub-folder the image is stored under (the notebook passes
            'train' or 'test').
        size: (width, height) of the stored image. Defaults to 224x224.

    Returns:
        None. Nothing is downloaded when the target file already exists.

    Raises:
        Any exception raised while downloading, decoding or saving. The target
        file is removed first so that a later run retries it.
    """
    filename = os.path.basename(path)
    directory = os.path.dirname(path)

    # Define a new directory path where permission is granted
    directory_path = f'/lakehouse/default/Files/images/{kind}/{directory}/'

    # Create the directory if it does not exist
    os.makedirs(directory_path, exist_ok=True)

    # Define the full target file path
    target_file_path = os.path.join(directory_path, filename)

    # An existing file is treated as already processed
    if os.path.isfile(target_file_path):
        return

    try:
        # Download the image
        urllib.request.urlretrieve(url, target_file_path)

        # resize() returns a new, fully loaded image, so the source file can be
        # closed before the result is written back over it.
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        with Image.open(target_file_path) as img:
            resized_img = img.resize(size, Image.LANCZOS)

        # Overwrite the downloaded file with the resized image
        resized_img.save(target_file_path)
    except Exception:
        # Do not leave a partial or full-size file behind: the existence check
        # above would skip it on every later run.
        if os.path.isfile(target_file_path):
            os.remove(target_file_path)
        raise

In [None]:
import concurrent.futures

def execute_parallel_download(spark_df, kind):
    """
    Download and resize every image listed in ``spark_df`` using a process pool.

    Args:
        spark_df: Spark DataFrame with "image_url" and "filename" columns.
        kind: Sub-folder name forwarded to ``download_and_resize_image``
            (the notebook passes 'train' or 'test').

    Returns:
        None. Images that could not be processed are reported with ``print``;
        the remaining images are still processed.
    """
    # Only the two columns the workers need are brought to the driver.
    df = spark_df.select("image_url", "filename").toPandas()
    urls = df['image_url'].tolist()
    paths = df['filename'].tolist()

    failures = []
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # One task per image
        futures = [executor.submit(download_and_resize_image, url, path, kind) for url, path in zip(urls, paths)]
        # Wait for all tasks to complete
        concurrent.futures.wait(futures)

        # An exception raised in a worker is stored on its future; without this
        # check failed downloads would go unnoticed.
        for path, future in zip(paths, futures):
            error = future.exception()
            if error is not None:
                failures.append((path, error))

    if failures:
        print(f"{len(failures)} of {len(paths)} '{kind}' images could not be downloaded; first failures:")
        for path, error in failures[:5]:
            print(f"  {path}: {error!r}")

In [None]:
# Load every test annotation whose category_id is greater than 1
df_test = spark.sql("SELECT * FROM SnapshotSerengeti_LH.test_annotations WHERE test_annotations.category_id > 1")

# Drop rows without an image_id and exact duplicate rows
df_test = df_test.filter(df_test["image_id"].isNotNull()).dropDuplicates()

# Derive the season from seq_id (text before the first '#') and a label without 'SER_'
df_test = df_test.withColumn("season_extracted", split(col("seq_id"), "#").getItem(0))
df_test = df_test.withColumn("season_label", regexp_replace(col("season_extracted"), "SER_", ""))

# Map category ids to labels, turn image ids into file names and attach the download URL
df_test = transform_image_data(df_test, categories_df)
test_image_url = get_ImageUrl_udf(col("filename"))
df_test = df_test.withColumn("image_url", test_image_url)

# Proportionally sample 0.27 percent of the test rows
sampled_test = proportional_allocation_percentage(df_test, 0.27)

In [None]:
import os

# Download the sampled train images first, then the sampled test images
for image_kind, sampled_frame in (("train", sampled_train), ("test", sampled_test)):
    execute_parallel_download(sampled_frame, image_kind)

In [None]:
import os

def list_all_files(directory):
    """Return the path of every file found under ``directory``, including its subdirectories."""
    return [
        os.path.join(parent, entry)
        for parent, _subdirs, entries in os.walk(directory)
        for entry in entries
    ]

# Folders the train and test images were downloaded into
train_images_path = "/lakehouse/default/Files/images/train/"
test_images_path = "/lakehouse/default/Files/images/test/"

# Compare the number of files on disk with the number of sampled rows
for images_path, sampled_frame in ((train_images_path, sampled_train), (test_images_path, sampled_test)):
    files_on_disk = len(list_all_files(images_path))
    print(f"{files_on_disk} files downloaded out of {sampled_frame.count()}")

In [None]:
# Save the filename and label of each sample as a table, replacing any existing table and schema
for table_name, sampled_frame in (("sampled_train", sampled_train), ("sampled_test", sampled_test)):
    filename_and_label = sampled_frame.select("filename", "label")
    filename_and_label.write.saveAsTable(table_name, mode="overwrite", overwriteSchema="true")