Fine-Tuned Model Using BEATs

In [9]:
import os
import sys
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, set_seed
from peft import get_peft_model, LoraConfig

# Add the BEATs repo to the Python path so that the import works
sys.path.append("K:/DCASE/BEATs")
from BEATs import BEATs, BEATsConfig

# ===== CONFIGURATION =====
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "K:/DCASE/BEATs_iter3.pt"  # pre-trained BEATs_iter3 checkpoint (input)
TRAIN_DIR = "K:/DCASE/Trainingdata/Source"   # one sub-folder per machine type, each holding .wav clips
OUTPUT_MODEL_PATH = "K:/DCASE/BEATs_finetuned_1.pt"  # fine-tuned checkpoint (output)
SEED = 42

# Hyperparameters
BATCH_SIZE = 2                 # Reduced batch size
NUM_EPOCHS = 5
LEARNING_RATE = 5e-5
TARGET_SECONDS = 15            # Shorter duration to reduce memory (15 sec instead of 30)
TARGET_LEN = 16000 * TARGET_SECONDS  # Total samples (240,000)

set_seed(SEED)

# ===== Load BEATs Model =====
print("🔄 Loading BEATs checkpoint...")
# torch.load unpickles the file: only point this at checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# NOTE(review): the BEATs README builds the config with BEATsConfig(checkpoint["cfg"]).
# Here the library defaults are used and only a few fields are overridden, so any
# architecture option stored in the checkpoint but not copied below is ignored.
# Every cell that loads these weights must build the same architecture - change
# them together if this is revisited.
cfg = BEATsConfig()
cfg.input_patch_size = (16, 16)
cfg.conv_bias = checkpoint["cfg"].get("conv_bias", False)
cfg.frame_length = 384  # Use 384 for the frame_length requirement
model = BEATs(cfg)

# strict=False tolerates key mismatches without a word, so inspect the result
# instead of assuming that the weights actually arrived.
load_result = model.load_state_dict(checkpoint["model"], strict=False)
if len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No tensor in {CHECKPOINT_PATH} matched the BEATs model; refusing to "
        "fine-tune randomly initialised weights."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint/model key mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected."
    )
model.to(DEVICE)
model.eval()
print("✅ BEATs model loaded successfully!")

# ===== Apply LoRA =====
# Nothing in this cell quantizes the base model, so this is plain LoRA (not QLoRA).
print("⚙️ Applying LoRA...")
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="FEATURE_EXTRACTION",
    target_modules=["q_proj", "v_proj"]  # Target attention projection layers
)
# get_peft_model wraps the model and leaves only the adapter weights trainable
# (see the trainable-parameter summary printed below).
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# ===== Custom Dataset =====
class AudioDataset(Dataset):
    """Mono 16 kHz waveforms of fixed length, read from ``base_dir/<machine>/*.wav``.

    Each item is ``{"input_values": tensor}`` where the tensor is 1-D with
    ``TARGET_LEN`` samples (shorter clips are zero-padded on the right, longer
    clips are truncated).
    """

    def __init__(self, base_dir):
        """Index every ``.wav`` file (case-insensitive) one level below ``base_dir``.

        Args:
            base_dir: Directory whose sub-directories hold the audio files.
                Entries of ``base_dir`` that are not directories are ignored.
        """
        self.file_list = []
        # os.listdir returns entries in arbitrary order; sorting pins the
        # index -> file mapping so that the fixed SEED yields the same batches
        # on every machine and run.
        for machine in sorted(os.listdir(base_dir)):
            machine_path = os.path.join(base_dir, machine)
            if not os.path.isdir(machine_path):
                continue
            for fname in sorted(os.listdir(machine_path)):
                if fname.lower().endswith(".wav"):
                    self.file_list.append(os.path.join(machine_path, fname))
        print(f"Found {len(self.file_list)} audio files.")

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, down-mix, resample and length-normalise the clip at ``idx``.

        Returns:
            dict with key ``"input_values"`` mapping to a 1-D tensor of
            ``TARGET_LEN`` samples.
        """
        wav_path = self.file_list[idx]
        try:
            waveform, sr = torchaudio.load(wav_path)
        except Exception as e:
            # Best effort: an unreadable file is reported and replaced by
            # silence so a single bad clip does not abort the whole run.
            print(f"Error loading {wav_path}: {e}")
            waveform = torch.zeros(1, TARGET_LEN)
            sr = 16000

        # Down-mix multi-channel audio to mono by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            waveform = resampler(waveform)
        # Force exactly TARGET_LEN samples: right-pad with zeros or truncate.
        if waveform.shape[1] < TARGET_LEN:
            waveform = torch.nn.functional.pad(waveform, (0, TARGET_LEN - waveform.shape[1]))
        else:
            waveform = waveform[:, :TARGET_LEN]
        return {"input_values": waveform.squeeze(0)}  # Returns a 1D tensor of length TARGET_LEN

# ===== Data Collator =====
def collate_fn(batch):
    """Stack per-sample waveforms into one batch tensor.

    Args:
        batch: Iterable of dataset items. An item that is not a dict, or that
            lacks the ``"input_values"`` key, is replaced by ``TARGET_LEN``
            zeros.

    Returns:
        dict with key ``"input_values"`` holding the stacked tensor (samples
        along dimension 0).
    """
    def _waveform_of(sample):
        # Well-formed items contribute their tensor; anything else becomes silence.
        if isinstance(sample, dict) and "input_values" in sample:
            return sample["input_values"]
        return torch.zeros(TARGET_LEN)

    stacked = torch.stack([_waveform_of(sample) for sample in batch])
    return {"input_values": stacked}

# ===== Prepare DataLoader =====
# Index the training clips found under TRAIN_DIR.
train_dataset = AudioDataset(TRAIN_DIR)
# NOTE(review): the Trainer further down receives train_dataset and collate_fn
# directly; nothing visible in this cell consumes train_loader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
# ===== Custom Trainer =====
class BEATsTrainer(Trainer):
    """Trainer whose loss is the mean of the squared BEATs features."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the mean squared activation of the extracted features.

        Args:
            model: Model exposing ``extract_features``.
            inputs: Batch dict; ``inputs["input_values"]`` is the stacked
                waveform tensor built by the data collator.
            return_outputs: When true, also return the feature tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``loss`` or ``(loss, features)`` depending on ``return_outputs``.
        """
        # NOTE(review): this objective has no target - it only penalises feature
        # magnitude. Confirm that this is the intended training signal.
        batch = inputs["input_values"].to(DEVICE)
        # extract_features returns a sequence; the first element is the feature tensor.
        extracted = model.extract_features(batch)
        features = extracted[0]
        loss = features.pow(2).mean()
        if return_outputs:
            return loss, features
        return loss

training_args = TrainingArguments(
    output_dir="./beats-qlora-checkpoints",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=50,
    save_strategy="no",
    report_to="none",
    # Mixed precision needs a GPU; DEVICE falls back to CPU, so this must too.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=4,  # Accumulate gradients for effective batch size
)

trainer = BEATsTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn
)

# ===== Fine-Tuning Loop =====
print("🚀 Starting LoRA fine-tuning...")
trainer.train()
print("✅ Fine-tuning completed!")

# ===== Save the Fine-Tuned Model =====
print(f"💾 Saving fine-tuned model to {OUTPUT_MODEL_PATH}")
# The PEFT wrapper's state_dict uses "base_model.model."-prefixed keys plus
# separate lora_* tensors; a plain BEATs model loading that with strict=False
# would silently pick up nothing. merge_and_unload() folds the adapters into the
# base weights and RETURNS the plain model, whose state_dict has ordinary BEATs keys.
merged_model = model.merge_and_unload()
torch.save({
    "cfg": checkpoint["cfg"],
    "model": merged_model.state_dict()
}, OUTPUT_MODEL_PATH)
print("✅ Fine-tuned model saved!")

🔄 Loading BEATs checkpoint...


No label_names provided for model class `PeftModelForFeatureExtraction`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


✅ BEATs model loaded successfully!
⚙️ Applying QLoRA...
trainable params: 294,912 || all params: 90,596,480 || trainable%: 0.3255
Found 6930 audio files.
🚀 Starting QLoRA fine-tuning...


Step,Training Loss


KeyboardInterrupt: 

In [3]:
import os
import sys
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, set_seed
from peft import get_peft_model, LoraConfig

# Add the BEATs repo to the Python path so that the import works
sys.path.append("K:/DCASE/BEATs")
from BEATs import BEATs, BEATsConfig

# ===== CONFIGURATION =====
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "K:/DCASE/BEATs_iter3.pt"  # pre-trained BEATs_iter3 checkpoint (input)
TRAIN_DIR = "K:/DCASE/Trainingdata/Source"   # one sub-folder per machine type, each holding .wav clips
OUTPUT_MODEL_PATH = "K:/DCASE/BEATs_finetuned_1.pt"  # fine-tuned checkpoint (output)
SEED = 42

# Hyperparameters
BATCH_SIZE = 2                 # Reduced batch size
NUM_EPOCHS = 12
LEARNING_RATE = 5e-5
TARGET_SECONDS = 15            # Shorter duration to reduce memory (15 sec instead of 30)
TARGET_LEN = 16000 * TARGET_SECONDS  # Total samples (240,000)

set_seed(SEED)

# ===== Load BEATs Model =====
print("🔄 Loading BEATs checkpoint...")
# torch.load unpickles the file: only point this at checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# NOTE(review): the BEATs README builds the config with BEATsConfig(checkpoint["cfg"]).
# Here the library defaults are used and only a few fields are overridden, so any
# architecture option stored in the checkpoint but not copied below is ignored.
# Every cell that loads these weights must build the same architecture - change
# them together if this is revisited.
cfg = BEATsConfig()
cfg.input_patch_size = (16, 16)
cfg.conv_bias = checkpoint["cfg"].get("conv_bias", False)
cfg.frame_length = 384  # Use 384 for the frame_length requirement
model = BEATs(cfg)

# strict=False tolerates key mismatches without a word, so inspect the result
# instead of assuming that the weights actually arrived.
load_result = model.load_state_dict(checkpoint["model"], strict=False)
if len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No tensor in {CHECKPOINT_PATH} matched the BEATs model; refusing to "
        "fine-tune randomly initialised weights."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint/model key mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected."
    )
model.to(DEVICE)
model.eval()
print("✅ BEATs model loaded successfully!")

# ===== Apply LoRA =====
# Nothing in this cell quantizes the base model, so this is plain LoRA (not QLoRA).
print("⚙️ Applying LoRA...")
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="FEATURE_EXTRACTION",
    target_modules=["q_proj", "v_proj"]  # Target attention projection layers
)
# get_peft_model wraps the model and leaves only the adapter weights trainable
# (see the trainable-parameter summary printed below).
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# ===== Custom Dataset =====
class AudioDataset(Dataset):
    """Mono 16 kHz waveforms of fixed length, read from ``base_dir/<machine>/*.wav``.

    Each item is ``{"input_values": tensor}`` where the tensor is 1-D with
    ``TARGET_LEN`` samples (shorter clips are zero-padded on the right, longer
    clips are truncated).
    """

    def __init__(self, base_dir):
        """Index every ``.wav`` file (case-insensitive) one level below ``base_dir``.

        Args:
            base_dir: Directory whose sub-directories hold the audio files.
                Entries of ``base_dir`` that are not directories are ignored.
        """
        self.file_list = []
        # os.listdir returns entries in arbitrary order; sorting pins the
        # index -> file mapping so that the fixed SEED yields the same batches
        # on every machine and run.
        for machine in sorted(os.listdir(base_dir)):
            machine_path = os.path.join(base_dir, machine)
            if not os.path.isdir(machine_path):
                continue
            for fname in sorted(os.listdir(machine_path)):
                if fname.lower().endswith(".wav"):
                    self.file_list.append(os.path.join(machine_path, fname))
        print(f"Found {len(self.file_list)} audio files.")

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, down-mix, resample and length-normalise the clip at ``idx``.

        Returns:
            dict with key ``"input_values"`` mapping to a 1-D tensor of
            ``TARGET_LEN`` samples.
        """
        wav_path = self.file_list[idx]
        try:
            waveform, sr = torchaudio.load(wav_path)
        except Exception as e:
            # Best effort: an unreadable file is reported and replaced by
            # silence so a single bad clip does not abort the whole run.
            print(f"Error loading {wav_path}: {e}")
            waveform = torch.zeros(1, TARGET_LEN)
            sr = 16000

        # Down-mix multi-channel audio to mono by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            waveform = resampler(waveform)
        # Force exactly TARGET_LEN samples: right-pad with zeros or truncate.
        if waveform.shape[1] < TARGET_LEN:
            waveform = torch.nn.functional.pad(waveform, (0, TARGET_LEN - waveform.shape[1]))
        else:
            waveform = waveform[:, :TARGET_LEN]
        return {"input_values": waveform.squeeze(0)}  # Returns a 1D tensor of length TARGET_LEN

# ===== Data Collator =====
def collate_fn(batch):
    """Stack per-sample waveforms into one batch tensor.

    Args:
        batch: Iterable of dataset items. An item that is not a dict, or that
            lacks the ``"input_values"`` key, is replaced by ``TARGET_LEN``
            zeros.

    Returns:
        dict with key ``"input_values"`` holding the stacked tensor (samples
        along dimension 0).
    """
    def _waveform_of(sample):
        # Well-formed items contribute their tensor; anything else becomes silence.
        if isinstance(sample, dict) and "input_values" in sample:
            return sample["input_values"]
        return torch.zeros(TARGET_LEN)

    stacked = torch.stack([_waveform_of(sample) for sample in batch])
    return {"input_values": stacked}

# ===== Prepare DataLoader =====
# Index the training clips found under TRAIN_DIR.
train_dataset = AudioDataset(TRAIN_DIR)
# NOTE(review): the Trainer further down receives train_dataset and collate_fn
# directly; nothing visible in this cell consumes train_loader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
# ===== Custom Trainer =====
class BEATsTrainer(Trainer):
    """Trainer whose loss is the mean of the squared BEATs features."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the mean squared activation of the extracted features.

        Args:
            model: Model exposing ``extract_features``.
            inputs: Batch dict; ``inputs["input_values"]`` is the stacked
                waveform tensor built by the data collator.
            return_outputs: When true, also return the feature tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``loss`` or ``(loss, features)`` depending on ``return_outputs``.
        """
        # NOTE(review): this objective has no target - it only penalises feature
        # magnitude. Confirm that this is the intended training signal.
        batch = inputs["input_values"].to(DEVICE)
        # extract_features returns a sequence; the first element is the feature tensor.
        extracted = model.extract_features(batch)
        features = extracted[0]
        loss = features.pow(2).mean()
        if return_outputs:
            return loss, features
        return loss

training_args = TrainingArguments(
    output_dir="./beats-qlora-checkpoints",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=50,
    save_strategy="no",
    report_to="none",
    # Mixed precision needs a GPU; DEVICE falls back to CPU, so this must too.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=4,  # Accumulate gradients for effective batch size
)

trainer = BEATsTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn
)

# ===== Fine-Tuning Loop =====
print("🚀 Starting LoRA fine-tuning...")
trainer.train()
print("✅ Fine-tuning completed!")

# ===== Save the Fine-Tuned Model =====
print(f"💾 Saving fine-tuned model to {OUTPUT_MODEL_PATH}")
# The PEFT wrapper's state_dict uses "base_model.model."-prefixed keys plus
# separate lora_* tensors; a plain BEATs model loading that with strict=False
# would silently pick up nothing. merge_and_unload() folds the adapters into the
# base weights and RETURNS the plain model, whose state_dict has ordinary BEATs keys.
merged_model = model.merge_and_unload()
torch.save({
    "cfg": checkpoint["cfg"],
    "model": merged_model.state_dict()
}, OUTPUT_MODEL_PATH)
print("✅ Fine-tuned model saved!")

🔄 Loading BEATs checkpoint...


No label_names provided for model class `PeftModelForFeatureExtraction`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


✅ BEATs model loaded successfully!
⚙️ Applying QLoRA...
trainable params: 294,912 || all params: 90,596,480 || trainable%: 0.3255
Found 6930 audio files.
🚀 Starting QLoRA fine-tuning...


Step,Training Loss
50,0.1262
100,0.0944
150,0.0591
200,0.0433
250,0.0377
300,0.0346
350,0.0326
400,0.0311
450,0.03
500,0.0291


✅ Fine-tuning completed!
💾 Saving fine-tuned model to K:/DCASE/BEATs_finetuned_1.pt
✅ Fine-tuned model saved!


In [None]:
import torch
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


cuda


In [7]:
import os, sys, torch, torchaudio
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments, set_seed
from peft import get_peft_model, LoraConfig
from tqdm import tqdm

# ===== Paths =====
sys.path.append("K:/DCASE/BEATs")
from BEATs import BEATs, BEATsConfig

# ===== Config =====
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOURCE_MODEL_PATH = "K:/DCASE/BEATs_finetuned_1.pt"       # checkpoint from the source-domain run (input)
TARGET_DIR = "K:/DCASE/Trainingdata/Target"               # target-domain clips, one sub-folder per group
OUTPUT_MODEL_PATH = "K:/DCASE/BEATs_finetuned_target.pt"  # adapted checkpoint (output)
SEED = 42
BATCH_SIZE = 2
NUM_EPOCHS = 5
LEARNING_RATE = 3e-5
TARGET_SECONDS = 15
TARGET_LEN = 16000 * TARGET_SECONDS  # samples per clip at 16 kHz
set_seed(SEED)

# ===== Load base model and source adapter =====
print("🔄 Loading source-finetuned model...")
# torch.load unpickles the file: only point this at checkpoints from a trusted source.
source_ckpt = torch.load(SOURCE_MODEL_PATH, map_location=DEVICE)
# NOTE(review): the BEATs README builds the config with BEATsConfig(ckpt["cfg"]);
# here defaults plus a few overrides are used. Every cell that loads these
# weights must build the same architecture.
cfg = BEATsConfig()
cfg.input_patch_size = (16, 16)
cfg.conv_bias = source_ckpt["cfg"].get("conv_bias", False)
cfg.frame_length = 384

model = BEATs(cfg)
# First attempt: the checkpoint holds plain BEATs keys (already-merged weights).
# strict=False never raises on mismatches, so count what actually matched.
plain_result = model.load_state_dict(source_ckpt["model"], strict=False)
plain_loaded = len(model.state_dict()) - len(plain_result.missing_keys)
model.to(DEVICE).eval()

# ===== Merge source adapter =====
source_lora = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    bias="none", task_type="FEATURE_EXTRACTION",
    target_modules=["q_proj", "v_proj"]
)
model = get_peft_model(model, source_lora)
# Second attempt: the checkpoint was saved from the PEFT wrapper, so its keys
# only line up once the same wrapper has been rebuilt around the model.
peft_result = model.load_state_dict(source_ckpt["model"], strict=False)
peft_loaded = len(model.state_dict()) - len(peft_result.missing_keys)
if plain_loaded == 0 and peft_loaded == 0:
    raise RuntimeError(
        f"No tensor in {SOURCE_MODEL_PATH} matched the model in either layout; "
        "refusing to continue with randomly initialised weights."
    )
# merge_and_unload() RETURNS the plain base model with the adapter folded in.
# Rebinding `model` matters: otherwise the next get_peft_model() call would wrap
# the old PeftModel and every state_dict key would gain a second prefix.
model = model.merge_and_unload()
print("✅ Source adapter merged into base model.")

# ===== Inject new LoRA adapter for PDA =====
target_lora = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    bias="none", task_type="FEATURE_EXTRACTION",
    target_modules=["q_proj", "v_proj"]
)
model = get_peft_model(model, target_lora)
model.print_trainable_parameters()

# ===== Target Dataset =====
class TargetAudioDataset(Dataset):
    """Mono 16 kHz target-domain waveforms of fixed length.

    Files are discovered as ``base_dir/<folder>/*.wav``. Each item is
    ``{"input_values": tensor}`` with a 1-D tensor of ``TARGET_LEN`` samples.
    """

    def __init__(self, base_dir):
        """Index every ``.wav`` file (case-insensitive) one level below ``base_dir``.

        Args:
            base_dir: Directory whose sub-directories hold the audio files.
                Entries that are not directories are ignored.
        """
        self.file_list = []
        # Sorted listings keep the index -> file mapping stable across runs
        # (os.listdir order is arbitrary), which the fixed SEED relies on.
        for folder in sorted(os.listdir(base_dir)):
            fpath = os.path.join(base_dir, folder)
            if os.path.isdir(fpath):
                self.file_list += [
                    os.path.join(fpath, f)
                    for f in sorted(os.listdir(fpath))
                    if f.lower().endswith(".wav")  # also accept ".WAV"
                ]
        print(f"🎧 Loaded {len(self.file_list)} target audio files.")

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, down-mix, resample and length-normalise the clip at ``idx``.

        Returns:
            dict with key ``"input_values"`` mapping to a 1-D tensor of
            ``TARGET_LEN`` samples; all zeros if the file cannot be read.
        """
        path = self.file_list[idx]
        try:
            waveform, sr = torchaudio.load(path)
        except Exception as e:
            # Catch Exception, not everything: a bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit. Report the file so that silent
            # substitution of training data does not go unnoticed.
            print(f"Error loading {path}: {e}")
            return {"input_values": torch.zeros(TARGET_LEN)}  # fallback

        # Down-mix multi-channel audio to mono by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

        # Right-pad short clips with zeros, then cut everything to TARGET_LEN.
        waveform = torch.nn.functional.pad(waveform, (0, max(0, TARGET_LEN - waveform.shape[1])))
        waveform = waveform[:, :TARGET_LEN]
        return {"input_values": waveform.squeeze(0)}

# Index the target-domain clips found under TARGET_DIR.
dataset = TargetAudioDataset(base_dir=TARGET_DIR)

# ===== Data Collator =====
def collate_fn(batch):
    """Stack per-sample waveforms into one batch tensor.

    Args:
        batch: Iterable of dataset items. An item that is not a dict, or that
            lacks the ``"input_values"`` key, is replaced by ``TARGET_LEN``
            zeros.

    Returns:
        dict with key ``"input_values"`` holding the stacked tensor (samples
        along dimension 0).
    """
    def _waveform_of(sample):
        # Well-formed items contribute their tensor; anything else becomes silence.
        if isinstance(sample, dict) and "input_values" in sample:
            return sample["input_values"]
        return torch.zeros(TARGET_LEN)

    stacked = torch.stack([_waveform_of(sample) for sample in batch])
    return {"input_values": stacked}

# ===== Custom Trainer =====
class BEATsTrainer(Trainer):
    """Trainer whose loss is the mean of the squared BEATs features."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the mean squared activation of the extracted features.

        Args:
            model: Model exposing ``extract_features``.
            inputs: Batch dict; ``inputs["input_values"]`` is the stacked
                waveform tensor built by the data collator.
            return_outputs: When true, also return the feature tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``loss`` or ``(loss, feat)`` depending on ``return_outputs``.
        """
        # NOTE(review): this objective has no target - it only penalises feature
        # magnitude. Confirm that this is the intended adaptation signal.
        waveforms = inputs["input_values"].to(DEVICE)
        # extract_features returns a sequence; the first element is the feature tensor.
        extracted = model.extract_features(waveforms)
        feat = extracted[0]
        loss = feat.pow(2).mean()
        if return_outputs:
            return loss, feat
        return loss

# ===== Training Args =====
training_args = TrainingArguments(
    output_dir="./pda-checkpoints",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=10,
    save_strategy="no",
    report_to="none",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    gradient_accumulation_steps=4
)

# ===== Train =====
trainer = BEATsTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn
)

print("🚀 Fine-tuning BEATs on target domain via PDA...")
trainer.train()
print("✅ PDA target fine-tuning complete.")

# ===== Save Final Model =====
print("💾 Merging target adapter into final model...")
# merge_and_unload() RETURNS the model with the adapter folded into the base
# weights; calling it and then saving the wrapper's own state_dict would write
# PEFT-prefixed keys that a plain BEATs model silently ignores under strict=False.
merged_model = model.merge_and_unload()

# The merged model may itself still sit inside one or more PEFT wrappers (if the
# model was wrapped more than once), so strip every leading wrapper prefix to
# end up with ordinary BEATs parameter names.
PEFT_KEY_PREFIX = "base_model.model."
final_state = {}
for key, tensor in merged_model.state_dict().items():
    while key.startswith(PEFT_KEY_PREFIX):
        key = key[len(PEFT_KEY_PREFIX):]
    final_state[key] = tensor

torch.save({"cfg": source_ckpt["cfg"], "model": final_state}, OUTPUT_MODEL_PATH)
print(f"✅ Final PDA model saved to: {OUTPUT_MODEL_PATH}")


🔄 Loading source-finetuned model...


No label_names provided for model class `PeftModelForFeatureExtraction`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


✅ Source adapter merged into base model.
trainable params: 294,912 || all params: 90,596,480 || trainable%: 0.3255
🎧 Loaded 70 target audio files.
🚀 Fine-tuning BEATs on target domain via PDA...


Step,Training Loss
10,0.0678
20,0.0678
30,0.0679
40,0.0677


✅ PDA target fine-tuning complete.
💾 Merging target adapter into final model...
✅ Final PDA model saved to: K:/DCASE/BEATs_finetuned_target.pt


In [8]:
import os
import torch
import torchaudio
import numpy as np
from BEATs import BEATs, BEATsConfig

# ==== CONFIGURATION ====
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "K:/DCASE/BEATs_finetuned_target.pt"  # fine-tuned checkpoint (input)
BASE_INPUT_DIR = "K:/DCASE/Trainingdata"                # one sub-folder of .wav files per machine
BASE_OUTPUT_DIR = "C:/DCASE_Temp/FineTuned"             # mirrored folder tree of .npy embeddings
# NOTE(review): no call in this cell passes MASK_PARAM on; apply_specaugment is
# invoked with its own default (also 80).
MASK_PARAM = 80  # SpecAugment mask width

# ==== LOAD CHECKPOINT ====
print("🔄 Loading checkpoint...")
# torch.load unpickles the file: only point this at checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# NOTE(review): the fine-tuning cells additionally set cfg.frame_length = 384;
# confirm whether that override is needed here so both sides build the same model.
cfg = BEATsConfig()
cfg.input_patch_size = (16, 16)
cfg.conv_bias = checkpoint["cfg"].get("conv_bias", False)

print("🚀 Initializing BEATs model...")
model = BEATs(cfg)

# A state_dict saved from a PEFT wrapper carries one or more leading
# "base_model.model." prefixes. Strip them so such checkpoints still line up
# with the plain BEATs parameter names.
PEFT_KEY_PREFIX = "base_model.model."
state_dict = {}
for key, tensor in checkpoint["model"].items():
    while key.startswith(PEFT_KEY_PREFIX):
        key = key[len(PEFT_KEY_PREFIX):]
    state_dict[key] = tensor

# strict=False never raises on mismatches, so check the outcome explicitly:
# extracting embeddings from an accidentally random model must not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No tensor in {CHECKPOINT_PATH} matched the BEATs model; refusing to "
        "extract embeddings from randomly initialised weights."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint/model key mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected."
    )
model.to(DEVICE)
model.eval()
print("✅ BEATs model loaded successfully!")

# ==== SPEC-AUGMENT FUNCTION ====
def apply_specaugment(tensor, mask_param=80):
    """Apply one random frequency mask and one random time mask to a 2-D tensor.

    Args:
        tensor: 2-D input; the caller passes BEATs features as [Time, Feature].
        mask_param: Maximum mask width handed to both masking transforms.

    Returns:
        The masked tensor with the same two dimensions as the input.
    """
    # NOTE(review): torchaudio documents these transforms for (..., freq, time)
    # input, i.e. FrequencyMasking acts on the second-to-last axis. With a
    # [1, Time, Feature] tensor that is the Time axis - confirm the intended axes.
    frequency_masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=mask_param)
    time_masker = torchaudio.transforms.TimeMasking(time_mask_param=mask_param)

    batched = tensor.unsqueeze(0)  # add a leading dimension for the transforms
    masked = time_masker(frequency_masker(batched))
    return masked.squeeze(0)  # drop the leading dimension again

# ==== EMBEDDING EXTRACTION ====
def extract_beats_embedding(wav_file, target_seconds=30, sample_rate=16000):
    """Compute SpecAugment-masked BEATs features for one audio file.

    The clip is down-mixed to mono, resampled to ``sample_rate``, zero-padded or
    truncated to ``target_seconds`` and passed through the global ``model``.

    Args:
        wav_file: Path of the audio file to read.
        target_seconds: Clip length fed to the model, in seconds. The default
            keeps the previous hard-coded 30 s.
            NOTE(review): the fine-tuning cells use 15 s clips - confirm which
            length is intended here.
        sample_rate: Sampling rate the model is fed with, in Hz (default 16000).

    Returns:
        2-D ``numpy`` array of masked features, or ``None`` when the file cannot
        be loaded or the model call fails (the error is printed).
    """
    try:
        waveform, source_rate = torchaudio.load(wav_file)
        print(f"🔹 Loaded {wav_file}: Shape {waveform.shape}, Sample Rate {source_rate}")
    except Exception as e:
        print(f"❌ Error loading {wav_file}: {e}")
        return None

    # Down-mix multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if source_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=sample_rate)
        waveform = resampler(waveform)

    waveform = waveform.to(DEVICE)
    # Right-pad short clips with zeros, then cut everything to the target length.
    target_length = sample_rate * target_seconds
    waveform = torch.nn.functional.pad(waveform, (0, max(0, target_length - waveform.shape[1])))
    waveform = waveform[:, :target_length]

    print(f"🎵 Prepared waveform shape: {waveform.shape}")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        try:
            features = model.extract_features(waveform)[0]  # [1, T, 768]
            print(f"🎛 Raw BEATs feature shape: {features.shape}")
            # The masking is random, so repeated runs on the same file differ.
            features = apply_specaugment(features.squeeze(0))  # [T, 768]
            return features.cpu().numpy()
        except Exception as e:
            print(f"❌ BEATs model error: {e}")
            return None

# ==== PROCESS FILES ====
def process_all_wav_files(base_input_dir, base_output_dir):
    """Extract and save an embedding for every ``.wav`` file one level below a root.

    For each sub-directory ``<machine>`` of ``base_input_dir``, every ``.wav``
    file (case-insensitive) is passed to ``extract_beats_embedding`` and the
    result is stored as ``base_output_dir/<machine>/BEATs_aug_<stem>.npy``.
    Files whose output already exists are skipped, so an interrupted run can be
    resumed.

    Args:
        base_input_dir: Root directory containing one sub-directory per machine.
        base_output_dir: Root directory for the ``.npy`` files; created if missing.
    """
    os.makedirs(base_output_dir, exist_ok=True)

    for machine in os.listdir(base_input_dir):
        machine_input_dir = os.path.join(base_input_dir, machine)
        # Stray files in the root are not machine folders: listing them would
        # raise NotADirectoryError (after creating a bogus output folder).
        if not os.path.isdir(machine_input_dir):
            continue

        machine_output_dir = os.path.join(base_output_dir, machine)
        os.makedirs(machine_output_dir, exist_ok=True)

        for file in os.listdir(machine_input_dir):
            # splitext touches only the real extension, whereas
            # str.replace('.wav', ...) rewrites every occurrence in the name.
            stem, extension = os.path.splitext(file)
            if extension.lower() != ".wav":
                continue

            input_path = os.path.join(machine_input_dir, file)
            output_path = os.path.join(machine_output_dir, f"BEATs_aug_{stem}.npy")

            if os.path.exists(output_path):
                print(f"⏭️ Skipping {file} - already processed")
                continue

            embedding = extract_beats_embedding(input_path)
            if embedding is not None:
                print(f"💡 Embedding shape: {embedding.shape}")
                np.save(output_path, embedding)
                print(f"💾 Saved embedding to {output_path}")
            else:
                print(f"⚠️ Skipping {file} due to error.")

# ==== RUN ====
# Walk the whole input tree and write one embedding file per clip.
if __name__ == "__main__":
    process_all_wav_files(
        base_input_dir=BASE_INPUT_DIR,
        base_output_dir=BASE_OUTPUT_DIR,
    )
    print("✅ All embeddings extracted with SpecAugment!")


🔄 Loading checkpoint...
🚀 Initializing BEATs model...
✅ BEATs model loaded successfully!
🔹 Loaded K:/DCASE/Trainingdata\train_bearing\section_00_source_train_normal_0001_pro_A_vel_4_loc_A.wav: Shape torch.Size([1, 160000]), Sample Rate 16000
🎵 Prepared waveform shape: torch.Size([1, 480000])
🎛 Raw BEATs feature shape: torch.Size([1, 1496, 768])
💡 Embedding shape: (1496, 768)
💾 Saved embedding to C:/DCASE_Temp/FineTuned\train_bearing\BEATs_aug_section_00_source_train_normal_0001_pro_A_vel_4_loc_A.npy
🔹 Loaded K:/DCASE/Trainingdata\train_bearing\section_00_source_train_normal_0002_pro_A_vel_4_loc_A.wav: Shape torch.Size([1, 160000]), Sample Rate 16000
🎵 Prepared waveform shape: torch.Size([1, 480000])
🎛 Raw BEATs feature shape: torch.Size([1, 1496, 768])
💡 Embedding shape: (1496, 768)
💾 Saved embedding to C:/DCASE_Temp/FineTuned\train_bearing\BEATs_aug_section_00_source_train_normal_0002_pro_A_vel_4_loc_A.npy
🔹 Loaded K:/DCASE/Trainingdata\train_bearing\section_00_source_train_normal_000