<a href="https://colab.research.google.com/github/ituki0426/How_to_improve_detecting_AI_voice_changer/blob/main/notebook/custom_HuBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 準備

In [None]:
# Mount Google Drive at /content/drive; the audio folders used below live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
import os
import torchaudio
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict
from transformers import AutoFeatureExtractor, TrainingArguments, Trainer
from transformers import AutoModelForAudioClassification
from datasets import DatasetDict
from transformers import DataCollatorWithPadding
import torch.nn.functional as F
import random
from transformers import AutoModel
from sklearn.metrics import accuracy_score
import numpy as np

# モデル定義

In [None]:
class HuBERTWithLogMel(nn.Module):
    """Audio classifier that adds a CNN log-mel embedding to a HuBERT embedding.

    The raw waveform goes through ``hubert_model``; the log-mel spectrogram goes
    through a small two-stage CNN and a linear projection to HuBERT's hidden
    size. The two vectors are summed and fed to a linear classifier.

    Args:
        hubert_model: Backbone exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state`` when called with ``input_values``
            and ``attention_mask``.
        num_labels: Number of output classes.
        mel_feature_dim: Size of the flattened CNN output for one sample, i.e.
            ``64 * (mel_bins // 4) * (frames // 4)`` because each of the two
            2x2 max-pools halves both spatial axes (rounding down). The default
            ``1228800`` equals ``64 * 32 * 600``, which corresponds to 128 mel
            bins and a time axis that pools down to 600 frames. Pass another
            value when the clip length or mel settings differ.
    """

    def __init__(self, hubert_model, num_labels=2, mel_feature_dim=1228800):
        super().__init__()
        hidden_size = hubert_model.config.hidden_size
        self.hubert = hubert_model
        # Two conv + ReLU + 2x2 max-pool stages: 1 -> 32 -> 64 channels.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Projects the flattened CNN feature map to the HuBERT hidden size so
        # the two embeddings can be added element-wise.
        self.fc_mel = nn.Linear(mel_feature_dim, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_values, attention_mask, mel_spec, labels=None):
        """Run both branches and classify.

        Args:
            input_values: Waveform batch passed to the HuBERT backbone.
            attention_mask: Mask passed to the HuBERT backbone.
            mel_spec: Log-mel batch with a single channel axis, as expected by
                the first ``Conv2d(1, ...)`` layer; its flattened CNN output
                must have ``mel_feature_dim`` elements per sample.
            labels: Optional class indices; when given, the cross-entropy loss
                is returned as well.

        Returns:
            ``{"logits": ...}``, plus ``"loss"`` when ``labels`` is provided.
        """
        hubert_output = self.hubert(input_values=input_values, attention_mask=attention_mask)
        # Use the first time step of the last hidden state as the utterance
        # embedding. This is a plain frame, not a dedicated [CLS] token.
        hubert_hidden_state = hubert_output.last_hidden_state[:, 0, :]

        # CNN branch: flatten everything except the batch axis, then project.
        cnn_output = self.cnn(mel_spec)
        cnn_output = cnn_output.flatten(start_dim=1)
        mel_hidden_state = self.fc_mel(cnn_output)

        # Fuse the two embeddings by element-wise addition.
        combined_hidden_state = hubert_hidden_state + mel_hidden_state

        logits = self.classifier(combined_hidden_state)

        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}



# データセットの準備関数

In [None]:
def _list_wav_files(directory, limit=None):
    """Return the paths of the ``.wav`` files in ``directory`` in sorted order, at most ``limit`` of them."""
    names = sorted(name for name in os.listdir(directory) if name.endswith('.wav'))
    if limit is not None:
        names = names[:limit]
    return [os.path.join(directory, name) for name in names]


def _load_fixed_length_mono(filepath, max_length, sampling_rate, resamplers):
    """Load one file as a single-channel waveform of exactly ``max_length`` samples at ``sampling_rate``."""
    waveform, sr = torchaudio.load(filepath)

    # Down-mix multi-channel audio: the feature extractor and the model's
    # Conv2d(1, ...) branch both need a single channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample when needed, reusing one Resample module per source rate.
    if sr != sampling_rate:
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(sr, sampling_rate)
        waveform = resamplers[sr](waveform)

    # Zero-pad on the right or truncate so every clip has the same length.
    num_samples = waveform.size(1)
    if num_samples < max_length:
        waveform = F.pad(waveform, (0, max_length - num_samples))
    else:
        waveform = waveform[:, :max_length]
    return waveform


def load_audio_data(feature_extractor, expanded_dir, kanata_dir, max_length, sampling_rate, mel_transform, max_files_per_class=40):
    """Load, normalise and featurise the wav files of both classes.

    Files in ``expanded_dir`` get label 0 and files in ``kanata_dir`` get
    label 1. Each file is down-mixed to mono, resampled to ``sampling_rate``,
    padded or truncated to ``max_length`` samples, then turned into HuBERT
    inputs and a ``log1p`` mel spectrogram.

    Args:
        feature_extractor: Callable producing ``input_values`` and
            ``attention_mask`` for a 1-D numpy waveform.
        expanded_dir: Directory with the label-0 wav files.
        kanata_dir: Directory with the label-1 wav files.
        max_length: Target clip length in samples.
        sampling_rate: Target sampling rate in Hz.
        mel_transform: Callable mapping a waveform tensor to a mel spectrogram.
        max_files_per_class: Upper bound on the wav files taken from each
            directory (first N in sorted name order); ``None`` means no limit.

    Returns:
        A list of dicts with keys ``label``, ``input_values``,
        ``attention_mask`` and ``mel_spec``, label-0 items first.
    """
    data = []
    resamplers = {}
    for label, directory in ((0, expanded_dir), (1, kanata_dir)):
        for filepath in _list_wav_files(directory, max_files_per_class):
            waveform = _load_fixed_length_mono(filepath, max_length, sampling_rate, resamplers)

            inputs = feature_extractor(waveform.squeeze(0).numpy(), sampling_rate=sampling_rate, return_attention_mask=True)
            # log1p keeps silent (zero-energy) bins finite.
            mel_spec = torch.log1p(mel_transform(waveform))

            print(f"now : {len(data) + 1}")
            data.append({
                "label": label,
                "input_values": inputs["input_values"][0],
                "attention_mask": inputs["attention_mask"][0],
                "mel_spec": mel_spec
            })

    return data

In [None]:
def prepare_dataset(feature_extractor, expanded_dir, kanata_dir, max_length, sampling_rate, mel_transform, train_split=0.8, seed=None):
    """Load both classes, shuffle them and build a train/test ``DatasetDict``.

    Args:
        feature_extractor, expanded_dir, kanata_dir, max_length, sampling_rate,
        mel_transform: Forwarded unchanged to ``load_audio_data``.
        train_split: Fraction of the shuffled items placed in the train split;
            the remainder becomes the test split.
        seed: Optional seed for the shuffle. With an integer the split is
            reproducible regardless of global RNG state; with ``None`` the
            module-level ``random`` generator is used.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"test"`` datasets whose
        columns are ``label``, ``input_values``, ``attention_mask`` and
        ``mel_spec``.
    """
    data = load_audio_data(feature_extractor, expanded_dir, kanata_dir, max_length, sampling_rate, mel_transform)

    # Shuffle so the label-ordered list is mixed before splitting.
    if seed is None:
        random.shuffle(data)
    else:
        random.Random(seed).shuffle(data)

    train_size = int(len(data) * train_split)
    train_data = data[:train_size]
    test_data = data[train_size:]

    def convert_to_dict(items):
        """Turn a list of per-sample dicts into the column-wise dict that Dataset.from_dict expects."""
        columns = ("label", "input_values", "attention_mask", "mel_spec")
        return {column: [item[column] for item in items] for column in columns}

    dataset = DatasetDict({
        "train": Dataset.from_dict(convert_to_dict(train_data)),
        "test": Dataset.from_dict(convert_to_dict(test_data)),
    })
    return dataset

In [None]:
# Folders holding the two classes (expanded -> label 0, kanata -> label 1)
expanded_dir = "/content/drive/MyDrive/customBERT/expanded"
kanata_dir = "/content/drive/MyDrive/customBERT/kanata"

# Seed the RNGs so that, for a given file listing, the shuffled train/test
# split and the later weight initialisation repeat across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Load the feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained('rinna/japanese-hubert-base')

# Sampling rate and maximum clip length
sampling_rate = feature_extractor.sampling_rate
max_length = int(sampling_rate * 30)  # 30 seconds

# Mel-spectrogram transform (the log is applied afterwards with log1p)
mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sampling_rate, n_mels=128)

# Build the dataset
print("Preparing dataset...")
dataset = prepare_dataset(feature_extractor, expanded_dir, kanata_dir, max_length, sampling_rate, mel_transform)
print(dataset)
print("Done.")

preprocessor_config.json:   0%|          | 0.00/216 [00:00<?, ?B/s]



Preparing dataset...
now : 1
now : 2
now : 3
now : 4
now : 5
now : 6
now : 7
now : 8
now : 9
now : 10
now : 11
now : 12
now : 13
now : 14
now : 15
now : 16
now : 17
now : 18
now : 19
now : 20
now : 21
now : 22
now : 23
now : 24
now : 25
now : 26
now : 27
now : 28
now : 29
now : 30
now : 31
now : 32
now : 33
now : 34
now : 35
now : 36
now : 37
now : 38
now : 39
now : 40
now : 41
now : 42
now : 43
now : 44
now : 45
now : 46
now : 47
now : 48
now : 49
now : 50
now : 51
now : 52
now : 53
now : 54
now : 55
now : 56
now : 57
now : 58
now : 59
now : 60
now : 61
now : 62
now : 63
now : 64
now : 65
now : 66
now : 67
now : 68
now : 69
now : 70
now : 71
now : 72
now : 73
now : 74
now : 75
now : 76
now : 77
now : 78
now : 79
now : 80
DatasetDict({
    train: Dataset({
        features: ['label', 'input_values', 'attention_mask', 'mel_spec'],
        num_rows: 64
    })
    test: Dataset({
        features: ['label', 'input_values', 'attention_mask', 'mel_spec'],
        num_rows: 16
    })
})
Done

In [None]:
# Load the pretrained HuBERT backbone (a bare model whose hidden-state output can be read)
hubert_model = AutoModel.from_pretrained('rinna/japanese-hubert-base')

# Initialise the custom HuBERT + log-mel classifier around it
model = HuBERTWithLogMel(hubert_model, num_labels=2)

# Print the module tree for inspection
print(model)

config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

HuBERTWithLogMel(
  (hubert): HubertModel(
    (feature_extractor): HubertFeatureEncoder(
      (conv_layers): ModuleList(
        (0): HubertGroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x HubertNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x HubertNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): HubertFeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): HubertEnco

In [None]:
# Evaluation metric callback
def compute_metrics(pred):
    """Return the accuracy of the arg-max class predictions.

    ``pred`` must expose ``predictions`` (per-class scores, one row per
    sample) and ``label_ids`` (the reference labels).
    """
    predicted_classes = np.argmax(pred.predictions, axis=1)
    return {"accuracy": accuracy_score(pred.label_ids, predicted_classes)}


In [None]:
# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    # NOTE(review): newer transformers releases spell this `eval_strategy` —
    # confirm against the installed version before renaming.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=20,
    # Warm up over the first 10% of the optimisation steps. A fixed 500-step
    # warmup would be longer than the whole run (the recorded run ends at
    # global_step=320), so the learning rate would never reach its peak.
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    report_to="wandb",
    run_name="audio-classification"
)

# Data collator: pads each batch using the feature extractor
data_collator = DataCollatorWithPadding(tokenizer=feature_extractor)

# Build the Trainer, with compute_metrics reporting accuracy at each evaluation
# NOTE(review): the recorded run printed a warning pointing at this call; newer
# transformers releases prefer `processing_class=` over `tokenizer=` — confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# Start training
trainer.train()

  trainer = Trainer(


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6943,0.454093,0.9375
2,0.2526,0.309399,0.9375
3,0.169,0.217254,0.9375
4,0.0808,0.280794,0.9375


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6943,0.454093,0.9375
2,0.2526,0.309399,0.9375
3,0.169,0.217254,0.9375
4,0.0808,0.280794,0.9375
5,0.0221,0.325377,0.75
6,0.0171,0.343335,0.9375
7,0.0213,0.228357,0.875
8,0.0045,0.923829,0.9375
9,0.0001,0.16037,0.9375
10,0.0001,0.626578,0.9375


TrainOutput(global_step=320, training_loss=0.06306835205596428, metrics={'train_runtime': 2623.0245, 'train_samples_per_second': 0.488, 'train_steps_per_second': 0.122, 'total_flos': 0.0, 'train_loss': 0.06306835205596428, 'epoch': 20.0})

In [None]:

# Select the device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Put the model in evaluation mode so train-only behaviour such as the dropout
# layers in the HuBERT backbone is disabled; .to() does not change the mode.
model.eval()

# Sampling rate and maximum clip length
sampling_rate = feature_extractor.sampling_rate
max_length = int(sampling_rate * 30)  # 30 seconds

# Mel-spectrogram transform (the log is applied afterwards with log1p)
mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sampling_rate, n_mels=128)

# Audio files to run inference on: kanata files first, then expanded files.
# NOTE(review): these come from the same two folders used to build the training
# set, so some may have been seen during training — confirm before treating the
# predictions as held-out results.
paths = [
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1408_1.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1408.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1407.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1396_1.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1389_1.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1386.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1356_2.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1323.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1274.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1214.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1133.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_1001.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_0932.wav.wav",
    "/content/drive/MyDrive/customBERT/kanata/BASIC5000_0844.wav.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1408_1.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1406.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1399.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1387.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1345.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1315.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1275.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1231_1.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1172.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1124.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1098_1.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1071.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_1009.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_0979_1.wav",
    "/content/drive/MyDrive/customBERT/expanded/BASIC5000_0937.wav"
]

# Collected predictions, one dict per successfully processed file
results = []

# Run inference file by file; a failure on one file is reported and skipped
for audio_path in paths:
    try:
        # Load the audio
        waveform, sr = torchaudio.load(audio_path)

        # Down-mix multi-channel audio: the feature extractor and the model's
        # single-channel CNN branch both need one channel.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the feature extractor's rate when needed
        if sr != sampling_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate)
            waveform = resampler(waveform)

        # Zero-pad or truncate to the fixed clip length
        if waveform.size(1) < max_length:
            waveform = F.pad(waveform, (0, max_length - waveform.size(1)))
        else:
            waveform = waveform[:, :max_length]

        # Feature extraction. return_tensors="pt" yields batched tensors
        # directly, avoiding the slow (and warned-about) torch.tensor call on a
        # list of numpy arrays.
        inputs = feature_extractor(
            waveform.squeeze(0).numpy(),
            sampling_rate=sampling_rate,
            return_attention_mask=True,
            return_tensors="pt",
        )
        mel_spec = mel_transform(waveform)
        mel_spec = torch.log1p(mel_spec)  # log scale

        # Move the model inputs to the target device
        input_values = inputs["input_values"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        mel_spec = mel_spec.unsqueeze(0).to(device)  # add the batch axis

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(input_values=input_values, attention_mask=attention_mask, mel_spec=mel_spec)

        # Pick the class with the highest logit
        logits = outputs["logits"]
        predicted_class = torch.argmax(logits, dim=1).item()

        results.append({
            "audio_path": audio_path,
            "predicted_class": predicted_class
        })

        print(f"Processed: {audio_path}, Predicted class: {predicted_class}")

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")

# Summary of all predictions
print("\nPrediction Results:")
for result in results:
    print(f"Audio: {result['audio_path']}, Predicted Class: {result['predicted_class']}")


  attention_mask = torch.tensor(inputs["attention_mask"]).to(device)


Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1408_1.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1408.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1407.wav.wav, Predicted class: 0
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1396_1.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1389_1.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1386.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1356_2.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1323.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1274.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBERT/kanata/BASIC5000_1214.wav.wav, Predicted class: 1
Processed: /content/drive/MyDrive/customBE