# Steps in this notebook:
1. Set multi-index on train and val df
2. Upsample positives in train df
3. Create CNN object
4. Visualize tensors
5. Train model (and initiate WandB logging) 
**Note:** This notebook was used for the first training attempt on December 3rd, 2025.


/home/brg226/projects/vira_beg/training_data/second_pass_nov25

In [None]:
# the cnn module provides classes for training/predicting with various types of CNNs
from opensoundscape import CNN

#other utilities and packages
import torch
import pandas as pd
from pathlib import Path
import numpy as np
import random
import subprocess
from glob import glob
import sklearn
from opensoundscape.data_selection import resample
import random
import wandb

# Plotting setup: import pyplot and configure default figure appearance for this notebook
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize']=[15,5] # default figure size (width, height) in inches; wide for large visuals
# IPython magic: request the 'retina' inline figure format (higher-resolution inline rendering)
%config InlineBackend.figure_format = 'retina'

In [None]:
# Load the training label DataFrame from its pickle file on disk.
# NOTE(review): unpickling can execute arbitrary code — only load pickles from a trusted source.
train_pickle_path = "/home/brg226/projects/vira_beg/training_data/second_pass_nov25/train_df_essential_nov25_2025.pkl"
train_df = pd.read_pickle(train_pickle_path)

In [None]:
# Load the validation label DataFrame from its pickle file on disk.
# NOTE(review): unpickling can execute arbitrary code — only load pickles from a trusted source.
val_pickle_path = "/home/brg226/projects/vira_beg/training_data/second_pass_nov25/val_df_essential_nov25_2025.pkl"
val_df = pd.read_pickle(val_pickle_path)

In [None]:
# Index the training rows by (file, start_time, end_time).
# The guard keeps this cell idempotent: after the first run those columns live in the
# index, so calling set_index on them a second time would raise a KeyError.
train_index_cols = ['file', 'start_time', 'end_time']
if list(train_df.index.names) != train_index_cols:
    train_df = train_df.set_index(train_index_cols)
print("DataFrame after setting multi-index:")
print(f"Index names: {train_df.index.names}")
print(f"Shape: {train_df.shape}")
print(f"Columns: {train_df.columns.tolist()}")
# Last expression: rich display of the first rows
train_df.head()

In [None]:
# Index the validation rows by (file, start_time, end_time).
# The guard keeps this cell idempotent: after the first run those columns live in the
# index, so calling set_index on them a second time would raise a KeyError.
val_index_cols = ['file', 'start_time', 'end_time']
if list(val_df.index.names) != val_index_cols:
    val_df = val_df.set_index(val_index_cols)
print("DataFrame after setting multi-index:")
print(f"Index names: {val_df.index.names}")
print(f"Shape: {val_df.shape}")
print(f"Columns: {val_df.columns.tolist()}")
# Last expression: rich display of the first rows
val_df.head()

In [None]:
# Resample the training labels toward a target of 5000 samples per class
# (repeating samples where a class has fewer), with a fixed seed for reproducibility.
# NOTE(review): depending on resample()'s defaults it may also downsample larger classes
# and/or omit rows with no positive label — confirm against the class counts printed below.
balanced_train_df = resample(
    train_df,
    n_samples_per_class=5000,
    random_state=0,
)

In [None]:
# Report the size of the resampled DataFrame and how many rows carry each 'virail' label value
print(f"Balanced DataFrame shape: {balanced_train_df.shape}")
print(f"Total rows: {len(balanced_train_df)}")
virail_labels = balanced_train_df['virail']
n_positive = (virail_labels == 1).sum()
print(f"Number of 1s in virail column: {n_positive}")
n_negative = (virail_labels == 0).sum()
print(f"Number of 0s in virail column: {n_negative}")
print("\nClass distribution:")
print(virail_labels.value_counts())

In [None]:
# Write the resampled training labels to a CSV file, then report where it went and its shape
output_path = (
    "/home/brg226/projects/vira_beg/training_data/second_pass_nov25/balanced_train_df_nov25_2025.csv"
)
balanced_train_df.to_csv(output_path)
print(
    f"Saved balanced training DataFrame to: {output_path}",
    f"Shape: {balanced_train_df.shape}",
    sep="\n",
)

In [None]:

# Import the CNN class used to build the model in a later cell
from opensoundscape import CNN
# show_tensor is used below to plot generated sample tensors for visual inspection
from opensoundscape.preprocess.utils import show_tensor


In [None]:

# The model's classes are the label columns of the resampled training DataFrame
class_list = balanced_train_df.columns.tolist()

# Build a resnet18-based CNN that operates on 1.5-unit sample durations
# (presumably seconds — confirm against the clip lengths in the label DataFrames)
model = CNN(architecture="resnet18", classes=class_list, sample_duration=1.5)

# Restrict the preprocessor's bandpass step to 2000-7000 (presumably Hz);
# the original author noted the band may need tuning
bandpass_action = model.preprocessor.pipeline.bandpass
bandpass_action.set(min_f=2000, max_f=7000)


In [None]:

# Draw 20 random rows from the resampled training set and run them through the model's
# preprocessing; bypass_augmentations=False keeps augmentations on so the tensors
# look like what the model sees during training
preview_rows = balanced_train_df.sample(20)
samples = model.generate_samples(
    preview_rows,
    bypass_augmentations=False,
)

In [None]:

# Visualize up to 50 generated samples with their labels (positive or negative for virail).
# The loop is bounded by the number of samples actually generated: the original fixed
# range(50) indexes past the end of `samples` (IndexError) whenever fewer than 50 exist.
n_to_show = min(50, len(samples))
for x in range(n_to_show):
    show_tensor(samples[x].data)
    plt.show()

    print(f"Labels: {samples[x].labels}")

In [None]:
# Initiate WandB logging.
# Best-effort: if login/init fails, fall back to wandb_session = None so training can
# still run without logging. Catch Exception (not a bare except) so that
# KeyboardInterrupt/SystemExit still propagate, and report the reason for the failure.
try:
    wandb.login()
    wandb_session = wandb.init(
        entity="kitzeslab",
        project="vira_beg",
        name="Train CNN",  # optional: remove to let wandb generate a random run name, or rename per run
    )
except Exception as wandb_error:  # if wandb.login/init fails, don't use wandb logging
    print("failed to create wandb session. wandb session will be None")
    print(f"wandb error: {wandb_error!r}")
    wandb_session = None

In [None]:
# IPython help lookup: displays the signature/docstring of model.train (interactive aid only; not needed for a run)
model.train?

In [None]:
# Select the device used by the model: CUDA device index 1.
# NOTE(review): this cell does not run nvidia-smi itself — check GPU availability/usage
# separately (e.g. `!nvidia-smi`) before choosing the device index.
model.device= 'cuda:1'

In [None]:
# Start model training!
# NOTE(review): this trains on train_df, not the resampled balanced_train_df built
# earlier in the notebook — confirm that is intended.
training_options = dict(
    validation_df=val_df,
    epochs=10,
    batch_size=128,
    num_workers=8,
    wandb_session=wandb_session,  # None disables wandb logging
    save_path='/home/brg226/projects/vira_beg/experiment_checkpoints/run_dec3_02',
)
model.train(train_df, **training_options)