In [1]:
import warnings
warnings.filterwarnings('ignore')

from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio
import numpy as np
import glob
import collections
from joblib import Parallel, delayed
import collections
from sklearn.metrics import roc_curve

import librosa, librosa.display
import scipy
from scipy.io import wavfile
import soundfile as sf

import torch
from torch import Tensor
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from tensorboardX import SummaryWriter

from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import sys 
import os
import math
import yaml
import copy
from scipy import signal
import wave
import cv2

In [2]:
######### helper functions ############

## For data
def pad(x, max_len=48000):
    """Force a signal to exactly ``max_len`` samples along its first axis.

    Inputs that are already long enough are truncated. Shorter inputs are
    repeated end-to-end and then cut to ``max_len``.

    Args:
        x: array exposing ``shape`` and slicing (presumably a 1-D NumPy
            waveform -- confirm against callers).
        max_len: target number of samples.

    Returns:
        ``x[:max_len]`` when ``x`` is long enough, otherwise a tiled copy of
        ``x`` cut to ``max_len`` samples.

    Raises:
        ValueError: if ``x`` is empty and therefore cannot be repeated.
    """
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    if x_len == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            "pad() received an empty signal; cannot repeat it up to max_len=%d" % max_len
        )
    # One repetition more than strictly needed so the cut below always has enough data.
    num_repeats = max_len // x_len + 1
    return np.tile(x, (1, num_repeats))[0, :max_len]

## For SincNet
def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` for a positive int; 1 when ``x`` is 0."""
    if x == 0:
        return 1
    # The bit length of (x - 1) is the exponent of the next power of two.
    return 1 << (x - 1).bit_length()

def to_2tuple(x):
    """Return ``x`` unchanged if it is iterable, otherwise the pair ``(x, x)``.

    Args:
        x: a scalar (e.g. one patch size) or an iterable already holding the
            per-dimension values.

    Returns:
        ``x`` itself when it is iterable (no copy, no length check),
        else the tuple ``(x, x)``.
    """
    # Import the ABC explicitly: `import collections` alone does not guarantee
    # that the `collections.abc` submodule is reachable as an attribute.
    from collections.abc import Iterable

    if isinstance(x, Iterable):
        return x
    return (x, x)

## For Transformer
def clones(module, N):
    """Build an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        # deepcopy gives every replica its own parameters
        replicas.append(copy.deepcopy(module))
    return replicas


## For Optimizer
def rate(step, model_size, factor, warmup):
    """Learning-rate multiplier: grows linearly with ``step``, then decays as ``step ** -0.5``.

    Step 0 is treated as step 1 so the power below is never taken of zero.
    """
    effective_step = 1 if step == 0 else step
    decay_term = effective_step ** (-0.5)
    warmup_term = effective_step * warmup ** (-1.5)
    return factor * (model_size ** (-0.5) * min(decay_term, warmup_term))

In [3]:
######### helper class ############
class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable scale and shift.

    ``gamma`` (initialised to ones) and ``beta`` (initialised to zeros) have
    shape ``features`` and are broadcast against the normalised input.
    The spread is ``Tensor.std`` with its default settings plus ``eps``.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation, not to the variance
        spread = x.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / spread + self.beta

In [4]:
class SincConv(nn.Module):
    """Sinc-based 1-D convolution (SincNet front end) applied to raw audio.

    Every output channel is a windowed band-pass FIR filter described by two
    learnable values: a low cut-off (``low_hz_``) and a bandwidth
    (``band_hz_``), both initialised evenly spaced on the mel scale. The
    filter bank is rebuilt from these values on every forward pass and kept
    in ``self.filters``.

    Args:
        out_channels: number of band-pass filters.
        kernel_size: filter length; an even value is increased by one so the
            filter has a centre tap.
        sample_rate: sampling rate in Hz used to convert cut-offs to radians.
        in_channels: must be 1.
        stride, padding, dilation: forwarded to ``F.conv1d``.
        bias: must be False.
        groups: must be 1.
        min_low_hz: lower bound added to the learned low cut-off.
        min_band_hz: lower bound added to the learned bandwidth.

    Raises:
        ValueError: if ``in_channels != 1``, ``bias`` is truthy or ``groups > 1``.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz (inverse of ``to_mel``)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size=128, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50):

        super(SincConv, self).__init__()

        if in_channels != 1:
            # The old template mixed "{}" with "%i" and printed literal braces.
            msg = "SincConv only support one input channel (here, in_channels = %i)" % (in_channels)
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # Force an odd length so the filter is symmetric around a centre tap.
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initial cut-offs: out_channels + 1 band edges, equally spaced in mel.
        low_hz = 30
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)

        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)

        # Learnable lower edge and bandwidth of each filter, shape (out_channels, 1).
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Left half of a Hamming window; the right half is obtained by mirroring.
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
        window = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)

        # Time axis (in radians per Hz) for the left half of the filter.
        n = (self.kernel_size - 1) / 2.0
        band_pass = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate

        # Non-persistent buffers: they follow .to()/.cuda()/.double() with the
        # module but stay out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer('window_', window, persistent=False)
        self.register_buffer('band_pass', band_pass, persistent=False)

    def forward(self, x):
        """Filter ``x`` with the current band-pass bank.

        Args:
            x: tensor of shape ``(batch, 1, n_samples)``.

        Returns:
            Tensor of shape ``(batch, out_channels, n_out)`` produced by
            ``F.conv1d`` with the configured stride, padding and dilation.
        """
        low = self.min_low_hz + torch.abs(self.low_hz_)

        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self.band_pass)
        f_times_t_high = torch.matmul(high, self.band_pass)

        # Difference of two sinc low-pass responses gives a band-pass; only the
        # left half is computed, the right half is its mirror image.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
                          / (self.band_pass / 2)) * self.window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)

        # Normalise so the centre tap of every filter equals 1.
        band_pass = band_pass / (2 * band[:, None])

        self.filters = (band_pass).view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(x, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)

In [5]:
class PatchEmbed(nn.Module):
    """Cut a 2-D feature map into non-overlapping patches and embed each one.

    A ``Conv2d`` whose kernel size and stride both equal the patch size
    produces one ``embed_dim``-dimensional vector per patch.
    """

    def __init__(self, feature_size, patch_size, embed_dim, in_chans=1):
        super(PatchEmbed, self).__init__()
        patch_size = to_2tuple(patch_size)
        rows = feature_size[0] // patch_size[0]
        cols = feature_size[1] // patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``(B, H, W)`` features to ``(B, embed_dim, n_patches)``."""
        batch, height, width = x.shape
        # A single channel axis is inserted here regardless of `in_chans`.
        as_image = x.view(batch, 1, height, width)
        return self.proj(as_image).flatten(2)

In [6]:
class EmbedReduce(nn.Module):
    """Shrink the last dimension in three Linear + LayerNorm stages.

    The last axis goes ``current_len -> seq_size[0] -> seq_size[1] -> seq_size[2]``;
    the result is returned with axes 1 and 2 swapped.
    """

    def __init__(self, current_len, seq_size):
        super().__init__()
        first, second, third = seq_size[0], seq_size[1], seq_size[2]

        self.linear1 = nn.Linear(current_len, first)
        self.lin_ln1 = nn.LayerNorm(first)

        self.linear2 = nn.Linear(first, second)
        self.lin_ln2 = nn.LayerNorm(second)

        self.linear3 = nn.Linear(second, third)
        self.lin_ln3 = nn.LayerNorm(third)

    def forward(self, x):
        stages = (
            (self.linear1, self.lin_ln1),
            (self.linear2, self.lin_ln2),
            (self.linear3, self.lin_ln3),
        )
        for project, normalise in stages:
            x = normalise(project(x))
        return x.transpose(1, 2)

In [7]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout.

    The table of shape ``(1, max_len, d_model)`` is computed once and stored
    as the buffer ``pe``: even columns hold sines, odd columns hold cosines.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Frequencies are computed in log space.
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout."""
        encoding = self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x + encoding)

In [8]:
######### Class for Transformer ############

class SublayerConnection(nn.Module):
    """Residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the result back onto ``x``."""
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch
    

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Positions where ``mask == 0`` receive a score of ``-1e9`` before the
    softmax. ``dropout``, if given, is applied to the attention weights.

    Returns:
        ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights

class MultiHeadedAttention(nn.Module):
    """Multi-head attention with ``h`` heads of width ``d_model // h``.

    Four ``d_model x d_model`` linear layers are kept in ``linears``: the first
    three project query, key and value, the last one mixes the concatenated
    heads. The most recent attention weights are stored in ``attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value`` and return ``(batch, seq, d_model)``."""
        if mask is not None:
            # The same mask is shared by every head.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            projected = projection(tensor)
            return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        heads_q = split_heads(self.linears[0], query)
        heads_k = split_heads(self.linears[1], key)
        heads_v = split_heads(self.linears[2], value)

        context, self.attn = attention(
            heads_q, heads_k, heads_v, mask=mask, dropout=self.dropout)

        # Concatenate the heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)

        output_projection = self.linears[-1]
        return output_projection(merged)
    
    
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each in a residual sublayer.

    ``dropout`` is passed to the two ``SublayerConnection`` wrappers only; the
    attention module is built with its own default dropout.
    """

    def __init__(self, d_model, d_hidden, h, dropout):
        super().__init__()
        self.self_attn = MultiHeadedAttention(h, d_model)
        # NOTE(review): PositionwiseFeedForward is not defined in the visible
        # part of this file -- confirm it is provided elsewhere.
        self.feed_forward = PositionwiseFeedForward(d_model, d_hidden)
        self.sublayer = clones(SublayerConnection(d_model, dropout), 2)

    def forward(self, x, mask):
        """Run ``x`` through the attention sublayer, then the feed-forward sublayer."""

        def self_attend(inp):
            # query, key and value are all the same tensor
            return self.self_attn(inp, inp, inp, mask)

        attn_block, ff_block = self.sublayer[0], self.sublayer[1]
        attended = attn_block(x, self_attend)
        return ff_block(attended, self.feed_forward)
    
    
class Encoder(nn.Module):
    """Stack of ``N`` independently parameterised ``EncoderLayer`` copies applied in sequence."""

    def __init__(self, d_model, d_hidden, N, h, dropout):
        super().__init__()
        prototype = EncoderLayer(d_model, d_hidden, h, dropout)
        self.layers = clones(prototype, N)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order, passing the same ``mask`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden
    
    
class Transformer(nn.Module):
    """Encoder-only Transformer with Xavier-initialised weight matrices.

    Args:
        d_model: width of the token embeddings.
        d_hidden: hidden width handed to each encoder layer's feed-forward part.
        N: number of encoder layers.
        h: number of attention heads.
        dropout: dropout rate handed to the encoder layers.
    """

    def __init__(self, d_model, d_hidden, N=6, h=8,dropout=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden

        self.encoder = Encoder(self.d_model, self.d_hidden, N, h, dropout)

        print('initialization: xavier')
        # Only matrices are re-initialised; 1-D parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, seq_len, d_model)``.

        The attention mask is 0 for the first sequence position and 1 for all
        others, i.e. position 0 is hidden from every query.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        # Build the mask on the input's device instead of calling .cuda(),
        # so the model also runs on CPU-only machines.
        src_mask = torch.cat(
            (torch.zeros(batch, 1, 1, device=x.device),
             torch.ones(batch, 1, seq_len - 1, device=x.device)),
            dim=2)
        x = self.encoder(x, src_mask)
        return x

In [9]:
######### The entire model ############

class Box(nn.Module):
    """Full detector: SincConv front end -> patch embedding -> Transformer -> MLP head.

    Args:
        model_args: mapping with the keys ``num_filter``, ``filt_len``,
            ``in_channels``, ``samp_len``, ``max_pool_len``, ``patch_embed``,
            ``patch_size``, ``seq_size`` (at least three entries),
            ``drop_out``, ``encoder_hidden``, ``num_block``, ``num_head``
            and ``nb_classes``.
        device: stored on the instance as ``self.device``; the module is not
            moved here.
    """

    def __init__(self, model_args, device):
        super(Box, self).__init__()

        self.device = device
        # Kept so forward() pools with the same factor used to size the layers below.
        self.max_pool_len = model_args['max_pool_len']

        self.Sinc_conv = SincConv(out_channels = model_args['num_filter'],
                                  kernel_size = model_args['filt_len'],
                                  in_channels = model_args['in_channels'])
        # Use the kernel size SincConv actually ended up with: it bumps even
        # lengths to the next odd value, which shortens the output by one sample.
        conv_len = model_args['samp_len'] - self.Sinc_conv.kernel_size + 1
        self.feature_dim = conv_len // self.max_pool_len
        self.lnorm1 = LayerNorm([model_args['num_filter'], self.feature_dim])
        self.leaky_relu = nn.LeakyReLU(0.2)

        self.d_embed = model_args['patch_embed']
        self.patchEmbed = PatchEmbed(feature_size = (model_args['num_filter'], self.feature_dim),
                                     patch_size = model_args['patch_size'],
                                     embed_dim = self.d_embed)
        # Same value as (num_filter // patch) * (feature_dim // patch), taken
        # from the layer that defines it so the two cannot drift apart.
        self.seq_len = self.patchEmbed.num_patches
        self.EmbedReduce = EmbedReduce(current_len = self.seq_len, seq_size = model_args['seq_size'])

        self.posEncode = PositionalEncoding(d_model = self.d_embed, dropout = model_args['drop_out'])

        self.transformer = Transformer(d_model = self.d_embed,
                                       d_hidden = model_args['encoder_hidden'],
                                       N = model_args['num_block'],
                                       h = model_args['num_head'],
                                       dropout = model_args['drop_out'])

        self.mlp = nn.Sequential(nn.LayerNorm(model_args['seq_size'][2]),
                                 nn.Linear(in_features = model_args['seq_size'][2], out_features = model_args['seq_size'][2]),
                                 nn.Linear(in_features = model_args['seq_size'][2], out_features = model_args['nb_classes']))

    def forward(self, x, y = None,is_test=False):
        """Classify a batch of raw waveforms.

        Args:
            x: tensor of shape ``(batch, n_samples)``.
            y: unused; accepted so callers can pass labels positionally.
            is_test: unused.

        Returns:
            Tensor of shape ``(batch, nb_classes)`` holding class
            probabilities. Softmax is already applied, so a loss that expects
            raw logits must not be fed this output directly.
        """
        batch = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(batch, 1, len_seq)

        x = self.Sinc_conv(x)
        # Pool with the configured factor (was hard-coded to 3) so the time
        # axis matches self.feature_dim.
        x = F.max_pool1d(torch.abs(x), self.max_pool_len)
        x = self.lnorm1(x)
        x = self.leaky_relu(x)

        x = self.patchEmbed(x)
        x = self.EmbedReduce(x)

        x = self.posEncode(x)

        x = self.transformer(x).transpose(1, 2)

        # Average over the embedding axis, leaving one value per sequence position.
        x = self.mlp(x.mean(dim=1))
        output = F.softmax(x, dim=1)

        return output

In [None]:
######### Loading dataset ############
# One row of the label file: relative file name, resolved path, textual label
# and numeric key (1 for 'genuine', 0 otherwise).
AudioFile = collections.namedtuple('AudioFile',
    ['file_name','path','label', 'key'])


class ADDDataset(Dataset):
    """In-memory audio dataset backed by an on-disk cache.

    On first use every file listed in ``label_path`` is loaded at 16 kHz,
    optionally transformed, and the result is stored with ``torch.save``.
    Later instantiations load that cache instead of re-reading the audio.

    Args:
        data_path: directory that the file names in the label file are relative to.
        label_path: text file with one ``<file_name> <label>`` pair per line.
        transform: optional callable applied to every loaded waveform.
        is_train: selects the 'train' split (ignored when ``is_eval`` is true).
        is_eval: selects the 'eval' split.
        feature: stored on the instance; not used by this class.
        track: only used in the cache file name of the 'eval' split.
        cache_dir: directory holding the cache file. Defaults to the path
            that used to be hard-coded.
    """

    def __init__(self, data_path=None, label_path=None,transform=None,
                 is_train=True,is_eval=False,feature=None,track=None,
                 cache_dir="/home/menglu/123/Deepfake/built"):
        self.data_path = data_path
        self.label_path = label_path
        self.transform = transform
        self.track = track
        self.feature = feature

        self.dset_name = 'eval' if is_eval else 'train' if is_train else 'dev'
        if self.dset_name == 'eval':
            cache_fname = 'cache_ADD_{}_{}.npy'.format(self.dset_name, self.track)
        else:
            cache_fname = 'cache_ADD_{}.npy'.format(self.dset_name)
        self.cache_fname = os.path.join(cache_dir, cache_fname)

        if os.path.exists(self.cache_fname):
            # NOTE: torch.load unpickles the file -- only use trusted cache locations.
            self.data_x, self.data_y, self.files_meta = torch.load(self.cache_fname)
            print('Dataset loaded from cache', self.cache_fname)
        else:
            self.files_meta = self.parse_protocols_file(self.label_path)
            data = list(map(self.read_file, self.files_meta))
            self.data_x, self.data_y = map(list, zip(*data))
            if self.transform:
                self.data_x = Parallel(n_jobs=5, prefer='threads')(delayed(self.transform)(x) for x in self.data_x)
            # Make sure the expensive load above is not lost to a missing directory.
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            torch.save((self.data_x, self.data_y, self.files_meta), self.cache_fname)

    def __len__(self):
        """Number of samples; also recorded in ``self.length``."""
        self.length = len(self.data_x)
        return self.length

    def __getitem__(self, idx):
        """Return the ``(waveform, label)`` pair at position ``idx``."""
        x = self.data_x[idx]
        y = self.data_y[idx]
        return x, y

    def read_file(self, meta):
        """Load one audio file at 16 kHz and return ``(samples, float(key))``."""
        data_x, sample_rate = librosa.load(meta.path, sr=16000)
        data_y = meta.key
        return data_x, float(data_y)

    def parse_line(self, line):
        """Turn one ``<file_name> <label>`` line into an ``AudioFile``."""
        # split() instead of split(' '): tolerate tabs and repeated spaces,
        # which previously produced an empty label token.
        tokens = line.split()
        audio_path = os.path.join(self.data_path, tokens[0]).replace('\\', '/')
        return AudioFile(file_name=tokens[0], path=audio_path,
                         label=tokens[1], key=int(tokens[1] == 'genuine'))

    def parse_protocols_file(self, label_path):
        """Parse every non-blank line of the label file into ``AudioFile`` records."""
        with open(label_path) as handle:
            return [self.parse_line(line) for line in handle if line.strip()]

In [None]:
# Dataloader for TRAINING set
batch_size = 32

# Shared preprocessing: fix every clip to pad()'s default length, then wrap it in a Tensor.
transform = transforms.Compose([pad, Tensor])

database_path = "/home/menglu/123/Dataset/ADD2022/ADD_train_dev/train"  # folder that stores the training data
label_path = "/home/menglu/123/Dataset/ADD2022/label/train_label.txt"

train_set = ADDDataset(
    data_path=database_path,
    label_path=label_path,
    is_train=True,
    transform=transform,
)
# Incomplete final batches are dropped during training.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Dataloader for VALIDATION set
dev_data_path = "/home/menglu/123/Dataset/ADD2022/ADD_train_dev/dev"
dev_label_path = "/home/menglu/123/Dataset/ADD2022/label/dev_label.txt"

dev_set = ADDDataset(
    data_path=dev_data_path,
    label_path=dev_label_path,
    is_train=False,
    transform=transform,
)
dev_loader = DataLoader(
    dev_set,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
def train_epoch(data_loader, model, lr, optim, device, scheduler = None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: iterable of ``(batch_x, batch_y)`` tensor pairs.
        model: callable as ``model(batch_x, batch_y)`` returning class
            PROBABILITIES of shape ``(batch, 2)``, as ``Box.forward`` does
            (softmax already applied).
        lr: unused; kept for call compatibility (the optimizer owns the rate).
        optim: optimizer stepped once per batch.
        device: device the batches are moved to.
        scheduler: optional object whose ``step()`` is called once per batch.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    running_loss = 0
    num_correct = 0.0
    num_total = 0.0
    model.train()
    # Class 1 is weighted 9x relative to class 0.
    weight = torch.FloatTensor([1.0, 9.0]).to(device)
    # The model already outputs probabilities, so take their log and use
    # NLLLoss. CrossEntropyLoss expects raw logits and would apply a second
    # softmax on top of the model's own.
    criterion = nn.NLLLoss(weight=weight)

    for batch_x, batch_y in data_loader:
        batch_size = batch_x.size(0)
        num_total += batch_size
        batch_x = batch_x.to(device)
        batch_y = batch_y.view(-1).type(torch.int64).to(device)
        batch_out = model(batch_x, batch_y)
        # clamp keeps log() finite if a probability underflows to exactly 0
        batch_loss = criterion(torch.log(batch_out.clamp_min(1e-12)), batch_y)
        _, batch_pred = batch_out.max(dim=1)
        num_correct += (batch_pred == batch_y).sum(dim=0).item()
        running_loss += (batch_loss.item() * batch_size)

        optim.zero_grad()
        batch_loss.backward()
        optim.step()
        if scheduler is not None:
            scheduler.step()

    running_loss /= num_total
    train_accuracy = (num_correct / num_total) * 100
    return running_loss, train_accuracy

In [None]:
def evaluate_accuracy(data_loader, model, device):
    """Return the classification accuracy (in percent) of ``model`` on ``data_loader``.

    Args:
        data_loader: iterable of ``(batch_x, batch_y)`` tensor pairs.
        model: callable as ``model(batch_x, batch_y)`` returning per-class
            scores of shape ``(batch, n_classes)``; it is switched to eval mode.
        device: device the batches are moved to.
    """
    num_correct = 0.0
    num_total = 0.0
    model.eval()

    # No gradients are needed for validation; skipping autograd saves memory and time.
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_size = batch_x.size(0)
            num_total += batch_size
            batch_x = batch_x.to(device)
            batch_y = batch_y.view(-1).type(torch.int64).to(device)
            batch_out = model(batch_x, batch_y)
            _, batch_pred = batch_out.max(dim=1)
            num_correct += (batch_pred == batch_y).sum(dim=0).item()
    return 100 * (num_correct / num_total)

In [None]:
# Seed every RNG in use: NumPy alone does not make weight initialisation or
# DataLoader shuffling reproducible, both of which draw from torch's generator.
np.random.seed(1234)
torch.manual_seed(1234)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# GPU device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Parameter
with open('model_config.yaml') as config_file:  # close the handle deterministically
    config = yaml.safe_load(config_file)
lr = config['lr']
warmup = config['warmup']
num_epochs = config['epoch']

d_model = config['model']['patch_embed']
num_filter = config['model']['num_filter']
num_block = config['model']['num_block']
num_head = config['model']['num_head']

# Model Initialization
model = Box(config['model'], device).to(device)
nb_params = sum(param.numel() for param in model.parameters())
print(nb_params)

# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr = lr, betas=(0.9, 0.98), eps=1e-9)
# LambdaLR multiplies the base `lr` by rate(step, ...) at every scheduler step.
lr_scheduler = LambdaLR(optimizer=optimizer,
                        lr_lambda=lambda step: rate(step, d_model, factor=1, warmup=warmup),)

In [None]:
### create folder to save model parameters
# The tag encodes the hyper-parameters so runs with different settings get separate folders.
model_tag = 'SincNet_Transformer_{}_{}_{}_{}_{}_{}'.format(batch_size,d_model, num_filter, num_block, num_head,lr)

#need to change the directory
model_save_path = os.path.join('/home/menglu/123/Deepfake/built', model_tag)
# exist_ok=True keeps the notebook re-runnable: without it this cell raises
# FileExistsError on the second run. Checkpoints from an earlier run with the
# same tag will be overwritten by the training loop.
os.makedirs(model_save_path, exist_ok=True)
print(model_tag)

In [None]:
# Training and validation 
# Training and validation: one checkpoint and three TensorBoard scalars per epoch.
writer = SummaryWriter('logs/{}'.format(model_tag))
best_acc = 0
try:
    for epoch in range(num_epochs):
        running_loss, train_accuracy = train_epoch(train_loader, model, lr, optimizer, device, lr_scheduler)
        valid_accuracy = evaluate_accuracy(dev_loader, model, device)
        writer.add_scalar('train_accuracy', train_accuracy, epoch)
        writer.add_scalar('valid_accuracy', valid_accuracy, epoch)
        writer.add_scalar('loss', running_loss, epoch)
        print('\n{} - {} - {:.4f} - {:.4f}'.format(epoch,
                                                   running_loss, train_accuracy, valid_accuracy))
        # Tracked for inspection only; every epoch is checkpointed regardless.
        best_acc = max(valid_accuracy, best_acc)
        torch.save(model.state_dict(), os.path.join(model_save_path, 'epoch_{}.pth'.format(epoch)))
finally:
    # Flush and close the event file even if training fails or is interrupted.
    writer.close()