In [None]:
import harmonic_conv
import util
import torch
from torch.nn.modules.activation import PReLU, ReLU
import torch.optim
import torch.nn as nn
import numpy as np

# Network

In [None]:
class Harmonic_Net(nn.Module):
    """Two-level U-Net style network built from ``Step`` blocks.

    HC: Harmonic Convolution, RC: Regular Convolution.

    Data flow (channels; spatial size relative to an input of F x T, assuming
    each ``Step`` keeps the spatial size unchanged — TODO confirm for HC):

        input              C          (F,   T)
        step1        C   -> 35        (F,   T)
        down                          (F/2, T/2)   = xd1
        step2        35  -> 70        (F/2, T/2)
        down                          (F/4, T/4)   = xd2
        step3        70  -> 70        (F/4, T/4)
        cat xd2          -> 140
        up                            (F/2, T/2)
        step4        140 -> 35
        cat xd1          -> 70
        up                            (F,   T)
        step5        70  -> 35
        last         35  -> C         (1x1 regular conv)

    Both F and T therefore need to be divisible by 4.

    Building blocks:
        Step(a, b): two convs (HC or RC), a->a then a->b, each + InstanceNorm + ReLU
        Down:       2x2 average pooling
        Up:         x2 bilinear up-sampling
        Last(a, b): 1x1 regular conv a->b

    Args:
        input_channel: number of input (and output) channels C.
        kernel_size: kernel size passed to every ``Step``.
        conv_type: sequence of 'HC'/'RC' for step1..step5; only the first
            five entries are used.
    """

    def __init__(self, input_channel=1, kernel_size=3, conv_type=('HC',) * 5):
        # Tuple default avoids a shared mutable default argument.
        super().__init__()
        self.step1 = Step(input_channel, 35, kernel_size, conv_type[0])
        self.step2 = Step(35, 70, kernel_size, conv_type[1])
        self.step3 = Step(70, 70, kernel_size, conv_type[2])
        self.step4 = Step(140, 35, kernel_size, conv_type[3])
        self.step5 = Step(70, 35, kernel_size, conv_type[4])
        self.last = Last(35, input_channel)
        # Down/Up have no parameters, so one instance of each serves both levels.
        self.down = Down()
        self.up = Up()

    def forward(self, x):
        """Map a (batch, C, F, T) tensor to a tensor of the same shape."""
        x = self.step1(x)                   # B,C,F,T       -> B,35,F,T
        xd1 = self.down(x)                  # B,35,F,T      -> B,35,F/2,T/2
        x = self.step2(xd1)                 # B,35,F/2,T/2  -> B,70,F/2,T/2
        xd2 = self.down(x)                  # B,70,F/2,T/2  -> B,70,F/4,T/4
        x = self.step3(xd2)                 # B,70,F/4,T/4  -> B,70,F/4,T/4
        x = torch.cat((x, xd2), 1)          # B,70,...      -> B,140,F/4,T/4
        x = self.up(x)                      # B,140,F/4,T/4 -> B,140,F/2,T/2
        x = self.step4(x)                   # B,140,F/2,T/2 -> B,35,F/2,T/2
        x = torch.cat((x, xd1), 1)          # B,35,...      -> B,70,F/2,T/2
        x = self.up(x)                      # B,70,F/2,T/2  -> B,70,F,T
        x = self.step5(x)                   # B,70,F,T      -> B,35,F,T
        x = self.last(x)                    # B,35,F,T      -> B,C,F,T
        return x
# HC base
class Step(nn.Module):
    """Two conv -> InstanceNorm2d -> ReLU stages: in->in channels, then in->out.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: convolution kernel size (also the harmonic ``anchor`` for HC).
        conv_type: 'HC' uses ``harmonic_conv.SingleHarmonicConv2d`` with
            ``padding=(0, time_pad)`` (presumably no padding along frequency,
            ``time_pad`` along time — TODO confirm harmonic_conv's convention);
            'RC' uses ``nn.Conv2d`` with symmetric ``padding=time_pad``.

    Raises:
        ValueError: if ``conv_type`` is neither 'HC' nor 'RC'.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, conv_type='HC'):
        super().__init__()
        time_pad = (kernel_size - 1) // 2
        if conv_type == 'HC':
            def conv(cin, cout):
                return harmonic_conv.SingleHarmonicConv2d(
                    cin, cout, kernel_size=kernel_size, anchor=kernel_size,
                    padding=(0, time_pad), padding_mode='zeros')
        elif conv_type == 'RC':
            def conv(cin, cout):
                return nn.Conv2d(cin, cout, kernel_size=kernel_size,
                                 padding=time_pad, padding_mode='zeros')
        else:
            # Previously this fell through and failed later with a confusing
            # AttributeError on `self.step`.
            raise ValueError(f"conv_type must be 'HC' or 'RC', got {conv_type!r}")

        self.step = nn.Sequential(
            conv(in_channels, in_channels),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(),
            conv(in_channels, out_channels),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(),
        )
        self.step.apply(init_weights)

    def forward(self, x):
        return self.step(x)

class Down(nn.Module):
    """Halve both spatial axes with non-overlapping 2x2 average pooling."""

    def __init__(self):
        super().__init__()
        pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.down = nn.Sequential(pool)
        # Pooling has no weights; the apply() is kept because init_weights in
        # this file also reseeds torch's global RNG as a side effect.
        self.down.apply(init_weights)

    def forward(self, x):
        pooled = self.down(x)
        return pooled


class Up(nn.Module):
    """Double both spatial axes with bilinear interpolation (align_corners=True)."""

    def __init__(self):
        super().__init__()
        # nn.UpsamplingBilinear2d is deprecated; it is exactly nn.Upsample with
        # mode='bilinear' and align_corners=True, so the output is unchanged.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)
        )
        # No weights to initialise; the apply() is kept because init_weights in
        # this file also reseeds torch's global RNG as a side effect.
        self.up.apply(init_weights)

    def forward(self, x):
        return self.up(x)

class Last(nn.Module):
    """Output projection: a 1x1 regular convolution, in_channels -> out_channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.last = nn.Sequential(projection)
        self.last.apply(init_weights)

    def forward(self, x):
        projected = self.last(x)
        return projected


class Skip(nn.Module):
    """Single harmonic conv -> InstanceNorm2d -> ReLU (in -> out channels).

    Harmonic_Net does not use this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        time_pad = (kernel_size - 1) // 2
        self.step = nn.Sequential(
            harmonic_conv.SingleHarmonicConv2d(
                in_channels, out_channels, kernel_size=kernel_size, anchor=kernel_size,
                padding=(0, time_pad), padding_mode='zeros'),
            # The norm sees the conv's output, so it must be sized by out_channels
            # (was in_channels, which mismatches whenever in != out).
            nn.InstanceNorm2d(out_channels),
            nn.ReLU()
        )
        self.step.apply(init_weights)

    def forward(self, x):
        return self.step(x)

def init_weights(m):
    """Initialise conv weights with N(0, 0.01^2); intended for ``Module.apply``.

    torch's global RNG is reseeded with 1 on *every* call, so every matching
    layer is drawn from the same stream start. Equally-shaped layers thus get
    identical weights, and the global RNG state is reset as a side effect.
    Exact type matches only: subclasses of the listed classes are left untouched.
    """
    torch.manual_seed(1)
    targets = (
        (nn.Conv2d, 'weight'),
        (harmonic_conv.SingleHarmonicConv2d, 'lowered_weight'),
    )
    for cls, attr in targets:
        if type(m) is cls:
            nn.init.normal_(getattr(m, attr), mean=0, std=0.01)

# Train class

In [None]:
class Train:
    """Fit a network to a single noisy utterance and resynthesise its output.

    Attribute naming convention:
        s: clean, n: noise, x: noisy
        ~_t: time domain, ~_f: time-frequency domain (complex STFT),
        ~_A: time-frequency amplitude
        ~_~n: numpy array, ~_~t: torch tensor

    Intended call order: ``target_set`` -> ``train_setting`` -> ``train_xy``
    (repeatedly, feeding back the step count it returns).
    """

    def __init__(self) -> None:
        pass

    def target_set(
        self,
        s_tn,
        n_tn,
        snr,
        x_tn=None,
        fftl=512,
        shift=128
    ) -> None:
        """Store the clean, noise and noisy signals and the noisy STFT.

        Args:
            s_tn: clean signal, or a file path read with ``util.wavread``.
            n_tn: noise signal, or a file path read with ``util.wavread``.
            snr: SNR passed to ``util.create_mixture`` (unused when ``x_tn`` is given).
            x_tn: optional pre-mixed noisy signal; when given, ``n_tn`` is
                stored unchanged instead of the noise returned by the mixer.
            fftl: STFT frame length for the noisy STFT; also the default
                for ``get_s``/``get_n``.
            shift: STFT hop size (same role as ``fftl``).
        """
        if isinstance(s_tn, str):
            s_tn, _, _ = util.wavread(s_tn)
        self.s_tn = s_tn
        if isinstance(n_tn, str):
            n_tn, _, _ = util.wavread(n_tn)
        self.n_tn = n_tn

        if x_tn is None:
            # create_mixture also returns the noise it mixed in (presumably rescaled to hit `snr`).
            self.x_tn, self.n_tn = util.create_mixture(self.s_tn, self.n_tn, snr)
        else:
            self.x_tn = x_tn

        self.x_fn = util.stft(self.x_tn, fftl=fftl, shift=shift)
        self.x_An = np.abs(self.x_fn)
        # Replace exact zeros in the amplitude with the smallest float spacing.
        self.x_An[self.x_An == 0] = np.spacing(1)

        self.fftl = fftl
        self.shift = shift

    def get_s(
        self,
        domain='time',
        fftl=None,
        shift=None
    ) -> np.ndarray:
        """Return the clean signal ('time') or its STFT ('freq').

        ``fftl``/``shift`` default to the values given to ``target_set``.

        Raises:
            AssertionError: if ``domain`` is neither 'time' nor 'freq'.
        """
        return self._in_domain(self.s_tn, domain, fftl, shift)

    def get_n(
        self,
        domain='time',
        fftl=None,
        shift=None
    ) -> np.ndarray:
        """Return the noise signal ('time') or its STFT ('freq').

        ``fftl``/``shift`` default to the values given to ``target_set``, as
        in ``get_s``, so the noise STFT lines up with the other STFTs (these
        defaults were previously fixed at 512/128).

        Raises:
            AssertionError: if ``domain`` is neither 'time' nor 'freq'.
        """
        return self._in_domain(self.n_tn, domain, fftl, shift)

    def _in_domain(self, signal, domain, fftl, shift):
        """Shared body of get_s/get_n: return ``signal`` or its STFT."""
        if domain == 'time':
            return signal
        if domain == 'freq':
            if fftl is None:
                fftl = self.fftl
            if shift is None:
                shift = self.shift
            return util.stft(signal, fftl=fftl, shift=shift)
        # Explicit raise instead of `assert False` so the check survives
        # `python -O`; the exception type is unchanged for existing callers.
        raise AssertionError("domain must be 'time' or 'freq'.")

    def train_setting(
        self,
        device='cuda:0',
        isrelu=False,
        input_channel=1,
        lr=0.05,
        kernel_size=3,
        conv_type=('HC',) * 7
    ) -> None:
        """Build the Harmonic_Net model and its Adam optimizer.

        Args:
            device: torch device for the model and, later, the training tensors.
            isrelu: unused; kept for backward compatibility.
            input_channel: model input/output channels. ``train_xy`` builds a
                2-channel (real, imag) target, so use 2 with it.
            lr: Adam learning rate.
            kernel_size: conv kernel size (also the harmonic anchor for HC).
            conv_type: per-Step 'HC'/'RC' choices; Harmonic_Net reads the
                first five entries.
        """
        torch.manual_seed(1234)
        # NOTE(review): init_weights in this file calls torch.manual_seed(1) for
        # every module, so this seed is likely overridden while building the model.
        self.model = Harmonic_Net(input_channel=input_channel, kernel_size=kernel_size, conv_type=conv_type).to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(list(self.model.parameters()), lr=lr)

    def Loss(self, x1, x2, p=2.0):
        """Return sum((|x1 - x2| + 1e-6) ** p) over all elements."""
        return torch.sum(torch.pow(torch.abs(x1 - x2) + 1.0e-6, p))

    def train_xy(
        self,
        iter,
        sum_iter,
        loss_p=1
    ):
        """Run ``iter`` optimizer steps and return the current reconstruction.

        On the first call (``sum_iter == 0``, or if no target was prepared yet)
        the noisy STFT is turned into a 2-channel (real, imag) target and a
        fixed random network input is drawn.

        Args:
            iter: number of optimizer steps to run in this call.
            sum_iter: number of steps already run by previous calls.
            loss_p: exponent ``p`` passed to ``Loss``.

        Returns:
            (out_tn, sum_iter, loss_num):
            - out_tn: time-domain reconstruction from ``util.istft``.
            - sum_iter: ``sum_iter + iter``, the total number of steps taken.
            - loss_num: scalar loss of the returned output, computed without
              gradient tracking.
        """
        if sum_iter == 0 or not hasattr(self, 'net_input'):
            self._prepare_training_data()

        # Exactly `iter` updates (the old loop ran `iter + 1`, so the returned
        # count under-reported the real number of steps by one per call).
        for _ in range(iter):
            sum_iter += 1
            print('\r%d回目' % sum_iter, end='')
            self.optimizer.zero_grad()
            loss = self.Loss(self.model(self.net_input), self.x_target_cut, loss_p)
            loss.backward()
            self.optimizer.step()

        # Evaluate the model as it stands after the updates above.
        with torch.no_grad():
            out_2c = self.model(self.net_input)
            loss_num = self.Loss(out_2c, self.x_target_cut, loss_p)

        return self._to_waveform(out_2c), sum_iter, loss_num

    def _prepare_training_data(self):
        """Build the (real, imag) training target from ``self.x_fn`` and a random network input."""
        self.x_realn = np.real(self.x_fn)
        self.x_realt = util.np_to_torch(self.x_realn).float().to(self.device)  # [time, freq]
        self.x_realt = torch.unsqueeze(torch.unsqueeze(self.x_realt, 0), 0)   # [1, 1, time, freq]
        self.x_cpxn = np.imag(self.x_fn)
        self.x_cpxt = util.np_to_torch(self.x_cpxn).float().to(self.device)   # [time, freq]
        self.x_cpxt = torch.unsqueeze(torch.unsqueeze(self.x_cpxt, 0), 0)     # [1, 1, time, freq]
        self.x_target = torch.cat((self.x_realt, self.x_cpxt), 1)             # [1, 2, time, freq]
        self.x_target, self.pad1, self.pad2 = util.pad32(self.x_target)       # [1, 2, time+pad1, freq+pad2]
        # HC must not change the number of frequency bins, so drop the frequency
        # padding again (slicing from the front assumes pad32 pads at the front).
        self.x_target = self.x_target[:, :, :, self.pad2:]                    # [1, 2, time+pad1, freq]
        self.pad2 = 0
        self.x_target = self.x_target.permute(0, 1, 3, 2)                     # [1, 2, freq, time+pad1]
        self.x_target_cut = self.x_target[:, :, 1:, :]                        # drop the DC frequency bin
        self.net_input = torch.rand(size=self.x_target_cut.shape).float().to(self.device) * 0.1

    def _to_waveform(self, out_2c):
        """Convert a [1, 2, freq-1, time+pad1] network output into a time-domain signal."""
        # The network never predicts the DC bin; re-attach the noisy target's DC bin.
        out_2c = torch.cat((self.x_target[:, :, :1, :], out_2c), 2)          # [1, 2, freq, time+pad1]
        out_2c = out_2c.permute(1, 0, 3, 2)                                   # [2, 1, time+pad1, freq]
        out_2n = util.torch_to_np(out_2c)
        # Strip the time padding (assumes pad32 pads at the front).
        out_2n = out_2n[:, :, self.pad1:, self.pad2:]                         # [2, 1, time, freq]
        out_fn = out_2n[0, 0] + 1j * out_2n[1, 0]                             # complex [time, freq]
        # abs(z) * exp(1j * angle(z)) == z, so the complex STFT is passed directly.
        return util.istft(out_fn, x_len=len(self.x_tn), fftl=self.fftl, shift=self.shift)

# Train

In [None]:
import os

import matplotlib.pyplot as plt

# --- configuration ----------------------------------------------------------
WAV_PATH = 'LJ037-0171.wav'  # LJspeech -> https://keithito.com/LJ-Speech-Dataset/
OUT_DIR = '.wav'             # reconstructed wavs are written here
FFTL = 1024
SHIFT = 64
SNR = 10
NOISE_STD = 0.1
SEED = 0                     # makes the added noise reproducible
CHECKPOINTS = [1, 10, 100, 1000, 2000, 5000, 10000]  # cumulative optimizer steps
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# --- data -------------------------------------------------------------------
dp = Train()

s_tn, sr, subtype = util.wavread(WAV_PATH)
rng = np.random.default_rng(SEED)
noise = rng.normal(0.0, NOISE_STD, size=len(s_tn))
dp.target_set(
    s_tn=s_tn,
    n_tn=noise,
    snr=SNR,
    fftl=FFTL,
    shift=SHIFT
)

# --- model --------------------------------------------------------------------
dp.train_setting(
    device=DEVICE,
    input_channel=2,  # the training target stacks real and imaginary STFT parts
    lr=0.001,
    kernel_size=3  # Note: the original paper's kernel_size and anchor is 7. (the parameter kernelsize=3, anchor=3 now)
)
noisy_sdr = round(util.sisdr(dp.get_s(), dp.x_tn), 4)
util.specshow(dp.x_tn, title=f'Noisy\nsi-sdr: {noisy_sdr}', sr=sr, fftl=dp.fftl, shift=dp.shift)

# --- training with periodic checkpoints --------------------------------------
os.makedirs(OUT_DIR, exist_ok=True)  # wavwrite below needs the directory to exist
sum_iter = 0
for checkpoint in CHECKPOINTS:
    out_tn, sum_iter, _ = dp.train_xy(checkpoint - sum_iter, sum_iter, loss_p=2)
    out_sdr = round(util.sisdr(dp.get_s(), out_tn), 4)
    util.specshow(out_tn, title=f'{sum_iter}\nsi-sdr: {out_sdr}', sr=sr, fftl=dp.fftl, shift=dp.shift)
    util.wavwrite(os.path.join(OUT_DIR, f'{checkpoint}.wav'), out_tn, sr, subtype)