# In [None]:  (Jupyter notebook cell marker left over from the export)
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
import torch, torchaudio, os, random, numpy as np
from torch.utils.data import Dataset, DataLoader  
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import gradio as gr
from torch import nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

base_path = "dog_sound_dataset"  # root directory of your dataset

# 1. Processor + Model
# NOTE(review): "facebook/wav2vec2-base" is a pretraining-only checkpoint;
# confirm that Wav2Vec2Processor (which also expects a tokenizer) loads from it
# in the installed transformers version.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=4,
    problem_type="single_label_classification",
)
model.to(device)  # nn.Module.to() moves the parameters in place
# model.classifier.dropout = nn.Dropout(0.05)

# Freeze the whole wav2vec2 backbone ...
model.wav2vec2.requires_grad_(False)

# ... then re-enable gradients for the last two transformer encoder layers.
# Parameters outside `model.wav2vec2` are never frozen here.
for layer in model.wav2vec2.encoder.layers[-2:]:
    layer.requires_grad_(True)

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=3e-5,
)

# 2. Dataset & DataLoader (only padding-related changes)
# Class-folder name -> integer label; the folder list follows the same order.
elementary = dict(Play=0, Defend=1, Beg=2, Fight=3)
folders = list(elementary)


class DogSoundDataset(Dataset):
    """Dog-sound clips laid out as ``<base_path>/<folder>/<train_wav|val_wav>/*.wav``.

    Each item is a tuple ``(input_values, attention_mask, label)`` where the
    first two are 1-D tensors produced by the module-level ``processor`` and
    ``label`` is a scalar tensor holding the class index from ``label_map``.
    """

    # Folder name -> class index.
    label_map = elementary

    def __init__(self, base_path, folders, sr=16000, split="train"):
        """Collect the ``(path, label)`` pairs for one split.

        Args:
            base_path: Root directory of the dataset.
            folders: Class folder names; each must be a key of ``label_map``.
            sr: Target sample rate every clip is resampled to.
            split: ``"train"`` reads ``train_wav``; anything else reads ``val_wav``.
        """
        self.items = []
        self.sr = sr
        # The sub-folder depends only on the split, so compute it once.
        subfolder = "train_wav" if split == "train" else "val_wav"
        for folder in folders:
            full_path = os.path.join(base_path, folder, subfolder)
            # isdir (not exists) so a stray file with that name is skipped too
            # instead of crashing in os.listdir.
            if not os.path.isdir(full_path):
                print(f"路徑不存在: {full_path}")
                continue

            label = self.label_map[folder]
            # Case-insensitive match so ".WAV" files are picked up as well.
            files = [f for f in os.listdir(full_path) if f.lower().endswith(".wav")]
            random.shuffle(files)
            print(f"{split} - {folder}: 找到 {len(files)} 個檔案")

            self.items.extend((os.path.join(full_path, fname), label) for fname in files)

    def __len__(self):
        """Return the number of clips in this split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(input_values, attention_mask, label)``."""
        path, label = self.items[idx]
        wav, sr = torchaudio.load(path)  # (channels, time)

        # Mix down to mono. For single-channel files this equals squeeze(0);
        # for multi-channel files squeeze(0) would have left a 2-D tensor that
        # the processor interprets as a batch of separate clips.
        wav = wav.mean(dim=0)
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)

        # Peak-normalise; the epsilon avoids division by zero on silent clips.
        wav = wav / (wav.abs().max() + 1e-8)

        proc_output = processor(
            wav.numpy(),
            sampling_rate=self.sr,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )
        input_values = proc_output.input_values.squeeze(0)       # (L,)
        attention_mask = proc_output.attention_mask.squeeze(0)   # (L,)

        return input_values, attention_mask, torch.tensor(label)


def collate_fn(batch):
    """Merge ``(input_values, attention_mask, label)`` samples into one batch.

    Returns a tuple ``(vals, masks, labels)``: the waveforms zero-padded to the
    longest sample and rescaled per row by their absolute peak, the attention
    masks zero-padded to the same length, and the stacked labels.
    """
    waveforms, attn_masks, targets = zip(*batch)

    # Right-pad both sequences to (B, L_max); padded positions are filled with 0,
    # which is also the "ignore" value for the attention mask.
    padded_waves, padded_masks = (
        pad_sequence(seqs, batch_first=True) for seqs in (waveforms, attn_masks)
    )

    # Per-sample peak normalisation; epsilon guards against all-zero rows.
    peak = padded_waves.abs().amax(dim=1, keepdim=True)   # (B, 1)
    scaled = padded_waves / (peak + 1e-8)

    return scaled, padded_masks, torch.stack(targets)     # waveform, mask, label


# Build train_ds / val_ds (train first, then val, so the scan messages keep
# their order).
train_ds, val_ds = (
    DogSoundDataset(base_path, folders, split=split_name)
    for split_name in ("train", "val")
)

# Batches of 8 for both splits; only the training data is reshuffled each epoch.
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)

# 3. Training Loop
# Ten epochs; each epoch trains on train_loader, evaluates on val_loader,
# prints the averaged metrics and plots a confusion matrix.
for epoch in range(10):
    model.train()
    train_loss = train_corr = train_total = 0
    for x, masks, y in train_loader:
        x, masks, y = x.to(device), masks.to(device), y.to(device)

        # Passing labels makes the model return the loss alongside the logits.
        out = model(input_values=x, attention_mask=masks, labels=y)
        loss, logits = out.loss, out.logits

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_corr += (logits.argmax(-1) == y).sum().item()
        train_total += y.size(0)

    torch.cuda.empty_cache()

    # Validation
    model.eval()
    val_loss = val_corr = val_total = 0
    preds, trues = [], []
    with torch.no_grad():
        for x, masks, y in val_loader:
            x, masks, y = x.to(device), masks.to(device), y.to(device)

            out = model(input_values=x, attention_mask=masks, labels=y)
            val_loss += out.loss.item()
            logits = out.logits
            pred = logits.argmax(-1).cpu()
            preds.extend(pred.tolist())
            trues.extend(y.cpu().tolist())
            val_corr += (pred == y.cpu()).sum().item()
            val_total += y.size(0)

    # Average once here; the report below must not divide again.
    val_loss = val_loss / len(val_loader)
    val_acc = val_corr / val_total

    print(f"Epoch {epoch+1}: "
          f"Train Loss={train_loss/len(train_loader):.4f}, "
          f"Train Acc={train_corr/train_total:.4f} | "
          f"Val Loss={val_loss:.4f}, "
          f"Val Acc={val_acc:.4f}")

    # Confusion matrix over the validation predictions of this epoch.
    cm = confusion_matrix(trues, preds, labels=[0,1,2,3])
    disp = ConfusionMatrixDisplay(cm, display_labels=["Play","Defend","Beg","Fight"])
    disp.plot()
    plt.title(f"Epoch {epoch+1} Confusion Matrix with human-based fine tuned pre-trained model and imbalanced data")
    plt.tight_layout()
    plt.show()
    torch.cuda.empty_cache()

# Persist the weights after the final epoch.
torch.save(model.state_dict(), "dog_classifier_final.pt")


def predict(path):
    """Classify one audio file and return ``"Predicted: <label>"``.

    The clip goes through the same preprocessing as the training data: mono
    mixdown, resampling to 16 kHz, peak normalisation, the module-level
    ``processor``, and the per-sample peak scaling that ``collate_fn`` applies
    to training batches.

    Args:
        path: Filesystem path of the audio clip. A falsy value (``None`` or an
            empty string, e.g. when nothing was uploaded) is reported instead
            of raising.

    Returns:
        A human-readable string with the predicted class name.
    """
    if not path:
        return "No audio provided"

    wav, sr = torchaudio.load(path)  # (channels, time)

    # Mix down to mono; squeeze(0) alone leaves stereo input 2-D, which the
    # processor would treat as a batch of two clips.
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)

    # Same peak normalisation as DogSoundDataset.__getitem__.
    wav = wav / (wav.abs().max() + 1e-8)

    proc_output = processor(
        wav.numpy(),
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )

    # Same per-sample scaling as collate_fn, so inference inputs match the
    # value range the model was fine-tuned on.
    input_values = proc_output.input_values                   # (1, L)
    peak = input_values.abs().amax(dim=1, keepdim=True)
    input_values = (input_values / (peak + 1e-8)).to(device)
    attention_mask = proc_output.attention_mask.to(device)

    model.eval()
    with torch.no_grad():
        logits = model(input_values=input_values, attention_mask=attention_mask).logits
    label = ["Play", "Defend", "Beg", "Fight"][logits.argmax(dim=1).item()]
    return f"Predicted: {label}"

# Gradio demo: the uploaded/recorded clip is handed to predict() as a file path
# and the returned string is shown as plain text.
audio_input = gr.Audio(type="filepath")
interface = gr.Interface(
    predict,
    inputs=audio_input,
    outputs="text",
    title="你在狗叫什麼",
    description="狗狗叫聲：Play/Defend/Beg/Fight",
)
# share=True additionally requests a public share link.
interface.launch(share=True)
