In [None]:
import torch
import os
import torchaudio
import torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm  # Use notebook version of tqdm

# Switch to the repository root inside the container before importing the
# project modules by bare name. (Per the original author's note, the
# Dockerfile clones the repo into /app.)
os.chdir('/app')
print(f"Current working directory: {os.getcwd()}")

# Import the project code only when model.py is present. When it is missing,
# `Model` and `reproducibility` stay undefined, which the later
# `in globals()` checks rely on to detect the failure.
if os.path.exists('model.py'):
    from model import Model  # the class is called Model, not XLSR_Mamba
    from utils import reproducibility
else:
    print("ERROR: model.py not found in the current directory /app.")
    print("Make sure the XLSR-Mamba repository was cloned correctly in the Dockerfile.")

# === Configuration ===
# Adapted from the guide and repository defaults
class Args:
    """Static settings for this inference run, exposed as class attributes.

    An instance is passed to the project's ``Model`` constructor and to
    ``reproducibility``; which attributes those consumers read is defined in
    the project code, not here.
    """

    # --- Runtime ---
    seed = 1234
    # Prefer the GPU when PyTorch can see one.
    device = 'cpu' if not torch.cuda.is_available() else 'cuda'

    # --- Model architecture / track ---
    emb_size = 144
    num_encoders = 12
    FT_W2V = True  # original note: use pretrained XLSR weights
    # Original note: 3 = DF (deepfake detection), chosen from the checkpoint name.
    # NOTE(review): the checkpoint directory below contains "LA" — confirm the track.
    algo = 3

    # --- Paths inside the container ---
    # Base XLSR model (original note: downloaded in the Dockerfile into /app).
    model_path = '/app/xlsr2_300m.pt'
    # Directory with the fine-tuned checkpoints (original note: mounted via volume).
    model_path_finetune = '/app/models_mounted/Bmamba3_LA_WCE_1e-06_ES144_NE12/best'

    # --- Checkpoint averaging ---
    n_average_model = 5  # number of best_<i>.pth files to look for and average

    # --- Training-time settings mirrored from the checkpoint directory name ---
    loss = 'WCE'
    lr = 1e-6
    comment = None  # original note: not used for inference

# One shared settings object for the rest of the notebook.
args = Args()

# Echo the effective configuration so it is visible in the cell output.
for config_line in (
    f"Using device: {args.device}",
    f"Algorithm Track (algo): {args.algo}",
    f"Base XLSR Model Path: {args.model_path}",
    f"Fine-tuned Checkpoint Dir: {args.model_path_finetune}",
    f"Averaging {args.n_average_model} checkpoints.",
):
    print(config_line)

# === Ensure reproducibility ===
# `reproducibility` is only defined when the project import above succeeded.
if 'reproducibility' not in globals():
    print("Skipping reproducibility step as function 'reproducibility' not found.")
else:
    # Presumably prints its own status lines (e.g. cudnn settings) — see utils.py.
    reproducibility(args.seed, args)

In [None]:
# === Initialize model and load averaged fine-tuned weights ===

def average_checkpoints(ckpt_dir, n_checkpoints, map_location='cpu'):
    """Average the state dicts stored as ``best_<i>.pth`` inside ``ckpt_dir``.

    Files ``best_0.pth`` .. ``best_{n_checkpoints-1}.pth`` are looked up.
    Missing or unreadable files are skipped with a message. Every further
    checkpoint is validated against the first one that loaded (same keys,
    same tensor shapes and dtypes) *before* it is added to the running sum,
    so a bad file can never leave a half-updated sum behind.

    Args:
        ckpt_dir: Directory containing the checkpoint files.
        n_checkpoints: How many ``best_<i>.pth`` indices to try.
        map_location: Passed to ``torch.load``. Defaults to ``'cpu'`` so the
            summed state dict does not occupy accelerator memory;
            ``load_state_dict`` copies values onto the model's own device.

    Returns:
        ``(state_dict, loaded_count)``. ``state_dict`` is the element-wise
        mean over the ``loaded_count`` checkpoints, or ``None`` when nothing
        could be loaded or a checkpoint was incompatible with the first one.
    """
    running_sum = None
    loaded_count = 0

    for i in range(n_checkpoints):
        ckpt_path = os.path.join(ckpt_dir, f'best_{i}.pth')
        if not os.path.exists(ckpt_path):
            print(f"Warning: Checkpoint {ckpt_path} not found. Skipping.")
            continue

        print(f" -> Loading {ckpt_path}")
        try:
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints from a trusted source.
            state = torch.load(ckpt_path, map_location=map_location)
        except Exception as e:
            print(f"Error loading checkpoint {ckpt_path}: {e}")
            continue

        if not isinstance(state, dict):
            print(f"Error loading checkpoint {ckpt_path}: expected a state dict, got {type(state).__name__}")
            continue

        if running_sum is None:
            # The first loaded dict is reused as the accumulator (this keeps
            # any metadata torch attached to the loaded mapping).
            running_sum = state
            loaded_count = 1
            continue

        if running_sum.keys() != state.keys():
            print(f"ERROR: Checkpoint {i} has different keys than the first one. Cannot average.")
            return None, loaded_count

        # In-place `+=` would silently broadcast mismatched shapes, so check
        # shape and dtype of every entry before touching the running sum.
        incompatible = [
            key for key in state
            if getattr(running_sum[key], 'shape', None) != getattr(state[key], 'shape', None)
            or getattr(running_sum[key], 'dtype', None) != getattr(state[key], 'dtype', None)
        ]
        if incompatible:
            print(f"ERROR: Checkpoint {i} has entries whose shape/dtype differ from the first one "
                  f"(e.g. '{incompatible[0]}'). Cannot average.")
            return None, loaded_count

        for key in state:
            running_sum[key] += state[key]
        loaded_count += 1

    if running_sum is None:
        return None, 0

    print(f"Averaging weights over {loaded_count} loaded checkpoints...")
    for key in running_sum:
        # Divide by the number actually loaded, not the number requested.
        running_sum[key] = torch.true_divide(running_sum[key], loaded_count)
    return running_sum, loaded_count


# `Model` only exists if the project import in the first cell succeeded.
if 'Model' in globals():
    print("Initializing model...")
    # The Model constructor takes the shared `args` object and the device.
    model = Model(args, device=args.device).to(args.device)
    print("Model initialized.")

    # === Load and average fine-tuned checkpoints ===
    print(f"🔄 Averaging top-{args.n_average_model} checkpoints from {args.model_path_finetune}...")

    if not os.path.isdir(args.model_path_finetune):
        print(f"ERROR: Checkpoint directory not found: {args.model_path_finetune}")
        print(f"Please ensure your local checkpoint directory is mounted so that {args.model_path_finetune} exists inside the container.")
    else:
        sd_avg, loaded_count = average_checkpoints(args.model_path_finetune, args.n_average_model)

        if sd_avg is not None and loaded_count > 0:
            print("Loading averaged state dict into model...")
            try:
                model.load_state_dict(sd_avg)
                model.eval()
                print("✅ Model loaded and averaged successfully.")
            except RuntimeError as e:
                print(f"Error loading averaged state_dict: {e}")
                print("This might indicate a mismatch between the averaged weights and the model architecture.")
            except Exception as e:
                print(f"An unexpected error occurred loading the averaged state_dict: {e}")
        elif loaded_count == 0:
            print("ERROR: No checkpoints were found or loaded. Model weights are not averaged.")
        else:  # at least one checkpoint loaded, but a later one was incompatible
            print("ERROR: Could not average checkpoints due to errors (e.g., key mismatch, load error).")
else:
    print("ERROR: Model class not imported. Cannot initialize model.")


In [None]:
# Cell 3: Dataset Preparation (with Optional Multiprocessing Check)

import glob
import os
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
from multiprocessing import Pool, cpu_count # Import multiprocessing components

# --- Configuration ---
data_dir = '/app/data/brspeech_df/'
target_file_extension = 'flac'
target_sample_rate = 16000  # should match the sample rate the model expects
num_check_workers = max(1, cpu_count() // 2)  # half of the CPU cores, but at least 1
skip_file_checking = True  # <<< Set to True to skip the detailed file check >>>

print(f"🔍 Searching for .{target_file_extension} files in {data_dir}...")
glob_pattern = os.path.join(data_dir, f'**/*.{target_file_extension}')
# Initial list of all found files. glob makes no ordering guarantee (it
# depends on the file system), so sort to keep the processing order — and
# therefore the order of the result rows — identical from run to run.
initial_audio_files = sorted(glob.glob(glob_pattern, recursive=True))

# --- Define Worker Function for File Check ---
# This function needs to be defined at the top level for multiprocessing
def check_audio_file(f_path):
    """Return ``f_path`` when its audio metadata is readable and non-empty.

    The file is probed with ``torchaudio.info``. Any exception raised while
    probing, as well as a frame count that is not greater than zero, yields
    ``None`` so the caller can count the file as skipped. Nothing is printed
    here on purpose: with many worker processes the output would be noisy.
    """
    try:
        has_frames = bool(torchaudio.info(f_path).num_frames > 0)
    except Exception:
        # Unreadable or otherwise problematic file: treat it as skipped.
        return None
    return f_path if has_frames else None

# --- Perform File Checking (Optional) ---
# Three cases: nothing found, checking skipped, or every file probed in a
# worker pool. In all of them `audio_files` ends up defined for later cells.
if not initial_audio_files:
    print(f"⚠️ No .{target_file_extension} files found using pattern '{glob_pattern}'.")
    audio_files = []  # keep the name defined for the checks further down
elif skip_file_checking:
    audio_files = initial_audio_files
    print(f"✅ Skipping detailed file check. Proceeding with all {len(audio_files)} found files.")
else:
    print(f"Found {len(initial_audio_files)} potential .{target_file_extension} files.")
    print(f"Verifying access and checking for non-empty files using {num_check_workers} workers...")

    valid_audio_files = []
    skipped_count = 0

    with Pool(processes=num_check_workers) as pool:
        # imap_unordered hands results back as soon as any worker finishes,
        # which lets tqdm show progress while the check is running.
        results_iterator = pool.imap_unordered(check_audio_file, initial_audio_files)
        for result_path in tqdm(results_iterator, total=len(initial_audio_files), desc="Checking files"):
            if result_path:
                valid_audio_files.append(result_path)
            else:
                skipped_count += 1  # check_audio_file returned None (error or empty)

    # Results arrive in completion order, which differs between runs; sort so
    # the dataset (and the result rows) have a reproducible order.
    audio_files = sorted(valid_audio_files)
    print(f"\nProceeding with {len(audio_files)} accessible, non-empty files ({skipped_count} skipped due to errors or being empty).")


# --- Define Custom Dataset ---
class AudioInferenceDataset(Dataset):
    """Dataset yielding ``(waveform, path)`` pairs for inference.

    Each item is loaded with ``torchaudio.load``, reduced to its first
    channel, resampled to ``target_sr`` when needed, cast to float32 and
    returned as a 1-D tensor together with the original file path. If
    anything goes wrong for a file (including a file with zero samples) the
    item is ``(None, path)`` so the consumer can record the failure instead
    of crashing.
    """

    def __init__(self, file_paths, target_sr):
        """
        Args:
            file_paths: Sequence of audio file paths (pre-filtered or not).
            target_sr: Sample rate every returned waveform is converted to.
        """
        self.file_paths = file_paths
        self.target_sr = target_sr
        self.resamplers = {}  # source sample rate -> cached T.Resample instance

    def __len__(self):
        return len(self.file_paths)

    def _prepare_waveform(self, waveform, sample_rate):
        """Convert a loaded ``[channels, samples]`` tensor to a 1-D float32 tensor.

        Raises:
            ValueError: If the waveform contains no samples.
        """
        if waveform.shape[-1] == 0:
            # Mirrors check_audio_file, which also rejects empty files; this
            # matters when the pre-check was skipped.
            raise ValueError("audio file contains no samples")

        # Keep only the first channel, and do so *before* resampling so the
        # discarded channels are not resampled for nothing.
        if waveform.shape[0] > 1:
            waveform = waveform[:1, :]  # shape: [1, num_samples]

        if sample_rate != self.target_sr:
            resampler = self.resamplers.get(sample_rate)
            if resampler is None:
                resampler = T.Resample(orig_freq=sample_rate, new_freq=self.target_sr)
                self.resamplers[sample_rate] = resampler
            waveform = resampler(waveform)

        # float32, channel dimension removed -> shape: [num_samples]
        return waveform.float().squeeze(0)

    def __getitem__(self, idx):
        """Return ``(waveform, path)``; ``waveform`` is ``None`` on any error."""
        audio_path = self.file_paths[idx]
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            return self._prepare_waveform(waveform, sample_rate), audio_path
        except Exception as e:
            # Catches problems the optional pre-filter did not (or could not)
            # detect, e.g. corrupted data.
            print(f"Error loading/processing {audio_path} within dataset __getitem__: {e}")
            return None, audio_path

# --- Instantiate Dataset ---
# Build the dataset only when there is at least one file to process;
# otherwise bind the name to None so the inference cell can detect it.
if not audio_files:
    inference_dataset = None
    print("Dataset not created as no audio files were found.")
else:
    inference_dataset = AudioInferenceDataset(audio_files, target_sample_rate)
    print(f"✅ Dataset created with {len(inference_dataset)} items.")

# Note: The DataLoader will be created in the next cell (Cell 4)

In [None]:
# Cell 4: Batched Inference (Modified for OOM/Triton Errors)

import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence # For handling variable length sequences in a batch
from tqdm.notebook import tqdm
import os
import traceback # For detailed error printing
import gc # For garbage collection

# --- Configuration ---
batch_size = 1  # <<< Set batch size to 1 to avoid padding OOM errors >>>
num_workers = 2  # Number of parallel workers for data loading (adjust based on your CPU cores)
# Column reported as the score when the model returns [batch, num_classes].
# 0 reproduces what this cell effectively did before for such outputs.
# NOTE(review): confirm against model.py which column is the intended score.
score_column = 0


def pad_collate_fn(batch):
    """Collate ``(waveform, path)`` items, separating out failed loads.

    Args:
        batch: List of ``(waveform, path)``; ``waveform`` is ``None`` for
            items the dataset could not load.

    Returns:
        ``(padded_waveforms, lengths, paths, failed_paths)`` where
        ``padded_waveforms`` is ``[n_valid, max_len]`` (zero-padded),
        ``lengths`` holds the unpadded lengths, ``paths`` the paths of the
        valid items and ``failed_paths`` the paths whose waveform was
        ``None``. When no item is valid the first two entries are ``None``
        and ``paths`` is empty, so the caller can still report the failures.
    """
    valid_items = [(wf, path) for wf, path in batch if wf is not None]
    failed_paths = [path for wf, path in batch if wf is None]
    if not valid_items:
        return None, None, [], failed_paths
    waveforms, paths = zip(*valid_items)
    lengths = torch.tensor([wf.size(0) for wf in waveforms], dtype=torch.long)
    padded_waveforms = pad_sequence(waveforms, batch_first=True, padding_value=0.0)
    return padded_waveforms, lengths, list(paths), failed_paths


def extract_scores(outputs, n_items, score_column=0):
    """Turn a model output tensor into a flat list with one float per input.

    Args:
        outputs: Model output; a scalar, ``[n_items]``, ``[n_items, 1]`` or
            ``[n_items, num_classes]`` tensor.
        n_items: Number of inputs in the batch.
        score_column: Column to take when there is more than one class.

    Returns:
        List of ``n_items`` Python floats.

    Raises:
        ValueError: If the output cannot be mapped to exactly one score per
            input (instead of silently truncating or nesting lists).
    """
    scores = outputs.detach().cpu()
    if scores.dim() == 2 and scores.size(0) == n_items and scores.size(1) > 1:
        scores = scores[:, score_column]
    scores = scores.reshape(-1).tolist()
    if len(scores) != n_items:
        raise ValueError(
            f"Model returned {len(scores)} scores for {n_items} inputs "
            f"(output shape {tuple(outputs.shape)})."
        )
    return scores


# --- Check if dataset and model are ready ---
if 'inference_dataset' not in globals() or inference_dataset is None or len(inference_dataset) == 0:
    print("🚫 Dataset not found or is empty. Please run Cell 3 successfully first.")
elif 'model' not in globals() or not hasattr(model, 'eval'):
    print("🚫 Model 'model' not found or not initialized. Please run Cell 2 first.")
else:
    # --- Create DataLoader ---
    device = args.device  # Get device from Cell 1 args
    inference_dataloader = DataLoader(
        inference_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(device == 'cuda'),
        collate_fn=pad_collate_fn
    )
    print(f"DataLoader created with batch size {batch_size}.")

    # --- Inference Loop ---
    results = []
    print(f"🚀 Starting inference (batch size {batch_size}) on {device}...")

    model.eval()
    model.to(device)

    for padded_waveforms_batch, lengths_batch, paths_batch, failed_paths_batch in tqdm(inference_dataloader, desc="Inferencing Files"):

        # Files that failed to load in the dataset get a NaN row so they
        # still appear in the results.
        for failed_path in failed_paths_batch:
            print(f"Skipping batch for {failed_path} due to loading error.")
            results.append({'filename': os.path.relpath(failed_path, data_dir), 'score': float('nan')})

        if padded_waveforms_batch is None:
            continue  # nothing loadable in this batch

        outputs = None
        inference_failed = False
        try:
            # The transfer is inside the try as well: it can run out of memory too.
            padded_waveforms_batch = padded_waveforms_batch.to(device)
            with torch.no_grad():
                # NOTE: only the padded tensor is passed; with batch_size > 1
                # shorter items are scored including their zero padding.
                outputs = model(padded_waveforms_batch)
            scores_batch = extract_scores(outputs, len(paths_batch), score_column)
        except Exception as e:
            inference_failed = True
            print(f"\n❌ Error processing file: {paths_batch[0]}...")
            print(f"Error type: {type(e).__name__}, Message: {e}")
            # traceback.print_exc() # Uncomment for full traceback if needed
        else:
            results.extend(
                {'filename': os.path.relpath(path, data_dir), 'score': score}
                for path, score in zip(paths_batch, scores_batch)
            )

        # Drop this iteration's tensors (success or failure) to help memory management.
        del padded_waveforms_batch, outputs

        if inference_failed:
            # Mark every file of the failed batch, not only the first one.
            results.extend(
                {'filename': os.path.relpath(path, data_dir), 'score': float('nan')}
                for path in paths_batch
            )
            # Clear the cache *after* the except block: while the handler is
            # active the exception's traceback still references the failed
            # forward pass, so its memory could not be released there.
            if device == 'cuda':
                print("Attempting to clear CUDA cache...")
                gc.collect()
                torch.cuda.empty_cache()
                print("CUDA cache cleared.")

    # --- Display and Save Results ---
    if results:
        results_df = pd.DataFrame(results)
        print("\n--- Inference Results ---")
        print("Score Summary:")
        print(results_df['score'].describe())
        error_count = results_df['score'].isna().sum()
        total_files = len(results_df)
        success_count = total_files - error_count
        print(f"\nProcessed {total_files} files: {success_count} successful, {error_count} errors (marked as NaN).")
        print("\nFirst 5 results:")
        print(results_df.head())
        print("\nLast 5 results:")
        print(results_df.tail())

        # --- Optional: Save to CSV ---
        output_dir = '/app/outputs'
        output_csv_path = os.path.join(output_dir, 'xlsr_mamba_inference_scores_bs1.csv')
        print(f"\n💾 Saving results to {output_csv_path}...")
        try:
            os.makedirs(output_dir, exist_ok=True)
            results_df.to_csv(output_csv_path, index=False, float_format='%.8f')
            print("💾 Results saved successfully.")
        except Exception as e:
            print(f"❌ Error saving results to CSV: {e}")
    else:
        print("No results were generated.")
