File Mount

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Check GPU

In [2]:
import torch

if torch.cuda.is_available():
    print("✅ GPU Available!")
    print("GPU :", torch.cuda.get_device_name(0))

✅ GPU Available!
GPU : Tesla T4


In [3]:
import os
import random

label_to_score = {
    0: 0.0,  # neutral
    1: 0.0,  # calm
    2: 0.0,  # happy
    3: 0.4,  # sad
    4: 0.7,  # angry
    5: 1.0,  # fearful
    6: 0.5,  # disgust
    7: 0.3,  # surprised
}

def parse_label_from_filename(fname):
    return int(fname.split('_')[-1].replace('.pt', ''))

def create_triplet_list(mel_folder, score_margin_pos=0.2, score_margin_neg=0.5, max_triplets=10000):
    file_list = [f for f in os.listdir(mel_folder) if f.endswith('.pt')]
    random.shuffle(file_list)

    file_score = {f: label_to_score[parse_label_from_filename(f)] for f in file_list}

    triplets = []

    for anchor_file in file_list:
        anchor_score = file_score[anchor_file]

        positives = [
            f for f in file_list
            if f != anchor_file and abs(file_score[f] - anchor_score) <= score_margin_pos
        ]

        negatives = [
            f for f in file_list
            if f != anchor_file and abs(file_score[f] - anchor_score) >= score_margin_neg
        ]

        for pos in positives:
            for neg in negatives:
                triplets.append((
                    os.path.join(mel_folder, anchor_file),
                    os.path.join(mel_folder, pos),
                    os.path.join(mel_folder, neg)
                ))
                if len(triplets) >= max_triplets:
                    return triplets  # 바로 종료

    return triplets


Load Dataset

In [9]:
from torch.utils.data import Dataset
import torch

def ensure_single_channel(mel):
    if mel.dim() == 2:
        return mel.unsqueeze(0)  # (H, W) → (1, H, W)
    elif mel.shape[0] != 1:
        return mel[:1, :, :]     # (C, H, W) → (1, H, W), 첫 채널만 사용
    return mel

class TripletMelDataset(Dataset):
    def __init__(self, triplet_list):
        self.triplets = triplet_list

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        a_path, p_path, n_path = self.triplets[idx]

        anchor = ensure_single_channel(torch.load(a_path))
        positive = ensure_single_channel(torch.load(p_path))
        negative = ensure_single_channel(torch.load(n_path))

        # (1, H, W) 보장
        if anchor.dim() == 2: anchor = anchor.unsqueeze(0)
        if positive.dim() == 2: positive = positive.unsqueeze(0)
        if negative.dim() == 2: negative = negative.unsqueeze(0)

        return anchor, positive, negative

Model Definition

In [10]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class MelEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super().__init__()
        base = models.resnet18(pretrained=True)
        base.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        base.fc = nn.Identity()  # fc 제거
        self.base = base
        self.fc = nn.Linear(512, output_dim)

    def forward(self, x):
        x = self.base(x)
        x = self.fc(x)
        return F.normalize(x, dim=1)  # cosine 유사도 기반 학습을 위해 정규화


In [12]:
from torch.utils.data import DataLoader

triplet_list = create_triplet_list("/content/drive/MyDrive/processed_data")
triplet_dataset = TripletMelDataset(triplet_list)
triplet_loader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)


Train

In [14]:
import torch
import torch.nn as nn
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MelEncoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)

for epoch in range(10):
    model.train()
    total_loss = 0

    for anchor, positive, negative in tqdm(triplet_loader, desc=f"Epoch {epoch+1}"):
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        anchor_embed = model(anchor)
        positive_embed = model(positive)
        negative_embed = model(negative)

        loss = criterion(anchor_embed, positive_embed, negative_embed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(triplet_loader)
    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

    save_path = f"/content/drive/MyDrive/mel_encoder_epoch{epoch+1}.pt"
    torch.save(model.state_dict(), save_path)

Epoch 1: 100%|██████████| 313/313 [02:36<00:00,  2.00it/s]


[Epoch 1] Loss: 0.0313


Epoch 2: 100%|██████████| 313/313 [02:36<00:00,  2.00it/s]


[Epoch 2] Loss: 0.0037


Epoch 3: 100%|██████████| 313/313 [02:36<00:00,  2.01it/s]


[Epoch 3] Loss: 0.0000


Epoch 4: 100%|██████████| 313/313 [02:35<00:00,  2.01it/s]


[Epoch 4] Loss: 0.0000


Epoch 5:  45%|████▌     | 142/313 [01:10<01:25,  2.01it/s]


KeyboardInterrupt: 