In [7]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from torch import nn
from torch.nn import functional as F
import pandas as pd

sys.path.append("src")

from decoder import TweetyBertClassifier
from spectogram_generator import WavtoSpec

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root folder walked by the processing loop below: each sub-directory is
# scanned for .wav files.
folder = "/media/george-vengrovski/disk2/canary/unsorted/USA5206"
# Location handed to TweetyBertClassifier.load_decoder_state.
classifier_path = "/media/george-vengrovski/disk1/linear_decoder_test"

# Output folder; the results CSV (song_database.csv) is written here.
dst_folder = "/media/george-vengrovski/disk2/canary/unsorted/USA5206_specs"
os.makedirs(dst_folder, exist_ok=True)

# NOTE(review): csv_file_dir=None presumably means "no annotation CSV" —
# confirm against WavtoSpec.
wav_to_spec = WavtoSpec(folder, dst_folder, csv_file_dir=None)

# NOTE(review): the model is not explicitly moved to `device` or put in eval
# mode here, while inputs are later sent to `device` — confirm that
# load_decoder_state handles both.
model = TweetyBertClassifier.load_decoder_state(classifier_path)

def inference_data_class(data, context_length=1000, num_classes=None):
    """Turn one recording into z-scored, fixed-length windows for the classifier.

    Args:
        data: Sequence ``(labels, spectrogram, vocalization)``. ``spectrogram``
            is a 2-D NumPy array indexed ``[frequency_bin, time_bin]``.
            ``labels`` holds one integer class id per time bin (a leading
            axis of size 1 is squeezed away) and ``vocalization`` holds one
            integer per time bin.
        context_length: Number of time bins per window.
        num_classes: Width of the one-hot label encoding. Defaults to the
            module-level ``model.num_classes``.

    Returns:
        A tuple ``(spectogram, ground_truth_labels, vocalization)`` of
        tensors shaped ``(n_windows, context_length, n_freq)``,
        ``(n_windows, context_length, num_classes)`` and
        ``(n_windows, context_length)``. The time axis is zero-padded at the
        end up to the next multiple of ``context_length``.
    """
    if num_classes is None:
        num_classes = model.num_classes

    recording_length = data[1].shape[1]

    # Keep only frequency rows 20..215 of the spectrogram.
    spectogram = data[1][20:216]

    # Z-score over the whole cropped spectrogram. A constant spectrogram has
    # zero standard deviation; dividing by 1 instead keeps it all-zero rather
    # than filling it with NaN/inf.
    spec_mean = np.mean(spectogram)
    spec_std = np.std(spectogram)
    if spec_std == 0:
        spec_std = 1.0
    spectogram = (spectogram - spec_mean) / spec_std

    ground_truth_labels = np.array(data[0], dtype=int)
    vocalization = np.array(data[2], dtype=int)

    # Move to torch; the spectrogram becomes (time, frequency).
    spectogram = torch.from_numpy(spectogram).float().permute(1, 0)
    ground_truth_labels = torch.from_numpy(ground_truth_labels).long().squeeze(0)
    ground_truth_labels = F.one_hot(ground_truth_labels, num_classes=num_classes).float()
    vocalization = torch.from_numpy(vocalization).long()

    # Bins needed to reach the next multiple of context_length: 0 when the
    # recording already fits exactly, and context_length - recording_length
    # when the recording is shorter than a single window.
    pad_amount = (-recording_length) % context_length
    if pad_amount:
        spectogram = F.pad(spectogram, (0, 0, 0, pad_amount), 'constant', 0)
        ground_truth_labels = F.pad(ground_truth_labels, (0, 0, 0, pad_amount), 'constant', 0)
        vocalization = F.pad(vocalization, (0, pad_amount), 'constant', 0)

    # Split the padded time axis into windows of context_length bins.
    n_windows = spectogram.shape[0] // context_length
    spectogram = spectogram.reshape(n_windows, context_length, spectogram.shape[1])
    ground_truth_labels = ground_truth_labels.reshape(
        n_windows, context_length, ground_truth_labels.shape[1]
    )
    vocalization = vocalization.reshape(n_windows, context_length)

    return spectogram, ground_truth_labels, vocalization

# Per-file results table: file path, whether any syllable was detected, and
# the list of (onset, offset) pairs.
# NOTE(review): this frame is re-created with the same columns right before
# the processing loop, so this first assignment looks redundant.
database = pd.DataFrame(columns=["file_name", "song_present", "syllable_onsets/offsets"])

# NOTE(review): `processed` and `song_length` are not referenced anywhere
# else in the visible code — confirm no other cell uses them before removing.
processed = False

song_length = 0 

def max_vote(predicted_labels):
    """Smooth labels by majority vote inside each run of non-zero labels.

    Every maximal stretch of consecutive non-zero entries is overwritten with
    the label that occurs most often inside that stretch (ties go to the
    smallest label, as decided by ``np.bincount(...).argmax()``). Zero
    entries are left untouched.

    Args:
        predicted_labels: 1-D NumPy array of non-negative integer labels.

    Returns:
        A smoothed copy of ``predicted_labels``; the input is not modified.
    """
    smoothed = predicted_labels.copy()

    voiced = np.nonzero(predicted_labels != 0)[0]
    if voiced.size == 0:
        return smoothed

    # A jump larger than one between neighbouring non-zero positions marks
    # the boundary between two separate runs.
    breaks = np.nonzero(np.diff(voiced) != 1)[0]
    run_starts = voiced[np.concatenate(([0], breaks + 1))]
    run_ends = voiced[np.concatenate((breaks, [voiced.size - 1]))]

    for first, last in zip(run_starts, run_ends):
        segment = predicted_labels[first:last + 1]
        smoothed[first:last + 1] = np.bincount(segment).argmax()

    return smoothed

def convert_to_onset_offset(labels, sampling_rate=44100, nfft=1024, hop_length=None):
    """Convert a per-time-bin label sequence into (onset, offset) pairs in ms.

    A segment starts wherever the label changes to a non-zero value and ends
    wherever the label changes again (to zero or to a different non-zero
    label), or at the end of the sequence. Two adjacent segments with
    different non-zero labels therefore share a boundary time.

    Args:
        labels: Sequence of integer labels, one per spectrogram time bin;
            0 means "no syllable".
        sampling_rate: Audio sampling rate in Hz used for the spectrogram.
        nfft: FFT window size used for the spectrogram.
        hop_length: Samples between consecutive time bins. Defaults to
            ``nfft // 2`` (50% overlap).

    Returns:
        List of ``(onset_ms, offset_ms)`` tuples in chronological order.
    """
    if hop_length is None:
        hop_length = nfft // 2
    ms_per_timebin = (hop_length / sampling_rate) * 1000

    segments = []
    current_label = 0
    onset_bin = 0

    for i, label in enumerate(labels):
        if label == current_label:
            continue
        # The label changed: close the running segment, if there is one...
        if current_label != 0:
            segments.append((onset_bin * ms_per_timebin, i * ms_per_timebin))
        # ...and open a new one when the new label is a syllable.
        if label != 0:
            onset_bin = i
        current_label = label

    # A segment still open at the end of the sequence ends with it.
    if current_label != 0:
        segments.append((onset_bin * ms_per_timebin, len(labels) * ms_per_timebin))

    return segments

# Per-file results table: file path, whether any syllable was detected, and
# the list of (onset_ms, offset_ms) pairs.
database_columns = ["file_name", "song_present", "syllable_onsets/offsets"]
database = pd.DataFrame(columns=database_columns)
# Rows are collected in a plain list and turned into a DataFrame only when
# saving; concatenating onto the DataFrame for every file re-copies the
# whole table each time.
database_rows = []
processed_count = 0
csv_path = os.path.join(dst_folder, "song_database.csv")

# Walk <folder>/<day>/<recording>.wav in sorted order so runs are reproducible.
for day in sorted(os.listdir(folder)):
    day_path = os.path.join(folder, day)
    if not os.path.isdir(day_path):
        continue
    for song in sorted(os.listdir(day_path)):
        file_path = os.path.join(day_path, song)
        if not file_path.lower().endswith('.wav'):
            continue
        try:
            spec, vocalization, labels = wav_to_spec.process_file(wav_to_spec, file_path=file_path)

            spectogram, _, _ = inference_data_class((labels, spec, vocalization), context_length=1000)
            # (n_windows, context, freq) -> (n_windows, 1, freq, context)
            spec_tensor = spectogram.to(device).unsqueeze(1)

            # Pure inference: skip building the autograd graph.
            with torch.no_grad():
                logits = model.classifier_model(spec_tensor.permute(0, 1, 3, 2))
            # Flatten windows back into one sequence of per-time-bin logits.
            logits = logits.reshape(logits.shape[0] * logits.shape[1], -1)

            predicted_labels = torch.argmax(logits, dim=1).cpu().numpy()
            # Drop predictions made on the zero padding appended to fill the
            # last window, so no onsets/offsets fall beyond the recording.
            # NOTE(review): assumes the model emits one prediction per input
            # time bin — confirm.
            predicted_labels = predicted_labels[:spec.shape[1]]
            post_processed_labels = max_vote(predicted_labels)

            onsets_offsets = convert_to_onset_offset(post_processed_labels)

            database_rows.append({
                "file_name": file_path,
                "song_present": len(onsets_offsets) > 0,
                "syllable_onsets/offsets": onsets_offsets,
            })
            processed_count += 1

            # Save the DataFrame to CSV every 100 processed files
            if processed_count % 100 == 0:
                database = pd.DataFrame(database_rows, columns=database_columns)
                database.to_csv(csv_path, index=False)
                print(f"Processed {processed_count} files. Database saved to {csv_path}")

        except Exception as e:
            # Best effort: report the failing file and continue with the rest.
            print(f"Error processing {file_path}: {e}")

# Save the final DataFrame to CSV
database = pd.DataFrame(database_rows, columns=database_columns)
database.to_csv(csv_path, index=False)
print(f"Processing complete. Final database saved to {csv_path}")

Decoder state loaded from /media/george-vengrovski/disk1/linear_decoder_test/decoder_state
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.42844195_12_2_11_54_4.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.37814723_12_2_10_30_14.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.48453659_12_2_13_27_33.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.52443483_12_2_14_34_3.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.42892134_12_2_11_54_52.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.44997595_12_

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f01ffe7f150>>
Traceback (most recent call last):
  File "/home/george-vengrovski/anaconda3/envs/tweetybert/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.45676962_12_2_12_41_16.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.44057354_12_2_12_14_17.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.37672147_12_2_10_27_52.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.42920214_12_2_11_55_20.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.43407680_12_2_12_3_27.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george-vengrovski/disk2/canary/unsorted/USA5206/74/USA5206_45262.37804867_12_2_10_30_4.wav: 'DataFrame' object has no attribute 'append'
Error processing /media/george