In [1]:
import torch
import torch.nn as nn

# Define the input tensor
input_tensor = torch.tensor([[[1, 1, 1]], [[0, 0, 0]]], dtype=torch.float32)

# 1x1 Conv1d mapping 1 input channel to 5 output channels, without bias
con = nn.Conv1d(in_channels=1, out_channels=5, kernel_size=1,
                              stride=1, dilation=1, padding=0, bias=False)

# TODO: apply `con` to `input_tensor` and inspect the output tensor



---

In [None]:
import config
import torch
import numpy as np
from torchlibrosa.stft import STFT, ISTFT, magphase
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Hyper-parameters shared by the cells below.
window_size = 2048  # used as both n_fft and win_length of the STFT
hop_size = config.hop_samples  # hop length is read from the project config module
window = 'hann'
pad_mode = 'reflect'
center = True
momentum = 0.01  # momentum passed to the BatchNorm2d layers
downsample_ratio = 2**6  # the time axis is later padded to a multiple of this value
channels=2
activation='relu'

In [None]:
stft = STFT(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=True)

In [None]:
def wav_to_spectrogram(input):
    """Convert a multi-channel waveform batch into magnitude spectrograms.

    Each channel along the last axis of ``input`` is passed through
    ``spectrogram`` and the per-channel results are concatenated on dim 1.

    Args:
        input: (batch_size, segment_samples, channels_num)

    Outputs:
        output: (batch_size, channels_num, time_steps, freq_bins)
    """
    per_channel = [
        spectrogram(input[:, :, idx])
        for idx in range(input.shape[2])
    ]
    return torch.cat(per_channel, dim=1)

In [None]:
def spectrogram(input):
    """Return the STFT magnitude of ``input`` using the module-level ``stft``."""
    real, imag = stft(input)
    energy = real.pow(2) + imag.pow(2)
    return energy.pow(0.5)

In [None]:
bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

In [None]:
encoder_block1 = EncoderBlock(in_channels=channels, out_channels=32, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero the bias if there is one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)


def init_emb(layer):
    """Draw ``layer.weight`` uniformly from [-0.1, 0.1] and zero the bias if there is one."""
    nn.init.uniform_(layer.weight, -0.1, 0.1)

    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)

def act(x, activation):
    """Apply the activation selected by name.

    'relu' and 'leaky_relu' (slope 0.2) operate in place on ``x``;
    'swish' returns a new tensor. Any other name raises an Exception.
    """
    if activation == 'relu':
        return F.relu_(x)
    if activation == 'leaky_relu':
        return F.leaky_relu_(x, negative_slope=0.2)
    if activation == 'swish':
        return x * torch.sigmoid(x)
    raise Exception('Incorrect activation!')

def init_bn(bn):
    """Reset a batch-norm layer to bias 0 and weight 1."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)


In [None]:
class ConvBlock(nn.Module):
    """Two conv -> batch-norm -> activation stages with additive condition embeddings.

    After each stage a Linear projection of ``condition`` (one per stage) is
    added to the feature map, broadcast over the two trailing axes.
    """

    def __init__(self, in_channels, out_channels, size, activation, momentum, classes_num = 527):
        super().__init__()

        half = size // 2
        # Both convolutions share the same geometry; only the channel counts differ.
        conv_kwargs = dict(kernel_size=(size, size), stride=(1, 1),
                           dilation=(1, 1), padding=(half, half), bias=False)
        self.activation = activation

        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        # Condition vector (classes_num features) -> one offset per output channel.
        self.emb1 = nn.Linear(classes_num, out_channels, bias=True)
        self.emb2 = nn.Linear(classes_num, out_channels, bias=True)

        self.init_weights()

    def init_weights(self):
        """Re-initialise the convolutions, batch norms and embeddings."""
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for norm in (self.bn1, self.bn2):
            init_bn(norm)
        for emb in (self.emb1, self.emb2):
            init_emb(emb)

    def forward(self, x, condition):
        """Run both stages on ``x``, adding the embedded ``condition`` after each."""
        shift1 = self.emb1(condition)[:, :, None, None]
        shift2 = self.emb2(condition)[:, :, None, None]
        x = act(self.bn1(self.conv1(x)), self.activation) + shift1
        x = act(self.bn2(self.conv2(x)), self.activation) + shift2
        return x

In [None]:
class EncoderBlock(nn.Module):
    """A 3x3 ``ConvBlock`` followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum, classes_num = 527):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels, 3, activation, momentum, classes_num)
        self.downsample = downsample

    def forward(self, x, condition):
        """Return ``(pooled, unpooled)`` features for ``x`` under ``condition``."""
        features = self.conv_block(x, condition)
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features

---

In [None]:
import librosa

# NOTE(review): machine-specific absolute path; this cell only runs where that file exists.
y, _ = librosa.load("/Users/cooky/HDD/Drum/tactspack/drum loops/drumroll 12_sel.wav", mono=True)

In [None]:
# input = y.reshape(1,y.shape[0],1)
input = np.array([y,y,y,y,y]).reshape(5,y.shape[0],1)

In [None]:
input.shape

(5, 18462, 1)

In [None]:
input = torch.tensor(input, dtype=torch.float32)

In [None]:
input.shape

torch.Size([5, 18462, 1])

In [None]:
sp = wav_to_spectrogram(input)

In [None]:
x = sp.transpose(1,3)
x = bn0(x)
x = x.transpose(1,3)

In [None]:
x.shape

torch.Size([5, 1, 58, 1025])

In [None]:
# Zero-pad the end of axis 2 so its length becomes a multiple of downsample_ratio.
origin_len = x.shape[2]
pad_len = int(np.ceil(x.shape[2] / downsample_ratio)) \
    * downsample_ratio - origin_len
# F.pad lists pads from the last axis backwards: (last_left, last_right, axis2_left, axis2_right).
x = F.pad(x, pad=(0, 0, 0, pad_len))


In [None]:
x.shape

torch.Size([5, 1, 64, 1025])

In [None]:
x = x[..., 0 : x.shape[-1] - 1]

In [None]:
x.shape

torch.Size([5, 1, 64, 1024])

In [None]:
encoder_block1 = EncoderBlock(in_channels=channels, out_channels=32, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = 99)

In [None]:
condition = torch.tensor(np.zeros((1,99)),dtype=torch.float32)

In [None]:
encoder_block1(x, condition)

RuntimeError: Given groups=1, weight of size [32, 5, 3, 3], expected input[1, 1, 64, 1024] to have 5 channels, but got 1 channels instead

In [None]:

size = 3
pad = size//2
conv_block = ConvBlock(channels, 32, size, activation, momentum, 99)


In [None]:
conv1 = nn.Conv2d(in_channels=1, 
                              out_channels=5,
                              kernel_size=(size, size), stride=(1, 1), 
                              dilation=(1, 1), padding=(pad, pad), bias=False)

In [None]:
conv1(x).shape

torch.Size([5, 5, 64, 1024])

In [None]:
bn1 = nn.BatchNorm2d(5, momentum=momentum)

In [None]:
x1 = act(bn1(conv1(x)), activation)

In [None]:
x1.shape

torch.Size([5, 5, 64, 1024])

In [None]:
emb1 = nn.Linear(99, 5, bias=True)

In [None]:
emb_imsi = nn.Conv1d(in_channels=channels, 
                              out_channels=5*5,
                              kernel_size=size, stride=1,
                              dilation=1, padding=pad, bias=False)

In [None]:
condition = torch.tensor(np.zeros((5,99)),dtype=torch.float32)

In [None]:
c1 = emb1(condition)
# c1 = emb_imsi(condition)


In [None]:
c1.shape

torch.Size([5, 5])

In [None]:
c11 = c1[:,:,None,None]
c11.shape

torch.Size([5, 5, 1, 1])

In [None]:
(x1 + c11).shape

torch.Size([5, 5, 64, 1024])

In [None]:
class ConvBlock(nn.Module):
    """Two conv -> batch-norm -> activation stages with additive condition embeddings.

    NOTE(review): this cell re-defines ConvBlock with the same body as the
    earlier ConvBlock cell in this notebook; the later definition shadows the
    earlier one for objects constructed afterwards.
    """
    def __init__(self, in_channels, out_channels, size, activation, momentum, classes_num = 527):
        super(ConvBlock, self).__init__()

        self.activation = activation
        # Padding of size // 2 on both spatial axes.
        pad = size // 2

        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(size, size), stride=(1, 1), 
                              dilation=(1, 1), padding=(pad, pad), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(size, size), stride=(1, 1), 
                              dilation=(1, 1), padding=(pad, pad), bias=False)

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
        # Linear projections of the condition vector (classes_num features)
        # to one offset per output channel, one projection per stage.
        #####
        self.emb1 = nn.Linear(classes_num, out_channels, bias=True)
        self.emb2 = nn.Linear(classes_num, out_channels, bias=True)
        ####
        # Disabled experiment: Conv1d-based condition embedding.
        # self.emb1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=size,
        #                       stride=1, dilation=1, padding=pad, bias=False)
        # self.emb2 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=size,
        #                       stride=1, dilation=1, padding=pad, bias=False)
        self.init_weights()
        
    def init_weights(self):
        """Re-initialise the convolutions, batch norms and embeddings."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_emb(self.emb1)
        init_emb(self.emb2)

    # latent query embedded 
    def forward(self, x, condition):
        """Run both stages on ``x``, adding the embedded ``condition`` after each."""
        c1 = self.emb1(condition)
        c2 = self.emb2(condition)
        # [:, :, None, None] appends two singleton axes so the embedding
        # broadcasts over the trailing axes of the feature map.
        x = act(self.bn1(self.conv1(x)), self.activation) + c1[:, :, None, None]
        x = act(self.bn2(self.conv2(x)), self.activation) + c2[:, :, None, None]
        return x

In [None]:

encoder = conv_block(x, condition)


RuntimeError: The size of tensor a (1024) must match the size of tensor b (32) at non-singleton dimension 4

In [None]:
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
return encoder_pool, encoder

NameError: name 'encoder' is not defined

----

In [3]:
import torchvggish.vggish as vggish

# Remote release URLs for the VGGish weights and PCA parameters.
# NOTE(review): this dict is overwritten by the assignment right below it,
# so only the local paths are ever passed on.
model_urls = {
    'vggish': 'https://github.com/harritaylor/torchvggish/'
              'releases/download/v0.1/vggish-10086976.pth',
    'pca': 'https://github.com/harritaylor/torchvggish/'
           'releases/download/v0.1/vggish_pca_params-970ea276.pth'
}
# Local checkpoint paths, relative to the working directory.
model_urls = {
    'vggish': './torchvggish/cp/vggish-10086976.pth',
    'pca': './torchvggish/cp/vggish_pca_params-970ea276.pth'
}

In [4]:
model = vggish.VGGish(model_urls)

In [None]:
import torchaudio.prototype.pipelines.VGGishBundle.VGGish as VGGish

OSError: dlopen(/Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at14RecordFunctionC1ENS_11RecordScopeEb
  Referenced from: <E741B6D5-E348-3601-ACC9-BC3101AD112C> /Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torchaudio/lib/libtorchaudio.so
  Expected in:     <AAE88793-2D9D-3CCA-96C4-EAC30CEA4202> /Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib

---

In [1]:
# Localize All Around Sounds


from museval.metrics import validate
from numba.core.types.containers import DictKeysIterableType
import numpy as np
import librosa
import os
import sys
import math
import bisect
import pickle
import soundfile as sf
import subprocess

import noisereduce as nr
from asp.utils import get_segment_bgn_end_samples, np_to_pytorch, get_mix_data, evaluate_sdr, wiener, split_nparray_with_overlap
# from losses import get_loss_func

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
from torch.nn.parameter import Parameter
import torch.distributed as dist
from torchlibrosa.stft import STFT, ISTFT, magphase
import pytorch_lightning as pl


#####
import torchvggish.vggish as vggish




def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero the bias if there is one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)


def init_emb(layer):
    """Draw ``layer.weight`` uniformly from [-0.1, 0.1] and zero the bias if there is one."""
    nn.init.uniform_(layer.weight, -0.1, 0.1)

    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer to bias 0 and weight 1."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)


def init_gru(rnn):
    """Initialise every layer of a GRU.

    Each stacked weight matrix is split row-wise into three equal chunks that
    are initialised separately: all three input-hidden chunks and the first
    two hidden-hidden chunks are uniform in +/- sqrt(3 / fan_in), the last
    hidden-hidden chunk is orthogonal. All biases are set to zero.
    """

    def _inner_uniform(block):
        fan_in = nn.init._calculate_correct_fan(block, 'fan_in')
        limit = math.sqrt(3 / fan_in)
        nn.init.uniform_(block, -limit, limit)

    def _concat_init(weight, initializers):
        rows, _ = weight.shape
        chunk = rows // len(initializers)
        for idx, initializer in enumerate(initializers):
            initializer(weight[idx * chunk : (idx + 1) * chunk, :])

    # Order matters: input-hidden parameters are handled before hidden-hidden ones.
    schemes = (
        ('ih', [_inner_uniform, _inner_uniform, _inner_uniform]),
        ('hh', [_inner_uniform, _inner_uniform, nn.init.orthogonal_]),
    )

    for layer in range(rnn.num_layers):
        for kind, initializers in schemes:
            _concat_init(getattr(rnn, f'weight_{kind}_l{layer}'), initializers)
            torch.nn.init.constant_(getattr(rnn, f'bias_{kind}_l{layer}'), 0)


def act(x, activation):
    """Apply the activation selected by name.

    'relu' and 'leaky_relu' (slope 0.2) operate in place on ``x``;
    'swish' returns a new tensor. Any other name raises an Exception.
    """
    if activation == 'relu':
        return F.relu_(x)
    if activation == 'leaky_relu':
        return F.leaky_relu_(x, negative_slope=0.2)
    if activation == 'swish':
        return x * torch.sigmoid(x)
    raise Exception('Incorrect activation!')



class ConvBlock(nn.Module):
    """Two conv -> batch-norm -> activation stages with additive condition embeddings.

    After each stage a Linear projection of ``condition`` (``classes_num``
    features -> ``out_channels``) is added to the feature map, broadcast over
    the two trailing axes.
    """
    def __init__(self, in_channels, out_channels, size, activation, momentum, classes_num = 128):
        super(ConvBlock, self).__init__()

        self.activation = activation
        # Padding of size // 2 on both spatial axes.
        pad = size // 2

        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(size, size), stride=(1, 1), 
                              dilation=(1, 1), padding=(pad, pad), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(size, size), stride=(1, 1), 
                              dilation=(1, 1), padding=(pad, pad), bias=False)

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
        # Condition vector -> one additive offset per output channel, per stage.
        self.emb1 = nn.Linear(classes_num, out_channels, bias=True)
        self.emb2 = nn.Linear(classes_num, out_channels, bias=True)
        # NOTE(review): an experimental Conv1d condition path (emb_conv1 /
        # emb_conv2) was disabled here. forward() must therefore not reference
        # those layers; it uses the Linear embeddings directly.
        self.init_weights()
        
    def init_weights(self):
        """Re-initialise the convolutions, batch norms and embeddings."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_emb(self.emb1)
        init_emb(self.emb2)

    # latent query embedded 
    def forward(self, x, condition):
        """Run both stages on ``x``, adding the embedded ``condition`` after each.

        Args:
          x: 4-D feature map accepted by ``nn.Conv2d`` with ``in_channels`` channels.
          condition: tensor whose last axis has ``classes_num`` features
            (presumably (batch_size, classes_num) -- confirm against the caller).

        Returns:
          Feature map with ``out_channels`` channels.
        """
        c1 = self.emb1(condition)
        c2 = self.emb2(condition)
        # Two trailing singleton axes make the (batch, out_channels) embedding
        # broadcast over both spatial axes, as DecoderBlock.forward does.
        x = act(self.bn1(self.conv1(x)), self.activation) + c1[:, :, None, None]
        x = act(self.bn2(self.conv2(x)), self.activation) + c2[:, :, None, None]
        return x


class EncoderBlock(nn.Module):
    """A 3x3 ``ConvBlock`` followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum, classes_num = 527):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels, 3, activation, momentum, classes_num)
        self.downsample = downsample

    def forward(self, x, condition):
        """Return ``(pooled, unpooled)`` features for ``x`` under ``condition``."""
        features = self.conv_block(x, condition)
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features


class DecoderBlock(nn.Module):
    """Transposed-conv upsampling stage followed by a skip-connected ``ConvBlock``.

    The upsampled input gets an additive condition embedding, is trimmed by
    one row and one column, concatenated with the skip tensor along the
    channel axis, and passed through a ``ConvBlock``.
    """
    def __init__(self, in_channels, out_channels, stride, activation, momentum, classes_num = 527):
        super(DecoderBlock, self).__init__()
        size = 3
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels, 
            out_channels=out_channels, kernel_size=(size, size), stride=stride, 
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=(1, 1))

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
        # The ConvBlock sees the upsampled map concatenated with the skip tensor,
        # hence out_channels * 2 input channels.
        self.conv_block2 = ConvBlock(out_channels * 2, out_channels, size, activation, momentum, classes_num)
        # Condition vector -> one additive offset per output channel.
        self.emb1 = nn.Linear(classes_num, out_channels, bias=True)
        # NOTE(review): init_weights() is not invoked here, so conv1 / bn1 / emb1
        # keep PyTorch's default initialisation unless a caller runs it.
        # Calling it from __init__ would change the initial weights -- confirm intent.

    def init_weights(self):
        """Re-initialise the transposed conv, its batch norm and the embedding."""
        init_layer(self.conv1)
        # The batch norm attribute is `bn1`; there is no `self.bn` on this module.
        init_bn(self.bn1)
        init_emb(self.emb1)

    def prune(self, x):
        """Prune the shape of x after transpose convolution.

        Drops the last index along each of the two trailing axes.
        """
        x = x[:, :, 0 : - 1, 0 : - 1]
        return x

    def forward(self, input_tensor, concat_tensor, condition):
        """Upsample ``input_tensor``, merge it with ``concat_tensor`` and refine.

        Args:
          input_tensor: feature map to upsample (``in_channels`` channels).
          concat_tensor: skip tensor concatenated on dim 1 after pruning; its
            non-channel sizes must match the pruned upsampled map.
          condition: tensor whose last axis has ``classes_num`` features.

        Returns:
          Output of ``conv_block2`` (``out_channels`` channels).
        """
        c1 = self.emb1(condition)
        x = act(self.bn1(self.conv1(input_tensor)), self.activation) + c1[:, :, None, None]
        x = self.prune(x)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, condition)
        return x



class ZeroShotASP(pl.LightningModule):
    '''
    Args:
    channels (int): number of spectrogram channels fed to the network; the constructor default is 100
    config (module): the configuration module as in config.py
    at_model (module): the sound event detection system
    dataset (module): the dataset variable to control the randomness in each epoch (not affect in evaluation mode) 
    '''
    def __init__(self, config, at_model, dataset, channels=100):
        """Build the STFT front-end and the conditioned U-Net.

        Args:
          config: configuration module; ``hop_samples``, ``loss_type``,
            ``using_whiting`` and ``latent_dim`` are read here.
          at_model: audio-tagging model, stored for later inference calls.
          dataset: dataset object, stored as ``self.dataset``.
          channels: number of input spectrogram channels, also used as the
            channel count of the final 1x1 convolution.
        """
        super().__init__()

        # hyper parameters
        window_size = 2048
        hop_size = config.hop_samples
        center = True
        pad_mode = 'reflect'
        window = 'hann'
        activation = 'relu'
        momentum = 0.01
        self.check_flag = False
        #####
        self.channels=channels
        #####
        self.config = config
        self.at_model = at_model
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # load 'opt_thres.pkl' from a trusted source.
        # The context manager closes the file handle, which was previously leaked.
        with open('opt_thres.pkl', 'rb') as thres_file:
            self.opt_thres = pickle.load(thres_file)
        # NOTE(review): get_loss_func is not imported in the visible part of this
        # file (the `from losses import get_loss_func` line is commented out).
        self.loss_func = get_loss_func(self.config.loss_type)
        self.dataset = dataset
        if self.config.using_whiting:
            temp = np.load("whiting_weight.npy", allow_pickle=True)
            temp = temp.item()
            self.whiting_kernel = temp["kernel"]
            self.whiting_bias = temp["bias"]

        self.downsample_ratio = 2 ** 6   # This number equals 2^{#encoder_blocks}

        self.stft = STFT(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=True)

        self.istft = ISTFT(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=True)

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        # encoder_block1 must emit 32 channels: encoder_block2 consumes 32 and
        # decoder_block6 concatenates this block's output with its own 32 channels.
        # The previous `32*channels` only matched when channels == 1.
        self.encoder_block1 = EncoderBlock(in_channels=channels, out_channels=32, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.encoder_block2 = EncoderBlock(in_channels=32, out_channels=64, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.encoder_block3 = EncoderBlock(in_channels=64, out_channels=128, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.encoder_block4 = EncoderBlock(in_channels=128, out_channels=256, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.encoder_block5 = EncoderBlock(in_channels=256, out_channels=512, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.encoder_block6 = EncoderBlock(in_channels=512, out_channels=1024, 
            downsample=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.conv_block7 = ConvBlock(in_channels=1024, out_channels=2048, 
            size=3, activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block1 = DecoderBlock(in_channels=2048, out_channels=1024, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block2 = DecoderBlock(in_channels=1024, out_channels=512, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block3 = DecoderBlock(in_channels=512, out_channels=256, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block4 = DecoderBlock(in_channels=256, out_channels=128, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block5 = DecoderBlock(in_channels=128, out_channels=64, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)
        self.decoder_block6 = DecoderBlock(in_channels=64, out_channels=32, 
            stride=(2, 2), activation=activation, momentum=momentum, classes_num = config.latent_dim)

        self.after_conv_block1 = ConvBlock(in_channels=32, out_channels=32, 
            size=3, activation=activation, momentum=momentum, classes_num = config.latent_dim)

        self.after_conv2 = nn.Conv2d(in_channels=32, out_channels=channels, 
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        """Initialise the layers owned directly by this module (bn0 and after_conv2)."""
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def spectrogram(self, input):
        """Return the STFT magnitude of ``input`` computed with ``self.stft``."""
        real, imag = self.stft(input)
        energy = real.pow(2) + imag.pow(2)
        return energy.pow(0.5)

    def wav_to_spectrogram(self, input):
        """Waveform to spectrogram.

        Only channel 0 of ``input`` is used; its spectrogram is replicated
        ``self.channels`` times along dim 1.

        Args:
          input: (batch_size, segment_samples, channels_num)

        Outputs:
          output: ``self.channels`` copies of the channel-0 spectrogram
            concatenated on dim 1 (presumably
            (batch_size, self.channels, time_steps, freq_bins) -- confirm
            against the STFT output shape).
        """
        # The spectrogram of channel 0 is the same for every replica, so compute
        # it once instead of running the STFT self.channels times.
        sp = self.spectrogram(input[:, :, 0])
        output = torch.cat([sp] * self.channels, dim=1)
        return output


    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Rebuild waveforms from magnitudes, reusing the phase of ``input``.

        For every channel of ``input`` the phase (cos, sin) of its STFT is
        combined with the matching channel of ``spectrogram`` and inverted
        with ``self.istft``.

        Args:
          input: (batch_size, segment_samples, channels_num)
          spectrogram: (batch_size, channels_num, time_steps, freq_bins)
          length: forwarded to ``self.istft``.

        Outputs:
          output: (batch_size, segment_samples, channels_num)
        """
        def _reconstruct(channel):
            real, imag = self.stft(input[:, :, channel])
            _, cos, sin = magphase(real, imag)
            magnitude = spectrogram[:, channel : channel + 1, :, :]
            return self.istft(magnitude * cos, magnitude * sin, length)

        per_channel = [_reconstruct(channel) for channel in range(input.shape[2])]
        return torch.stack(per_channel, dim=2)

    def forward(self, input, condition):
        """Separate ``input`` according to ``condition``.

        Args:
          input: (batch_size, segment_samples, channels_num)
          condition: conditioning tensor handed to every encoder/decoder block.

        Outputs:
          output_dict: {
            'wav': (batch_size, segment_samples, channels_num),
            'sp': (batch_size, channels_num, time_steps, freq_bins)}
        """
        sp = self.wav_to_spectrogram(input)

        # bn0 normalises over axis 1, so swap axes 1 and 3 around the call.
        x = self.bn0(sp.transpose(1, 3)).transpose(1, 3)

        # Pad axis 2 up to the next multiple of the downsample ratio.
        origin_len = x.shape[2]
        ratio = self.downsample_ratio
        padded_len = int(np.ceil(origin_len / ratio)) * ratio
        x = F.pad(x, pad=(0, 0, 0, padded_len - origin_len))

        # Drop the last bin of the final axis so it halves evenly.
        x = x[..., :-1]

        # UNet: encoder path, collecting the un-pooled maps as skip connections.
        encoders = (
            self.encoder_block1, self.encoder_block2, self.encoder_block3,
            self.encoder_block4, self.encoder_block5, self.encoder_block6,
        )
        skips = []
        for encoder in encoders:
            x, skip = encoder(x, condition)
            skips.append(skip)

        x = self.conv_block7(x, condition)

        # Decoder path: consume the skips from the deepest to the shallowest.
        decoders = (
            self.decoder_block1, self.decoder_block2, self.decoder_block3,
            self.decoder_block4, self.decoder_block5, self.decoder_block6,
        )
        for decoder in decoders:
            x = decoder(x, skips.pop(), condition)

        x = self.after_conv_block1(x, condition)
        x = self.after_conv2(x)

        # Recover shape: restore the dropped bin and remove the padding on axis 2.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, :origin_len, :]

        # The network output acts as a sigmoid mask on the input spectrogram.
        sp_out = torch.sigmoid(x) * sp

        wav_out = self.spectrogram_to_wav(input, sp_out, input.shape[1])
        return {"wav": wav_out, "sp": sp_out}

    def get_new_indexes(self, x):
        """Return ``[0, 1, ..., x.shape[0] - 1]`` as a list."""
        return list(range(x.shape[0]))
    def get_auto_tagging(self, data):
        """Crop each waveform around its strongest detection and embed the crops.

        ``data["waveform"]`` is run through ``self.at_model.inference``; for each
        clip the framewise output of its own class id is smoothed with a box
        filter, the peak is converted to an anchor index, and the segment
        returned by ``get_segment_bgn_end_samples`` is cut out. The crops are
        then passed through the model again to obtain their latent vectors.

        Returns:
          (at_waveforms, at_vectors): the cropped waveforms as an array and the
          model's "latent_output" for them.
        """
        waveforms = data["waveform"]
        class_ids = data["class_id"]
        cfg = self.config
        sed_vectors = self.at_model.inference(waveforms)["framewise_output"]  # B, T, C

        def _crop(i):
            # Box-filter the framewise activations of this clip's class.
            sed_vector = np.convolve(
                sed_vectors[i, :, class_ids[i]], np.ones(cfg.segment_frames),
                mode = "same"
            )
            anchor_index = math.floor(np.argmax(sed_vector) * cfg.clip_samples / cfg.hop_samples / 1024)
            bgn_sample, end_sample = get_segment_bgn_end_samples(
                anchor_index, cfg.segment_frames,
                cfg.hop_samples, cfg.clip_samples
            )
            return waveforms[i, bgn_sample: end_sample]

        at_waveforms = np.array([_crop(i) for i in range(len(waveforms))])
        at_vectors = self.at_model.inference(at_waveforms)["latent_output"]
        return at_waveforms, at_vectors


    def combine_batch(self, x, y):
        """Interleave two equal-length batches as ``[x0, y0, x1, y1, ...]`` in an array."""
        assert len(x) == len(y), "two combined batches should be in the same length"
        interleaved = []
        for idx in range(len(x)):
            interleaved.append(x[idx])
            interleaved.append(y[idx])
        return np.array(interleaved)
    
    def training_step(self, batch, batch_idx):
        """Build mixtures from the paired batch and return the separation loss.

        Returns ``None`` when ``get_mix_data`` produces no mixtures.
        """
        self.train()
        self.device_type = next(self.parameters()).device
        if not self.check_flag:
            self.check_flag = True

        # Interleave the two halves of the batch.
        combined = {}
        combined["class_id"] = self.combine_batch(batch["class_id_1"], batch["class_id_2"])
        combined["waveform"] = self.combine_batch(batch["waveform_1"], batch["waveform_2"])

        # Latent embedding from the sound event detection / auto tagging system.
        at_waveforms, at_vectors = self.get_auto_tagging(combined)
        indexes = self.get_new_indexes(np.zeros_like(at_vectors))  # [batch, classes_num]
        if self.config.using_whiting:
            whitened = (at_vectors + self.whiting_bias).dot(self.whiting_kernel)
            at_vectors = whitened[:, :self.config.latent_dim]

        # Define the input data by mixing.
        mixtures, sources, conditions, _ = get_mix_data(
            at_waveforms, at_vectors, combined["class_id"], indexes,
            mix_type = "mixture"
        )
        if len(mixtures) == 0:
            return None

        # Convert to tensors with a trailing channel axis.
        mixtures = np_to_pytorch(np.array(mixtures)[:, :, None], self.device_type)
        sources = np_to_pytorch(np.array(sources)[:, :, None], self.device_type)
        conditions = np_to_pytorch(np.array(conditions), self.device_type)

        batch_output_dict = self(mixtures, conditions)
        return self.loss_func(batch_output_dict["wav"], sources)
    def training_epoch_end(self, outputs):
        """Ask the dataset for a new queue and reset ``check_flag``; ``outputs`` is unused."""
        self.dataset.generate_queue()
        self.check_flag = False

    def validation_step(self, batch, batch_idx):
        """Score the separator on "mixture", "clean" and "silence" inputs.

        Returns:
          dict with keys "mixture", "clean" and "silence". Each value is the
          result of ``evaluate_sdr`` for that mix type, or an empty list when
          ``get_mix_data`` produced no data for it.
        """
        self.device_type = next(self.parameters()).device
        combine_batch = {}
        combine_batch["class_id"] = self.combine_batch(batch["class_id_1"], batch["class_id_2"])
        combine_batch["waveform"] = self.combine_batch(batch["waveform_1"], batch["waveform_2"])

        # latent embedding from the sound event detection/auto tagging system
        at_waveforms, at_vectors = self.get_auto_tagging(combine_batch)
        indexes = self.get_new_indexes(np.zeros_like(at_vectors))  # [batch, classes_num]
        if self.config.using_whiting:
            at_vectors = (at_vectors + self.whiting_bias).dot(self.whiting_kernel)
            at_vectors = at_vectors[:,:self.config.latent_dim]

        # The three evaluations differ only in the mix type passed through.
        scores = {}
        for mix_type in ("mixture", "clean", "silence"):
            scores[mix_type] = self._separate_and_score(
                at_waveforms, at_vectors, combine_batch["class_id"], indexes, mix_type
            )
        return scores

    def _separate_and_score(self, at_waveforms, at_vectors, class_ids, indexes, mix_type):
        """Run the model on one mix type and return its ``evaluate_sdr`` result ([] if no data)."""
        mixtures, sources, conditions, gds = get_mix_data(
            at_waveforms, at_vectors, class_ids, indexes,
            mix_type = mix_type
        )
        if len(mixtures) == 0:
            return []

        # convert to tensor
        mixtures = np_to_pytorch(np.array(mixtures)[:, :, None], self.device_type)
        sources = np_to_pytorch(np.array(sources)[:, :, None], self.device_type)
        conditions = np_to_pytorch(np.array(conditions), self.device_type)
        gds = np_to_pytorch(np.array(gds), self.device_type)

        preds = self(mixtures, conditions)["wav"]

        # "silence" is scored against the input mixture, the others against the sources.
        reference = mixtures if mix_type == "silence" else sources
        return evaluate_sdr(
            ref = reference.data.cpu().numpy(),
            est = preds.data.cpu().numpy(),
            class_ids = gds.data.cpu().numpy(),
            mix_type = mix_type
        )

    def validation_epoch_end(self, validation_step_outputs):
        """Average the per-step SDR results and log one value per mix type.

        Args:
            validation_step_outputs: iterable of dicts as returned by
                ``validation_step``; each has the keys "mixture", "clean"
                and "silence" holding sequences whose items are indexable
                (only element 0 of every item is used).
        """
        self.device_type = next(self.parameters()).device

        # (key in the step output, name under which the mean is logged)
        metric_table = (
            ("mixture", "mixture_sdr"),
            ("clean", "clean_sdr"),
            ("silence", "silence_sdr"),
        )

        # gather the first element of every entry, per mix type
        gathered = {step_key: [] for step_key, _ in metric_table}
        for step_output in validation_step_outputs:
            for step_key, _ in metric_table:
                gathered[step_key] += [entry[0] for entry in step_output[step_key]]

        # compute every mean before logging anything
        averages = [
            (log_name, np.mean(np.array(gathered[step_key])))
            for step_key, log_name in metric_table
        ]
        for log_name, average in averages:
            self.log(log_name, average, on_epoch = True, prog_bar=True, sync_dist=True)
        
    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, test_step_outputs):
        self.validation_epoch_end(test_step_outputs)             

    def configure_optimizers(self):
        """Create the Adam optimizer and its epoch-based LR schedule.

        The learning-rate multiplier is ``0.1 ** (3 - epoch)`` for epochs
        0, 1 and 2 (warm-up). From epoch 3 onwards it is ``0.1 ** k``, where
        ``k`` is ``bisect.bisect_left(self.config.lr_scheduler_epoch, epoch)``.

        Returns:
            A pair ``([optimizer], [scheduler])``.
        """
        def lr_multiplier(epoch):
            # warm up lr
            if epoch < 3:
                return 0.1 ** (3 - epoch)
            # one extra factor of 0.1 per configured boundary lying below ``epoch``
            return 0.1 ** (bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))

        adam = optim.Adam(
            self.parameters(),
            lr = self.config.learning_rate,
            betas = (0.9, 0.999),
            eps = 1e-08,
            weight_decay = 0.,
            amsgrad = True,
        )
        schedule = optim.lr_scheduler.LambdaLR(adam, lr_lambda=lr_multiplier)

        return [adam], [schedule]



  from .autonotebook import tqdm as notebook_tqdm


OSError: dlopen(/Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at14RecordFunctionC1ENS_11RecordScopeEb
  Referenced from: <E741B6D5-E348-3601-ACC9-BC3101AD112C> /Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torchaudio/lib/libtorchaudio.so
  Expected in:     <AAE88793-2D9D-3CCA-96C4-EAC30CEA4202> /Users/cooky/miniforge3/envs/cid2rch/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib