In [1]:
#!pip install openai-whisper git+https://github.com/sooftware/conformer.git PyYAML gdown gradio -q
import torch
from cfg_parse import models_folder_path, cfg, class_file, data_base_path
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"
print("Using device:", device)

Using device: cpu


In [None]:
# %%bash
# # Create cfg.yaml with model parameters (adapted from the official repo)
# cat > ../models/RENE/cfg.yaml << 'CFG'
# # Data and model config
# device: 'cuda:0'        # computation device
# sampling_rate: 8000     # audio sampling rate
# win_len: 256            # STFT window length (32 ms at 8 kHz)
# hop: 80                 # STFT hop length (10ms)
# lowfreq: 50.0           # mel filterbank low freq cutoff
# highfreq: 2500.0        # mel filterbank high freq cutoff
# max_record_time: 16     # max duration of each recording (s)
# max_event_time: 3       # max duration of each respiratory event (s)
# # Model hyperparameters
# whisper_seq: 1500
# whiper_dim: 384
# encoder_dim: 256
# num_encoder_layers: 16
# num_attention_heads: 4
# rnn_hid_dim: 512
# rnn_layers: 2
# bidirect: true
# n_fc_layers: 2
# fc_layer_dim: 1024
# output_dim: 15
# input_dropout: 0.1
# feed_forward_dropout: 0.1
# attention_dropout: 0.1
# conv_dropout: 0.1
# rtb_data_channels: 1
# CFG

# %%bash
# # Create class-id.txt mapping 15 classes (Name|ID)
# cat > ../models/RENE/class-id.txt << 'CLASSIDS'
# Healthy|0
# Bronchiectasis|1
# Bronchiolitis|2
# COPD|3
# Asthma|4
# LRTI|5
# Pneumonia|6
# URTI|7
# Bronchitis|8
# Lung Fibrosis|9
# Asthma & Lung Fibrosis|10
# Heart Failure & Lung Fibrosis|11
# Heart Failure|12
# Heart Failure & COPD|13
# Pleural Effusion|14
# CLASSIDS

# %%bash
# # Create cfg_parse.py to load the YAML config
# cat > ../models/RENE/cfg_parse.py << 'PYCODE'
# import yaml
# cfg = yaml.safe_load(open('cfg.yaml'))
# PYCODE

bash: line 31: fg: no job control
bash: line 51: fg: no job control


In [2]:
# Import the config and define model architecture classes
import math
import torch
import torch.nn as nn
from conformer import Conformer

# Depthwise Separable Conv2D layer used in ReneTrialBlock
class DSConv2d(nn.Module):
    """Depthwise-separable 2D convolution.

    A per-channel (grouped) square convolution with "same"-style padding is
    followed by a 1x1 convolution that mixes channels and maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the 1x1 convolution.
        kernel_size: side length of the square depthwise kernel.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        half_kernel = kernel_size // 2
        # groups == in_channels: every input channel is filtered independently.
        self.depth_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=half_kernel,
            groups=in_channels,
        )
        # 1x1 convolution that recombines the per-channel responses.
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise_conv(self.depth_conv(x))

# ReneTrialBlock: the final convolutional block that produces class logits
class ReneTrialBlock(nn.Module):
    """Final convolutional block that turns a flat feature vector into class logits.

    The flat input of width ``2 * cfg['rnn_hid_dim']`` is reshaped into a square
    feature map, passed through two parallel convolution stacks whose outputs
    are summed with the input (residual), flattened again, and projected to
    ``cfg['output_dim']`` logits.

    Args:
        cfg: mapping providing ``rnn_hid_dim``, ``rtb_data_channels`` and
            ``output_dim``.
        in_channels: channel count used by every convolution in both flows.

    Raises:
        ValueError: if ``2 * cfg['rnn_hid_dim']`` is not a perfect square, since
            the flat vector could then not be reshaped into a square map.
    """

    def __init__(self, cfg, in_channels):
        super(ReneTrialBlock, self).__init__()
        self.cfg = cfg
        flat_dim = cfg['rnn_hid_dim'] * 2
        # Fail at construction time rather than with an opaque reshape error in forward().
        if math.isqrt(flat_dim) ** 2 != flat_dim:
            raise ValueError(
                "2 * rnn_hid_dim must be a perfect square to form a square "
                "feature map, got %d" % flat_dim
            )
        # Left convolution flow
        self.left_flow = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=(1,1)),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            DSConv2d(in_channels, in_channels, kernel_size=3),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=(5,5), padding=(5//2, 5//2))
        )
        # Right convolution flow (mirror of left_flow with reversed conv order)
        self.right_flow = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=(5,5), padding=(5//2, 5//2)),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            DSConv2d(in_channels, in_channels, kernel_size=3),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=(1,1))
        )
        # Final linear layer: maps the flattened feature map to output classes
        self.layer = nn.Linear(flat_dim, cfg['output_dim'])

    def forward(self, input_data):
        """Return class logits for a batch of flat feature vectors.

        Args:
            input_data: tensor whose first dimension is the batch and whose
                remaining elements can be reshaped to
                ``(rtb_data_channels, side, side)`` with
                ``side = sqrt(2 * rnn_hid_dim)``.

        Returns:
            Output of the final linear layer, one row per batch element.
        """
        # Read the configuration stored on the module, not a notebook-level
        # global, so the block behaves the same wherever it is instantiated.
        cfg = self.cfg
        batch_size = input_data.size(0)
        feature_size = math.isqrt(cfg['rnn_hid_dim'] * 2)
        x = input_data.reshape(batch_size, cfg['rtb_data_channels'], feature_size, feature_size)
        # Convolution flows and residual
        out = self.left_flow(x) + self.right_flow(x) + x
        # Flatten and linear layer to class logits
        out = out.reshape(batch_size, -1)
        return self.layer(out)

# Main RENE Model class
class Model(nn.Module):
    """RENE classifier: Conformer encoder -> GRU -> ReneTrialBlock head.

    Args:
        cfg: mapping with the hyperparameters read below. Note that the
            Whisper feature size is looked up under the key ``'whiper_dim'``
            (spelled exactly like that in the configuration).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        hidden_dim = cfg['rnn_hid_dim']

        # Conformer encoder (from the installed library); its output width is
        # set to the GRU hidden size so the two stages line up.
        conformer_kwargs = dict(
            num_classes=hidden_dim,
            input_dim=cfg['whiper_dim'],
            encoder_dim=cfg['encoder_dim'],
            num_encoder_layers=cfg['num_encoder_layers'],
            num_attention_heads=cfg['num_attention_heads'],
            input_dropout_p=cfg['input_dropout'],
            feed_forward_dropout_p=cfg['feed_forward_dropout'],
            attention_dropout_p=cfg['attention_dropout'],
            conv_dropout_p=cfg['conv_dropout'],
        )
        self.conformer = Conformer(**conformer_kwargs)

        # Recurrent stage; direction count comes from cfg['bidirect'].
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=cfg['rnn_layers'],
            bidirectional=cfg['bidirect'],
        )

        # Convolutional classification head.
        self.rene = ReneTrialBlock(cfg, in_channels=cfg['rtb_data_channels'])

    def forward(self, x, input_lengths):
        """Compute class logits for a batch of feature sequences.

        Args:
            x: encoder features, batch-first (presumably
                ``[batch, time_frames, whisper_dim]`` — confirm against caller).
            input_lengths: per-sequence lengths forwarded to the Conformer.

        Returns:
            Output of the ReneTrialBlock head applied to the GRU output at the
            last time step.
        """
        # The lengths returned by the Conformer are not used downstream.
        encoded, _ = self.conformer(x, input_lengths)
        # The GRU is not batch_first, so move time to the leading axis.
        time_major = encoded.permute(1, 0, 2)
        gru_out, _ = self.gru(time_major)
        # Index -1 on the time axis selects the final time step for every batch element.
        return self.rene(gru_out[-1])

# Build the network with freshly initialised weights; the pretrained
# checkpoint is loaded into it in a later cell.
model = Model(cfg)
print("Model instantiated with %d output classes." % (model.cfg['output_dim'],))

Model instantiated with 15 output classes.


In [3]:
# Download the RENE(S) pretrained checkpoint from Google Drive
import os
model_path = models_folder_path / "Rene.pth"
if not os.path.exists(model_path):
    # Using gdown with the shared file ID
    !gdown --id 1NcGPIURY4mWtRr_KkwHAodssOexN-PbC -O {model_path}
else:
    print("Model checkpoint already downloaded.")

Model checkpoint already downloaded.


In [4]:
# Load the pretrained weights into the model

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location='cpu')
# strict=False tolerates key mismatches, but a mismatch means part of the
# network keeps its random initialisation. Surface it instead of hiding it.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model: "
        f"{len(load_result.missing_keys)} missing key(s), "
        f"{len(load_result.unexpected_keys)} unexpected key(s)."
    )
    print("  Missing (first 10):", load_result.missing_keys[:10])
    print("  Unexpected (first 10):", load_result.unexpected_keys[:10])

# load fine tuned model
# checkpoint = torch.load("../models/RENE/fine_tuned_rene_model.pt", map_location='cpu')
# model.load_state_dict(checkpoint, strict=False)

model.to(device).eval()
print("Pretrained RENE model loaded.")

Pretrained RENE model loaded.


### Test Model with one sample

In [5]:
# Download a sample lung sound WAV (from SPRSound open dataset).
# The file is fetched on every run and written to data_base_path / "sample.wav".
sample_url = "https://raw.githubusercontent.com/SJTU-YONGFU-RESEARCH-GRP/SPRSound/main/example/65097128_5.6_1_p1_2242.wav"
sample_path = data_base_path / "sample.wav"
# IPython expands $sample_path / $sample_url from the Python variables above;
# -q silences wget and -O sets the output file.
!wget -q -O $sample_path $sample_url

In [6]:
import whisper

# Load Whisper tiny model for feature extraction
whisper_model = whisper.load_model("tiny").to(device)
whisper_model.eval()

# Load the audio
audio = whisper.load_audio(data_base_path / "sample.wav")  # returns NumPy array in float32
# Whisper expects 16 kHz audio padded/clipped to a fixed 30-second window.
MAX_SEC = 30
# Measure the real (non-padded) length BEFORE padding; after pad_or_trim the
# array always has MAX_SEC * SAMPLE_RATE samples, which would hide the true length.
orig_len_samples = min(len(audio), MAX_SEC * whisper.audio.SAMPLE_RATE)
audio = whisper.pad_or_trim(audio, length=MAX_SEC * whisper.audio.SAMPLE_RATE)
mel = whisper.log_mel_spectrogram(audio).to(device)

# Use Whisper encoder to get audio features
with torch.no_grad():
    encoder_out = whisper_model.encoder(mel.unsqueeze(0).to(device))  # shape [1, n_frames, 384]
# Total number of encoder frames, padding included (informational only)
n_frames_total = encoder_out.shape[1]
# Number of frames corresponding to real (non-padded) audio content
orig_frames = math.floor(orig_len_samples / 160)  # 160-sample hop = 10ms frame step
input_length = torch.LongTensor([orig_frames // 2])  # //2 because Whisper encoder downsamples by 2x in time

# Run the RENE model to get class logits
encoder_out = encoder_out.to(device)
with torch.no_grad():
    logits = model(encoder_out, input_length.to(device))
    probs = torch.softmax(logits, dim=1)[0]  # probabilities for each of the 15 classes

# Look up class names (indexed by class id) and print results
classes = class_file
top_idx = int(torch.argmax(probs))
top_class = classes[top_idx]
top_conf = probs[top_idx].item()

print(f"Top predicted class: **{top_class}** ({top_conf*100:.1f}% confidence)")
print("\nClass probabilities:")
ranked = sorted(zip(classes, probs.cpu().numpy()), key=lambda x: x[1], reverse=True)
for cls, p in ranked:
    print(f"  {cls:25s}: {p*100:.2f}%")



Top predicted class: **Healthy** (14.9% confidence)

Class probabilities:
  Healthy                  : 14.90%
  URTI                     : 14.63%
  Asthma & Lung Fibrosis   : 10.62%
  Pneumonia                : 10.19%
  Heart Failure            : 6.69%
  Heart Failure & Lung Fibrosis: 5.62%
  Bronchiolitis            : 5.52%
  Pleural Effusion         : 5.17%
  Bronchitis               : 5.09%
  Lung Fibrosis            : 4.88%
  LRTI                     : 3.94%
  Asthma                   : 3.70%
  Heart Failure & COPD     : 3.38%
  COPD                     : 2.96%
  Bronchiectasis           : 2.73%


## Load UI

In [7]:
import gradio as gr

# Define the prediction function for Gradio
def classify_respiratory_sound(audio_file):
    """Classify one recording and return the top label plus all class confidences.

    Relies on the notebook-level ``whisper_model``, ``model``, ``classes``,
    ``MAX_SEC`` and ``device`` defined in earlier cells.

    Args:
        audio_file: path of the uploaded audio file (Gradio passes a filepath).

    Returns:
        A tuple ``(top_label, confidences)`` where ``confidences`` maps every
        class name to its softmax probability as a Python float.
    """
    max_samples = MAX_SEC * whisper.audio.SAMPLE_RATE
    # Load audio from the uploaded file
    audio = whisper.load_audio(audio_file)
    # Record the real length BEFORE padding; afterwards len(audio) is always
    # max_samples and would no longer describe the recording.
    orig_len = min(len(audio), max_samples)
    audio = whisper.pad_or_trim(audio, length=max_samples)
    mel = whisper.log_mel_spectrogram(audio).to(device)
    with torch.no_grad():
        enc_out = whisper_model.encoder(mel.unsqueeze(0).to(device))
    # Original length in encoder frames: 160-sample hop, then 2x downsampling
    orig_frames = math.floor(orig_len / 160)
    inp_len = torch.LongTensor([orig_frames // 2])
    with torch.no_grad():
        logits = model(enc_out.to(device), inp_len.to(device))
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    # Prepare outputs
    top_idx = int(probs.argmax())
    top_label = classes[top_idx]
    # Build dict of class confidences
    confidences = {cls: float(probs[i]) for i, cls in enumerate(classes)}
    return top_label, confidences

# Build the Gradio UI: one audio upload in, a text box with the top label and
# a ranked label widget with every class probability out.
audio_input = gr.Audio(type="filepath", label="Upload Lung Sound (.wav)")
top_label_output = gr.Textbox(label="Top Predicted Disease")
all_probs_output = gr.Label(num_top_classes=15, label="All Class Probabilities")

interface = gr.Interface(
    fn=classify_respiratory_sound,
    inputs=audio_input,
    outputs=[top_label_output, all_probs_output],
    title="RENE Respiratory Disease Classifier",
    description="Upload a lung sound recording (.wav) to get the predicted respiratory condition and confidence scores for all 15 classes.",
)

# Launch the app. share=True requests a public shareable link in addition to
# the inline/local interface, so anyone with the link can reach the model.
interface.launch(share=True, debug=False)

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

## Testing accuracy

In [None]:
import kagglehub
import shutil
from pathlib import Path

# Download the latest dataset version into data_base_path, once.
# Re-running the cell must not download again or move a second copy into the
# existing directory, so the work is skipped when the target already exists.
custom_path = data_base_path / "respiratory-sound-database"
if Path(custom_path).exists():
    print("Dataset already present at:", custom_path)
else:
    path = kagglehub.dataset_download("vbookshelf/respiratory-sound-database")
    shutil.move(path, custom_path)
    print("Moved to:", custom_path)

path = custom_path
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/vbookshelf/respiratory-sound-database?dataset_version_number=2...


  0%|          | 0.00/3.69G [00:00<?, ?B/s]

In [None]:
import os
import pandas as pd

# -----------------------------
# Configuration: update these paths as needed.
# NOTE(review): these paths sit directly under data_base_path, while the
# download cell moves the dataset to data_base_path / "respiratory-sound-database"
# — confirm the folders below actually exist after a fresh download.
# -----------------------------
wav_folder = data_base_path / "Respiratory_Sound_Database" / "Respiratory_Sound_Database" / "audio_and_txt_files"               # Folder containing all the .wav files
diagnosis_csv = data_base_path / "Respiratory_Sound_Database" / "Respiratory_Sound_Database" / "patient_diagnosis.csv"  # CSV file with patient diagnoses
label_csv = data_base_path / "labeled_wav_files.csv"             # Output CSV file

# -----------------------------
# Step 1: Read the header-less diagnosis table (rows look like "101,URTI")
# and turn it into a patient_id -> diagnosis lookup. IDs are converted to
# strings so they compare equal to the prefix taken from each file name.
# -----------------------------
diag_df = pd.read_csv(diagnosis_csv, header=None, names=["patient_id", "diagnosis"])
diag_df["patient_id"] = diag_df["patient_id"].astype(str)
diagnosis_map = dict(zip(diag_df["patient_id"], diag_df["diagnosis"]))

# -----------------------------
# Step 2: Collect every .wav file (case-insensitive extension match).
# -----------------------------
wav_files = [entry for entry in os.listdir(wav_folder) if entry.lower().endswith(".wav")]
print(f"Found {len(wav_files)} .wav files.")

# -----------------------------
# Step 3: Label each recording. The patient ID is the part of the file name
# before the first underscore, e.g. '101_1b1_Al_sc_Meditron.wav' -> '101'.
# Recordings whose patient has no diagnosis entry are labelled "unknown".
# -----------------------------
labeled_records = [
    {
        "wav_file": name,
        "patient_id": name.split("_")[0],
        "diagnosis": diagnosis_map.get(name.split("_")[0], "unknown"),
    }
    for name in wav_files
]

# -----------------------------
# Step 4: Write the labelled table to disk.
# -----------------------------
labeled_df = pd.DataFrame(labeled_records)
labeled_df.to_csv(label_csv, index=False)
print(f"Labeled file saved to: {label_csv}")

Found 920 .wav files.
Labeled file saved to: /content/labeled_wav_files.csv


## Evaluate top-3 accuracy on the full downloaded dataset

In [None]:
# Tests accuracy of total dataset

import os
import math
import torch
import pandas as pd
import whisper

# -----------------------------
# Configuration
# -----------------------------
classes = class_file         # Class names, indexed by predicted class id
MAX_SEC = 30  # Maximum duration (in seconds) for padding/trimming

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load Whisper model for feature extraction
# -----------------------------
whisper_model = whisper.load_model("tiny").to(device)
whisper_model.eval()

# -----------------------------
# Use the RENE model built and loaded in the earlier cells: move it to
# 'device' and switch it to eval mode.
# -----------------------------
model.to(device)
model.eval()

# -----------------------------
# Read the CSV file containing true labels.
# Expected CSV format: wav_file,patient_id,diagnosis
# -----------------------------
df_labels = pd.read_csv(label_csv)
print("Loaded {} label entries.".format(len(df_labels)))

# -----------------------------
# Loop through each file, run through the pipeline, and collect predictions.
# For each file, check if the true label is present among the top 3 predictions.
# -----------------------------
correct = 0
total = 0
results = []  # To store (file, true_label, predicted_top1, predicted_top3)

for idx, row in df_labels.iterrows():
    wav_file = row['wav_file']
    true_label = row['diagnosis'].strip()  # True diagnosis label (e.g., "COPD", "Healthy", etc.)
    wav_path = os.path.join(wav_folder, wav_file)

    # Load the audio using Whisper utilities.
    audio = whisper.load_audio(wav_path)  # Loads audio as NumPy float32 array
    # Record the real length BEFORE padding; after pad_or_trim the array always
    # has MAX_SEC * SAMPLE_RATE samples and no longer reflects the recording.
    orig_len_samples = min(len(audio), MAX_SEC * whisper.audio.SAMPLE_RATE)
    audio = whisper.pad_or_trim(audio, length=MAX_SEC * whisper.audio.SAMPLE_RATE)

    # Get the log-mel spectrogram of the audio.
    mel = whisper.log_mel_spectrogram(audio).to(device)

    # Run the Whisper encoder to get features.
    with torch.no_grad():
        encoder_out = whisper_model.encoder(mel.unsqueeze(0).to(device))  # shape: [1, n_frames, 384]

    # Number of encoder frames corresponding to the original audio.
    orig_frames = math.floor(orig_len_samples / 160)  # 160-sample hop (10ms per hop)
    input_length = torch.LongTensor([orig_frames // 2])  # Encoder downsamples time by 2×

    # Run the RENE classification model.
    encoder_out = encoder_out.to(device)
    with torch.no_grad():
        logits = model(encoder_out, input_length.to(device))
        probs = torch.softmax(logits, dim=1)[0]  # probabilities for each class

    # Get the top 3 predicted indices.
    topk = torch.topk(probs, k=3)
    top_indices = topk.indices.cpu().numpy()
    # Top predicted class (highest probability)
    top1_label = classes[int(torch.argmax(probs))]
    # Top 3 predicted class names
    top3_labels = [classes[i] for i in top_indices]

    # Count as correct if the true label is among the top 3 predictions.
    if true_label.lower() in [label.lower() for label in top3_labels]:
        correct += 1
    total += 1
    results.append((wav_file, true_label, top1_label, top3_labels))
    print(f"[{total}] File: {wav_file} | True: {true_label} | Top1: {top1_label} | Top3: {top3_labels}")

# -----------------------------
# Compute and print overall top-3 accuracy.
# -----------------------------
accuracy = (correct / total) * 100 if total > 0 else 0
print(f"\nOverall Top-3 Accuracy: {accuracy:.2f}% ({correct} / {total} correct)")

Loaded 920 label entries.
[1] File: 162_1b2_Ar_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[2] File: 193_1b2_Pl_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[3] File: 138_2p2_Ll_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[4] File: 207_2b2_Ar_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[5] File: 176_2b3_Pr_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[6] File: 151_2p3_Ll_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[7] File: 215_1b2_Ar_sc_Meditron.wav | True: Bronchiectasis | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[8] File: 210_1b1_Ar_sc_Meditron.wav | True: URTI | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[9] File: 156_8b3_Pl_mc_AKGC417L.wav | True: COPD | Top1: COPD | Top3: ['COPD', 'Healthy', 'Pneumonia']
[10] File: 140_2b3_Tc_mc_Lit

## Fine-tuning the model on the downloaded data

In [None]:
import os
import math
import torch
import pandas as pd
import whisper
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

# -----------------------------
# Configuration
# -----------------------------
classes = class_file
MAX_SEC = 30  # Maximum duration in seconds for padding/trimming

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# Whisper feature extractor: evaluation mode, with every parameter frozen so
# that fine-tuning never updates it.
# -----------------------------
whisper_model = whisper.load_model("tiny").to(device)
whisper_model.eval()
whisper_model.requires_grad_(False)

# -----------------------------
# Reuse the RENE model built and loaded in the earlier cells; move it to the
# selected device and switch it to training mode for fine-tuning.
# -----------------------------
model.to(device)
model.train()

# -----------------------------
# Class names come straight from `class_file`; the position of a name in that
# sequence is its class index.
# -----------------------------
print("Classes:", classes)

# Name -> index lookup (the CSV 'diagnosis' field is assumed to exactly match one of these class names).
class_to_idx = dict((name, index) for index, name in enumerate(classes))
num_classes = len(class_to_idx)
print("Class to index mapping:", class_to_idx)

# -----------------------------
# Define a custom dataset.
# For each entry, load the .wav file, extract features using Whisper encoder,
# and return (encoder_output, input_length, label) for training.
# -----------------------------
class RespiratoryDataset(Dataset):
    """Dataset yielding Whisper-encoder features for each labelled recording.

    Args:
        csv_file: CSV with at least the columns ``wav_file`` and ``diagnosis``.
        wav_folder: directory containing the audio files named in the CSV.
        max_sec: duration in seconds every recording is padded/trimmed to.
        whisper_model: Whisper model whose ``encoder`` produces the features.
        class_to_idx: mapping from diagnosis name to integer class index.

    Each item is ``(features, input_length, label)``: the encoder output with
    the batch dimension removed, a one-element LongTensor holding the number
    of encoder frames that correspond to real audio, and the class index.
    Features are computed on the notebook-level ``device``.
    """

    def __init__(self, csv_file, wav_folder, max_sec=30, whisper_model=whisper_model, class_to_idx=class_to_idx):
        self.df = pd.read_csv(csv_file)
        self.wav_folder = wav_folder
        self.max_sec = max_sec
        self.whisper_model = whisper_model
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Number of rows in the label CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, encode and label the recording at row ``idx``.

        Raises:
            KeyError: if the row's diagnosis is not a key of ``class_to_idx``.
        """
        row = self.df.iloc[idx]
        wav_file = row['wav_file']
        true_label = row['diagnosis'].strip()
        label_idx = self.class_to_idx[true_label]
        wav_path = os.path.join(self.wav_folder, wav_file)
        max_samples = self.max_sec * whisper.audio.SAMPLE_RATE

        # Load audio
        audio = whisper.load_audio(wav_path)  # Returns a NumPy float32 array.
        # Record the real length BEFORE padding; after pad_or_trim the array
        # always has max_samples samples and no longer reflects the recording.
        orig_len_samples = min(len(audio), max_samples)
        # Pad or trim audio to fixed length
        audio = whisper.pad_or_trim(audio, length=max_samples)
        # Compute log-mel spectrogram
        mel = whisper.log_mel_spectrogram(audio).to(device)

        # Pass through the Whisper encoder (without gradient updates)
        with torch.no_grad():
            encoder_out = self.whisper_model.encoder(mel.unsqueeze(0).to(device))  # shape: [1, n_frames, 384]

        # Frames corresponding to the original (non-padded) audio.
        orig_frames = math.floor(orig_len_samples / 160)  # Whisper uses a hop of 160 samples (~10ms)
        # The encoder downsamples time by a factor of 2.
        input_length = torch.LongTensor([orig_frames // 2])

        # Squeeze the batch dimension from encoder output (resulting shape: [n_frames, 384])
        features = encoder_out.squeeze(0)

        return features, input_length, torch.tensor(label_idx, dtype=torch.long)

# -----------------------------
# Create dataset and split into training and validation sets (80% / 20%).
# -----------------------------
dataset = RespiratoryDataset(label_csv, wav_folder, max_sec=MAX_SEC, whisper_model=whisper_model, class_to_idx=class_to_idx)
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
# Seed the split so the same recordings land in train/validation on every run;
# otherwise validation numbers are not comparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

# Create DataLoaders.
batch_size = 8  # You may need to adjust batch size based on available memory.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print("Training samples:", len(train_dataset), "Validation samples:", len(val_dataset))

# -----------------------------
# Define training parameters: epoch count, optimizer and loss.
# -----------------------------
num_epochs = 3
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def train_epoch(model, data_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: module called as ``model(features, input_length)`` returning logits.
        data_loader: iterable of ``(features, input_length, labels)`` batches.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy in percent. Both are ``0.0`` when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for features, input_length, labels in data_loader:
        # Move data to device.
        features = features.to(device)
        input_length = input_length.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Forward pass through the classification model.
        logits = model(features, input_length)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a smaller final batch does not skew the mean.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += batch_count
    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total
    epoch_acc = (correct / total) * 100
    return epoch_loss, epoch_acc

def evaluate(model, data_loader, criterion, device):
    """Compute loss and accuracy over ``data_loader`` without updating the model.

    Args:
        model: module called as ``model(features, input_length)`` returning logits.
        data_loader: iterable of ``(features, input_length, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy in percent. Both are ``0.0`` when the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for features, input_length, labels in data_loader:
            features = features.to(device)
            input_length = input_length.to(device)
            labels = labels.to(device)

            logits = model(features, input_length)
            loss = criterion(logits, labels)
            # Weight each batch loss by its size so the result is a per-sample mean.
            batch_count = labels.size(0)
            running_loss += loss.item() * batch_count
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += batch_count
    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total
    epoch_acc = (correct / total) * 100
    return epoch_loss, epoch_acc

# -----------------------------
# Training loop: one training pass followed by one validation pass per epoch,
# with a single summary line printed for each epoch.
# -----------------------------
for epoch in range(num_epochs):
    train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
    val_metrics = evaluate(model, val_loader, criterion, device)
    train_loss, train_acc = train_metrics
    val_loss, val_acc = val_metrics

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

# -----------------------------
# Persist the fine-tuned weights (state dict only) next to the other model files.
# -----------------------------
fine_tuned_weights = model.state_dict()
torch.save(fine_tuned_weights, models_folder_path / "fine_tuned_rene_model.pt")
print("Model saved as fine_tuned_rene_model.pt")

Using device: cuda
Classes: ['Healthy', 'Bronchiectasis', 'Bronchiolitis', 'COPD', 'Asthma', 'LRTI', 'Pneumonia', 'URTI', 'Bronchitis', 'Lung Fibrosis', 'Asthma & Lung Fibrosis', 'Heart Failure & Lung Fibrosis', 'Heart Failure', 'Heart Failure & COPD', 'Pleural Effusion']
Class to index mapping: {'Asthma': 0, 'Asthma & Lung Fibrosis': 1, 'Bronchiectasis': 2, 'Bronchiolitis': 3, 'Bronchitis': 4, 'COPD': 5, 'Healthy': 6, 'Heart Failure': 7, 'Heart Failure & COPD': 8, 'Heart Failure & Lung Fibrosis': 9, 'LRTI': 10, 'Lung Fibrosis': 11, 'Pleural Effusion': 12, 'Pneumonia': 13, 'URTI': 14}
Training samples: 736 Validation samples: 184


KeyboardInterrupt: 