In [1]:
import matplotlib.pyplot as plt
import os
import torch 
from torch.utils.data import DataLoader
import json
import sys
import torch

# Make the project's `src/` package importable and run from the project root.
# Use an absolute path (not the cwd-relative "src") and guard against appending
# duplicates when the cell is re-run.
PROJECT_ROOT = '/home/george-vengrovski/Documents/projects/tweety_bert_paper'
SRC_DIR = os.path.join(PROJECT_ROOT, "src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
os.chdir(PROJECT_ROOT)

from utils import load_model

# Weights and config come from the same experiment directory; derive both from one
# constant so the two paths cannot drift apart.
EXPERIMENT_DIR = "/home/george-vengrovski/Documents/experiments_backup/tweety_bert_paper/experiments/TweetyBERT-Combined-MSE-1"
weights_path = os.path.join(EXPERIMENT_DIR, "saved_weights", "model_step_11800.pth")
config_path = os.path.join(EXPERIMENT_DIR, "config.json")

tweety_bert_model = load_model(config_path, weights_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


FileNotFoundError: Configuration file not found at /home/george-vengrovski/Documents/experiments_backup/TweetyBERT-Combined-MSE-1/config.json

## Datasets and data loaders

In [None]:
from torch.utils.data import DataLoader
from data_class import SongDetectorDataClass, CollateFunctionSongDetection

train_dir = "/home/george-vengrovski/Documents/data/finetune_labeled_data_train"
test_dir = "/home/george-vengrovski/Documents/data/finetune_labeled_data_test"

# Both splits are constructed with identical options: a 2-class (song / not song)
# setup. The keyword spelling `psuedo_labels_generated` matches data_class's API.
dataset_kwargs = dict(num_classes=2, psuedo_labels_generated=False)
train_dataset = SongDetectorDataClass(train_dir, **dataset_kwargs)
test_dataset = SongDetectorDataClass(test_dir, **dataset_kwargs)

# segment_length is forwarded to the collate function (presumably time bins per
# example -- confirm in data_class). Adjust if needed.
collate_fn = CollateFunctionSongDetection(segment_length=1000)

# Both loaders share batch size, shuffling and collate function.
loader_kwargs = dict(batch_size=2, shuffle=True, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

## Define Linear Classifier and Train

In [None]:
from linear_probe import LinearProbeModel, LinearProbeTrainer

# Wrap the loaded TweetyBERT model in a 2-class probe, with layers frozen. The
# probed layer is selected by layer_num / layer_id. Then move it to the target device.
classifier_model = LinearProbeModel(
    num_classes=2,
    model_type="neural_net",
    model=tweety_bert_model,
    freeze_layers=True,
    layer_num=-1,
    layer_id="attention_output",
    classifier_dims=384,
).to(device)

In [None]:
# Train the probe. `desired_total_batches` is a batch count, so pass an int
# (`1e3` is the float 1000.0, which breaks anything like range() or % on it).
trainer = LinearProbeTrainer(
    model=classifier_model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    lr=1e-5,
    plotting=True,
    batches_per_eval=1,
    desired_total_batches=1_000,
    patience=15,
)
trainer.train()

## Analyze

In [None]:
from linear_probe import ModelEvaluator

# Score the trained probe on the test loader, then hand the returned per-class and
# total frame error rates to save_results, which writes them under results_dir.
results_dir = '/home/george-vengrovski/Documents/projects/tweety_bert_paper/results/test'
evaluator = ModelEvaluator(classifier_model, test_loader)
class_frame_error_rates, total_frame_error_rate = evaluator.validate_model_multiple_passes(
    num_passes=1,
    max_batches=1250,
)
evaluator.save_results(class_frame_error_rates, total_frame_error_rate, results_dir)

## Visualize Song and Not Song

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)) for scalars or NumPy arrays.

    Evaluated as exp(-log(1 + exp(-x))) using np.logaddexp. The naive form overflows
    (np.exp(-x) -> inf, with a RuntimeWarning) for large negative x; this form does not.
    """
    return np.exp(-np.logaddexp(0, -x))

def as_numpy(x):
    """Return `x` as a NumPy array; torch tensors are detached and copied to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def plot_spectrogram_with_labels_and_logits(spec, ground_truth_label, logits):
    """Plot a spectrogram with ground-truth class bars and per-class logit traces.

    Args:
        spec: 2-D spectrogram, assumed (frequency bins, time bins). Torch tensor or array.
        ground_truth_label: per-time-bin labels with 2 columns. Column 0 is drawn as
            "Not Song", column 1 as "Song". Its length must equal spec's time axis.
        logits: per-time-bin classifier outputs with the same 2-column order.

    Tensors may be on any device and may require grad; they are detached and copied to
    CPU before plotting.
    """
    spec_np = as_numpy(spec)
    labels_np = as_numpy(ground_truth_label)
    logits_np = as_numpy(logits)

    # Squash each logit into (0, 1), then stretch it over the frequency axis so both
    # class traces can be drawn directly on top of the spectrogram.
    freq_range = spec_np.shape[0]
    logits_scaled = sigmoid(logits_np) * freq_range

    plt.figure(figsize=(10, 4))

    plt.imshow(spec_np, aspect='auto', origin='lower')

    # Ground-truth bars occupy the strip y in [-5, 0] at the bottom of the axes.
    # Column order matches the logits below: column 0 = not song, column 1 = song.
    time_bins = range(spec_np.shape[1])
    not_song_mask = labels_np[:, 0] > 0.5
    song_mask = labels_np[:, 1] > 0.5
    plt.fill_between(time_bins, -5, 0, where=not_song_mask, color='green', step='mid', alpha=0.5, label='Not Song')
    plt.fill_between(time_bins, -5, 0, where=song_mask, color='red', step='mid', alpha=0.5, label='Song')

    # Scaled logit traces, one line per class.
    plt.plot(logits_scaled[:, 0], color='cyan', label='Logits - Not Song')
    plt.plot(logits_scaled[:, 1], color='magenta', label='Logits - Song')

    # The colorbar attaches to the spectrogram image (the current image in pyplot).
    plt.colorbar(label='Spectrogram Intensity')
    plt.xlabel('Time Bins')
    plt.ylabel('Frequency Bins')
    plt.title('Spectrogram with Ground Truth and Logits')
    plt.legend(loc='upper right')

    plt.tight_layout()
    plt.show()


# Take one batch from the test loader, run the probe on it and plot the first example.
spec, ground_truth_label, _ = next(iter(test_loader))

# Inference only: no_grad avoids building an autograd graph for this forward pass.
# NOTE(review): the model is not switched to eval() here. If it contains dropout or
# batch-norm layers, these logits reflect train-mode behaviour -- confirm.
with torch.no_grad():
    logits = classifier_model.forward(spec.to(device))

# Keep only the first example of the batch.
spec = spec[0]
ground_truth_label = ground_truth_label[0]
logits = logits[0]

# Remove the channel dimension, leaving a 2-D spectrogram.
spec = spec[0]

# Move everything to CPU before plotting.
plot_spectrogram_with_labels_and_logits(spec.detach().cpu(), ground_truth_label.detach().cpu(), logits.detach().cpu())


In [None]:
import os

# Directory of unlabeled spectrogram files to step through one at a time (next cell).
src = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/sort_these"

# os.listdir returns entries in arbitrary order; sort so stepping is reproducible.
files_iterator = iter(sorted(os.listdir(src)))


In [None]:
import numpy as np
import torch

# Step through `files_iterator` (run the cell that creates it first). Each run loads
# the next file, runs the probe on its first SEGMENT_LENGTH time bins and plots it.
SEGMENT_LENGTH = 1000  # same value as the collate segment_length used for the loaders

try:
    file = next(files_iterator)
    file_path = os.path.join(src, file)

    # Load the spectrogram stored under key 's'. The context manager closes the archive.
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code -- only use on trusted files.
    with np.load(file_path, allow_pickle=True) as f:
        spec = f['s']
    # Previously tried and currently disabled: trimming to spec[20:216] and
    # z-score normalization of the spectrogram.

    spec_segment = torch.Tensor(spec[:, :SEGMENT_LENGTH])

    # Inference only: add batch and channel dims; no autograd graph needed.
    with torch.no_grad():
        logits = classifier_model.forward(spec_segment.unsqueeze(0).unsqueeze(0).to(device))
    logits = logits[0]

    # These files carry no ground truth. Pass an all-zero 2-column label so no bars are
    # drawn. Reusing `ground_truth_label` from the test-loader cell would overlay an
    # unrelated sample's labels, and it fails when this segment is shorter than that label.
    no_labels = torch.zeros(spec_segment.shape[1], 2)

    plot_spectrogram_with_labels_and_logits(spec_segment, no_labels, logits.detach().cpu())

except StopIteration:
    print("No more files to process.")


In [None]:
import os
import numpy as np
import torch
from tqdm import tqdm  # Import tqdm for progress tracking


# Assuming `classifier_model` and `device` are already defined

def process_spectrogram(spec, max_length=1000, model=None, torch_device=None):
    """
    Classify every time bin of a spectrogram as song (1) or not song (0).

    The spectrogram is split along axis 1 (time) into consecutive chunks of at most
    `max_length` bins. Each chunk goes through the model as a (1, 1, freq, time) tensor.
    A bin is labelled 1 where the model's column-1 score exceeds its column-0 score.

    Args:
        spec: 2-D array (or tensor) whose axis 1 is treated as time.
        max_length: maximum number of time bins per forward pass.
        model: classifier whose `forward` is called. Defaults to the notebook-global
            `classifier_model`. Assumed to return (1, T, >=2) scores for a T-bin chunk.
        torch_device: device to move each chunk to. Defaults to the notebook-global `device`.

    Returns:
        1-D int64 NumPy array of 0/1 predictions, one per time bin (assuming the model
        emits one row per input bin). An empty spectrogram yields an empty array.
    """
    # Resolve defaults lazily so the notebook globals are only needed when not overridden.
    if model is None:
        model = classifier_model
    if torch_device is None:
        torch_device = device

    num_bins = spec.shape[1]
    chunk_predictions = []

    # Inference only: no autograd graph is needed for any chunk.
    with torch.no_grad():
        for start in range(0, num_bins, max_length):
            chunk = spec[:, start:start + max_length]
            chunk_input = torch.Tensor(chunk).unsqueeze(0).unsqueeze(0).to(torch_device)
            chunk_logits = model.forward(chunk_input)[0]

            # 1 where the "song" column (1) outscores the "not song" column (0).
            is_song = (chunk_logits[:, 1] > chunk_logits[:, 0]).long()
            chunk_predictions.append(is_song.cpu().numpy())

    # np.concatenate raises on an empty list, so handle zero-length input explicitly.
    if not chunk_predictions:
        return np.zeros(0, dtype=np.int64)

    return np.concatenate(chunk_predictions, axis=-1)

def process_files(src, dst=None, prefix="processed_"):
    """
    Run song detection on every spectrogram file in `src` and save the results.

    For each file, the array stored under key 's' is loaded and passed to
    `process_spectrogram`. The result is written as `<dst>/<prefix><file>`, an .npz
    with `s` (the original spectrogram) and `song` (the per-bin predictions).

    Files whose names already start with `prefix` are skipped. Without this, outputs
    that an earlier run wrote into `src` would be reprocessed into
    "processed_processed_*" files.

    Args:
        src: directory containing the input files.
        dst: output directory (created if missing). Defaults to `src`.
        prefix: filename prefix for outputs, also used to recognise them.

    Failures on individual files are reported and skipped (best effort).
    """
    if dst is None:
        dst = src
    else:
        os.makedirs(dst, exist_ok=True)

    # Sort for a deterministic processing order.
    files = sorted(name for name in os.listdir(src) if not name.startswith(prefix))
    for file in tqdm(files, desc="Processing files"):
        file_path = os.path.join(src, file)

        try:
            # Context manager closes the .npz archive after the spectrogram is read.
            # NOTE(review): allow_pickle=True can execute code from untrusted files.
            with np.load(file_path, allow_pickle=True) as f:
                spec = f['s']

            logits = process_spectrogram(spec)

            save_path = os.path.join(dst, f"{prefix}{file}")
            np.savez(save_path, s=spec, song=logits)

        except Exception as e:
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"Failed to process file {file}: {str(e)}")


# Batch-process every spectrogram in the sort_these directory.
src = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/sort_these"
process_files(src=src)
