# Imports

In [1]:
!git clone https://github.com/NVIDIA/CleanUNet.git

Cloning into 'CleanUNet'...
Filtering content: 100% (2/2)
Filtering content: 100% (2/2), 351.59 MiB | 10.59 MiB/s, done.


In [2]:
import librosa

from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import random
from typing import Union

from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch
from torch import Tensor

from torch.utils import data
from torch.nn.parameter import Parameter

import os

import pickle

import json

import sys
sys.path.append('./CleanUNet/')
from network import CleanUNet

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
device

device(type='cuda')

# Config

In [None]:
# 설정 파일 경로
CONFIG_PATH = './CleanUNet/configs/DNS-large-full.json'

# 모델 체크포인트 파일 경로
MODEL_PATH = './CleanUNet/exp/DNS-large-full/checkpoint/pretrained.pkl'

In [3]:
class Config:
    SR = 16000  # RawNet uses 16kHz sample rate
    # Dataset
    ROOT_FOLDER = './'
    # Training
    N_CLASSES = 2
    # batch_size
    BATCH_SIZE = 16
    # Others
    SEED = 42

CONFIG = Config()

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CONFIG.SEED) # Seed 고정

In [None]:
from sklearn.utils import resample
df1 = pd.read_csv('./train_augmented1.csv') # 1명 목소리
df2 = pd.read_csv('./train_augmented3.csv') # 2명 목소리
df3 = pd.read_csv('./train_noise.csv') # 노이즈 데이터

# Split each dataframe
def split_train_val(df, test_size=0.2, random_state=42):
    train, val = train_test_split(df, test_size=test_size, random_state=random_state, stratify=df['label'])
    return train, val

train1, val1 = split_train_val(df1, random_state=CONFIG.SEED)
train2, val2 = split_train_val(df2, random_state=CONFIG.SEED)
train3, val3 = split_train_val(df3, random_state=CONFIG.SEED)

def stratified_sample(df, target_col, sample_sizes, random_state=None):
    sample_dfs = []
    for label, sample_size in sample_sizes.items():
        class_subset = df[df[target_col] == label]
        sample_dfs.append(resample(class_subset, n_samples=sample_size, random_state=random_state, replace=False))
    return pd.concat(sample_dfs).reset_index(drop=True)

sample1_sizes = {'fake': 10000, 'real': 10000}
train1_sampled = stratified_sample(train1, 'label', sample1_sizes, random_state=CONFIG.SEED)

sample2_sizes =  {'fakereal': 20000, 'fake': 10000, 'real': 10000}
train2_sampled = stratified_sample(train2, 'label', sample2_sizes, random_state=CONFIG.SEED)

train3_sampled = train3.sample(n=8000, random_state=CONFIG.SEED)

train_combined = pd.concat([train1_sampled, train2_sampled, train3_sampled], ignore_index=True)
print(train_combined['label'].value_counts())

## Data Pre-processing : denoise

In [None]:
with open(CONFIG_PATH) as f:
    config = json.load(f)

network_config = config["network_config"]
denoiser = CleanUNet(**network_config).cuda()
checkpoint = torch.load(MODEL_PATH, map_location='cuda')
denoiser.load_state_dict(checkpoint['model_state_dict'])
denoiser.eval()

CleanUNet(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(1, 64, kernel_size=(4,), stride=(2,))
      (1): ReLU()
      (2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (1): Sequential(
      (0): Conv1d(64, 128, kernel_size=(4,), stride=(2,))
      (1): ReLU()
      (2): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (2): Sequential(
      (0): Conv1d(128, 256, kernel_size=(4,), stride=(2,))
      (1): ReLU()
      (2): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (3): Sequential(
      (0): Conv1d(256, 512, kernel_size=(4,), stride=(2,))
      (1): ReLU()
      (2): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (4): Sequential(
      (0): Conv1d(512, 768, kernel_size=(4,), stride=(2,))
      (1): ReLU()
      (2): Conv1d(768, 1536, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (5-7): 3 x Sequential(
      (0): Conv

In [None]:
def denoise_audio(audio, sr):
    # 오디오를 float32로 변환하고 -1에서 1 사이로 정규화
    audio = audio.astype(np.float32)
    audio = audio / np.max(np.abs(audio))

    # torch 텐서로 변환
    audio_tensor = torch.FloatTensor(audio).unsqueeze(0).unsqueeze(0).cuda()

    # 노이즈 제거
    with torch.no_grad():
        denoised_audio = denoiser(audio_tensor)

    # numpy 배열로 변환
    denoised_audio = denoised_audio.squeeze().cpu().numpy()

    # -1에서 1 사이로 클리핑
    denoised_audio = np.clip(denoised_audio, -1, 1)
    return denoised_audio

In [None]:
def get_audio_data(df, train_mode=True):
    features = []
    labels = []
    for _, row in tqdm(df.iterrows()):
        try:
            path = row['path']
            y, sr = librosa.load(path, sr=CONFIG.SR)
        except KeyboardInterrupt:
            print("Process interrupted by user, stopping...")
            break
        except Exception as e:
            print(f"Unsupported codec or corrupted file: {path}, skipping... Error: {e}")
            continue

        if len(y) < CONFIG.SR * 5:  # Pad if less than 5 seconds
            y = np.pad(y, (0, CONFIG.SR * 5 - len(y)))
        else:
            y = y[:CONFIG.SR * 5]  # Truncate to 5 seconds

        #노이즈 제거
        y = denoise_audio(y, sr)
        
        # 오류 데이터 제거
        if train_mode and (np.isnan(y).any() or np.isinf(y).any()):
            continue

        features.append(y)

        if train_mode:
            label = row['label']
            label_vector = np.zeros(CONFIG.N_CLASSES, dtype=float)
            if label == 'fake':
                label_vector = np.array([1, 0], dtype=float)
            elif label == 'real':
                label_vector = np.array([0, 1], dtype=float)
            elif label == 'fakereal':
                label_vector = np.array([1, 1], dtype=float)
            elif label == 'silence':
                label_vector = np.array([0, 0], dtype=float)
            labels.append(label_vector)

    if train_mode:
        return features, labels
    return features

# Dataset

In [None]:
class CustomDataset(Dataset):
    def __init__(self, audio, label):
        self.audio = audio
        self.label = label

    def __len__(self):
        return len(self.audio)

    def __getitem__(self, index):
        if self.label is not None:
            return torch.FloatTensor(self.audio[index]), torch.FloatTensor(self.label[index])
        return torch.FloatTensor(self.audio[index])

In [None]:
def create_dataloader(df, batch_size=32, train_mode=True, shuffle=True):
    audio, labels = get_audio_data(df, train_mode)
    dataset = CustomDataset(audio, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# RawBoost를 사용 안하려면, augment=False 추가
train_loader = create_dataloader(train_combined, batch_size=CONFIG.BATCH_SIZE, train_mode=True)
val_loaders = [
    create_dataloader(val1, batch_size=CONFIG.BATCH_SIZE, train_mode=True, shuffle=False),
    create_dataloader(val2, batch_size=CONFIG.BATCH_SIZE, train_mode=True, shuffle=False),
    create_dataloader(val3, batch_size=CONFIG.BATCH_SIZE, train_mode=True, shuffle=False)
]

68000it [28:20, 39.99it/s]
11088it [04:31, 40.78it/s]
11088it [04:41, 39.40it/s]
2000it [00:43, 46.46it/s]


# Define Model

In [None]:
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class HtrgGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1  :(#bs, #node, #dim)
        x2  :(#bs, #node, #dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)

        return new_h

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, n_nodes, n_feat = h.size()
        n_nodes = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)

        h = h * scores
        h = torch.gather(h, 1, idx)

        return h


class CONV(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        out = self.conv1(x)

        # print('out',out.shape)
        out = self.bn2(out)
        out = self.selu(out)
        # print('out',out.shape)
        out = self.conv2(out)
        #print('conv2 out',out.shape)
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class Model(nn.Module):
    def __init__(self, d_args):
        super().__init__()

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        # self.out_layer = nn.Linear(5 * gat_dims[1], 2)
        self.fc1 = nn.Linear(5 * gat_dims[1], 64)
        self.out_layer = nn.Linear(64, 2)

    def forward(self, x, Freq_aug=False):

        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        #output = self.out_layer(last_hidden)
        last_hidden = self.fc1(last_hidden)
        last_hidden = nn.ReLU()(last_hidden)
        output = self.out_layer(last_hidden)
        output = torch.sigmoid(output)

        return output

# Train & Validation

In [None]:
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import mean_squared_error
from sklearn.metrics import roc_auc_score
import pandas as pd


def expected_calibration_error(y_true, y_prob, n_bins=10):
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy='uniform')
    bin_totals = np.histogram(y_prob, bins=np.linspace(0, 1, n_bins + 1), density=False)[0]
    non_empty_bins = bin_totals > 0
    bin_weights = bin_totals / len(y_prob)
    bin_weights = bin_weights[non_empty_bins]
    prob_true = prob_true[:len(bin_weights)]
    prob_pred = prob_pred[:len(bin_weights)]
    ece = np.sum(bin_weights * np.abs(prob_true - prob_pred))
    return ece

def auc_brier_ece(answer_df, submission_df):
    # Check for missing values in submission_df
    if submission_df.isnull().values.any():
        raise ValueError("The submission dataframe contains missing values.")

    # Check if the number and names of columns are the same in both dataframes
    if len(answer_df.columns) != len(submission_df.columns) or not all(answer_df.columns == submission_df.columns):
        raise ValueError("The columns of the answer and submission dataframes do not match.")

    # Sort the dataframes by index to ensure they are aligned
    answer_df = answer_df.sort_index()
    submission_df = submission_df.sort_index()

    # Calculate AUC for each class
    auc_scores = []
    for column in answer_df.columns:
        y_true = answer_df[column]
        y_scores = submission_df[column]
        if len(np.unique(y_true)) == 1:
            # If only one class is present in y_true, skip AUC calculation
            auc_scores.append(0.5)  # Neutral value for AUC
        else:
            auc = roc_auc_score(y_true, y_scores)
            auc_scores.append(auc)

    # Calculate mean AUC
    mean_auc = np.mean(auc_scores)

    brier_scores = []
    ece_scores = []

    # Calculate Brier Score and ECE for each class
    for column in answer_df.columns:
        y_true = answer_df[column].values
        y_prob = submission_df[column].values

        # Brier Score
        brier = mean_squared_error(y_true, y_prob)
        brier_scores.append(brier)

        # ECE
        ece = expected_calibration_error(y_true, y_prob)
        ece_scores.append(ece)

    # Calculate mean Brier Score and mean ECE
    mean_brier = np.mean(brier_scores)
    mean_ece = np.mean(ece_scores)

    # Calculate combined score
    combined_score = 0.5 * (1 - mean_auc) + 0.25 * mean_brier + 0.25 * mean_ece

    return combined_score

In [None]:
from sklearn.metrics import roc_auc_score

def training(model, optimizer, train_loader, val_loader, device, epochs):
    model.to(device)
    criterion = nn.BCELoss().to(device)

    best_val_score = 0
    best_model = None
    try:
      for epoch in range(1, epochs+1):
          model.train()
          train_loss = []
          for features, labels in tqdm(iter(train_loader)):
              features = features.float().to(device)
              labels = labels.float().to(device)

              optimizer.zero_grad()

              output = model(features)

              loss = criterion(output, labels)

              loss.backward()
              optimizer.step()

              train_loss.append(loss.item())

          _val_loss, _val_score, _val_all_score = [], [], []

          for val_loader in tqdm(val_loaders):
              v_loss, v_score, v_all_score = validation(model, criterion, val_loader, device)
              _val_loss.append(v_loss)
              _val_score.append(v_score)
              _val_all_score.append(v_all_score)

          avg_val_all_score = np.mean(_val_all_score)
          _train_loss = np.mean(train_loss)

          print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}], Total score: [{avg_val_all_score:.5f}]')

          for i, (v_loss, v_score, v_all_score) in enumerate(zip(_val_loss, _val_score, _val_all_score)):
              print(f'Val Set {i+1} Loss : [{v_loss:.5f}] AUC : [{v_score:.5f}] Total score: [{v_all_score:.5f}]')

          if best_val_score < avg_val_all_score:
              best_val_score = avg_val_all_score
              best_model = model
    except KeyboardInterrupt:
        print("Training interrupted. Returning the best model so far.")
        return best_model

    return best_model

def multiLabel_AUC(y_true, y_scores):
    auc_scores = []
    for i in range(y_true.shape[1]):
        if len(np.unique(y_true[:, i])) == 1:
            auc_scores.append(0.5)
        else:
            auc = roc_auc_score(y_true[:, i], y_scores[:, i])
            auc_scores.append(auc)
    mean_auc_score = np.mean(auc_scores)
    return mean_auc_score

def validation(model, criterion, val_loader, device):
    model.eval()
    val_loss, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for features, labels in iter(val_loader):
            if torch.isnan(features).any() or torch.isinf(features).any():
                  #print("NaN detected, skipping this batch.")
                  continue
            features = features.float().to(device)
            labels = labels.float().to(device)

            probs = model(features)

            loss = criterion(probs, labels)

            val_loss.append(loss.item())

            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

        _val_loss = np.mean(val_loss)

        all_labels = np.concatenate(all_labels, axis=0)
        all_probs = np.concatenate(all_probs, axis=0)

        # Calculate AUC score
        auc_score = multiLabel_AUC(all_labels, all_probs)
        all_score = auc_brier_ece(pd.DataFrame(all_labels), pd.DataFrame(all_probs))

    return _val_loss, auc_score, all_score

## Run

### 전이학습

In [None]:
d_args = {
    "architecture": "AASIST",
    "nb_samp": 64600,
    "first_conv": 128,
    "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
    "gat_dims": [64, 32],
    "pool_ratios": [0.5, 0.7, 0.5, 0.5],
    "temperatures": [2.0, 2.0, 100.0, 100.0]
}

In [None]:
# 모델 인스턴스 생성
model = Model(d_args).to(device)

# 사전 학습된 파라미터 로드
pretrained_dict = torch.load('./AASIST.pth', map_location=device)

# 현재 모델의 state_dict와 사전 학습된 모델의 state_dict 키 비교
model_dict = model.state_dict()

# 마지막 레이어의 이름이 'out_layer'라고 가정하고, 해당 레이어의 파라미터를 필터링
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'out_layer' not in k}

# 모델에 사전 학습된 파라미터 로드
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

<All keys matched successfully>

In [None]:
# 모든 파라미터를 고정
for param in model.parameters():
    param.requires_grad = False
# 마지막 fully connected 레이어 2개만 학습 가능하도록 설정
model.out_layer.weight.requires_grad = True
model.out_layer.bias.requires_grad = True
model.fc1.weight.requires_grad = True
model.fc1.bias.requires_grad = True

In [None]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=7e-6)

### 학습

In [None]:
infer_model = training(model, optimizer, train_loader, val_loaders, device, epochs=20)

100%|██████████| 4138/4138 [05:52<00:00, 11.72it/s]
100%|██████████| 3/3 [01:40<00:00, 33.50s/it]

Epoch [1], Train Loss : [0.44865], Total score: [0.17335]
Val Set 1 Loss : [0.39728] AUC : [0.93420] Total score: [0.09120]
Val Set 2 Loss : [0.44591] AUC : [0.84071] Total score: [0.14507]
Val Set 3 Loss : [0.15389] AUC : [0.50000] Total score: [0.28377]





In [None]:
torch.save(infer_model.state_dict(), "./clean_aasist.pth")