In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import librosa
import numpy as np
import whisper
import pandas as pd
import random
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from src.models import (
    lcnn,
    specrnet,
    whisper_specrnet,
    rawnet3,
    whisper_lcnn,
    meso_net,
    whisper_meso_net
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import warnings

# Silence every warning for the rest of the session so that library notices
# do not clutter the cell outputs.
warnings.filterwarnings(action='ignore')

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class Config:
    """Static settings shared across the notebook, exposed as class attributes."""

    # Audio front-end settings.
    # NOTE(review): the preprocessing cell further down defines its own
    # SAMPLING_RATE and n_mfcc values instead of reading these — confirm
    # which ones are meant to apply.
    SR: int = 32000
    N_MFCC: int = 13

    # Dataset
    ROOT_FOLDER: str = './'

    # Training
    N_CLASSES: int = 2
    BATCH_SIZE: int = 96
    N_EPOCHS: int = 50
    LR: float = 3e-4

    # Others
    SEED: int = 42


# Single shared instance used by the cells below.
CONFIG = Config()

In [5]:
def seed_everything(seed):
    """Seed every random number generator used here for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; this is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms and may select different
    # kernels from run to run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=CONFIG.SEED)  # Fix the seed

In [6]:
# Load the test-set manifest (the same CSV the dataset class reads below).
test = pd.read_csv(filepath_or_buffer='../SW/test.csv')

In [7]:
# Constructor arguments for the MesoInception4 front-end model.
model_config = dict(
    fc1_dim=1024,
    # Need to check the model definition to see how this list is used.
    frontend_algorithm=["mfcc"],
    input_channels=1,
)

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Create the model instance. The device argument is passed as well because
# the constructor may require it.
# model_config holds exactly the fc1_dim / frontend_algorithm / input_channels
# keyword arguments, so it is unpacked directly.
model = meso_net.FrontendMesoInception4(**model_config, device=device)

Using ['mfcc'] frontend


In [9]:
# Checkpoint file that is loaded into the model in the next cell.
weights_path = './trained_models/meso+mfcc/weights.pth'

In [10]:
# Read the checkpoint onto the active device and copy its tensors into the model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Switch the model to evaluation mode; as the last expression of the cell this
# also displays the module summary.
model.eval()

FrontendMesoInception4(
  (Incption1_conv1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption1_conv2_1): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption1_conv2_2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (Incption1_conv3_1): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption1_conv3_2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
  (Incption1_conv4_1): Conv2d(1, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption1_conv4_2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
  (Incption1_bn): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Incption2_conv1): Conv2d(11, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption2_conv2_1): Conv2d(11, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (Incption2_conv2_2): Conv2d(4, 4, kernel_size=(3,

In [20]:
# Target audio format and preprocessing switches.
SAMPLING_RATE = 16_000
APPLY_NORMALIZATION, APPLY_TRIMMING, APPLY_PADDING = True, True, True
FRAMES_NUMBER = 480_000  # Originally 64_600; 480_000 samples is 30 s at 16 kHz

# Analysis window and hop sizes, in samples.
win_length = 400  # int((25 / 1_000) * SAMPLING_RATE), i.e. 25 ms
hop_length = 160  # 10 ms at 16 kHz

# MFCC extractor placed on the active device.
MFCC_FN = torchaudio.transforms.MFCC(
    sample_rate=SAMPLING_RATE,
    n_mfcc=128,
    melkwargs=dict(
        n_fft=512,
        win_length=win_length,
        hop_length=hop_length,
    ),
).to(device)

# SoX effect chain passed to apply_effects_tensor by apply_trim.
# Trim silence longer than 0.2s and louder than 1% volume
SOX_SILENCE = [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]

class SimpleAudioDataset(Dataset):
    """Dataset that loads and preprocesses the audio files listed in a CSV file.

    Each CSV row must provide a ``path`` column (file location relative to
    ``root_dir``) and an ``id`` column.

    Args:
        csv_file: Path of the CSV manifest.
        subset: Label of the split; it is only echoed back in the metadata.
        transform: Stored on the instance.
            NOTE(review): it is never applied in ``__getitem__`` — confirm
            whether that is intended.
        return_meta: When true, each item also carries a metadata tuple.
        root_dir: Prefix prepended verbatim to every ``path`` value. Defaults
            to the previously hard-coded ``"../SW/"``.
    """

    def __init__(self, csv_file, subset, transform=None, return_meta=False,
                 root_dir="../SW/"):
        self.samples = pd.read_csv(csv_file)
        self.transform = transform
        self.subset = subset
        self.return_meta = return_meta
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of rows in the manifest."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, preprocess and return the audio at row ``index``.

        Returns:
            list: ``[waveform, sample_rate]``, followed by the tuple
            ``(file_id, path, subset, real_sec_length)`` when ``return_meta``
            is set.
        """
        sample = self.samples.iloc[index]
        # Plain concatenation keeps the resulting path string identical to the
        # previous hard-coded behaviour for the default root_dir.
        path = self.root_dir + str(sample["path"])
        file_id = sample["id"]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        # Duration of the file as loaded, before any resampling/trimming/padding.
        real_sec_length = waveform.shape[-1] / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_meta:
            return_data.append((file_id, path, self.subset, real_sec_length))

        return return_data


def apply_preprocessing(waveform, sample_rate):
    """Bring a loaded waveform into the format used for inference.

    Steps, in order: resample to ``SAMPLING_RATE`` when the rates differ, keep
    only the first channel of multi-channel audio, trim silence when
    ``APPLY_TRIMMING`` is set, and pad/crop to ``FRAMES_NUMBER`` samples when
    ``APPLY_PADDING`` is set.

    Args:
        waveform: Audio tensor as returned by ``torchaudio.load``.
        sample_rate: Sampling rate of ``waveform``.

    Returns:
        tuple: ``(waveform, sample_rate)`` after preprocessing.
    """
    needs_resampling = sample_rate != SAMPLING_RATE
    if needs_resampling:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    is_multichannel = waveform.dim() > 1 and waveform.shape[0] > 1
    if is_multichannel:
        # Keep the channel axis but drop everything except the first channel.
        waveform = waveform[0:1]

    if APPLY_TRIMMING:
        waveform, sample_rate = apply_trim(waveform, sample_rate)

    if APPLY_PADDING:
        waveform = apply_pad(waveform, FRAMES_NUMBER)

    return waveform, sample_rate


def resample_wave(waveform, sample_rate, target_sample_rate):
    """Resample ``waveform`` to ``target_sample_rate`` with the SoX ``rate`` effect.

    Args:
        waveform: Audio tensor to resample.
        sample_rate: Current sampling rate of ``waveform``.
        target_sample_rate: Desired sampling rate.

    Returns:
        tuple: ``(waveform, sample_rate)`` as produced by the SoX effect chain.
    """
    rate_effect = [["rate", str(target_sample_rate)]]
    resampled, new_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, rate_effect
    )
    return resampled, new_rate


def apply_trim(waveform, sample_rate):
    """Strip silence from ``waveform`` using the ``SOX_SILENCE`` effect chain.

    Args:
        waveform: Audio tensor to trim.
        sample_rate: Sampling rate of ``waveform``.

    Returns:
        tuple: The trimmed ``(waveform, sample_rate)``, or the untouched inputs
        when trimming would leave no samples at all.
    """
    trimmed, trimmed_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, SOX_SILENCE
    )

    # Trimming removed everything: fall back to the original audio.
    if trimmed.shape[1] == 0:
        return waveform, sample_rate

    return trimmed, trimmed_rate


def apply_pad(waveform, cut):
    """Crop or repeat-pad a waveform to exactly ``cut`` samples.

    Longer inputs are truncated; shorter inputs are tiled end to end and then
    truncated. An input without any samples yields ``cut`` zeros instead of
    failing with a division by zero.

    Args:
        waveform: Audio tensor, either 1-D ``(samples,)`` or 2-D
            ``(channels, samples)``; only the first channel of a 2-D input is
            used.
        cut: Number of samples in the result.

    Returns:
        torch.Tensor: 1-D tensor of length ``cut``.
    """
    if waveform.dim() > 1:
        # Select the first channel explicitly; unlike squeeze(0) this also
        # behaves for one-sample and multi-channel inputs.
        waveform = waveform[0]
    waveform_len = waveform.shape[0]

    if waveform_len >= cut:
        return waveform[:cut]

    if waveform_len == 0:
        # Nothing to repeat: return silence with the input's dtype and device.
        return waveform.new_zeros(cut)

    num_repeats = cut // waveform_len + 1
    return waveform.repeat(num_repeats)[:cut]

In [21]:
# Build the test dataset and a sequential (non-shuffled) loader over it.
csv_file = '../SW/test.csv'
subset = 'test'

audio_dataset = SimpleAudioDataset(csv_file, subset, return_meta=True)
audio_loader = DataLoader(dataset=audio_dataset, batch_size=16, shuffle=False)
print(audio_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fa5dc5669a0>


In [22]:
def inference(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect sigmoid scores.

    Args:
        model: Network to evaluate; it is switched to eval mode and moved to
            ``device``.
        loader: Iterable of batches. A batch is either a list/tuple whose first
            element is the feature tensor, or the feature tensor itself.
        device: Device on which the forward pass runs.

    Returns:
        list: One entry per row of the model output, each holding
        ``sigmoid(model(features))`` for that row as NumPy data.
    """
    model.eval()
    model.to(device)
    all_probs = []
    with torch.no_grad():
        for batch in tqdm(loader):
            # Collated batches may arrive as a list or a tuple; the features
            # come first and the remaining entries (sample rate, metadata)
            # are not needed here. Otherwise the batch is already a tensor.
            features = batch[0] if isinstance(batch, (list, tuple)) else batch
            # Cast in both cases so the model always receives float32 input.
            features = features.float().to(device)
            logits = model(features)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.extend(probs)

    return all_probs

In [23]:
# Score the whole test set with the loaded model.
preds = inference(model=model, loader=audio_loader, device=device)

  0%|                                                                                                 | 1/3125 [00:00<19:19,  2.69it/s]

tensor([[[[-3.5445e+02, -1.6710e+01,  1.1079e+02,  ..., -1.7561e+02,
           -1.7415e+02, -1.6293e+02],
          [ 8.4408e+01,  4.6711e+01,  2.1555e+01,  ...,  3.2955e+01,
            3.2773e+01,  3.2692e+01],
          [-2.5486e+01,  2.0013e+01,  2.1526e+01,  ..., -5.1567e+01,
           -3.1446e+01, -1.0841e+01],
          ...,
          [ 3.2288e-05,  3.2710e-05,  3.3134e-05,  ...,  1.1960e-05,
            1.1880e-05,  1.1796e-05],
          [ 1.2332e-05,  1.2513e-05,  1.2693e-05,  ..., -6.8617e-06,
           -6.7205e-06, -6.5829e-06],
          [ 5.6295e-05,  5.7031e-05,  5.7762e-05,  ...,  1.1436e-05,
            1.1255e-05,  1.1081e-05]]],


        [[[-1.4766e+02, -1.6300e+02, -1.6068e+02,  ..., -1.9398e+02,
           -2.1464e+02, -1.8817e+02],
          [ 1.3624e+02,  1.3937e+02,  1.4358e+02,  ...,  2.5106e+01,
            4.0132e+01,  4.1161e+01],
          [ 1.8788e+00, -1.8004e+01, -2.9482e+01,  ...,  1.1691e+02,
            1.1575e+02,  8.2711e+01],
          ...,
   

  0%|                                                                                                 | 2/3125 [00:00<19:04,  2.73it/s]

tensor([[[[-1.3523e+02, -1.2583e+02, -1.4506e+02,  ..., -2.3623e+02,
           -2.4758e+02, -2.5094e+02],
          [ 1.4724e+01,  2.4617e+01,  1.0761e+01,  ...,  5.4242e+01,
            6.1364e+01,  6.9576e+01],
          [-4.9529e+01, -5.4039e+01, -5.3669e+01,  ..., -4.4621e+01,
           -3.2781e+01, -1.5433e+01],
          ...,
          [-1.4878e-05, -1.4994e-05, -1.5112e-05,  ..., -1.0131e-05,
           -1.0009e-05, -9.8909e-06],
          [-9.7092e-06, -9.7652e-06, -9.8241e-06,  ..., -4.8793e-05,
           -4.8151e-05, -4.7499e-05],
          [-3.6455e-05, -3.6930e-05, -3.7406e-05,  ..., -4.5722e-05,
           -4.5107e-05, -4.4492e-05]]],


        [[[-1.2278e+02, -2.5072e+02, -2.6437e+02,  ..., -1.4014e+02,
           -2.5297e+02, -1.7492e+02],
          [-4.9077e+01, -2.7904e-01, -6.3117e+00,  ...,  1.1521e+02,
            8.5282e+01,  6.2756e+01],
          [-1.1543e+02, -1.0458e+02, -8.5607e+01,  ..., -8.1125e+00,
            6.9651e+00, -2.1906e+01],
          ...,
   

  0%|                                                                                                 | 3/3125 [00:01<18:56,  2.75it/s]

tensor([[[[-5.6337e+01, -5.5371e+01, -7.7872e+01,  ..., -6.7552e+01,
           -7.2632e+01, -5.8182e+01],
          [-1.0933e+01, -7.6833e+00, -2.5809e+01,  ..., -1.7587e+01,
           -1.4904e+01, -1.4257e+01],
          [ 8.5037e+00, -1.7096e+01, -1.5225e+01,  ..., -1.9793e+01,
           -1.3356e+01, -1.0738e+01],
          ...,
          [ 2.9959e-05,  3.0391e-05,  3.0822e-05,  ...,  1.5607e-05,
            1.5355e-05,  1.5104e-05],
          [ 7.0094e-06,  7.1436e-06,  7.2762e-06,  ...,  2.6268e-05,
            2.5896e-05,  2.5516e-05],
          [ 1.7632e-05,  1.7840e-05,  1.8045e-05,  ...,  5.1646e-06,
            5.0800e-06,  4.9908e-06]]],


        [[[-1.2557e+02, -1.0165e+02, -1.0322e+02,  ..., -7.2718e+01,
           -6.7113e+01, -6.8480e+01],
          [ 5.5437e+01,  7.6288e+01,  7.6561e+01,  ...,  7.8319e+01,
            7.9119e+01,  8.2171e+01],
          [ 3.6132e+01,  2.3559e+01,  1.7908e+01,  ..., -2.8625e+01,
           -2.3087e+01, -1.2621e+01],
          ...,
   

  0%|                                                                                                 | 4/3125 [00:01<19:35,  2.66it/s]

tensor([[[[-9.6869e+01, -3.5542e+01, -7.3376e-01,  ..., -3.5507e+01,
           -5.3883e+01, -3.8733e+01],
          [ 6.5265e+01,  6.8722e+01,  4.6264e+01,  ...,  9.8887e+01,
            8.6942e+01,  8.9670e+01],
          [-1.4694e+00, -1.7697e+01, -3.4030e+01,  ..., -7.9244e+00,
           -1.8939e+01, -2.4635e+01],
          ...,
          [-1.6877e-05, -1.7107e-05, -1.7340e-05,  ..., -1.6550e-05,
           -1.6396e-05, -1.6236e-05],
          [ 3.8684e-06,  3.9426e-06,  4.0167e-06,  ...,  4.1350e-06,
            4.0897e-06,  4.0437e-06],
          [-2.1045e-06, -2.1213e-06, -2.1365e-06,  ..., -1.1788e-05,
           -1.1677e-05, -1.1560e-05]]],


        [[[-2.8821e+02, -2.8826e+02, -3.0519e+02,  ..., -1.6446e+02,
           -1.6538e+02, -1.6191e+02],
          [ 1.0740e+02,  1.1576e+02,  1.0352e+02,  ...,  1.1411e+02,
            1.1073e+02,  9.3652e+01],
          [ 2.1711e+00,  8.9400e+00, -6.8585e+00,  ..., -4.7811e+01,
           -3.3418e+01, -3.6879e+01],
          ...,
   

  0%|▏                                                                                                | 5/3125 [00:01<19:26,  2.67it/s]

tensor([[[[-1.0316e+02, -8.3749e+01, -8.0802e+01,  ..., -4.7045e+00,
           -6.4262e+00,  6.4437e+00],
          [ 2.8555e+01,  4.8594e+01,  4.6793e+01,  ...,  8.4762e+00,
            4.6763e+00,  2.3235e+01],
          [ 1.4594e+01,  2.0831e+01,  2.4774e+01,  ...,  7.9491e+00,
           -1.2981e+00, -3.4203e-01],
          ...,
          [-8.5013e-05, -8.6022e-05, -8.7016e-05,  ..., -3.8966e-06,
           -3.8869e-06, -3.8786e-06],
          [-7.9011e-05, -7.9996e-05, -8.0972e-05,  ...,  8.5915e-06,
            8.4507e-06,  8.3061e-06],
          [-5.4166e-05, -5.4864e-05, -5.5552e-05,  ..., -2.3372e-05,
           -2.3095e-05, -2.2819e-05]]],


        [[[-4.3963e+01, -9.1845e+01, -6.9672e+01,  ..., -2.6342e+02,
           -3.9404e+01, -8.1325e+01],
          [ 6.4056e+01,  5.7113e+01,  6.4841e+01,  ..., -7.8658e+01,
            2.0408e+01,  4.3268e+01],
          [-8.1023e+00, -4.5833e+01, -3.3207e+01,  ...,  7.1725e+01,
            1.7128e+01, -3.1104e+01],
          ...,
   

  0%|▏                                                                                                | 6/3125 [00:02<19:24,  2.68it/s]

tensor([[[[-2.8048e+02, -3.2719e+02, -3.1412e+02,  ..., -1.6927e+02,
           -1.8565e+02, -1.9742e+02],
          [-2.6053e+01, -1.4600e+01, -2.2017e+01,  ..., -1.1055e+02,
           -1.2280e+02, -1.3345e+02],
          [-7.4092e+01, -5.1563e+01, -5.0461e+01,  ...,  6.3370e+01,
            6.6750e+01,  6.1396e+01],
          ...,
          [ 3.9632e-05,  4.0103e-05,  4.0568e-05,  ...,  1.4735e-06,
            1.3774e-06,  1.2813e-06],
          [-8.5787e-06, -8.6824e-06, -8.7882e-06,  ..., -3.1565e-06,
           -3.1444e-06, -3.1324e-06],
          [ 7.8033e-06,  7.8306e-06,  7.8501e-06,  ...,  1.2202e-05,
            1.2042e-05,  1.1881e-05]]],


        [[[-2.9734e+02, -3.2674e+02, -2.8707e+02,  ..., -4.2500e+02,
           -4.7045e+02, -4.8617e+02],
          [ 7.8561e+01,  9.4140e+01,  1.1250e+02,  ...,  5.8163e+01,
            6.7982e+01,  8.7253e+01],
          [-6.1709e+01, -7.7057e+01, -7.7208e+01,  ...,  4.8077e+01,
            4.4329e+01,  4.9430e+01],
          ...,
   

  0%|▏                                                                                                | 7/3125 [00:02<19:42,  2.64it/s]

tensor([[[[-4.9037e+00,  9.4668e+00,  2.1524e+01,  ..., -1.7630e+02,
           -1.8411e+02, -1.3175e+02],
          [ 6.4552e+01,  7.6688e+01,  9.1613e+01,  ...,  1.2487e+02,
            7.3314e+01,  3.2244e+01],
          [ 2.2805e+01,  2.5994e+01,  2.7958e+01,  ...,  3.0951e+01,
            6.8892e+01,  7.8934e+01],
          ...,
          [-1.6253e-06, -1.5715e-06, -1.5179e-06,  ..., -2.1988e-05,
           -2.1741e-05, -2.1493e-05],
          [ 3.1979e-06,  3.3191e-06,  3.4372e-06,  ...,  5.6979e-06,
            5.5757e-06,  5.4541e-06],
          [-2.1280e-05, -2.1486e-05, -2.1692e-05,  ..., -1.1681e-06,
           -1.1539e-06, -1.1400e-06]]],


        [[[-3.7384e+02, -3.3023e+02, -3.6918e+02,  ..., -1.7367e+02,
           -1.6174e+02, -1.1668e+02],
          [ 1.6825e+01,  3.3680e+01,  4.0275e+01,  ...,  3.7229e+01,
            3.5132e+01, -2.6171e+00],
          [ 2.1291e+01, -1.1720e+01, -8.2263e+00,  ..., -6.3046e+01,
           -5.8458e+01, -2.3050e+01],
          ...,
   

  0%|▏                                                                                                | 8/3125 [00:03<20:28,  2.54it/s]

tensor([[[[-1.0739e+02, -1.1988e+02, -1.2885e+02,  ..., -1.1640e+02,
           -1.2510e+02, -8.0409e+01],
          [ 2.6859e+01,  3.4147e+01,  3.7069e+01,  ...,  5.2460e+01,
            5.1073e+01,  3.5656e+01],
          [-5.0044e+01, -6.1640e+01, -4.4575e+01,  ..., -4.5057e+01,
           -3.5214e+01, -4.2503e+01],
          ...,
          [-9.7865e-06, -9.8880e-06, -9.9907e-06,  ...,  3.9016e-05,
            3.8541e-05,  3.8061e-05],
          [ 1.5795e-06,  1.6148e-06,  1.6504e-06,  ...,  1.1714e-05,
            1.1563e-05,  1.1411e-05],
          [ 6.2683e-06,  6.3707e-06,  6.4757e-06,  ...,  9.1997e-06,
            9.0469e-06,  8.8981e-06]]],


        [[[-7.7535e+01, -9.5336e+01, -1.0263e+02,  ..., -9.5286e+01,
           -1.0125e+02, -9.2045e+01],
          [ 1.6179e+01,  4.3430e+00,  4.1221e+00,  ..., -1.1652e+02,
           -1.1508e+02, -1.0543e+02],
          [ 5.6668e+01,  3.1183e+01,  2.4448e+01,  ...,  2.4737e+01,
            3.8585e+01,  2.0762e+01],
          ...,
   

  0%|▎                                                                                                | 9/3125 [00:03<20:16,  2.56it/s]

tensor([[[[-1.4269e+02, -1.4187e+02, -1.3023e+02,  ..., -1.9518e+02,
           -1.6812e+02, -1.1363e+02],
          [ 6.9027e+00, -2.1990e+01, -3.5445e+01,  ..., -1.1022e+01,
            8.0006e-01, -1.1071e+00],
          [ 5.9588e+01,  6.0623e+01,  5.9478e+01,  ...,  1.3227e+02,
            1.2769e+02,  9.1470e+01],
          ...,
          [ 2.9555e-05,  3.0036e-05,  3.0512e-05,  ..., -2.0613e-05,
           -2.0496e-05, -2.0375e-05],
          [ 2.4278e-06,  2.5290e-06,  2.6277e-06,  ..., -1.3353e-05,
           -1.3289e-05, -1.3226e-05],
          [-7.7219e-06, -7.7953e-06, -7.8695e-06,  ..., -7.2317e-07,
           -7.4394e-07, -7.6379e-07]]],


        [[[-3.1141e+02, -2.5010e+02, -2.2653e+02,  ..., -2.1804e+02,
           -2.1952e+02, -2.1885e+02],
          [ 7.6610e+01,  7.5876e+01,  8.6486e+01,  ...,  8.6835e+01,
            7.1105e+01,  6.4378e+01],
          [ 3.3816e+01,  2.6450e+01,  4.7404e+01,  ...,  4.6570e+01,
            4.7752e+01,  3.9634e+01],
          ...,
   

  0%|▎                                                                                               | 10/3125 [00:03<19:48,  2.62it/s]

tensor([[[[ 6.3241e+01,  6.8821e+01,  5.7739e+01,  ...,  3.9577e+01,
            4.2782e+01,  3.8644e+01],
          [ 1.5344e+01,  1.5881e+01,  9.3338e+00,  ...,  7.5275e+01,
            7.3879e+01,  6.4556e+01],
          [ 1.3315e+01,  2.3449e+01,  2.2143e+01,  ..., -6.3548e+00,
           -1.9984e+01, -1.0548e+01],
          ...,
          [-1.8488e-05, -1.8698e-05, -1.8909e-05,  ..., -2.0227e-05,
           -2.0026e-05, -1.9822e-05],
          [-1.1305e-05, -1.1427e-05, -1.1547e-05,  ..., -2.2472e-05,
           -2.2198e-05, -2.1914e-05],
          [ 7.2326e-06,  7.3220e-06,  7.4082e-06,  ..., -2.1147e-05,
           -2.0880e-05, -2.0613e-05]]],


        [[[-1.7642e+02, -1.7120e+02, -1.6836e+02,  ...,  5.4837e+01,
            5.1421e+01,  5.2284e+01],
          [ 7.3072e+01,  5.0008e+01,  3.4627e+01,  ...,  6.9041e+00,
           -3.5635e+00, -6.6932e+00],
          [ 1.6799e+01, -1.3328e+01, -2.0952e+01,  ..., -3.5081e+01,
           -4.2159e+01, -3.6836e+01],
          ...,
   

  0%|▎                                                                                               | 10/3125 [00:04<20:54,  2.48it/s]


KeyboardInterrupt: 

In [125]:
# The first value of each prediction is used as the "real" score and its
# complement as the "fake" score.
# NOTE(review): confirm that the model output really encodes the "real" class.
real_probabilities = list(map(lambda prob: prob[0], preds))
fake_probabilities = [1 - real for real in real_probabilities]

In [127]:
# Start from the sample submission and attach the predicted values.
# Caution: the column names may differ depending on the format of
# sample_submission.csv.
submit = pd.read_csv('../SW/sample_submission.csv').assign(
    fake=fake_probabilities,
    real=real_probabilities,
)

# Print the first rows as a sanity check.
print(submit.head())

# Save as a CSV file.
submit.to_csv('./baseline_submit3.csv', index=False)

           id      fake      real
0  TEST_00000  0.496536  0.503464
1  TEST_00001  0.509840  0.490160
2  TEST_00002  0.466382  0.533618
3  TEST_00003  0.444550  0.555450
4  TEST_00004  0.456526  0.543474
