In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
! gcloud storage ls gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023 | head

gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/eBird_Taxonomy_v2021.csv
gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/sample_submission.csv
gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/train_metadata.csv
gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/test_soundscapes/
gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/train_audio/


In [4]:
from birdclef.utils import get_spark
from IPython.display import Image, display

spark = get_spark()
display(spark)

# read straight from the bucket
df_meta = spark.read.csv(
    "gs://dsgt-clef-birdclef-2024/data/raw/birdclef-2023/train_metadata.csv"
)
df_meta.printSchema()
df_meta.show(vertical=True, n=1, truncate=100)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/19 19:59:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/02/19 20:00:03 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


                                                                                

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: string (nullable = true)
 |-- _c6: string (nullable = true)
 |-- _c7: string (nullable = true)
 |-- _c8: string (nullable = true)
 |-- _c9: string (nullable = true)
 |-- _c10: string (nullable = true)
 |-- _c11: string (nullable = true)

-RECORD 0----------------
 _c0  | primary_label    
 _c1  | secondary_labels 
 _c2  | type             
 _c3  | latitude         
 _c4  | longitude        
 _c5  | scientific_name  
 _c6  | common_name      
 _c7  | author           
 _c8  | license          
 _c9  | rating           
 _c10 | url              
 _c11 | filename         
only showing top 1 row



In [6]:
df_meta.show(n=5)

+-------------+----------------+--------+--------+---------+------------------+--------------------+-------------+--------------------+------+--------------------+--------------------+
|          _c0|             _c1|     _c2|     _c3|      _c4|               _c5|                 _c6|          _c7|                 _c8|   _c9|                _c10|                _c11|
+-------------+----------------+--------+--------+---------+------------------+--------------------+-------------+--------------------+------+--------------------+--------------------+
|primary_label|secondary_labels|    type|latitude|longitude|   scientific_name|         common_name|       author|             license|rating|                 url|            filename|
|      abethr1|              []|['song']|  4.3906|  38.2788|Turdus tephronotus|African Bare-eyed...|Rolf A. de By|Creative Commons ...|   4.0|https://www.xeno-...|abethr1/XC128013.ogg|
|      abethr1|              []|['call']| -2.9524|  38.2921|Turdus tephrono

In [7]:
from matplotlib import pyplot as plt
import pandas as pd

# Group by species and count
species_counts = df_meta.groupBy("_c5").count().orderBy("count", ascending=False)

# Convert to Pandas DataFrame for plotting
species_df = species_counts.toPandas()

# Get the top 5 species
top_5 = species_df.head(5)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  PyArrow >= 4.0.0 must be installed; however, it was not found.
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)
                                                                                

In [8]:
top_5

Unnamed: 0,_c5,count
0,Actitis hypoleucos,500
1,Hirundo rustica,499
2,Phylloscopus trochilus,499
3,Luscinia luscinia,498
4,Motacilla flava,498


In [9]:
from pyspark.sql import Window
from pyspark.sql.functions import desc, row_number

# filter the top 5 species
num_species = 5
top_species = species_df.head(num_species)["_c5"].values.tolist()
df_meta_top = df_meta.filter(df_meta._c5.isin(top_species))

# take top 100 of every species
num_per_species = 100
window = Window.partitionBy("_c5").orderBy(desc("_c9"))
df_meta_top = df_meta_top.withColumn("row_num", row_number().over(window))
df_meta_top = df_meta_top.filter(df_meta_top.row_num <= num_per_species)

In [10]:
count = df_meta_top.count()
count

500

In [11]:
df_meta_top.show()

[Stage 18:>                                                         (0 + 1) / 1]

+------+---+--------------------+-------+-------+------------------+----------------+--------------------+--------------------+---+--------------------+-------------------+-------+
|   _c0|_c1|                 _c2|    _c3|    _c4|               _c5|             _c6|                 _c7|                 _c8|_c9|                _c10|               _c11|row_num|
+------+---+--------------------+-------+-------+------------------+----------------+--------------------+--------------------+---+--------------------+-------------------+-------+
|comsan| []|['female', 'fligh...|52.3527|20.9197|Actitis hypoleucos|Common Sandpiper|      Jarek Matusiak|Creative Commons ...|5.0|https://www.xeno-...|comsan/XC129378.ogg|      1|
|comsan| []|            ['call']|54.2697| 8.8683|Actitis hypoleucos|Common Sandpiper|       Volker Arnold|Creative Commons ...|5.0|https://www.xeno-...|comsan/XC142817.ogg|      2|
|comsan| []|            ['call']|55.0194|82.8918|Actitis hypoleucos|Common Sandpiper|          

                                                                                

## Encodec with a single recording

In [99]:
from pyspark.sql.functions import rand

# row = df_meta_top.head()
seed = 6
# row = df_meta_top.sample(fraction=0.01, seed=seed).head()
df_with_random = df_meta_top.withColumn("random", rand(seed))
row = df_with_random.orderBy("random").limit(1).head()
local_path = row._c11
species_code, filename = local_path.split("/")
local_path

'eaywag1/XC639588.ogg'

In [108]:
from encodec import EncodecModel
from encodec.utils import convert_audio

import torchaudio
import torch

bandwidth = 6.0

# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
# The number of codebooks used will be determined bythe bandwidth selected.
# E.g. for a bandwidth of 6kbps, `n_q = 8` codebooks are used.
# Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8) and 12 kbps (n_q =16) and 24kbps (n_q=32).
# For the 48 kHz model, only 3, 6, 12, and 24 kbps are supported. The number
# of codebooks for each is half that of the 24 kHz model as the frame rate is twice as much.
model.set_target_bandwidth(bandwidth)

In [109]:
import os
from pathlib import Path

gs_path = f"data/raw/birdclef-2023/train_audio/{local_path}"

root = Path("data/encodec_reconstruction")
original_dir = root / "original" / species_code
reconstructed_dir = root / "reconstructed" / f"{int(bandwidth)}kbps" / species_code

original_path = original_dir / filename
reconstructed_path = reconstructed_dir / (filename[:-4] + ".wav")

os.makedirs(original_dir, exist_ok=True)
os.makedirs(reconstructed_dir, exist_ok=True)

In [110]:
from google.cloud import storage
import io

client = storage.Client()
bucket = client.get_bucket('dsgt-clef-birdclef-2024')
blob = bucket.blob(gs_path)
file_stream = io.BytesIO()
blob.download_to_file(file_stream)
file_stream.seek(0)

blob.download_to_filename(original_path)

In [111]:
# Load and pre-process the audio waveform
wav, sr = torchaudio.load(file_stream)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.unsqueeze(0)

# Extract discrete codes from EnCodec
with torch.no_grad():
    encoded_frames = model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]

In [112]:
codes.shape

torch.Size([1, 8, 2435])

In [113]:
with torch.no_grad():
    decoded_frames = model.decode(encoded_frames)
decoded = decoded_frames[0]
torchaudio.save(reconstructed_path, decoded, model.sample_rate)

In [114]:
decoded_frames.shape

torch.Size([1, 1, 779200])