In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"


In [5]:
import random
import warnings
from sklearn.model_selection import train_test_split

import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

warnings.filterwarnings(action='ignore')

In [6]:
import glob
import math
import os
import tempfile
import time
from typing import List, Optional, Tuple, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from PIL import Image
from torch import Tensor
from torchaudio.backend.common import AudioMetaData

from df import config
from df.enhance import enhance, init_df, load_audio, save_audio
from df.io import resample

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
CFG = {
    'SR':16_000,
    'SEED':42,
    'BATCH_SIZE':24, 
    'EPOCHS':20,
    'LR':1e-4,
}

In [9]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [18]:
base_path = os.getcwd()
train_df = pd.read_csv(f'{base_path}/dataset/dacon_dataset/train_bandstop_new_dataset_addmix.csv')
test_df = pd.read_csv(f'{base_path}/dataset/dacon_dataset/test.csv')

x_train_zero = train_df['path'].values.tolist()[-384:] * 156
y_train_zero = train_df[['fake', 'real']][-384:].values.tolist() * 156

x_train = train_df['path'][:-384].values.tolist() + x_train_zero
y_train = train_df[['fake', 'real']][:-384].values.tolist() + y_train_zero


print(len(x_train))

239971


In [11]:
label_mapping = {
    (1, 0): 0,
    (0, 1): 1,
    (2, 0): 2,
    (0, 2): 3,
    (1, 1): 4,
    (0, 0): 5
}

y_train_transformed = [label_mapping[tuple(labels)] for labels in y_train]



In [12]:
# X_train, X_val, y_train, y_val = train_test_split(X_train, y_train_transformed, test_size=0.1, stratify=y_train_transformed, random_state=CFG['SEED'])
_, X_val, _, y_val = train_test_split(x_train, y_train_transformed, test_size=0.35, stratify=y_train_transformed, random_state=CFG['SEED'])
X_train, X_val, y_train, y_val = train_test_split(X_val, y_val, test_size=0.1, stratify=y_val, random_state=CFG['SEED'])

print(f"train: {len(X_train)} val: {len(X_val)} test: {len(test_df)}")

train: 75570 val: 8397 test: 50000


In [13]:
class RandomPitchShiftSegment:
    def __init__(self, min_semitones=-12, max_semitones=2, duration_range=(1, 5), fade_duration=0.1, p=0.5):
        self.min_semitones = min_semitones
        self.max_semitones = max_semitones
        self.duration_range = duration_range
        self.fade_duration = fade_duration
        self.p = p

    def __call__(self, samples, sample_rate):
        if random.random() > self.p:
            return samples
        
        total_duration = len(samples) / sample_rate
        segment_duration = random.uniform(*self.duration_range)
        fade_samples = int(self.fade_duration * sample_rate)
        
        start_time = random.uniform(0, total_duration - segment_duration)
        start_sample = int(start_time * sample_rate)
        end_sample = int((start_time + segment_duration) * sample_rate)
        
        pitch_shift = PitchShift(min_semitones=self.min_semitones, max_semitones=self.max_semitones, p=1.0)
        original_segment = samples[start_sample:end_sample].copy()
        shifted_segment = pitch_shift(samples=original_segment, sample_rate=sample_rate)
        
        # Apply crossfade
        for i in range(fade_samples):
            fade_in_factor = i / fade_samples
            fade_out_factor = 1 - fade_in_factor
            shifted_segment[i] = fade_out_factor * original_segment[i] + fade_in_factor * shifted_segment[i]
            shifted_segment[-i-1] = fade_in_factor * original_segment[-i-1] + fade_out_factor * shifted_segment[-i-1]
        
        samples[start_sample:end_sample] = shifted_segment
        
        return samples

In [14]:
val_df = pd.read_csv("./valid.csv")
y_val = val_df['label'].values.tolist()

In [15]:
X_val = [f"./dataset/dacon_dataset/valid_denoising/enhanced_VALID_{i:05d}.wav" for i in range(len(y_val))]

In [16]:
MODEL_NAME = 'facebook/wav2vec2-xls-r-300m'
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)

def pad_and_randomize_audio(audio, sr=16000, target_length=5):
    """5초 이하의 음성을 5초 길이로 패딩하고 랜덤 위치에 배치"""
    target_samples = sr * target_length
    audio_length = len(audio)

    if audio_length >= target_samples:
        return audio[:target_samples]  # 이미 5초 이상인 경우, 잘라서 반환

    # 5초 이하인 경우
    padded_audio = np.zeros(target_samples)
    start_pos = random.randint(0, target_samples - audio_length)
    padded_audio[start_pos:start_pos + audio_length] = audio
    return padded_audio


from audiomentations import Compose, BandStopFilter, AirAbsorption

# 증강 기법 정의 예시
augment = Compose([
    BandStopFilter(min_center_freq=200.0, max_center_freq=8000.0, min_bandwidth_fraction=0.1, max_bandwidth_fraction=0.4, p=1.0)
])

# 배치 데이터를 처리하는 함수
def process_batch(file_paths, is_augment=False):
    audio_batch = []
    for audio_path in tqdm(file_paths):
        audio, sr = librosa.load(audio_path)
        if is_augment:
            audio = augment(audio[:sr*5], sample_rate=sr)
        audio = pad_and_randomize_audio(audio, sr=sr, target_length=5)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
        # audio = processor(audio, sampling_rate=16000, return_tensors="np", padding=True).input_values
        audio_batch.append(audio)

    audio_batch = np.vstack(audio_batch)
    return audio_batch

In [17]:
train_x = process_batch(X_train, is_augment=True)
valid_x = process_batch(X_val, is_augment=False)

  0%|          | 0/75570 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [17]:
from collections import Counter

label_counts = Counter(y_train)

combined_counts = {
    '0_and_2': label_counts[0] + label_counts[2],
    '1_and_3': label_counts[1] + label_counts[3],
    '4': label_counts[4],
    '5': label_counts[5]
}

# 합산된 라벨 균형을 출력
for label, count in combined_counts.items():
    print(f"Label {label}: {count} occurrences")


Label 0_and_2: 18900 occurrences
Label 1_and_3: 18900 occurrences
Label 4: 18900 occurrences
Label 5: 18870 occurrences


In [18]:
idx2ten = {0: [1,0], 1:[0,1],
           2: [1,0], 3:[0,1],
           4: [1,1], 5:[0,0]} 

In [19]:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, OneOf

class CustomDataSet(torch.utils.data.Dataset):
    def __init__(self, x, y, train_mode=False, augment_func=None):
        self.x = x
        self.y = y
        self.train_mode = train_mode
        self.augment_func = augment_func
        self.augmentation = Compose([
            OneOf([
                # Compose([AddReverb(reverb_amount=0.6, p=1), AddEcho(delay=0.09, attenuation=0.5, sample_rate=32000, p=1)]),
                # PitchShift(min_semitones=-10, max_semitones=2, p=1),
                GainTransition(min_gain_in_db=-140, max_gain_in_db=0, p=1.),
                RandomPitchShiftSegment(min_semitones=-10, max_semitones=2, duration_range=(1, 3.1), fade_duration=0.1, p=1.0),
            ], p=0.35)
        ])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        input_values = self.x[idx]
        
        if self.y is not None:
            labels = torch.tensor(self.y[idx], dtype=torch.float32)
            labels_trans = torch.tensor(idx2ten[self.y[idx]], dtype=torch.float32)
            if (labels == torch.tensor([5], dtype=torch.float32)).all():
                input_values = self.augment_func(samples=input_values, sample_rate=16000)

            if random.random() < 0.5 and self.train_mode:
                input_values = self.augmentation(input_values, sample_rate=16000)            
            input_values = processor(input_values, sampling_rate=16000, return_tensors="pt", padding=True).input_values.squeeze()

            return input_values, labels, labels_trans
        else:
            input_values = processor(input_values, sampling_rate=16000, return_tensors="pt", padding=True).input_values.squeeze()
            return input_values

# Define the augmentation pipeline using audiomentations
augmentations = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    Shift(min_shift=-0.5, max_shift=0.5, p=0.5)
])



In [20]:
def create_data_loader(dataset, batch_size, shuffle, num_workers=0):
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      pin_memory=True
                      )


In [21]:

train_dataset = CustomDataSet(train_x, y_train, train_mode=True, augment_func=augmentations)
valid_dataset = CustomDataSet(valid_x, y_val, augment_func=augmentations)

train_loader = create_data_loader(train_dataset, CFG['BATCH_SIZE'], True, 4)
valid_loader = create_data_loader(valid_dataset, CFG['BATCH_SIZE'], False, 4)

In [22]:
data, target, target_trans = next(iter(train_loader))
print(data.shape, target.shape, target_trans.shape)
print(target)

torch.Size([24, 80000]) torch.Size([24]) torch.Size([24, 2])
tensor([4., 5., 4., 5., 5., 5., 1., 4., 0., 5., 5., 2., 5., 4., 0., 1., 5., 5.,
        0., 4., 1., 0., 5., 4.])


In [23]:
"""
AASIST
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class HtrgGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1  :(#bs, #node, #dim)
        x2  :(#bs, #node, #dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)

        return new_h

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, n_nodes, n_feat = h.size()
        n_nodes = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)

        h = h * scores
        h = torch.gather(h, 1, idx)

        return h


class CONV(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        out = self.conv1(x)

        # print('out',out.shape)
        out = self.bn2(out)
        out = self.selu(out)
        # print('out',out.shape)
        out = self.conv2(out)
        #print('conv2 out',out.shape)
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class Model(nn.Module):
    def __init__(self, d_args):
        super().__init__()

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer1 = nn.Linear(5 * gat_dims[1], 2)
        self.out_layer2 = nn.Linear(5 * gat_dims[1], 6)


    def forward(self, x, Freq_aug=False):

        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output1 = self.out_layer1(last_hidden)
        output2 = self.out_layer2(last_hidden)

        return output1, output2

d_args = {
        "architecture": "AASIST",
        "nb_samp": 80000,
        "first_conv": 128,
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        "gat_dims": [64, 32],
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0]
    }
audio_model = Model(d_args=d_args)

In [24]:
# # NO PRETRAINING

# audio_model = W2VAASIST()
# # audio_model = torch.load("/home/hyj/ChanHyung/Audio/fake_voice_detection/anti-spoofing_feat_model.pt")
# audio_model(datas)

In [25]:
import torch
import torch.nn as nn
from torchmetrics.classification import BinaryAUROC, BinaryCalibrationError
from torchmetrics import MeanSquaredError

class DaconMetrics(nn.Module):
    def __init__(self, device):
        super(DaconMetrics, self).__init__()
        self.auc_metric_class_0 = BinaryAUROC().to(device)
        self.auc_metric_class_1 = BinaryAUROC().to(device)
        self.brier_metric_class_0 = MeanSquaredError().to(device)
        self.brier_metric_class_1 = MeanSquaredError().to(device)
        self.ece_metric_class_0 = BinaryCalibrationError(n_bins=10, norm='l1').to(device)
        self.ece_metric_class_1 = BinaryCalibrationError(n_bins=10, norm='l1').to(device)
        self.device = device

    def forward(self, pred, label):
        # For class 0
        self.auc_metric_class_0.update(pred[:, 0].detach(), label[:, 0].detach())
        self.brier_metric_class_0.update(pred[:, 0].detach(), label[:, 0].detach())
        self.ece_metric_class_0.update(pred[:, 0].detach(), label[:, 0].detach())

        # For class 1
        self.auc_metric_class_1.update(pred[:, 1].detach(), label[:, 1].detach())
        self.brier_metric_class_1.update(pred[:, 1].detach(), label[:, 1].detach())
        self.ece_metric_class_1.update(pred[:, 1].detach(), label[:, 1].detach())

    def reset(self):
        self.auc_metric_class_0.reset()
        self.auc_metric_class_1.reset()
        self.brier_metric_class_0.reset()
        self.brier_metric_class_1.reset()
        self.ece_metric_class_0.reset()
        self.ece_metric_class_1.reset()

    def compute(self):
        with torch.no_grad():
            auc_class_0 = self.auc_metric_class_0.compute().item()
            auc_class_1 = self.auc_metric_class_1.compute().item()
            mean_auc = (auc_class_0 + auc_class_1) / 2

            mse_class_0 = self.brier_metric_class_0.compute().item()
            mse_class_1 = self.brier_metric_class_1.compute().item()
            mean_mse = (mse_class_0 + mse_class_1) / 2

            ece_class_0 = self.ece_metric_class_0.compute().item()
            ece_class_1 = self.ece_metric_class_1.compute().item()
            mean_ece = (ece_class_0 + ece_class_1) / 2

            final = 0.5 * (1 - mean_auc) + 0.25 * mean_mse + 0.25 * mean_ece

        return final, mean_auc, mean_mse, mean_ece
    

dacon_metrics = DaconMetrics(torch.device('cpu'))

# Example data
pred = torch.tensor([[0.9, 0.1], [0.6, 0.4], [0.65, 0.35], [0.2, 0.8]])  
label = torch.tensor([[1, 0], [1, 0], [1, 1], [0, 1]])

# Update metrics
dacon_metrics(pred, label)
# Compute results
final, auc, brier, ece = dacon_metrics.compute()
print(f'Final Score: {final} auc: {auc} brier: {brier} ece: {ece}')


Final Score: 0.16765624564141035 auc: 0.875 brier: 0.12062500044703484 ece: 0.29999998211860657


In [26]:
def validation(model, valid_loader, criterion1, criterion2):
    model.eval()
    val_loss = []
    valid_metrics = DaconMetrics(device=device)

    # Initialize F1 score metric
    valid_loader_tqdm = tqdm(valid_loader, desc=f"Valid ")
    with torch.no_grad():
        for x, y, y_trans in valid_loader_tqdm:
            x = x.to(device)
            y_trans = y_trans.to(device)
            y = y.long().to(device)

            output = model(x)
            loss1 = criterion1(output[0], y_trans)
            loss2 = criterion2(output[1], y)
            loss = loss1 * 0.65 + loss2 * 0.35

            val_loss.append(loss.item())

            # Update metrics
            probs = torch.sigmoid(output[0])

            # Update metrics
            valid_metrics(probs, y_trans)
            dacon_score, auc, brier, ece = valid_metrics.compute()
            
            valid_loader_tqdm.set_postfix({'loss': np.mean(val_loss), 'dacon_score': dacon_score, 'auc': auc, 'brier': brier, 'ece': ece})
            
    avg_loss = np.mean(val_loss)
    avg_dacon_score, avg_acu, avg_brier, avg_ece = valid_metrics.compute()

    return avg_loss, avg_dacon_score, avg_acu, avg_brier, avg_ece


In [27]:
def train(model, train_loader, valid_loader, optimizer, scheduler):
    model.to(device)
    criterion1 = nn.MultiLabelSoftMarginLoss().to(device)
    criterion2 = nn.CrossEntropyLoss().to(device)

    best_model = None
    best_dacon = 14651651513221231231
    train_metrics = DaconMetrics(device=device)
    num_batches = len(train_loader)

    for epoch in range(1, CFG['EPOCHS'] + 1):
        train_loss = []
        model.train()

        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch} Training")
        for i, (x, y, y_trans) in enumerate(train_loader_tqdm):
            x = x.to(device)
            y_trans = y_trans.to(device)
            y = y.long().to(device)

            optimizer.zero_grad()
            output = model(x)
            loss1 = criterion1(output[0], y_trans)
            loss2 = criterion2(output[1], y)
            loss = loss1 * 0.65 + loss2 * 0.35
            loss.backward()

            
            with torch.no_grad():
                probs = torch.sigmoid(output[0])
                train_metrics(probs, y_trans)
                dacon_score, auc, brier, ece = train_metrics.compute()

        
            optimizer.step()
            optimizer.zero_grad()

            train_loss.append(loss.item())
            train_loader_tqdm.set_postfix({
                'loss': np.mean(train_loss),
                'dacon_score': dacon_score,
                'auc': auc,
                'brier': brier,
                'ece': ece
            })

            # if i % 3 == 0:
            #     wandb.log({
            #         'train_loss': np.mean(train_loss),
            #         'train_dacon_score': dacon_score,
            #         'train_auc': auc,
            #         'train_brier': brier,
            #         'train_ece': ece
            #     })



        dacon_score, auc, brier, ece = train_metrics.compute()

        valid_loss, valid_dacon, valid_acu, valid_brier, valid_ece = validation(model, valid_loader, criterion1, criterion2)

        
        # wandb.log({
        #         'valid_loss': valid_loss,
        #         'valid_dacon_score': valid_dacon,
        #         'valid_auc': valid_acu,
        #         'valid_brier': valid_brier,
        #         'valid_ece': valid_ece
        #     })
        print(f'Epoch {epoch} - step: [{i + 1}/{num_batches}] train_loss:[{np.mean(train_loss):.5f}] dacon_score:[{dacon_score:.5f}] auc:[{auc:.5f}] brier:[{brier:.5f}] ece:[{ece:.5f}]')
        print(f'Validation - step: [{i + 1}/{num_batches}] valid_loss:[{valid_loss:.5f}] valid_dacon:[{valid_dacon:.5f}] valid_auc:[{valid_acu:.5f}] valid_brier:[{valid_brier:.5f}] valid_ece:[{valid_ece:.5f}]')
        train_metrics.reset()



        if scheduler is not None:
            scheduler.step(valid_dacon)

        if valid_dacon < best_dacon:
            best_dacon = valid_dacon
            best_model = model
            print(f"{best_dacon} saved.")
        torch.save(model.state_dict(), f"./model_save_epoch_{epoch}_aasist_bandstop_addzero_new_dataset_2loss.pth")

    print(f'best_dacon:{best_dacon:.5f}')

    return best_model


In [28]:
optimizer = torch.optim.AdamW(audio_model.parameters(), lr=CFG['LR'], weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

In [29]:
infer_model = train(audio_model, train_loader, valid_loader, optimizer, scheduler)

Epoch 1 Training:   0%|          | 0/3149 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 3149/3149 [20:02<00:00,  2.62it/s, loss=0.586, dacon_score=0.0772, auc=0.912, brier=0.119, ece=0.0127]
Valid : 100%|██████████| 200/200 [00:26<00:00,  7.51it/s, loss=0.842, dacon_score=0.142, auc=0.874, brier=0.176, ece=0.139]


Epoch 1 - step: [3149/3149] train_loss:[0.58636] dacon_score:[0.07716] auc:[0.91152] brier:[0.11899] ece:[0.01270]
Validation - step: [3149/3149] valid_loss:[0.84231] valid_dacon:[0.14183] valid_auc:[0.87356] valid_brier:[0.17570] valid_ece:[0.13875]
0.14183190930634737 saved.


Epoch 2 Training: 100%|██████████| 3149/3149 [19:56<00:00,  2.63it/s, loss=0.37, dacon_score=0.0372, auc=0.965, brier=0.0727, ece=0.00603] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.70it/s, loss=1.28, dacon_score=0.205, auc=0.848, brier=0.254, ece=0.261]


Epoch 2 - step: [3149/3149] train_loss:[0.36959] dacon_score:[0.03723] auc:[0.96488] brier:[0.07266] ece:[0.00603]
Validation - step: [3149/3149] valid_loss:[1.27747] valid_dacon:[0.20478] valid_auc:[0.84776] valid_brier:[0.25370] valid_ece:[0.26093]


Epoch 3 Training: 100%|██████████| 3149/3149 [19:54<00:00,  2.64it/s, loss=0.31, dacon_score=0.0291, auc=0.975, brier=0.0606, ece=0.00518] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.88it/s, loss=0.871, dacon_score=0.137, auc=0.899, brier=0.177, ece=0.168]


Epoch 3 - step: [3149/3149] train_loss:[0.30962] dacon_score:[0.02906] auc:[0.97476] brier:[0.06059] ece:[0.00518]
Validation - step: [3149/3149] valid_loss:[0.87052] valid_dacon:[0.13652] valid_auc:[0.89934] valid_brier:[0.17722] valid_ece:[0.16753]
0.13651809096336365 saved.


Epoch 4 Training: 100%|██████████| 3149/3149 [19:55<00:00,  2.63it/s, loss=0.276, dacon_score=0.0245, auc=0.98, brier=0.0532, ece=0.00435] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.87it/s, loss=0.794, dacon_score=0.108, auc=0.904, brier=0.144, ece=0.0957]


Epoch 4 - step: [3149/3149] train_loss:[0.27603] dacon_score:[0.02447] auc:[0.97984] brier:[0.05321] ece:[0.00435]
Validation - step: [3149/3149] valid_loss:[0.79398] valid_dacon:[0.10818] valid_auc:[0.90361] valid_brier:[0.14423] valid_ece:[0.09571]
0.10818029288202524 saved.


Epoch 5 Training: 100%|██████████| 3149/3149 [19:52<00:00,  2.64it/s, loss=0.255, dacon_score=0.0217, auc=0.983, brier=0.0485, ece=0.00412]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.89it/s, loss=0.713, dacon_score=0.0994, auc=0.899, brier=0.133, ece=0.0623]


Epoch 5 - step: [3149/3149] train_loss:[0.25465] dacon_score:[0.02174] auc:[0.98281] brier:[0.04845] ece:[0.00412]
Validation - step: [3149/3149] valid_loss:[0.71284] valid_dacon:[0.09939] valid_auc:[0.89897] valid_brier:[0.13317] valid_ece:[0.06232]
0.09938732162117958 saved.


Epoch 6 Training: 100%|██████████| 3149/3149 [19:54<00:00,  2.64it/s, loss=0.239, dacon_score=0.02, auc=0.985, brier=0.0452, ece=0.00448]  
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.91it/s, loss=0.7, dacon_score=0.105, auc=0.897, brier=0.138, ece=0.0782]  


Epoch 6 - step: [3149/3149] train_loss:[0.23885] dacon_score:[0.02001] auc:[0.98479] brier:[0.04515] ece:[0.00448]
Validation - step: [3149/3149] valid_loss:[0.70007] valid_dacon:[0.10544] valid_auc:[0.89741] valid_brier:[0.13837] valid_ece:[0.07822]


Epoch 7 Training: 100%|██████████| 3149/3149 [19:49<00:00,  2.65it/s, loss=0.224, dacon_score=0.0182, auc=0.986, brier=0.0422, ece=0.00362]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.87it/s, loss=0.856, dacon_score=0.101, auc=0.92, brier=0.136, ece=0.108]  


Epoch 7 - step: [3149/3149] train_loss:[0.22439] dacon_score:[0.01822] auc:[0.98647] brier:[0.04220] ece:[0.00362]
Validation - step: [3149/3149] valid_loss:[0.85636] valid_dacon:[0.10078] valid_auc:[0.92013] valid_brier:[0.13552] valid_ece:[0.10788]


Epoch 8 Training: 100%|██████████| 3149/3149 [19:46<00:00,  2.65it/s, loss=0.212, dacon_score=0.0166, auc=0.988, brier=0.0399, ece=0.00245]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.90it/s, loss=0.831, dacon_score=0.118, auc=0.909, brier=0.155, ece=0.133]


Epoch 8 - step: [3149/3149] train_loss:[0.21229] dacon_score:[0.01664] auc:[0.98788] brier:[0.03988] ece:[0.00245]
Validation - step: [3149/3149] valid_loss:[0.83102] valid_dacon:[0.11768] valid_auc:[0.90865] valid_brier:[0.15488] valid_ece:[0.13316]


Epoch 9 Training: 100%|██████████| 3149/3149 [19:46<00:00,  2.65it/s, loss=0.189, dacon_score=0.0145, auc=0.99, brier=0.0348, ece=0.00379]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.89it/s, loss=0.74, dacon_score=0.0995, auc=0.906, brier=0.135, ece=0.0759] 


Epoch 9 - step: [3149/3149] train_loss:[0.18877] dacon_score:[0.01445] auc:[0.99037] brier:[0.03478] ece:[0.00379]
Validation - step: [3149/3149] valid_loss:[0.73991] valid_dacon:[0.09954] valid_auc:[0.90629] valid_brier:[0.13480] valid_ece:[0.07595]


Epoch 10 Training: 100%|██████████| 3149/3149 [19:46<00:00,  2.65it/s, loss=0.185, dacon_score=0.0139, auc=0.991, brier=0.034, ece=0.00303] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.94it/s, loss=0.728, dacon_score=0.0987, auc=0.923, brier=0.135, ece=0.106]


Epoch 10 - step: [3149/3149] train_loss:[0.18468] dacon_score:[0.01386] auc:[0.99079] brier:[0.03401] ece:[0.00303]
Validation - step: [3149/3149] valid_loss:[0.72794] valid_dacon:[0.09870] valid_auc:[0.92312] valid_brier:[0.13483] valid_ece:[0.10620]
0.09869708865880966 saved.


Epoch 11 Training: 100%|██████████| 3149/3149 [19:43<00:00,  2.66it/s, loss=0.179, dacon_score=0.0134, auc=0.991, brier=0.0329, ece=0.00337]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.95it/s, loss=0.727, dacon_score=0.0964, auc=0.915, brier=0.132, ece=0.0836]


Epoch 11 - step: [3149/3149] train_loss:[0.17945] dacon_score:[0.01342] auc:[0.99128] brier:[0.03289] ece:[0.00337]
Validation - step: [3149/3149] valid_loss:[0.72713] valid_dacon:[0.09644] valid_auc:[0.91480] valid_brier:[0.13177] valid_ece:[0.08360]
0.0964397843927145 saved.


Epoch 12 Training: 100%|██████████| 3149/3149 [19:45<00:00,  2.66it/s, loss=0.173, dacon_score=0.0126, auc=0.992, brier=0.0316, ece=0.00269]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.95it/s, loss=0.92, dacon_score=0.117, auc=0.923, brier=0.158, ece=0.158] 


Epoch 12 - step: [3149/3149] train_loss:[0.17271] dacon_score:[0.01263] auc:[0.99188] brier:[0.03157] ece:[0.00269]
Validation - step: [3149/3149] valid_loss:[0.91976] valid_dacon:[0.11719] valid_auc:[0.92345] valid_brier:[0.15796] valid_ece:[0.15772]


Epoch 13 Training: 100%|██████████| 3149/3149 [19:43<00:00,  2.66it/s, loss=0.168, dacon_score=0.0121, auc=0.992, brier=0.0307, ece=0.00247]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.95it/s, loss=0.742, dacon_score=0.0992, auc=0.922, brier=0.137, ece=0.104]


Epoch 13 - step: [3149/3149] train_loss:[0.16792] dacon_score:[0.01210] auc:[0.99238] brier:[0.03068] ece:[0.00247]
Validation - step: [3149/3149] valid_loss:[0.74194] valid_dacon:[0.09921] valid_auc:[0.92228] valid_brier:[0.13716] valid_ece:[0.10425]


Epoch 14 Training: 100%|██████████| 3149/3149 [19:45<00:00,  2.66it/s, loss=0.166, dacon_score=0.0119, auc=0.993, brier=0.0302, ece=0.00255]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.89it/s, loss=0.876, dacon_score=0.107, auc=0.931, brier=0.144, ece=0.144]


Epoch 14 - step: [3149/3149] train_loss:[0.16566] dacon_score:[0.01190] auc:[0.99256] brier:[0.03015] ece:[0.00255]
Validation - step: [3149/3149] valid_loss:[0.87607] valid_dacon:[0.10671] valid_auc:[0.93066] valid_brier:[0.14430] valid_ece:[0.14385]


Epoch 15 Training: 100%|██████████| 3149/3149 [19:45<00:00,  2.66it/s, loss=0.152, dacon_score=0.0108, auc=0.994, brier=0.0275, ece=0.00296]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.95it/s, loss=0.75, dacon_score=0.098, auc=0.916, brier=0.133, ece=0.0903]  


Epoch 15 - step: [3149/3149] train_loss:[0.15229] dacon_score:[0.01079] auc:[0.99364] brier:[0.02749] ece:[0.00296]
Validation - step: [3149/3149] valid_loss:[0.75030] valid_dacon:[0.09799] valid_auc:[0.91583] valid_brier:[0.13335] valid_ece:[0.09026]


Epoch 16 Training: 100%|██████████| 3149/3149 [19:43<00:00,  2.66it/s, loss=0.15, dacon_score=0.0106, auc=0.994, brier=0.0268, ece=0.00326] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.93it/s, loss=0.951, dacon_score=0.124, auc=0.913, brier=0.169, ece=0.154]


Epoch 16 - step: [3149/3149] train_loss:[0.14960] dacon_score:[0.01058] auc:[0.99386] brier:[0.02679] ece:[0.00326]
Validation - step: [3149/3149] valid_loss:[0.95093] valid_dacon:[0.12435] valid_auc:[0.91283] valid_brier:[0.16880] valid_ece:[0.15429]


Epoch 17 Training: 100%|██████████| 3149/3149 [19:46<00:00,  2.65it/s, loss=0.147, dacon_score=0.0101, auc=0.994, brier=0.0261, ece=0.00264] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.92it/s, loss=1, dacon_score=0.127, auc=0.904, brier=0.171, ece=0.147]    


Epoch 17 - step: [3149/3149] train_loss:[0.14658] dacon_score:[0.01011] auc:[0.99418] brier:[0.02613] ece:[0.00264]
Validation - step: [3149/3149] valid_loss:[1.00203] valid_dacon:[0.12744] valid_auc:[0.90378] valid_brier:[0.17051] valid_ece:[0.14680]


Epoch 18 Training: 100%|██████████| 3149/3149 [19:44<00:00,  2.66it/s, loss=0.138, dacon_score=0.00928, auc=0.995, brier=0.0243, ece=0.00272]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.96it/s, loss=0.951, dacon_score=0.117, auc=0.921, brier=0.156, ece=0.152]


Epoch 18 - step: [3149/3149] train_loss:[0.13759] dacon_score:[0.00928] auc:[0.99497] brier:[0.02434] ece:[0.00272]
Validation - step: [3149/3149] valid_loss:[0.95143] valid_dacon:[0.11667] valid_auc:[0.92073] valid_brier:[0.15590] valid_ece:[0.15223]


Epoch 19 Training: 100%|██████████| 3149/3149 [19:44<00:00,  2.66it/s, loss=0.137, dacon_score=0.00926, auc=0.995, brier=0.0242, ece=0.0026] 
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.89it/s, loss=0.983, dacon_score=0.12, auc=0.917, brier=0.16, ece=0.154]  


Epoch 19 - step: [3149/3149] train_loss:[0.13679] dacon_score:[0.00926] auc:[0.99489] brier:[0.02421] ece:[0.00260]
Validation - step: [3149/3149] valid_loss:[0.98272] valid_dacon:[0.11969] valid_auc:[0.91734] valid_brier:[0.15983] valid_ece:[0.15361]


Epoch 20 Training: 100%|██████████| 3149/3149 [19:44<00:00,  2.66it/s, loss=0.136, dacon_score=0.00911, auc=0.995, brier=0.0242, ece=0.00219]
Valid : 100%|██████████| 200/200 [00:25<00:00,  7.95it/s, loss=0.894, dacon_score=0.109, auc=0.924, brier=0.149, ece=0.134]


Epoch 20 - step: [3149/3149] train_loss:[0.13607] dacon_score:[0.00911] auc:[0.99496] brier:[0.02419] ece:[0.00219]
Validation - step: [3149/3149] valid_loss:[0.89409] valid_dacon:[0.10876] valid_auc:[0.92406] valid_brier:[0.14872] valid_ece:[0.13442]
best_dacon:0.09644


In [29]:
base_path = os.path.join(os.getcwd(), 'dataset', 'dacon_dataset')
test_x = [f"{base_path}/test_bandstop_denoising/enhanced_{i}.wav" for i in test_df['id'].tolist()]

In [None]:
test_x = process_batch(test_x)

In [79]:
test_dataset = CustomDataSet(test_x, None)
test_loader = create_data_loader(test_dataset, CFG['BATCH_SIZE'], False, 4)

In [None]:
audio_model.load_state_dict(torch.load('./model_save_epoch_4_aasist_bandstop_addzero_new_dataset_2loss.pth'))

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLabelTemperatureScaling(nn.Module):
    def __init__(self, temperature_fake_divide=1.5, temperature_real_divide=1.5):
        super(MultiLabelTemperatureScaling, self).__init__()
        # self.temperature_multiply = nn.Parameter(torch.ones(1) * temperature_multiply)
        # self.temperature_divide = nn.Parameter(torch.ones(1) * temperature_divide)
        self.temperature_fake_divide = nn.Parameter(torch.ones(1) * temperature_fake_divide)
        self.temperature_real_divide = nn.Parameter(torch.ones(1) * temperature_real_divide)

    def forward(self, logits):
        # Apply temperature scaling to logits
        scaled_logits = logits.clone()
        scaled_logits[:, 0] = logits[:, 0] / self.temperature_fake_divide
        scaled_logits[:, 1] = logits[:, 1] / self.temperature_real_divide
        # Convert scaled logits to probabilities
        probabilities = torch.sigmoid(scaled_logits)
        return probabilities

In [82]:
import torch.nn.functional as F
audio_model.to(device)
audio_model.eval()
preds = []
test_labels = []
temperature_scaling = MultiLabelTemperatureScaling(temperature_fake_divide=3.5, temperature_real_divide=3.5).to(device)

# Initialize F1 score metric
test_loader_tqdm = tqdm(test_loader, desc=f"inference ")
with torch.no_grad():
    for x in test_loader_tqdm:
        x = x.to(device)
        output = audio_model(x)

        probs = temperature_scaling(output[0])
        preds.append(probs.cpu().detach())

        probabilities = F.softmax(output[1], dim=1)
        predictions = torch.argmax(probabilities, dim=1)

        for i in range(predictions.size(0)):
            test_labels.append(predictions[i].item())  

        
        

inference : 100%|██████████| 2084/2084 [03:44<00:00,  9.27it/s]


In [23]:
from collections import Counter

label_counts = Counter(test_labels)

print(label_counts)


Counter({1: 25785, 4: 7903, 5: 6373, 0: 4499, 3: 3448, 2: 1992})


In [83]:
preds = np.vstack(preds)

In [84]:
import json
import os

def check_boundaries_and_update_pred(index, file_path):
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            data = json.load(f)
            if not data.get('boundaries'): 
                preds[index] = [0, 0]
    else:
        print(f"File {file_path} does not exist")

for i in tqdm(range(50000)):
    name = f"enhanced_TEST_{i:05d}_boundaries.json"
    path = f"./vad_bandstop_boundaries_test/{name}"
    check_boundaries_and_update_pred(i, path)



100%|██████████| 50000/50000 [00:01<00:00, 34393.17it/s]


In [85]:
submit = pd.read_csv('./dataset/dacon_dataset/sample_submission.csv')
submit.iloc[:, 1:] = preds
submit.head()
submit.to_csv('./sw24_4label_epoch6.csv', index=False)

In [37]:
from dacon_submit_api import dacon_submit_api 

result = dacon_submit_api.post_submission_file(
'./sw24_4label_epoch6.csv', 
'', 
'236253', 
'어떻게 너의 목소리를 잊겠어', 
'' )

{'isSubmitted': True, 'detail': 'Success'}
