In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

ravdess_path = '/content/drive/MyDrive/RAVDESS'

# 폴더 존재 여부 확인
if os.path.exists(ravdess_path):
    print(f"RAVDESS 폴더가 '{ravdess_path}'에 있습니다.")
else:
    print(f"RAVDESS 폴더가 '{ravdess_path}'에 없습니다. 경로를 확인하세요.")


RAVDESS 폴더가 '/content/drive/MyDrive/RAVDESS'에 있습니다.


In [None]:
import os

def extract_emotion_from_filename(filename):
    emotion_code = int(filename.split('-')[2])
    return emotion_code - 1  # make it 0~7

def load_ravdess_dataset(data_dir):
    audio_paths = []
    labels = []
    for root, _, files in os.walk(data_dir):  # 재귀적으로 탐색
        for fname in files:
            if fname.lower().endswith(".wav"):
                path = os.path.join(root, fname)
                label = extract_emotion_from_filename(fname)
                audio_paths.append(path)
                labels.append(label)
    return audio_paths, labels

In [11]:
from torch.utils.data import Dataset
from transformers import Wav2Vec2Processor
import torchaudio

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

class RAVDESSDataset(Dataset):
    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        waveform, sr = torchaudio.load(path)
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
            waveform = waveform[:, :48000]
        inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)
        attention_mask = inputs.attention_mask.squeeze(0) if 'attention_mask' in inputs else None
        label = self.labels[idx]
        return input_values, attention_mask, label


In [12]:
def collate_fn(batch):
    input_values = [item[0] for item in batch]
    labels = torch.tensor([item[2] for item in batch], dtype=torch.long)

    padded = processor.pad(
        {"input_values": input_values},
        padding=True,
        return_tensors="pt"
    )

    return padded.input_values, labels

In [13]:
from transformers import Wav2Vec2Model
import torch.nn as nn
import torch

class EmotionClassifier(nn.Module):
    def __init__(self, num_classes=8):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
      with torch.no_grad():  # feature extractor freeze (선택)
          output = self.wav2vec(input_values)
      hidden = output.last_hidden_state.mean(dim=1)
      return self.classifier(hidden)


In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import os


ravdess_path = '/content/drive/MyDrive/RAVDESS'

audio_paths, labels = load_ravdess_dataset(ravdess_path)
print(f"총 오디오 파일 수: {len(audio_paths)}")
print("샘플 경로:", audio_paths[0])
print("샘플 라벨:", labels[0])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EmotionClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

dataset = RAVDESSDataset(audio_paths, labels)

def collate_fn(batch):
    input_values = [item[0] for item in batch]
    labels = torch.tensor([item[2] for item in batch], dtype=torch.long)

    padded = processor.pad(
        {"input_values": input_values},
        padding=True,
        return_tensors="pt"
    )

    return padded.input_values, labels

train_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_fn
)

for epoch in range(10):
    model.train()
    total_loss = 0.0

    for input_values, labels in train_loader:
        input_values = input_values.to(device)
        labels = labels.to(device)

        logits = model(input_values) 
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/10] Loss: {avg_loss:.4f}")


총 오디오 파일 수: 1440
샘플 경로: /content/drive/MyDrive/RAVDESS/Actor_06/03-01-02-01-01-02-06.wav
샘플 라벨: 1


pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

