# 음성 분류 모델 생성

### 디펜던시 설치

In [None]:
!pip install ipywidgets tqdm torch torchvision ai-edge-torch

### 데이터로더 생성

In [None]:
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import ai_edge_torch
import tensorflow as tf

In [None]:
# Pin every RNG used here (Python, NumPy, PyTorch CPU and CUDA) to the same seed
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Request deterministic cuDNN kernels and disable cuDNN benchmark mode
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device

In [None]:
# Dataset root; each immediate sub-directory is treated as one class
ROOT_DIR = './custom_dataset'
# Sample rate in Hz that the audio clips are expected to have
TARGET_SAMPLE_RATE = 16000
# Clip length in seconds; shorter clips are zero-padded, longer ones truncated
TARGET_LENGTH_SEC = 5

데이터셋 불러오기

In [None]:
class AudioDataSet(Dataset):
    """Folder-per-class dataset of fixed-length mono waveforms.

    Every immediate sub-directory of ``root_dir`` is one class (indexed in
    sorted-name order) and every ``.wav`` file inside it is one sample.

    Items are ``(waveform, label)`` where ``waveform`` has shape
    ``(1, target_sample_rate * target_length_sec)``, or ``(None, None)`` when
    the file could not be loaded.

    Args:
        root_dir: Directory containing one sub-directory per class.
        target_sample_rate: Sample rate in Hz every clip is brought to.
        target_length_sec: Clip length in seconds after padding/truncation.
    """

    def __init__(self, root_dir, target_sample_rate, target_length_sec):
        self.root_dir = root_dir
        self.target_sample_rate = target_sample_rate
        self.target_num_samples = target_sample_rate * target_length_sec

        # --- 1. Discover classes: one per sub-directory, sorted for stable indices ---
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {class_name: i for i, class_name in enumerate(self.classes)}

        print(f"Found classes: {self.class_to_idx}")

        # --- 2. Find all .wav files and store their path and label ---
        self.file_list = []
        for class_name in self.classes:
            class_idx = self.class_to_idx[class_name]
            class_dir = os.path.join(root_dir, class_name)

            # os.listdir order is arbitrary; sorting keeps sample indices (and
            # therefore any seeded random split) identical across machines.
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith('.wav'):
                    self.file_list.append((os.path.join(class_dir, filename), class_idx))

        print(f"Found {len(self.file_list)} audio files.")

    def __len__(self):
        """Returns the total number of files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load one clip and return ``(waveform, label)``.

        The clip is down-mixed to mono, resampled to ``target_sample_rate``
        when its own rate differs, then zero-padded on the right or truncated
        to exactly ``target_num_samples`` samples.

        Returns:
            ``(waveform, label)``, or ``(None, None)`` if loading or processing
            failed; such items are expected to be dropped by the collate function.
        """
        file_path, label = self.file_list[idx]

        try:
            # --- 1. Load the audio file ---
            waveform, original_sr = torchaudio.load(file_path)

            # --- 2. Ensure single channel (mono) ---
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # --- 3. Resample when the file is not at the target rate ---
            # Without this, a clip recorded at another rate would be padded or
            # cut to the wrong duration and yield a mis-scaled spectrogram.
            if original_sr != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform, original_sr, self.target_sample_rate
                )

            # --- 4. Pad (zeros on the right) or truncate to the fixed length ---
            current_num_samples = waveform.shape[1]
            if current_num_samples < self.target_num_samples:
                padding_needed = self.target_num_samples - current_num_samples
                waveform = F.pad(waveform, (0, padding_needed))
            elif current_num_samples > self.target_num_samples:
                waveform = waveform[:, :self.target_num_samples]

            return waveform, label

        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort an epoch.
            print(f"Error loading or processing file {file_path}: {e}")
            return None, None

In [None]:
def collate_fn(batch):
    """Collate ``(waveform, label)`` pairs, dropping items that failed to load.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs; a failed item has
            ``None`` as its waveform.

    Returns:
        ``(waveforms, labels)``: the waveforms stacked along a new leading
        dimension and the labels as a ``torch.long`` tensor. If every item
        failed, both tensors are empty; labels keep the ``torch.long`` dtype
        so the result is typed the same as a regular batch.
    """
    valid = [item for item in batch if item[0] is not None]

    if not valid:
        return torch.tensor([]), torch.tensor([], dtype=torch.long)

    waveforms, labels = zip(*valid)
    return torch.stack(waveforms), torch.tensor(labels, dtype=torch.long)

학습 데이터 80%, validation 10%, evaluation 10%

In [None]:
# Fraction of the dataset used for training
TRAIN_RATIO = 0.8
# Fraction used for validation; whatever remains becomes the test split
VAL_RATIO = 0.1
# Number of clips per batch
BATCH_SIZE = 32

In [None]:
# Build the dataset over every class folder under ROOT_DIR
custom_dataset = AudioDataSet(ROOT_DIR, TARGET_SAMPLE_RATE, TARGET_LENGTH_SEC)

# Train and validation sizes are floored; the test split takes the remainder
total_size = len(custom_dataset)
train_size = int(TRAIN_RATIO * total_size)
val_size = int(VAL_RATIO * total_size)
test_size = total_size - (train_size + val_size)

print(f"Total: {total_size} |Train: {train_size} |Validation: {val_size} |Test: {test_size}")

# Randomly partition the dataset into the three subsets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset=custom_dataset,
    lengths=[train_size, val_size, test_size],
)

데이터로더 생성

In [None]:
# Training batches are reshuffled every epoch
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn
)
# Validation and test loaders keep a fixed order: shuffling does not change
# the aggregated metrics and only makes runs harder to compare.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn
)
eval_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn
)

데이터 로더 테스트

In [None]:
# Smoke test: pull a single batch from each loader and report its shape
print("\nTesting loaders - fetching one batch from each...")
try:
    # Get one batch from training loader
    waveforms_train, labels_train = next(iter(train_loader))
    print(f"Train batch waveforms shape: {waveforms_train.shape}")
    
    # Get one batch from validation loader
    waveforms_val, labels_val = next(iter(val_loader))
    print(f"Validation batch waveforms shape: {waveforms_val.shape}")
   
    # Get one batch from test loader
    waveforms_test, labels_test = next(iter(eval_loader))
    print(f"Test batch waveforms shape: {waveforms_test.shape}")
    
except Exception as e:
    # Best-effort check: report the failure instead of stopping the notebook
    print(f"Error while fetching a batch: {e}")

### 모델 생성

상수 설정

In [None]:
# Audio sample rate in Hz used to convert the window/hop durations to samples
sample_rate = 16000
# Analysis window length in milliseconds
window_ms = 30
# Hop between consecutive windows in milliseconds
hop_ms = 20
# Default number of output classes of the classifier
NUM_CLASSES = 4
# Default dropout probability applied before the final linear layer
DROPOUT_RATE = 0.3

# Window and hop lengths converted from milliseconds to samples
win_length_samples = int(sample_rate * (window_ms / 1000.0)) #480
hop_length_samples = int(sample_rate * (hop_ms / 1000.0)) #320

In [None]:
# Audio -> mel spectrogram preprocessor
# Originally meant to be built into the model, but quantization did not support
# it, so it is implemented separately (presumably on the ESP target device;
# the original note is partly garbled - TODO confirm)
preprocessor = transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft = 512,
    win_length= win_length_samples,
    hop_length = hop_length_samples,
    n_mels=64,
    power=2.0,
    normalized=False
)

모델 생성

In [None]:
class Model(nn.Module):
    """Small two-stage CNN classifier over spectrogram-shaped input.

    The input is instance-normalised, passed through two
    conv -> ReLU -> 2x2 max-pool stages, flattened, and mapped by dropout plus
    one linear layer to ``num_classes`` logits.

    The linear layer size is fixed for a ``[B, 64, 251]`` input
    (``[B, 16, 16, 62]`` after the second pooling stage).

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied before the linear layer.
    """

    def __init__(self, num_classes=NUM_CLASSES, dropout_rate=DROPOUT_RATE):
        super().__init__()

        # Normalises each sample's single channel independently
        self.inst_norm = nn.InstanceNorm2d(1)

        # Stage 1: 1 -> 8 channels, 5x5 kernel, spatial size halved by pooling
        self.conv1 = nn.Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1), padding="same")
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # Stage 2: 8 -> 16 channels, 3x3 kernel, spatial size halved again
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding='same')
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # After conv2: [B, 16, 32, 125]
        # After pool2: [B, 16, 16, 62] (32/2, floor(125/2))
        self.flattened_size = 16 * 16 * 62
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc1 = nn.Linear(self.flattened_size, num_classes)

    def forward(self, waveform):
        """Return class logits for a batch.

        Args:
            waveform: 3-D input batch; a channel dimension is inserted at
                index 1 before the convolutions.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        features = self.inst_norm(waveform.unsqueeze(1))

        features = self.pool1(self.relu1(self.conv1(features)))
        features = self.pool2(self.relu2(self.conv2(features)))

        flat = features.flatten(start_dim=1)
        return self.fc1(self.dropout(flat))

In [None]:
# Smoke test: run a random batch (32 clips of 80000 samples) through the
# preprocessor and a freshly initialised model
dummy_waveform = torch.randn(32, 80000)
preprocessor = transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=512,
    win_length=win_length_samples,
    hop_length=hop_length_samples,
    n_mels=64,
    power=2.0,
    normalized=False,
)
model = Model()
output_logits = model(preprocessor(dummy_waveform))
print(f"Model output shape: {output_logits.shape}")

# Turn the logits into per-class probabilities
probabilities = output_logits.softmax(dim=1)
print(f"Probabilities shape: {probabilities.shape}")
print(probabilities)

### 모델 학습

In [None]:
# Early-stopping state and training hyper-parameters

patience = 3  # stop after 3 consecutive epochs without validation-loss improvement
best_val_loss = np.inf
epoch_no_improve = 0
early_stop = False
best_model_path = 'model_checkpoint.pth'

# Resolve the device before anything is moved to it, and move the model
# before building the optimizer over its parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

# Same preprocessor as defined earlier; re-created here for readability and
# placed on the training device
preprocessor = transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=512,
    win_length=win_length_samples,
    hop_length=hop_length_samples,
    n_mels=64,
    power=2.0,
    normalized=False
).to(device)

학습 페이즈

In [None]:
# Train for up to num_epochs, validating after every epoch. The checkpoint
# with the lowest validation loss is saved, and training stops early once the
# validation loss has not improved for `patience` consecutive epochs.
for epoch in range(num_epochs):
    print(f"===== {epoch + 1} / {num_epochs} =====")

    # ---- training ----
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for waveforms, labels in progress_bar:
        # collate_fn returns empty tensors when every file in the batch failed
        # to load; squeeze(1) below would raise on them, so skip the batch.
        if labels.numel() == 0:
            continue
        waveforms, labels = waveforms.to(device), labels.to(device)

        optimizer.zero_grad()

        # Drop the channel dimension before computing the mel spectrogram
        outputs = model(preprocessor(waveforms.squeeze(1)))
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by batch size so the
        # epoch figure is a per-sample average
        running_loss += loss.item() * waveforms.size(0)
        predicted = outputs.argmax(dim=1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        progress_bar.set_postfix(loss=loss.item())
    train_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    print(f"Epoch {epoch+1} Training   -> Loss: {train_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    # ---- validation ----
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation", leave=False)
        for waveforms, labels in progress_bar:
            if labels.numel() == 0:
                continue
            waveforms, labels = waveforms.to(device), labels.to(device)

            outputs = model(preprocessor(waveforms.squeeze(1)))
            loss = criterion(outputs, labels)

            running_loss += loss.item() * waveforms.size(0)
            predicted = outputs.argmax(dim=1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

            progress_bar.set_postfix(loss=loss.item())
    val_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    print(f"Epoch {epoch+1} Validation   -> Loss: {val_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    # ---- early-stopping check ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epoch_no_improve = 0
        torch.save(model.state_dict(), best_model_path)
        print("model saved as validation loss improved")
    else:
        epoch_no_improve += 1
        print(f"validate did not improve for {epoch_no_improve} epochs")

    if epoch_no_improve >= patience:
        print(f"Early stopping triggered at {epoch + 1} epochs")
        early_stop = True
        break
if not early_stop:
    print("Training completed!")
else:
    print("early stop executed")

검증 손실(validation loss)이 가장 낮았던 시점에 저장된 모델 파라미터 불러오기

In [None]:
# map_location keeps the load working when the checkpoint was written on a
# different device (e.g. saved on a GPU, reloaded on a CPU-only runtime)
model.load_state_dict(torch.load(best_model_path, map_location=device))

Evaluation

In [None]:
# Evaluate the restored model on the held-out test split
print("Start Evaluation Phase:")
model.eval()
running_loss = 0.0
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    progress_bar = tqdm(eval_loader, desc="Evaluation", leave=False)
    for waveforms, labels in progress_bar:
        # collate_fn returns empty tensors when every file in the batch failed
        # to load; squeeze(1) below would raise on them, so skip the batch.
        if labels.numel() == 0:
            continue
        waveforms, labels = waveforms.to(device), labels.to(device)

        outputs = model(preprocessor(waveforms.squeeze(1)))
        loss = criterion(outputs, labels)

        # Weight the batch-mean loss by batch size for a per-sample average
        running_loss += loss.item() * waveforms.size(0)
        predicted = outputs.argmax(dim=1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        progress_bar.set_postfix(loss=loss.item())
epoch_loss = running_loss / total_samples
epoch_acc = correct_predictions / total_samples
print(f"Epoch {epoch+1} Evaluation   -> Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

### 모델 경량화

모델 경량화(INT8 양자화)를 위해선 각 텐서의 값 범위가 필요하기 때문에, 경량화 과정 중에 대표 데이터(representative dataset)를 입력해줘야 한다.

In [None]:
preprocessor.to(device)

def representative_dataset_gen():
    """Yield calibration samples for post-training quantization.

    Takes up to 50 batches from the validation loader, converts them to mel
    spectrograms and yields them one at a time.

    Yields:
        A one-element list holding a float32 NumPy array with a leading batch
        dimension of 1. NumPy is used because the samples are consumed by the
        TFLite converter, not by PyTorch.
    """
    for i, (waveforms, _) in enumerate(val_loader):
        if i >= 50:
            print("...dataset generation complete.")
            break
        # collate_fn returns an empty tensor when a whole batch failed to load
        if waveforms.numel() == 0:
            continue

        spectrograms = preprocessor(waveforms.squeeze(1).to(device))
        spectrograms_np = spectrograms.detach().cpu().numpy().astype(np.float32)

        for sample in spectrograms_np:
            # Re-add the batch dimension the converted model expects
            yield [sample[np.newaxis, ...]]

In [None]:
# Convert on the CPU with the model in eval mode
model.to("cpu").eval()

# Example input; it fixes the input shape of the converted model
sample_args = (torch.randn(1, 64, 251).to("cpu"),)

# Converter flags: default optimizations calibrated with the representative
# dataset, INT8 builtin ops only, and int8 input/output tensors
tfl_converter_flags = dict(
    optimizations=[tf.lite.Optimize.DEFAULT],
    representative_dataset=representative_dataset_gen,
    target_spec=dict(supported_ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
    inference_input_type=tf.int8,
    inference_output_type=tf.int8,
)

tfl_int8_model = ai_edge_torch.convert(
    model, sample_args, _ai_edge_converter_flags=tfl_converter_flags
)

In [None]:
# Write the converted model to disk as a .tflite file
INT8_MODEL_PATH = "model_full_int8.tflite"
tfl_int8_model.export(INT8_MODEL_PATH)

print(f"Successfully exported full INT8 model to: {INT8_MODEL_PATH}")

경량화된 모델의 성능 확인

In [None]:
print("Start Evaluation Phase:")

# The converted model is a TFLite flatbuffer, not an nn.Module: it has no
# .to()/.eval(), takes NumPy input of the fixed shape used at conversion (one
# spectrogram per call) and was requested with int8 input/output tensors.
# Run the exported file through the TFLite interpreter and quantize /
# dequantize at the boundary.
interpreter = tf.lite.Interpreter(model_path=INT8_MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
input_scale, input_zero_point = input_details["quantization"]
output_scale, output_zero_point = output_details["quantization"]

running_loss = 0.0
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    progress_bar = tqdm(eval_loader, desc="Evaluation", leave=False)
    for waveforms, labels in progress_bar:
        # collate_fn returns empty tensors when a whole batch failed to load
        if labels.numel() == 0:
            continue

        # Spectrograms are computed on `device`, then handed to TFLite as NumPy
        spectrograms = preprocessor(waveforms.squeeze(1).to(device)).cpu().numpy()

        batch_logits = []
        for spectrogram in spectrograms:
            sample = spectrogram[np.newaxis, ...].astype(np.float32)
            # A scale of 0 means the tensor is not quantized
            if input_scale > 0:
                limits = np.iinfo(input_details["dtype"])
                sample = np.clip(
                    np.round(sample / input_scale) + input_zero_point,
                    limits.min,
                    limits.max,
                )
            interpreter.set_tensor(
                input_details["index"], sample.astype(input_details["dtype"])
            )
            interpreter.invoke()

            logits = interpreter.get_tensor(output_details["index"]).astype(np.float32)
            if output_scale > 0:
                logits = (logits - output_zero_point) * output_scale
            batch_logits.append(logits[0])

        # Back to torch (CPU) so the same criterion and metrics can be reused
        outputs = torch.from_numpy(np.stack(batch_logits))
        loss = criterion(outputs, labels)

        running_loss += loss.item() * labels.size(0)
        predicted = outputs.argmax(dim=1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        progress_bar.set_postfix(loss=loss.item())
epoch_loss = running_loss / total_samples
epoch_acc = correct_predictions / total_samples
print(f"Epoch {epoch+1} Evaluation   -> Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")