In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification 
import pandas as pd
import librosa
import numpy as np
from tqdm import tqdm

import sys
sys.path.append("..")
from src.models import EModel, AASIST, Wav2Vec2Facebook
from src.datamodules import AASIST2DataModule, AASISTCenterLossDataset

In [4]:
# Load the AASIST model from a saved Lightning-style checkpoint.
# An earlier experiment used Wav2Vec2Facebook.load_from_checkpoint(path, args={})
# together with the "facebook/wav2vec2-base" feature extractor; it is disabled.
path = "/home/work/joono/joono/joono/DV_DV.Deep/x2hfl5gz/checkpoints/best-checkpoint_aug_oneshot.ckpt"

model = AASIST.load_from_checkpoint(
    path,
    args={},
)

In [5]:
def collate_fn(batch):
    """Zero-pad a batch of 1-D signals to a common length.

    Args:
        batch: sequence of tensors; each item's first dimension is its length.

    Returns:
        A ``(len(batch), longest_length)`` tensor of torch's default dtype in
        which row ``i`` holds ``batch[i]`` followed by zeros.

    Raises:
        ValueError: if ``batch`` is empty (there is no maximum length).
    """
    lengths = [waveform.size(0) for waveform in batch]
    padded = torch.zeros(len(batch), max(lengths))
    for row, (waveform, length) in enumerate(zip(batch, lengths)):
        # Copy the signal into the front of its row; the tail stays zero.
        padded[row, :length] = waveform
    return padded

In [7]:
# Read the test manifest and rewrite every `path` entry: drop its first
# character and prefix '../dataset/'.
# NOTE(review): presumably the CSV stores paths like './test/...' — confirm.
test_df = pd.read_csv('../dataset/test.csv', index_col=None).assign(
    path=lambda frame: '../dataset/' + frame['path'].str[1:]
)

# Inference-mode dataset (train_mode=False).
test_dataset = AASISTCenterLossDataset(test_df, train_mode=False)

# Keep the original order (shuffle=False) so predictions line up with the
# manifest rows; collate_fn zero-pads each batch to its longest signal.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=24,
    collate_fn=collate_fn,
)

In [44]:
from torchaudio.models import conv_tasnet_base
from IPython.display import Audio


# Source-separation experiment (ConvTasNet, two output sources).
# It is bound to its own name on purpose: assigning it to `model` would
# overwrite the AASIST classifier loaded earlier, and the inference cells
# further down would then receive the wrong network when the notebook is run
# top to bottom.
separation_model = conv_tasnet_base(num_sources=2)

In [48]:
# Pull one test item; indexing the dataset yields three values here and only
# the first one is used.
audio, _, _ = test_dataset[3]
print(audio.shape)

# Render an inline audio player for the clip.
# NOTE(review): rate=16000 assumes the dataset delivers 16 kHz audio — confirm.
Audio(audio, rate=16000)

# Disabled source-separation experiment. `model` was presumably meant to be the
# ConvTasNet separator; check which object that name refers to before re-enabling.
# result = model(audio)
# Audio(result[0, 1, :].detach().numpy(), rate=16000)

torch.Size([96000])


In [8]:
def inference(model, test_loader, device):
    """Run the model over every batch of a loader and gather its first output.

    Args:
        model: module that is moved to ``device``, switched to eval mode and
            called on each batch; its call must return a pair whose first
            element is a tensor (the second element is ignored).
        test_loader: iterable yielding input tensors (no labels).
        device: target device for the model and for each batch.

    Returns:
        A flat Python list with one entry per row of the model's first
        output, accumulated in loader order.
    """
    model.to(device)
    model = model.eval()

    collected = []
    # No gradients are needed for pure forward passes.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            scores, _ = model(batch.to(device))
            # Bring the scores back to host memory as plain Python lists.
            collected.extend(scores.detach().cpu().numpy().tolist())

    return collected

In [9]:
# Score the whole test set on the first CUDA device; `preds` ends up with one
# entry per row of model output, in loader order.
# (A model-level helper, model.inference(test_loader=test_loader), was tried
# earlier and is not used here.)
preds = inference(model, test_loader, device='cuda:0')

100%|██████████| 1563/1563 [01:32<00:00, 16.88it/s]


In [11]:
submit = pd.read_csv('/home/work/joono/joono/dataset/sample_submission.csv')

# Scores strictly above max_thres become 1, scores strictly below min_thres
# become 0, and anything else keeps its raw value.
max_thres = 0.5
min_thres = 0.5


def _harden(score):
    """Map one raw score to 1, 0, or itself according to the two thresholds."""
    if score > max_thres:
        return 1
    if score < min_thres:
        return 0
    return score


# Column 0 of the sample submission is skipped; prediction column k is written
# to submission column k + 1, one row per prediction.
for row in tqdm(range(len(preds))):
    for col in (1, 2):
        submit.iloc[row, col] = _harden(preds[row][col - 1])

submit[10000:10050]

100%|██████████| 50000/50000 [00:09<00:00, 5429.96it/s]


Unnamed: 0,id,fake,real
10000,TEST_10000,1,1
10001,TEST_10001,1,0
10002,TEST_10002,0,1
10003,TEST_10003,0,1
10004,TEST_10004,0,1
10005,TEST_10005,1,1
10006,TEST_10006,1,1
10007,TEST_10007,0,1
10008,TEST_10008,0,1
10009,TEST_10009,0,1


In [None]:
# submit.to_csv('j0.5.csv', index=False)

In [13]:
# Write the filled-in submission frame to CSV in the working directory,
# without the DataFrame index column.
submit.to_csv('AASIST_train_test_align_centerloss_test_zero_one2.csv', index=False)

In [None]:
# Peek at the first few raw predictions only; echoing the whole list would
# dump one entry per test clip into the notebook output.
preds[:5]