In [None]:
import warnings
warnings.filterwarnings(action='ignore') 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader

import os
import sys
import math
import time
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import hashlib
import itertools
import librosa
import librosa.display
from pympler import asizeof
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import IPython.display # IPython.display for audio output

In [None]:
# tensorize preprocessed files (Dataset, DataLoader)
def volume_check(waves, specs, threshold=10):
    """
    Report whether the combined in-memory size of ``waves`` and ``specs`` exceeds a budget.

    waves, specs : containers whose deep size is measured with pympler ``asizeof``
    threshold    : budget of the data block. unit : GB (2**30 bytes)
    returns      : True when the summed size is strictly greater than the budget, else False.
    """
    budget_bytes = threshold * 2**30
    used_bytes = asizeof.asizeof(waves) + asizeof.asizeof(specs)
    return used_bytes > budget_bytes

class MIBERT_Dataset(Dataset):
    """
    Dataset for MIBERT
    : [ wave, masked_wave, mel_spectrogram, masked_mel_spectrogram, boundaries, labels, ... ]
    each sample shape info
    : wave : [hop_length * k]
    : spectrogram : [n_mels, k+1] if center = True, else  [n_mels, k + 1 - n_fft/hop_length ]
    : boundaries : [ num_seg + 1 ]
    : labels : [ num_seg ]
    : seg_nums : int
    : masking_indice : [ int(masking_ratio * len(set(labels))) ]

    meta_csv layout (positional columns): 2 = cropped wave path, 3 = pickled mel-spectrogram path;
    for 6-column files, 4 = masked wave path and 5 = a pickle holding
    (masked_spec, boundaries, labels, seg_num, masking_index).

    Rows are read in order until the loaded waves + spectrograms exceed
    config['data_block_threshold'] GB (checked with volume_check); the remaining rows are skipped.

    NOTE: pickle.load can execute arbitrary code -- only point meta_csv at trusted preprocessed files.
    """
    def __init__(self, meta_csv, config, transform=None):
        self.df = pd.read_csv(meta_csv)
        self.columns = self.df.columns
        self.transform = transform

        self.waves, self.specs, self.masked_waves, self.masked_specs = [], [], [], []
        self.boundaries_collection, self.labels_collection, self.seg_nums, self.masking_indice = [], [], [], []
        for i in range(len(self.df)): # TODO : Random selection?
            # volume_check returns True once the loaded block is over budget -> stop loading.
            # TODO : pseudo volume check using waves and spectrograms (asizeof is a deep traversal).
            if volume_check(self.waves, self.specs, config['data_block_threshold']):
                break
            unit = self.df.iloc[i]
            if len(unit) == 6:
                # case : using single feature and single algorithm of msaf
                self._load_single_unit(unit, config)
            else:
                # case : using multiple features & algorithms of msaf
                # TODO : not implemented yet; such rows are skipped.
                pass

    def _load_single_unit(self, unit, config):
        """Load one single-feature/single-algorithm row and append every field to the collections."""
        cropped_wave_path, mel_spectrogram_path = unit.iloc[2], unit.iloc[3]
        masked_wave_path, masked_other_path = unit.iloc[4], unit.iloc[5]

        # librosa.load returns (samples, sample_rate); only the samples are stored.
        cropped_wave, _ = librosa.load(cropped_wave_path, sr=config['sr'])
        masked_wave, _ = librosa.load(masked_wave_path, sr=config['sr'])

        with open(mel_spectrogram_path, 'rb') as f:
            mel_spec = pickle.load(f)
        with open(masked_other_path, 'rb') as f:
            masked_spec, boundaries, labels, seg_num, masking_index = pickle.load(f)

        self.waves.append(cropped_wave)
        self.specs.append(mel_spec)
        self.masked_waves.append(masked_wave)
        self.masked_specs.append(masked_spec)
        self.boundaries_collection.append(boundaries)
        self.labels_collection.append(labels)
        self.seg_nums.append(seg_num)
        self.masking_indice.append(masking_index)

    def __len__(self):
        return len(self.waves)

    def __getitem__(self, idx):
        """Return the idx-th sample as a dict (optionally passed through ``transform``)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {
            'wave' : self.waves[idx],
            'spec' : self.specs[idx],
            'masked_wave' : self.masked_waves[idx],
            'masked_spec' : self.masked_specs[idx],
            'boundaries' : self.boundaries_collection[idx],
            'labels' : self.labels_collection[idx],
            'seg_num' : self.seg_nums[idx],
            'masking_idx' : self.masking_indice[idx]
        }
        if self.transform:
            sample = self.transform(sample)
        return sample

# mibert_trn_dataset = MIBERT_Dataset('./meta.csv',config)
# mibert_tst_dataset = MIBERT_Dataset('./meta.csv',config)

# trn_mibert_dataloader = DataLoader(mibert_trn_dataset, batch_size = config['batch_size'], drop_last=True)
# tst_mibert_dataloader = DataLoader(mibert_tst_dataset, batch_size = config['batch_size'], drop_last=True)

In [None]:
def tensorize_padding_batch(input, mode = 'wave', n_mels = 128):
    """
    Right-pad variable-length samples with zeros along the last axis and stack them into a float tensor.

    input  : list of waves (each [hop_length * k]) or spectrograms (each [n_mels, k+1]);
             k may differ among batch samples.
    mode   : 'wave' or 'spec'. Kept for backward compatibility -- padding is always applied
             to the last axis, so both kinds are handled the same way.
    n_mels : kept for backward compatibility; the number of mel bins is taken from each sample.
    returns: (float32 tensor [batch_size, ..., max_length], list of original last-axis lengths)
    """
    real_length = [sample.shape[-1] for sample in input]
    max_length = max(real_length)

    # Pad only the trailing axis; any leading axes (e.g. mel bins) are left untouched.
    padded_input = np.stack([
        np.pad(sample, [(0, 0)] * (np.ndim(sample) - 1) + [(0, max_length - sample.shape[-1])])
        for sample in input
    ])
    return torch.from_numpy(padded_input).float(), real_length

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight``; zero ``layer.bias`` when the layer has one."""
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
            
def init_bn(bn):
    """Reset a batch-norm layer's affine parameters to the identity map (weight 1, bias 0)."""
    bn.weight.data.fill_(1.)
    bn.bias.data.fill_(0.)

class FramizationCNN(nn.Module):
    """
    Strided 1-D CNN that turns a raw waveform into a sequence of frame features.

    input  : [batch_size, 1, wave_length]
    output : [batch_size, feature, num_frames], with feature = config['FramizationCNN']['feature']
             (128 when the key is absent). Strides 64, 2, 2, 2 give roughly one frame per 512 samples.
    The last conv block is not followed by a ReLU.
    """
    def __init__(self, config):
        super(FramizationCNN, self).__init__()
        # Channel count comes from config (previously hard-coded); 128 keeps the old default.
        feature = config.get('FramizationCNN', {}).get('feature', 128)
        self.conv1 = nn.Conv1d(in_channels = 1, out_channels = feature, kernel_size = 128, stride =64 , padding = 64, bias = False)
        self.conv2 = nn.Conv1d(in_channels = feature, out_channels = feature, kernel_size = 128, stride = 2, padding = 64, bias = False)
        self.conv3 = nn.Conv1d(in_channels = feature, out_channels = feature, kernel_size = 128, stride = 2, padding = 64, bias = False)
        self.conv4 = nn.Conv1d(in_channels = feature, out_channels = feature, kernel_size = 256, stride = 2, padding = 128, bias = False)

        self.bn1 = nn.BatchNorm1d(feature)
        self.bn2 = nn.BatchNorm1d(feature)
        self.bn3 = nn.BatchNorm1d(feature)
        self.bn4 = nn.BatchNorm1d(feature)

        self.init_weight()

    def init_weight(self):
        """Xavier-init the convolutions and reset batch-norm layers to identity."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)
        init_layer(self.conv4)

        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)
        init_bn(self.bn4)

    def forward(self, x):
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        x = F.relu_(self.bn3(self.conv3(x)))
        x = self.bn4(self.conv4(x))
        return x

In [None]:
def gelu(x):
    """Exact (erf-based) GELU activation as implemented by Hugging Face: x * Phi(x)."""
    gaussian_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * gaussian_cdf
    
class FramewiseFC(nn.Module):
    """
    Two-layer MLP applied to the last axis of every frame: fc2(gelu(fc1(x))).

    Sizes are read from config['FraemwiseFC'] (the config key is spelled this way).
    """
    def __init__(self, config):
        super(FramewiseFC, self).__init__()
        params = config['FraemwiseFC']
        self.fc1 = nn.Linear(params['fc1_in_dim'], params['fc1_out_dim'])
        self.fc2 = nn.Linear(params['fc2_in_dim'], params['fc2_out_dim'])

    def forward(self, x):
        hidden = gelu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
def time2frame(t, sr=22050, hop_length=512):
    """Convert a time in seconds to a (fractional) frame index: t * sr / hop_length."""
    return t * sr / hop_length

# TODO : to modify batch version for speed up - need for optimization
def segment_pooling(frames, boundaries_batch, real_len, config, mode='max'):
    """
    pooling by segment according to boundaries/labels of msaf
    < input >
    frames : [batch_size, num_frames, feature]
    boundaries_batch : per-sample boundary times in seconds, [batch_size, num_boundaries]
    real_len : [batch_size] unit : frame (currently unused)
    config : 'sr' (default 22050) and 'hop_length' (default 512) are used for time -> frame conversion
    mode : pooling mode. 'max' : max-pooling / 'avg' : average-pooling
    < output >
    [batch_size, num_boundaries - 1, feature]
    raises ValueError for an unknown mode.
    """

    # TODO : real segment zero-padding according to config['num_segment']
    # Check how BERT handles this.

    # Use the data's sample rate; the time2frame default (22050) does not match e.g. 44100 audio.
    sr = config.get('sr', 22050)
    hop_length = config.get('hop_length', 512)

    if mode == 'max':
        def pool_func(segment):
            return torch.max(segment, dim=0).values
    elif mode == 'avg':
        def pool_func(segment):
            return torch.mean(segment, dim=0)
    else:
        raise ValueError("segment_pooling: unknown mode {!r}; expected 'max' or 'avg'".format(mode))

    num_frames = frames.shape[1]
    result = []
    for idx, boundaries in enumerate(boundaries_batch):
        batch_buffer = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            lo = min(math.floor(time2frame(start, sr, hop_length)), num_frames - 1)
            # Always keep at least one frame so zero-length segments do not pool an empty slice.
            hi = max(math.ceil(time2frame(end, sr, hop_length)), lo + 1)
            batch_buffer.append(pool_func(frames[idx, lo:hi])) # [feature]
        result.append(torch.stack(batch_buffer))
    return torch.stack(result)

In [None]:
# BERT

def split_last(x, shape):
    "split the last dimension to given shape"
    dims = list(shape)
    assert dims.count(-1) <= 1
    unknown = dims.index(-1) if -1 in dims else None
    if unknown is not None:
        # the product still contains the single -1, so negate it to get the known size
        known = -np.prod(dims)
        dims[unknown] = int(x.size(-1) / known)
    leading = x.size()[:-1]
    return x.view(*leading, *dims)

def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    sizes = x.size()
    assert 1 < n_dims < len(sizes)
    kept = sizes[:len(sizes) - n_dims]
    return x.view(*kept, -1)


class LayerNorm(nn.Module):
    "A layernorm module in the TF style (epsilon inside the square root)."
    def __init__(self, hidden, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden))
        self.beta  = nn.Parameter(torch.zeros(hidden))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """Normalise over the last axis with biased variance, then apply gamma/beta."""
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normed + self.beta

class Time2Vec(nn.Module):
    """
    Time2Vec-style embedding of per-segment time features.

    The input is projected linearly (t @ w + b) and the projected FEATURE axis is split into four
    contiguous groups of size ceil(out_dim / 4) (the last group takes the remainder), which receive
    identity, sin, tanh and relu respectively.
    """
    def __init__(self, in_dim=3, out_dim=1024):
        super(Time2Vec,self).__init__()
        self.w = nn.parameter.Parameter(torch.randn(in_dim, out_dim))
        self.b = nn.parameter.Parameter(torch.randn(out_dim))
        self.F_sin = torch.sin
        self.F_tanh = torch.tanh
        self.F_relu = torch.relu

    def forward(self, t):
        """
        t : [batch_size, seg_num, 2 or 3] / 2 or 3 means start, end of boundary, (length of segment )
        ----------------------
        h : [batch_size, seg_num, feature]
        """
        h = torch.matmul(t, self.w) + self.b
        # Split along the feature axis (not the segment axis) so every segment gets all four
        # activation groups and any seg_num works.
        chunk = math.ceil(h.shape[-1] / 4)
        h = torch.cat([h[..., :chunk],
                       self.F_sin(h[..., chunk:2 * chunk]),
                       self.F_tanh(h[..., 2 * chunk:3 * chunk]),
                       self.F_relu(h[..., 3 * chunk:])], dim=-1)
        return h
        
class BoundaryEmbedding(nn.Module):
    """
    Boundary Embedding module to merge information among frames, position(boundary, time), etc.
    Adds a Time2Vec embedding of each segment's (start, end[, length]) to the segment features
    and layer-normalises the sum.
    """
    def __init__(self, config):
        super(BoundaryEmbedding,self).__init__()
        t2v_cfg = config['Time2Vec']
        self.boundary_embed = Time2Vec(t2v_cfg['in_dim'], t2v_cfg['out_dim'])
        self.norm = LayerNorm(t2v_cfg['out_dim'])

    def forward(self, x, boundaries, add_length=True):
        """
        x : [batch_size, seg_num, feature]
        boundaries : [batch_size, seg_num, 2] / 2 means start, end of boundary
        add_length : if True, append end - start as a third time feature
        """
        time_feats = boundaries
        if add_length:
            seg_len = boundaries[:, :, 1:2] - boundaries[:, :, 0:1]  # [batch_size, seg_num, 1]
            time_feats = torch.cat((boundaries, seg_len), dim=-1)  # [batch_size, seg_num, 3]
        embedded = self.boundary_embed(time_feats)  # [batch_size, seg_num, feature]
        return self.norm(x + embedded)

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Dot-Product Self-Attention.
    The last attention weights are kept in ``self.scores`` ([batch_size, head, seg_num, seg_num]).
    """
    def __init__(self, config):
        super(MultiHeadSelfAttention,self).__init__()
        self.proj_q = nn.Linear(config['BERT']['hidden'], config['BERT']['hidden']) # W_query
        self.proj_k = nn.Linear(config['BERT']['hidden'], config['BERT']['hidden']) # W_key
        self.proj_v = nn.Linear(config['BERT']['hidden'], config['BERT']['hidden']) # W_value
        self.n_heads = config['BERT']['head']

    def forward(self, x, mask=None):
        """
        x : [batch_size, seg_num, hidden]
        mask : optional [batch_size, seg_num]; 1 for keys that may be attended to, 0 for padded keys.
               Masked keys get a -10000 additive bias before the softmax.
        ------------------------
        h : [batch_size, seg_num, hidden]
        """
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        # (projection) q, k ,v : [batch_size, seg_num, hidden ]
        q, k, v = [split_last(t, (self.n_heads, -1)).transpose(1, 2) for t in (q, k, v)]
        # (split) q, k, v : [batch_size, seg_num, head, w (= hidden/head) ]
        # (transpose) q, k, v : [batch_size, head, seg_num, w]
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        # scores : [batch_size, head, seg_num, seg_num]

        if mask is not None:
            # broadcast over heads and query positions; push masked keys to ~0 probability
            mask = mask[:, None, None, :].to(dtype=scores.dtype, device=scores.device)
            scores = scores - 10000.0 * (1.0 - mask)

        scores = F.softmax(scores, dim=-1)
        # scores : [batch_size, head, seg_num, seg_num]
        h = (scores @ v).transpose(1, 2).contiguous()
        # h : [batch_size, head, seg_num, w]
        # (transpose) h : [batch_size, seg_num, head, w]
        h = merge_last(h, 2)
        # h : [batch_size, seg_num, hidden]
        self.scores = scores
        return h
    
class PositionWiseFeedForward(nn.Module):
    """
    FFN(Feed Forward Neural Networks) applied to each position: fc2(gelu(fc1(x))),
    expanding hidden -> hidden_ff -> hidden.
    """
    def __init__(self, config):
        super(PositionWiseFeedForward, self).__init__()
        bert_cfg = config['BERT']
        self.fc1 = nn.Linear(bert_cfg['hidden'], bert_cfg['hidden_ff'])
        self.fc2 = nn.Linear(bert_cfg['hidden_ff'], bert_cfg['hidden'])

    def forward(self, x):
        expanded = gelu(self.fc1(x))
        return self.fc2(expanded)

class BERT(nn.Module):
    """
    Transformer encoder over segment features with Time2Vec boundary embeddings.

    NOTE(review): one attention/FFN block is applied n_layers times, so all layers share weights.
    The pytorchic-bert code this appears to be ported from uses an independent block per layer --
    confirm the sharing is intended.
    """
    def __init__(self, config):
        super(BERT, self).__init__()
        self.be = BoundaryEmbedding(config)
        self.num_layers = config['BERT']['n_layers']
        self.attn = MultiHeadSelfAttention(config)
        self.proj = nn.Linear(config['BERT']['hidden'], config['BERT']['hidden'])
        self.norm1 = LayerNorm(config['BERT']['hidden'])
        self.pwff = PositionWiseFeedForward(config)
        self.norm2 = LayerNorm(config['BERT']['hidden'])

    def forward(self, x, boundaries, mask=None):
        """
        x : [batch_size, num_segment, feature]
        boundaries : [batch_size, num_segment, 2]
        mask : optional, forwarded to the self-attention
        ------------------------
        h : [batch_size, num_segment, hidden ]  (same segment count as x)
        """

        h = self.be(x, boundaries, add_length=True)
        for _ in range(self.num_layers):
            # Post-LN block: the residual adds the block INPUT h to the projected attention output.
            attn_out = self.attn(h, mask)
            h = self.norm1(h + self.proj(attn_out))
            h = self.norm2(h + self.pwff(h))
        return h

In [None]:
class DomainConcatFC(nn.Module):
    """
    Fuse wave- and spectrogram-domain segment features: fc2(gelu(fc1(concat(wave, spec)))).

    NOTE(review): fc1 expects hidden * config['num_domain'] inputs while forward always concatenates
    exactly two tensors, so num_domain must be 2.
    """
    def __init__(self, config):
        super(DomainConcatFC, self).__init__()
        hidden = config['BERT']['hidden']
        fused = config['DomainConcatFC']['hidden']
        self.fc1 = nn.Linear(hidden * config['num_domain'], fused)
        self.fc2 = nn.Linear(fused, hidden)

    def forward(self, hidden_wave, hidden_spec):
        """
        hidden_* : [batch_size, num_segment, feature]
        returns  : [batch_size, num_segment, BERT hidden]
        """
        joined = torch.cat([hidden_wave, hidden_spec], dim=-1)  # [batch_size, num_segment, 2 * feature]
        return self.fc2(gelu(self.fc1(joined)))

In [None]:
# Deep InfoMax(DIM) 
# reference : https://github.com/rdevon/DIM
# reference : https://github.com/DuaneNielsen/DeepInfomaxPytorch

def raise_measure_error(measure):
    """Raise NotImplementedError naming the unsupported divergence ``measure`` and the supported ones."""
    supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1']
    message = 'Measure `{}` not supported. Supported: {}'.format(measure, supported_measures)
    raise NotImplementedError(message)

def get_positive_expectation(p_samples, measure, average = True):
    """
    Positive-sample term of the divergence-based lower bound used by Deep InfoMax.

    p_samples : tensor of discriminator scores for positive pairs
    measure   : one of 'GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'
                (parameter was previously misspelled `mesaure`, which left `measure` undefined)
    average   : if True return the mean, otherwise the element-wise tensor
    raises NotImplementedError (via raise_measure_error) for an unknown measure.
    """
    log_2 = math.log(2.)

    if measure == 'GAN':
        Ep = - F.softplus(-p_samples)
    elif measure == 'JSD':
        Ep = log_2 - F.softplus(-p_samples)
    elif measure == 'X2':
        Ep = p_samples ** 2
    elif measure == 'KL':
        Ep = p_samples
    elif measure == 'RKL':
        Ep = -torch.exp(-p_samples)
    elif measure == 'DV':
        Ep = p_samples
    elif measure == 'H2':
        Ep = 1. - torch.exp(-p_samples)
    elif measure == 'W1':
        Ep = p_samples
    else:
        raise_measure_error(measure)

    if average:
        return Ep.mean()
    else:
        return Ep

def get_negative_expectation(q_samples, measure, average = True):
    """
    Negative-sample term of the divergence-based lower bound used by Deep InfoMax.

    q_samples : tensor of discriminator scores for negative pairs
    measure   : one of 'GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'
    average   : if True return the mean, otherwise the element-wise tensor
                (for 'DV' the element-wise result is reduced over dim 0 by logsumexp)
    raises NotImplementedError (via raise_measure_error) for an unknown measure.
    """
    log_2 = math.log(2.)

    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - log_2
    elif measure == 'X2':
        Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples - 1.)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # log-mean-exp over the sample axis (torch.logsumexp replaces the undefined log_sum_exp)
        Eq = torch.logsumexp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise_measure_error(measure)

    if average:
        return Eq.mean()
    else:
        return Eq

def generator_loss(q_samples, measure, loss_type=None):
    """Computes the loss for the generator of a GAN.
    Args:
        q_samples: fake samples.
        measure: Measure to compute loss for.
        loss_type: Type of loss: basic `minimax` (also used when None) or `non-saturating`.
    Raises:
        NotImplementedError: for any other loss_type.
    """
    if not loss_type or loss_type == 'minimax':
        return get_negative_expectation(q_samples, measure)
    if loss_type == 'non-saturating':
        return -get_positive_expectation(q_samples, measure)
    raise NotImplementedError(
        'Generator loss type `{}` not supported. '
        'Supported: [None, minimax, non-saturating]'.format(loss_type))

def fenchel_dual_loss(x, measure):
    """
    Fenchel-dual (divergence-based) DIM loss over a square score matrix.

    x : [batch_size, batch_size]; diagonal entries are positive pairs, off-diagonal entries negatives.
    returns the scalar E_neg - E_pos (NaN when batch_size == 1, as there are no negatives).
    """
    n = x.size(0)
    positives = torch.eye(n).to(x.device)
    negatives = 1 - positives

    pos_terms = get_positive_expectation(x, measure, average = False)
    neg_terms = get_negative_expectation(x, measure, average = False)

    e_pos = (pos_terms * positives).sum() / positives.sum()
    e_neg = (neg_terms * negatives).sum() / negatives.sum()
    return e_neg - e_pos

def info_nce_loss(x):
    """
    InfoNCE loss over a square score matrix (previously an unimplemented stub returning None).

    x : [batch_size, batch_size] scores; x[i, i] is row i's positive pair and the other entries
        of row i are its negatives.
    returns the scalar mean over rows of -log softmax(x[i])[i].
    """
    targets = torch.arange(x.size(0), device=x.device)
    return F.cross_entropy(x, targets)

def donsker_varadhan_loss(x):
    """
    Donsker-Varadhan DIM loss over a square score matrix (previously a stub returning None).

    x : [batch_size, batch_size]; diagonal = positive pairs, off-diagonal = negative pairs.
    returns E_neg - E_pos, where E_pos is the mean of the diagonal and
        E_neg = log(mean(exp(off-diagonal))) computed with logsumexp for numerical stability.
    raises ValueError when batch_size < 2 (no negative pairs exist).
    """
    batch_size = x.size(0)
    if batch_size < 2:
        raise ValueError('donsker_varadhan_loss needs batch_size >= 2 to form negative pairs')
    pos_mask = torch.eye(batch_size, dtype=torch.bool, device=x.device)
    e_pos = x[pos_mask].mean()
    negatives = x[~pos_mask]
    e_neg = torch.logsumexp(negatives, dim=0) - math.log(negatives.numel())
    return e_neg - e_pos

def compute_dim_loss(x, mode, measure):
    """
    Dispatch to a DIM loss over a [batch_size, batch_size] score matrix.

    mode : 'fd'  -> fenchel_dual_loss(x, measure)
           'nce' -> info_nce_loss(x)
           'dv'  -> donsker_varadhan_loss(x)
    raises NotImplementedError for any other mode instead of silently returning None.
    """
    if mode == 'fd':
        return fenchel_dual_loss(x, measure)
    if mode == 'nce':
        return info_nce_loss(x)
    if mode == 'dv':
        return donsker_varadhan_loss(x)
    raise NotImplementedError(
        'DIM loss mode `{}` not supported. Supported: [fd, nce, dv]'.format(mode))


class GlobalDiscriminator(nn.Module):
    """
    global discriminator of DIM.
    ==============================
    <input>
    : global feature, G : [ batch_size, feature_dim ]
    : local feature sequence, LS : [batch_size, seg_num, feature_dim]
    ==============================
    <output>
    : discriminator score, [batch_size, 1]
      (G is tiled over the segments, concatenated with LS, and the whole sequence is scored at once)
    ==============================
    """
    def __init__(self, config):
        super(GlobalDiscriminator, self).__init__()
        self.num_segment = config['num_segment']
        feature_dim = config['BERT']['hidden']
        gd_cfg = config['GlobalDiscriminator']
        self.layer_0 = nn.Linear(feature_dim * 2, gd_cfg['hidden'])
        self.layer_1 = nn.Linear(gd_cfg['hidden'], gd_cfg['hidden_ns'])
        self.layer_2 = nn.Linear(gd_cfg['hidden_ns'] * self.num_segment, 1)

    def forward(self, g, ls):
        batch = g.shape[0]
        g_tiled = g.repeat(self.num_segment, 1, 1).transpose(1, 0)  # [batch_size, num_segment, feature_dim]
        joint = torch.cat((ls, g_tiled), dim=-1)  # [batch_size, num_segment, feature_dim * 2]
        joint = F.relu(self.layer_0(joint))  # [batch_size, num_segment, GD_hidden]
        joint = F.relu(self.layer_1(joint))  # [batch_size, num_segment, GD_hidden_ns]
        flat = joint.view(batch, -1)  # [batch_size, num_segment * GD_hidden_ns]
        return self.layer_2(flat)  # [batch_size, 1]

class LocalDiscriminator(nn.Module):
    """
    local discriminator of DIM.
    ==============================
    <input>
    : global feature, G : [ ..., feature_dim ]
    : local features, L : [ ..., feature_dim ] with the same leading shape as G
      (e.g. [batch_size, num_segment, feature_dim] for both)
    (option) idx : index / time-vector conditioning -- not implemented; any value other than
      None raises NotImplementedError instead of silently returning None.
    ==============================
    <output>
    : discriminator score, [ ..., 1 ]
    ==============================
    """
    def __init__(self, config):
        super(LocalDiscriminator, self).__init__()
        self.layer_0 = nn.Linear(config['BERT']['hidden']*2, config['LocalDiscriminator']['hidden'])
        self.layer_1 = nn.Linear(config['LocalDiscriminator']['hidden'], 1)

    def forward(self, g, l, idx=None):
        if idx is not None:
            raise NotImplementedError(
                'LocalDiscriminator: conditioning on idx / time vectors is not implemented')
        h = torch.cat((g, l), dim=-1) # [..., 2 * feature_dim]
        h = F.relu(self.layer_0(h)) # [..., LD_hidden ]
        return self.layer_1(h) # [..., 1]

class MaskDiscriminator(nn.Module):
    """
    masked discriminator of DIM.
    ==============================
    <input>
    : local feature, L : [ batch_size , num_segment , feature_dim ]
    : origin feature, O : [ batch_size, num_segment , feature_dim ]
    ==============================
    <output>
    : per-segment discriminator score, [batch_size , num_segment, 1 ]
    ==============================
    """
    def __init__(self, config):
        super().__init__()
        feature_dim = config['BERT']['hidden']
        md_hidden = config['MaskDiscriminator']['hidden']
        self.layer_0 = nn.Linear(feature_dim * 2, md_hidden)
        self.layer_1 = nn.Linear(md_hidden, 1)

    def forward(self, ls, origins):
        pair = torch.cat([ls, origins], dim=-1)  # [batch_size, num_segment, feature_dim * 2]
        hidden = F.relu(self.layer_0(pair))  # [batch_size, num_segment, MD_hidden]
        return self.layer_1(hidden)  # [batch_size, num_segment, 1]

class PriorMatching(nn.Module):
    """
    prior matching of DIM.
    ==============================
    <input>
    : x : <global feature, G : [batch_size, feature_dim] >  or < prior, P : [ batch_size, feature_dim ]>
    ==============================
    <output>
    : sigmoid probability in (0, 1), [batch_size, 1]; trained to be high for prior samples
      and low for global features
    ==============================
    """
    def __init__(self, config):
        super().__init__()
        feature_dim = config['BERT']['hidden']
        pm_hidden = config['PriorMatching']['hidden']
        self.layer_0 = nn.Linear(feature_dim, pm_hidden)
        self.layer_1 = nn.Linear(pm_hidden, pm_hidden)
        self.layer_2 = nn.Linear(pm_hidden, 1)

    def forward(self, x):
        h = x
        for layer in (self.layer_0, self.layer_1):
            h = F.relu(layer(h))
        return torch.sigmoid(self.layer_2(h))

class MIBERT_Loss(nn.Module):
    """
    Weighted sum of the DIM objectives used by MIBERT:
        alpha * GLOBAL + beta * LOCAL + gamma * MASK + delta * PRIOR.

    GLOBAL / LOCAL / MASK use softplus-based estimators, with positive pairs built from ``ls`` and
    negative pairs from ``ls_prime``. PRIOR trains PriorMatching to separate uniform [0, 1) samples
    from the global feature.
    """
    _EPS = 1e-7  # keeps log() finite when PriorMatching's sigmoid saturates to exactly 0 or 1

    def __init__(self, config):
        super(MIBERT_Loss, self).__init__()
        self.GD = GlobalDiscriminator(config)
        self.LD = LocalDiscriminator(config)
        self.MD = MaskDiscriminator(config)
        self.PM = PriorMatching(config)

        self.alpha = config['alpha']
        self.beta = config['beta']
        self.gamma = config['gamma']
        self.delta = config['delta']
        self.num_segment = config['num_segment']

    def forward(self, g, ls, ls_prime, origins):
        """
        g : [batch_size, 1, feature]
        ls : [batch_size, num_segment, feature]
        ls_prime : [batch_size, num_segment, feature]
        origins : [batch_Size, num_segment, feature]
        returns a scalar loss tensor
        """
        # Squeeze only the singleton segment axis: a bare squeeze() would also drop the batch
        # axis when batch_size == 1 and break GlobalDiscriminator's reshape.
        g_flat = g.squeeze(1)  # [batch_size, feature]
        g_tiled = g.repeat(1, self.num_segment, 1)  # [batch_size, num_segment, feature]

        Ej = -F.softplus(-self.GD(g_flat, ls)).mean()
        Em = F.softplus(self.GD(g_flat, ls_prime)).mean()
        GLOBAL = (Em - Ej) * self.alpha

        Ej = -F.softplus(-self.LD(g_tiled, ls)).mean()
        Em = F.softplus(self.LD(g_tiled, ls_prime)).mean()
        LOCAL = (Em - Ej) * self.beta

        Ej = -F.softplus(-self.MD(ls, origins)).mean()
        Em = F.softplus(self.MD(ls_prime, origins)).mean()
        MASK = (Em - Ej) * self.gamma

        prior = torch.rand_like(g_flat)
        term_a = torch.log(self.PM(prior).clamp(min=self._EPS)).mean()
        term_b = torch.log((1.0 - self.PM(g_flat)).clamp(min=self._EPS)).mean()
        PRIOR = -( term_a + term_b ) * self.delta

        return GLOBAL + LOCAL + MASK + PRIOR

In [None]:
# Hyper-parameters shared by the modules and the smoke-test cell of this notebook.
config = {
    'batch_size' : 4,  # used by the smoke-test cell when building init-segment / boundary tensors
    
    'data_block_threshold' : 10,  # GB; memory budget passed to volume_check by MIBERT_Dataset
    'sr' : 44100,  # sample rate given to librosa.load in MIBERT_Dataset
    'FramizationCNN' : {
        'feature' : 128,  # conv channel count; must equal FraemwiseFC fc1_in_dim (and n_mels for spectrogram input)
    },
    # NOTE: key spelling 'FraemwiseFC' is what FramewiseFC reads; rename both together if fixing it.
    'FraemwiseFC' : {
        'fc1_in_dim' : 128,
        'fc1_out_dim' : 1024,
        'fc2_in_dim' : 1024,
        'fc2_out_dim' : 1024,  # must match BERT hidden / Time2Vec out_dim (features are summed)
    },
    'Time2Vec':{
        'in_dim': 3,  # segment start, end and length (length appended in BoundaryEmbedding)
        'out_dim': 1024,
    },
    'num_segment' : 15,  # segments per sample; the discriminators are sized with it
    't2v_add_length' : True,  # NOTE(review): not read by any visible code -- confirm before relying on it
    'BERT' : {
        'hidden' : 1024,  # must be divisible by 'head'
        'head' : 8,  # number of attention heads
        'n_layers' : 3,
        'hidden_ff' : 2048,  # inner size of the position-wise feed-forward network
    },
    'num_domain' : 2,  # wave + spectrogram; multiplies DomainConcatFC's input size
    'DomainConcatFC' : {
        'hidden' : 3072
    },
    'init_segment' : 1024,  # size of the learnable leading segment vector (smoke-test cell)
    'measure' : 'JSD',  # divergence name for get_*_expectation; NOTE(review): not read by MIBERT_Loss
    'mode' : 'fd',  # compute_dim_loss mode; NOTE(review): not read by MIBERT_Loss
    'GlobalDiscriminator' : {
        'hidden' : 1024,
        'hidden_ns' : 64,
    },
    'LocalDiscriminator' : {
        'hidden' : 1024,
    },
    'MaskDiscriminator' : {
        'hidden' : 1024,
    },
    'PriorMatching' : {
        'hidden' : 1024,
    },
    # weights of the GLOBAL, LOCAL, MASK and PRIOR terms in MIBERT_Loss
    'alpha' : 0.3,
    'beta' : 0.3,
    'gamma' : 0.3,
    'delta' : 0.1
}


In [None]:
# For debugging
# Frame counts k of four synthetic clips; each debug wave below has 512 * k samples.
debug_info_lst = [1530, 150, 1529, 631]
# masking
def mel_filterbank(spectrogram, sr, n_fft, n_mels):
    """Project the power |S|**2 of an STFT onto an ``n_mels``-band librosa mel filterbank."""
    power = np.abs(spectrogram) ** 2
    filters = librosa.filters.mel(sr = sr , n_fft = n_fft, n_mels = n_mels)
    return filters @ power

# Smoke test: push random data through the whole pipeline and print intermediate shapes.
# No seeds are set (numpy, random, torch), so the printed loss changes between runs.
batch_size = 4  # NOTE(review): unused below; the code reads config['batch_size'] instead

# white-noise waves of 512 * k samples and their mel power spectrograms (librosa.stft with default args)
debug_wavs = [ np.random.normal(size=512*k) for k in debug_info_lst ]
debug_specs = [mel_filterbank(librosa.stft(w), 44100, 2048, 128) for w in debug_wavs]

W, w_real_len = tensorize_padding_batch(debug_wavs, mode="wave")
S, s_real_len = tensorize_padding_batch(debug_specs, mode="spec")
print("shape of wave : ",W.shape)
print("shape of spectrogram : ", S.shape)

frameCNN = FramizationCNN(config)
FFC = FramewiseFC(config)

framed_w = frameCNN(torch.unsqueeze(W, 1)) # unsqueeze : [batch_size, 1, wave_length]
print("shape of frameCNN result : ", framed_w.shape)

# transpose for fully-connected layer
framed_w = torch.transpose(framed_w, 1,2)
S = torch.transpose(S, 1,2)

ffc_w = FFC(framed_w)
ffc_s = FFC(S)

print("shape of wave FramewiseFC result : ", ffc_w.shape)
print("shape of spec FramewiseFC result : ", ffc_s.shape)

# per clip: 0, the clip duration in seconds, and num_segment - 1 random interior times, sorted
# -> num_segment + 1 boundaries = num_segment segments
boundaries = [
    sorted(
        [0, k*512/44100]+
        [random.uniform(0, k * 512 / 44100) for sn in range(config['num_segment']-1)]
        )
    for k in debug_info_lst
]
# random segment labels; NOTE(review): not used later in this cell
labels = [
    [
        random.randrange(0, config['num_segment']//2)
        for _ in range(len(boundaries[idx])-1)
    ]
    for idx, _ in enumerate(debug_info_lst)
]

sp_w_frame = segment_pooling(ffc_w, boundaries,w_real_len, config) # [batch_size, num_segment, feature]
sp_s_frame = segment_pooling(ffc_s, boundaries,s_real_len, config)

# turn the boundary list into (start, end) pairs per segment
boundaries = torch.tensor(boundaries)
boundaries = torch.stack((boundaries[:,:-1] , boundaries[:,1:]), dim=-1) # [batch_size, num_segment, 2]

# add init segment block to boundaries / segment_pooling frames
# (a learnable vector prepended as segment 0, with boundary (0, 0))
init_segment = nn.Parameter(torch.randn(config['init_segment']))
boundaries_tensor = torch.zeros(size = [config['batch_size'], config['num_segment']+1, 2])
boundaries_tensor[:,1:,:] = boundaries # [batch_size, num_segment+1, 2]
sp_w_frame = torch.cat((init_segment.repeat(config['batch_size'],1).unsqueeze(1), sp_w_frame), dim =1)
sp_s_frame = torch.cat((init_segment.repeat(config['batch_size'],1).unsqueeze(1), sp_s_frame), dim =1)

print('shape of segment pool result about wave frame : ', sp_w_frame.shape) # [batch_size, num_segment+1, feature]
print('shape of segment pool result about spec frame : ', sp_s_frame.shape)
print("shape of boundaries tensor : ", boundaries_tensor.shape)

# the same BERT instance (shared weights) encodes both domains
bert = BERT(config)
h_w = bert(sp_w_frame, boundaries_tensor)
h_s = bert(sp_s_frame, boundaries_tensor)

print('shape of BERT result about wave : ', h_w.shape)
print('shape of BERT result about spec : ', h_s.shape)

dcfc = DomainConcatFC(config)

h = dcfc(h_w, h_s)

print('shape of DCFC result : ', h.shape)

mb_loss = MIBERT_Loss(config)

# segment 0 (the init segment) is used as the global feature, the rest as local features
g = h[:,0:1,:]
ls = h[:,1:,:]
# negatives: local features rolled by one along the batch axis, pairing each g with another clip
ls_prime = torch.cat((ls[1:], ls[0].unsqueeze(0)), dim=0)

# NOTE(review): origins is ls itself here, presumably a placeholder for unmasked features -- confirm
loss = mb_loss(g, ls, ls_prime, ls)

print(loss)

shape of wave :  torch.Size([4, 783360])
shape of spectrogram :  torch.Size([4, 128, 1531])
shape of frameCNN result :  torch.Size([4, 128, 1531])
shape of wave FramewiseFC result :  torch.Size([4, 1531, 1024])
shape of spec FramewiseFC result :  torch.Size([4, 1531, 1024])
shape of segment pool result about wave frame :  torch.Size([4, 16, 1024])
shape of segment pool result about spec frame :  torch.Size([4, 16, 1024])
shape of boundaries tensor :  torch.Size([4, 16, 2])
shape of BERT result about wave :  torch.Size([4, 16, 1024])
shape of BERT result about spec :  torch.Size([4, 16, 1024])
shape of DCFC result :  torch.Size([4, 16, 1024])
tensor(1.3848, grad_fn=<AddBackward0>)
