In [1]:
import os
import sys

# Make the directory two levels above the notebook's working directory importable
# (presumably the repository root that holds the project-local `splearn` package).
# `os.pardir` replaces the hard-coded "..\\..\\" so the path also resolves on
# non-Windows platforms; the trailing "" keeps the trailing separator the original
# string had, in case later code concatenates onto `path`.
cwd = os.getcwd()
path = os.path.join(cwd, os.pardir, os.pardir, "")
# Guard against duplicate entries when the cell is re-run in the same kernel.
if path not in sys.path:
    sys.path.append(path)

In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

import pytorch_lightning
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

import logging
import warnings
# Quieten library output so the cell outputs below stay readable.
# Level 0 is logging.NOTSET: the 'lightning' logger then defers to its parent's level.
logging.getLogger('lightning').setLevel(0)
# Suppress every Python warning for the remainder of the kernel session.
warnings.filterwarnings('ignore')
# NOTE(review): `pytorch_lightning.utilities.distributed.log` is a module-internal logger;
# presumably it only exists in some pytorch_lightning releases - confirm against the pinned version.
pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR)

from splearn.data import MultipleSubjects, PyTorchDataset, PyTorchDataset2Views, HSSSVEP
from splearn.filter.butterworth import butter_bandpass_filter
from splearn.filter.notch import notch_filter
from splearn.filter.channels import pick_channels
from splearn.utils import Logger, Config


In [3]:
# Experiment configuration. It is written to the run log as a plain dict and then
# wrapped in `Config`, after which it is read with attribute access
# (e.g. `config.seed`, `config.data.selected_channels`).
config = {
    "run_name": "ssl_hsssvep",
    "data": {
        # np.arange(1, 36) -> subject ids 1..35 inclusive.
        "load_subject_ids": np.arange(1,36),
        # Nine channel names; keep "input_channels" below in sync with this list's length.
        "selected_channels": ["PO8", "PZ", "PO7", "PO4", "POz", "PO3", "O2", "Oz", "O1"],
        "input_channels": 9,
        "target_sources_num": 40,
        # Matches the 250-sample window cut out in func_preprocessing.
        "sample_length": 250,
        "num_classes": 40
    },
    "training": {
        "num_epochs": 100,
        "num_warmup_epochs": 10,
        "learning_rate": 0.03,
        # "gpus": torch.cuda.device_count(),
        # Hard-coded to the first GPU; the commented line above would use all visible GPUs.
        "gpus": [0],
        "batchsize": 256
    },
    "model": {
        "projection_size": 1024,
        "optimizer": "adamw",
        "scheduler": "cosine_with_warmup",
    },
    "testing": {
        # np.arange(33, 34) -> only subject 33.
        "test_subject_ids": np.arange(33,34),
        # np.arange(0, 3) -> fold indices 0, 1, 2.
        "kfolds": np.arange(0,3),
    },
    "seed": 1234
}

# Record the raw configuration in the run log before it is wrapped.
main_logger = Logger(filename_postfix=config["run_name"])
main_logger.write_to_log("Config")
main_logger.write_to_log(config)

# From here on `config` is a `Config` object rather than a dict.
config = Config(config)

# Seed the random number generators through pytorch_lightning for reproducibility.
seed_everything(config.seed)

Global seed set to 1234


1234

In [4]:
def onehot_targets(targets, num_classes=None):
    """One-hot encode integer class labels.

    Args:
        targets: array-like of non-negative integer labels, any shape.
        num_classes: width of the one-hot axis. When omitted it is inferred as
            ``targets.max() + 1`` (the original behaviour). Pass it explicitly when
            encoding several splits, otherwise a split that happens to lack the
            highest label gets a narrower encoding than the others.

    Returns:
        Integer array of shape ``targets.shape + (num_classes,)`` holding a 1 at
        each label's index and 0 elsewhere.

    Raises:
        ValueError: if ``targets`` is empty and ``num_classes`` is not given, since
            the width cannot be inferred.
    """
    # Accept plain lists/tuples as well as ndarrays.
    targets = np.asarray(targets)
    if num_classes is None:
        if targets.size == 0:
            raise ValueError("cannot infer num_classes from empty targets; pass num_classes explicitly")
        num_classes = targets.max() + 1
    # Broadcasting compares every label against the full class range in one step.
    return (np.arange(num_classes) == targets[..., None]).astype(int)


def func_preprocessing(data):
    """Preprocess one loaded dataset object in place.

    Steps, in order:
      1. keep only the channels listed in ``config.data.selected_channels``;
      2. band-pass filter between 4 and 75 (6th-order Butterworth) at the
         dataset's own sampling rate;
      3. crop the last axis to samples [125, 375), i.e. a 250-sample window;
      4. store the result back with ``data.set_data``.

    Notch filtering is intentionally not applied here.

    Args:
        data: object exposing ``data``, ``channel_names``, ``sampling_rate`` and
            ``set_data``; ``data.data`` is indexed with four axes below.

    Returns:
        None. The dataset object is updated through ``set_data``.
    """
    # 250-sample window that skips the first 125 samples of the last axis.
    crop = slice(125, 125 + 250)

    picked = pick_channels(
        data.data,
        channel_names=data.channel_names,
        selected_channels=config.data.selected_channels,
    )
    filtered = butter_bandpass_filter(
        picked,
        lowcut=4,
        highcut=75,
        sampling_rate=data.sampling_rate,
        order=6,
    )
    data.set_data(filtered[:, :, :, crop])


def leave_one_subject_out(data, **kwargs):
    """Build train/val/test datasets with one subject held out for testing.

    The held-out subject's trials form the test set. The trials of all other
    subjects are pooled into one flat set and split into train and validation
    parts by ``data.dataset_split_stratified``.

    Args:
        data: multi-subject container exposing ``subject_ids``, ``data`` (indexed
            here as 4-D, subject axis first), ``targets`` (indexed as 2-D) and
            ``dataset_split_stratified``.
        **kwargs:
            test_subject_id: id of the subject to hold out (default 1).
            kfold_k: index of the fold used for validation (default 0).
            kfold_split: total number of folds (default 3).

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of ``PyTorchDataset``.

    Raises:
        IndexError: if ``test_subject_id`` does not occur in ``data.subject_ids``.
    """
    held_out_subject = kwargs.get("test_subject_id", 1)
    fold_index = kwargs.get("kfold_k", 0)
    fold_count = kwargs.get("kfold_split", 3)

    # Position of the held-out subject along the subject axis.
    held_out_pos = np.where(data.subject_ids == held_out_subject)[0][0]
    test_dataset = PyTorchDataset(data.data[held_out_pos], data.targets[held_out_pos])

    # Boolean mask selecting every subject except the held-out one.
    keep_mask = np.arange(data.data.shape[0]) != held_out_pos
    pooled_x = data.data[keep_mask, :, :, :]
    pooled_y = data.targets[keep_mask, :]

    # Merge the subject and trial axes so each row is one trial.
    n_subjects, n_trials, n_channels, n_samples = pooled_x.shape
    pooled_x = pooled_x.reshape((n_subjects * n_trials, n_channels, n_samples))
    pooled_y = pooled_y.reshape((pooled_y.shape[0] * pooled_y.shape[1]))

    (X_train, y_train), (X_val, y_val) = data.dataset_split_stratified(
        pooled_x, pooled_y, k=fold_index, n_splits=fold_count
    )

    train_dataset = PyTorchDataset(X_train, y_train)
    val_dataset = PyTorchDataset(X_val, y_val)
    return train_dataset, val_dataset, test_dataset

# Load every configured subject, applying func_preprocessing as the preprocessing hook,
# and register leave_one_subject_out as the train/val/test splitting strategy.
# NOTE: relies on `os` and `path` from the first cell.
data = MultipleSubjects(
    dataset=HSSSVEP, 
    root=os.path.join(path, "../data/hsssvep"), 
    subject_ids=config.data.load_subject_ids, 
    func_preprocessing=func_preprocessing,
    func_get_train_val_test_dataset=leave_one_subject_out,
    verbose=True, 
)

print("Final data shape:", data.data.shape)

# NOTE(review): the held-out subject and fold are hard-coded here rather than read from
# config.testing (whose test_subject_ids also contains only 33) - keep them in sync.
test_subject_id = 33
kfold_k = 0

# Keyword arguments are forwarded to the registered split function (leave_one_subject_out).
train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k)
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False)

# Sanity check: print the data/target shapes of each split.
print("train_loader", train_loader.dataset.data.shape, train_loader.dataset.targets.shape)
print("val_loader", val_loader.dataset.data.shape, val_loader.dataset.targets.shape)
print("test_loader", test_loader.dataset.data.shape, test_loader.dataset.targets.shape)

Load subject: 1
Load subject: 2
Load subject: 3
Load subject: 4
Load subject: 5
Load subject: 6
Load subject: 7
Load subject: 8
Load subject: 9
Load subject: 10
Load subject: 11
Load subject: 12
Load subject: 13
Load subject: 14
Load subject: 15
Load subject: 16
Load subject: 17
Load subject: 18
Load subject: 19
Load subject: 20
Load subject: 21
Load subject: 22
Load subject: 23
Load subject: 24
Load subject: 25
Load subject: 26
Load subject: 27
Load subject: 28
Load subject: 29
Load subject: 30
Load subject: 31
Load subject: 32
Load subject: 33
Load subject: 34
Load subject: 35
Final data shape: (35, 240, 9, 250)
train_loader (5440, 9, 250) (5440,)
val_loader (2720, 9, 250) (2720,)
test_loader (240, 9, 250) (240,)


In [5]:
    
# class ResUNet143_Subbandtime(nn.Module, Base):
#     def __init__(self, input_channels, target_sources_num):
#         super(ResUNet143_Subbandtime, self).__init__()
        
#         self.input_channels = input_channels
#         self.target_sources_num = target_sources_num

#         window_size = 64
#         hop_size = 25
#         center = True
#         pad_mode = "reflect"
#         window = "hann"
#         activation = "leaky_relu"
#         momentum = 0.01

#         self.subbands_num = 1 # 4
#         self.K = 4  # outputs: |M|, cos∠M, sin∠M, Q

#         self.downsample_ratio = 2 ** 3 # 5  # This number equals 2^{#encoder_blocks}

#         self.stft = STFT(
#             n_fft=window_size,
#             hop_length=hop_size,
#             win_length=window_size,
#             window=window,
#             center=center,
#             pad_mode=pad_mode,
#             freeze_parameters=True,
#         )

#         self.istft = ISTFT(
#             n_fft=window_size,
#             hop_length=hop_size,
#             win_length=window_size,
#             window=window,
#             center=center,
#             pad_mode=pad_mode,
#             freeze_parameters=True,
#         )
        
#         self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
        
#         self.encoder_block1 = EncoderBlockRes4B(
#             in_channels=input_channels,
#             out_channels=32,
#             kernel_size=(3, 3),
#             downsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
        
#         self.encoder_block2 = EncoderBlockRes4B(
#             in_channels=32,
#             out_channels=64,
#             kernel_size=(3, 3),
#             downsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.encoder_block3 = EncoderBlockRes4B(
#             in_channels=64,
#             out_channels=128,
#             kernel_size=(3, 3),
#             downsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.encoder_block4 = EncoderBlockRes4B(
#             in_channels=128,
#             out_channels=256,
#             kernel_size=(3, 3),
#             downsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.encoder_block5 = EncoderBlockRes4B(
#             in_channels=256,
#             out_channels=384,
#             kernel_size=(3, 3),
#             downsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.encoder_block6 = EncoderBlockRes4B(
#             in_channels=384,
#             out_channels=384,
#             kernel_size=(3, 3),
#             downsample=(1, 2),
#             activation=activation,
#             momentum=momentum,
#         )
        
#         conv_block_in_channels = 128 # 384
        
#         self.conv_block7a = EncoderBlockRes4B(
#             in_channels=conv_block_in_channels,
#             out_channels=conv_block_in_channels,
#             kernel_size=(3, 3),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.conv_block7b = EncoderBlockRes4B(
#             in_channels=conv_block_in_channels,
#             out_channels=conv_block_in_channels,
#             kernel_size=(3, 3),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.conv_block7c = EncoderBlockRes4B(
#             in_channels=conv_block_in_channels,
#             out_channels=conv_block_in_channels,
#             kernel_size=(3, 3),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.conv_block7d = EncoderBlockRes4B(
#             in_channels=conv_block_in_channels,
#             out_channels=conv_block_in_channels,
#             kernel_size=(3, 3),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )
        
#         self.decoder_block1 = DecoderBlockRes4B(
#             in_channels=384,
#             out_channels=384,
#             kernel_size=(3, 3),
#             upsample=(1, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.decoder_block2 = DecoderBlockRes4B(
#             in_channels=384,
#             out_channels=384,
#             kernel_size=(3, 3),
#             upsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.decoder_block3 = DecoderBlockRes4B(
#             in_channels=384,
#             out_channels=256,
#             kernel_size=(3, 3),
#             upsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.decoder_block4 = DecoderBlockRes4B(
#             in_channels=128,
#             out_channels=128,
#             kernel_size=(3, 3),
#             upsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.decoder_block5 = DecoderBlockRes4B(
#             in_channels=128,
#             out_channels=64,
#             kernel_size=(3, 3),
#             upsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )
#         self.decoder_block6 = DecoderBlockRes4B(
#             in_channels=64,
#             out_channels=32,
#             kernel_size=(3, 3),
#             upsample=(2, 2),
#             activation=activation,
#             momentum=momentum,
#         )

#         self.after_conv_block1 = EncoderBlockRes4B(
#             in_channels=32,
#             out_channels=32,
#             kernel_size=(3, 3),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )

#         self.after_conv2 = nn.Conv2d(
#             in_channels=32,
#             out_channels=target_sources_num
#             * input_channels
#             * self.K
#             * self.subbands_num,
#             kernel_size=(1, 1),
#             stride=(1, 1),
#             padding=(0, 0),
#             bias=True,
#         )
        
#         self.out_conv_block = EncoderBlockRes4B(
#             in_channels=target_sources_num
#             * input_channels
#             * self.subbands_num,
#             out_channels=target_sources_num,
#             kernel_size=(1, 1),
#             downsample=(1, 1),
#             activation=activation,
#             momentum=momentum,
#         )

#         self.init_weights()
        
#     def init_weights(self):
#         init_bn(self.bn0)
#         init_layer(self.after_conv2)
        
#     def feature_maps_to_wav(
#         self,
#         input_tensor: torch.Tensor,
#         sp: torch.Tensor,
#         sin_in: torch.Tensor,
#         cos_in: torch.Tensor,
#         audio_length: int,
#     ) -> torch.Tensor:
#         r"""Convert feature maps to waveform.
#         Args:
#             input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
#             sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
#             sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
#             cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
#         Outputs:
#             waveform: (batch_size, target_sources_num * input_channels, segment_samples)
#         """
#         batch_size, _, time_steps, freq_bins = input_tensor.shape

#         x = input_tensor.reshape(
#             batch_size,
#             self.target_sources_num,
#             self.input_channels,
#             self.K,
#             time_steps,
#             freq_bins,
#         )
#         # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)

#         mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
#         _mask_real = torch.tanh(x[:, :, :, 1, :, :])
#         _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
#         linear_mag = torch.tanh(x[:, :, :, 3, :, :])
#         _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
#         # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

#         # Y = |Y|cos∠Y + j|Y|sin∠Y
#         #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
#         #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
#         out_cos = (
#             cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
#         )
#         out_sin = (
#             sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
#         )
#         # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
#         # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

#         # Calculate |Y|.
#         out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
#         # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

#         # Calculate Y_{real} and Y_{imag} for ISTFT.
#         out_real = out_mag * out_cos
#         out_imag = out_mag * out_sin
#         # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

#         # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
#         shape = (
#             batch_size * self.target_sources_num * self.input_channels,
#             1,
#             time_steps,
#             freq_bins,
#         )
#         out_real = out_real.reshape(shape)
#         out_imag = out_imag.reshape(shape)

#         # ISTFT.
#         x = self.istft(out_real, out_imag, audio_length)
#         # (batch_size * target_sources_num * input_channels, segments_num)

#         # Reshape.
#         waveform = x.reshape(
#             batch_size, self.target_sources_num * self.input_channels, audio_length
#         )
#         # (batch_size, target_sources_num * input_channels, segments_num)

#         return waveform
        
#     def forward(self, x):
        
#         subband_x = x

#         mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
#         # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
        
        
        
#         # Batch normalize on individual frequency bins.
#         x = mag.transpose(1, 3)
#         x = self.bn0(x)
#         x = x.transpose(1, 3)
#         # (batch_size, input_channels * subbands_num, time_steps, freq_bins)
        
#         # Pad spectrogram to be evenly divided by downsample ratio.
#         origin_len = x.shape[2]
#         pad_len = (
#             int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
#             - origin_len
#         )
#         x = F.pad(x, pad=(0, 0, 0, pad_len))
#         # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
        
#         # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
#         x = x[..., 0 : x.shape[-1] - 1]  # (bs, input_channels, T, F)
#         # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
        
#         # UNet
#         print("x", x.shape)
#         (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
#         # print(x1_pool.shape, x1.shape)
#         (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
#         # print(x2_pool.shape, x2.shape)
#         (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
#         # print(x3_pool.shape, x3.shape)
#         # (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
#         # (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
#         # (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
#         (x_center, _) = self.conv_block7a(x3_pool)  # (bs, 384, T / 32, F / 64)
#         (x_center, _) = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
#         # (x_center, _) = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
#         # (x_center, _) = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
#         # x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
#         # x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
#         # x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
#         # print("x_center.shape, x3.shape", x_center.shape, x3.shape)
#         x10 = self.decoder_block4(x_center, x3)  # (bs, 128, T / 4, F / 4)
#         x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
#         x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
#         print("x12", x12.shape)
        
#         (x, _) = self.after_conv_block1(x12)  # (bs, 32, T, F)
        
#         x = self.after_conv2(x)
#         # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
#         # print(33, "x.shape", x.shape)

#         # Recover shape
#         x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 256 -> 257.

#         x = x[:, :, 0:origin_len, :]
#         # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
#         print(99, x.shape)
#         audio_length = subband_x.shape[2]
        
#         # Recover each subband spectrograms to subband waveforms. Then synthesis
#         # the subband waveforms to a waveform.
#         C1 = x.shape[1] // self.subbands_num
#         C2 = mag.shape[1] // self.subbands_num

#         separated_subband_audio = torch.cat(
#             [
#                 self.feature_maps_to_wav(
#                     input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
#                     sp=mag[:, j * C2 : (j + 1) * C2, :, :],
#                     sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
#                     cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
#                     audio_length=audio_length,
#                 )
#                 for j in range(self.subbands_num)
#             ],
#             dim=1,
#         )
#         # (batch_size, subbands_num * target_sources_num * input_channels, segment_samples)
        
#         separated_subband_audio = torch.unsqueeze(separated_subband_audio, 2)
#         (y, _) = self.out_conv_block(separated_subband_audio)
        
#         y = torch.squeeze(y, 2)
        
#         return y
        
        
        
# tmp_x = torch.rand(3, 9, 1000)
# # 
# # input_dict = {
# #     "waveform": tmp_x
# # }

# tmp_layer = ResUNet143_Subbandtime(input_channels=9, target_sources_num=10)
# tmp_y = tmp_layer(tmp_x)
# tmp_y.shape

# # torch.Size([3, 9, 10, 33]) torch.Size([3, 9, 10, 33]) torch.Size([3, 9, 10, 33])

In [6]:
def init_layer(layer: nn.Module):
    r"""Xavier-uniform initialise a layer's weight and zero its bias, if it has one.

    Args:
        layer: module exposing a ``weight`` tensor (e.g. Linear or Conv). A
            missing ``bias`` attribute or a ``bias`` of ``None`` is left alone.
    """
    nn.init.xavier_uniform_(layer.weight)

    # A single lookup covers both "no bias attribute" and "bias disabled (None)".
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
            
def init_bn(bn: nn.Module):
    r"""Reset a batch-norm layer to the identity transform.

    Sets bias and running mean to 0 and weight and running variance to 1,
    all in place.

    Args:
        bn: batch-norm module exposing ``bias``, ``weight``, ``running_mean``
            and ``running_var`` tensors.
    """
    resets = (
        (bn.bias, 0.0),
        (bn.weight, 1.0),
        (bn.running_mean, 0.0),
        (bn.running_var, 1.0),
    )
    for tensor, value in resets:
        tensor.data.fill_(value)

In [7]:
# -*- coding: utf-8 -*-
"""Common 2D convolutions
"""

import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from typing import Tuple, List

from splearn.nn.modules.functional import Swish
from splearn.nn.utils import get_class_name


class Conv2d(nn.Module):
    """2D convolution with optional TensorFlow-like 'SAME' padding.

    Input: 4-dim tensor
        Shape [batch, in_channels, H, W]
    Return: 4-dim tensor
        Shape [batch, out_channels, H_out, W_out]; with 'SAME' padding each
        spatial size becomes ceil(size / stride), i.e. unchanged for stride 1.

    Args:
        in_channels : int
            Should match input `channel`
        out_channels : int
            Return tensor with `out_channels`
        kernel_size : int or 2-dim tuple
        stride : int or 2-dim tuple, default: 1
        padding : "SAME", True, int or 2-dim tuple, default: "SAME"
            "SAME" or True pads the input dynamically in `forward` (asymmetric,
            TensorFlow-like). An int or 2-dim tuple is passed to the wrapped
            `nn.Conv2d` as explicit padding. Any other string is forwarded to
            `nn.Conv2d` unchanged; `w_out` is then left as None.
        dilation : int or 2-dim tuple, default: 1
        groups : int, default: 1
        w_in: int, optional
            The size of the `W` axis of the input. If given, `w_out` holds the
            resulting output width; otherwise `w_out` is None.
        bias : bool, default: True

    Attributes:
        conv: the wrapped `nn.Conv2d`.
        weight: alias of `conv.weight`.
        padding_same: True when dynamic 'SAME' padding is active.
        w_out: output width for an input of width `w_in`, or None.

    Usage:
        x = torch.randn(1, 22, 1, 256)
        conv1 = Conv2d(22, 64, kernel_size=17, padding="SAME", w_in=256)
        y = conv1(x)   # -> [1, 64, 1, 256], conv1.w_out == 256
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding="SAME", dilation=1, groups=1, w_in=None, bias=True):
        super().__init__()

        # 'SAME' can be requested with the string or with `padding=True`. True is
        # checked by identity *before* the int handling below, because bool is a
        # subclass of int and would otherwise become the pair (True, True).
        self.padding_same = padding is True or (isinstance(padding, str) and padding == "SAME")
        if self.padding_same:
            # Padding is applied in forward(); the wrapped conv itself pads nothing.
            padding = (0, 0)
        elif isinstance(padding, int):
            padding = (int(padding), int(padding))

        self.kernel_size = kernel_size = self._as_pair(kernel_size)
        self.stride = stride = self._as_pair(stride)
        self.dilation = dilation = self._as_pair(dilation)

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )

        # Alias kept so helpers that expect `.weight` on the layer itself (such as
        # init_layer) work on this wrapper. Because it is a Parameter, it is also
        # registered on this module under the name "weight".
        self.weight = self.conv.weight

        # Always define w_out so that reading it never raises AttributeError.
        self.w_out = None
        if w_in is not None:
            if self.padding_same:
                # 'SAME' yields ceil(w_in / stride); equal to w_in for stride 1.
                self.w_out = math.ceil(w_in / stride[1])
            elif not isinstance(padding, str):
                # Standard convolution output-size formula, including the stride.
                self.w_out = (w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    @staticmethod
    def _as_pair(value):
        # Expand a scalar int into an (h, w) pair; leave sequences untouched.
        return (value, value) if isinstance(value, int) else value

    def forward(self, x):
        """Apply dynamic 'SAME' padding when enabled, then the wrapped convolution."""
        if self.padding_same:
            x = self.pad_same(x, self.kernel_size, self.stride, self.dilation)
        return self.conv(x)

    # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
    def get_same_padding(self, x: int, k: int, s: int, d: int):
        """Total padding needed along one axis of size `x` for kernel `k`, stride `s`, dilation `d`."""
        return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

    # Dynamically pad input x with 'SAME' padding for conv with specified args
    def pad_same(self, x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
        """Pad the last two axes of `x`; when the total is odd the extra element goes on the right/bottom."""
        ih, iw = x.size()[-2:]
        pad_h, pad_w = self.get_same_padding(ih, k[0], s[0], d[0]), self.get_same_padding(iw, k[1], s[1], d[1])
        if pad_h > 0 or pad_w > 0:
            # F.pad order is (left, right, top, bottom).
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
        return x
    
######

class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Pre-activation residual block: BatchNorm -> LeakyReLU -> Conv2d, plus a skip path.

        The convolution uses stride 1 and padding ``kernel_size // 2`` per axis,
        so odd kernel sizes keep the spatial size. When the channel count
        changes, the skip path goes through a 1x1 convolution; otherwise the
        input is added back unchanged.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: (height, width) pair; indexed as a sequence.
            activation: stored on the instance but not consulted; ``forward``
                always applies leaky ReLU.
            momentum: momentum of the batch-norm layer.
        """
        super(ConvBlockRes, self).__init__()

        self.activation = activation

        pad_h, pad_w = kernel_size[0] // 2, kernel_size[1] // 2

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        # A second BatchNorm/Conv stage (bn2 / conv2) is currently disabled in this block.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=(1, 1),
            padding=[pad_h, pad_w],
            dilation=(1, 1),
            bias=False,
        )

        # A projection on the skip path is only needed when the channel count changes.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )

        self.init_weights()

    def init_weights(self):
        r"""Re-initialise the batch norm, the main conv and, if present, the shortcut conv."""
        init_bn(self.bn1)
        init_layer(self.conv1)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        r"""Return ``skip(x) + conv1(leaky_relu(bn1(x)))``."""
        # The in-place activation acts on the batch-norm output, not on `x` itself.
        branch = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        skip = self.shortcut(x) if self.is_shortcut else x
        return skip + branch
        
        
# in_channels=384,
# out_channels=384,


# activation = "leaky_relu"
# momentum = 0.01
# tmp_layer = ConvBlockRes(in_channels=9, out_channels=9, kernel_size=(3,3), activation=activation, momentum=momentum)
# tmp_x = torch.rand(3, 9, 240, 240)
# tmp_y = tmp_layer(tmp_x)
# tmp_y.shape

In [8]:
from typing import List#, NoReturn


class Base:
    r"""Mixin providing STFT-based spectrogram helpers.

    The host class must provide ``self.stft``, a callable that maps a waveform
    batch to a ``(real, imag)`` pair of tensors.
    """

    def __init__(self):
        r"""Base function for extracting spectrogram, cos, and sin, etc."""
        pass

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the magnitude spectrogram.
        Args:
            input: (batch_size, segments_num)
            eps: lower bound applied to the power (real^2 + imag^2) before the
                square root, so the magnitude is at least sqrt(eps).
        Returns:
            spectrogram: magnitude tensor with the same shape as the tensors
                returned by ``self.stft``.
        """
        (real, imag) = self.stft(input)
        return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of input.
        Args:
            input: (batch_size, segments_num)
            eps: lower bound applied to the power before the square root.
        Returns:
            mag, cos, sin: three tensors shaped like the output of ``self.stft``
                (``wav_to_spectrogram_phase`` relies on them being 4-D).
                Where the magnitude is exactly zero, cos and sin are 0 instead
                of NaN.
        """
        (real, imag) = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
        # With the default eps=0.0 a silent bin has mag == 0 and real/mag would be
        # 0/0 = NaN. Dividing by 1 there gives 0; every bin with mag > 0 is divided
        # by its own magnitude exactly as before.
        safe_mag = torch.where(mag > 0, mag, torch.ones_like(mag))
        cos = real / safe_mag
        sin = imag / safe_mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Convert waveforms to magnitude, cos, and sin of STFT.
        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float
        Outputs:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape

        # Fold the channel axis into the batch axis: the stft function works on
        # inputs of shape (n, segments_num).
        x = input.reshape(batch_size * channels_num, segment_samples)

        mag, cos, sin = self.spectrogram_phase(x, eps=eps)
        # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)

        # Unfold the channel axis again.
        _, _, time_steps, freq_bins = mag.shape
        mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins)
        cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins)
        sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins)

        return mag, cos, sin

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        r"""Convert waveforms to the magnitude spectrogram only.
        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float
        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
        """
        mag, _, _ = self.wav_to_spectrogram_phase(input, eps)
        return mag
    
    
class EncoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, downsample, activation, momentum, 
    ):
        r"""Encoder block: two residual conv blocks followed by average pooling.

        Args:
            in_channels: channels of the incoming feature map.
            out_channels: channels produced by both residual blocks.
            kernel_size: kernel size forwarded to each ``ConvBlockRes``.
            downsample: pooling kernel size used by ``forward``.
            activation: forwarded to ``ConvBlockRes``.
            momentum: batch-norm momentum forwarded to ``ConvBlockRes``.
        """
        super(EncoderBlockRes4B, self).__init__()

        # The first block changes the channel count; the second keeps it.
        # Two further residual blocks (conv_block3 / conv_block4) are currently disabled.
        self.conv_block1 = ConvBlockRes(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            activation=activation,
            momentum=momentum,
        )
        self.conv_block2 = ConvBlockRes(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            activation=activation,
            momentum=momentum,
        )
        self.downsample = downsample

    def forward(self, x):
        r"""Return ``(pooled, features)``.

        ``features`` is the full-resolution output of the residual blocks,
        returned alongside the pooled map so callers can use it as a skip
        connection; ``pooled`` is its average-pooled version.
        """
        features = self.conv_block2(self.conv_block1(x))
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features
    
class DecoderBlockRes4B(nn.Module):
    r"""Decoder stage: BN -> ReLU -> transposed-conv upsampling, skip concatenation,
    then a single ``ConvBlockRes``."""

    def __init__(
        self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Build the upsampling path and the fusion block.

        Args:
            in_channels: channel count of the tensor to be upsampled.
            out_channels: channel count after upsampling and after fusion.
            kernel_size: forwarded to ``ConvBlockRes``.
            upsample: used as both kernel size and stride of the transposed conv.
            activation: forwarded to ``ConvBlockRes``.
            momentum: batch-norm momentum (also forwarded to ``ConvBlockRes``).
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        # Kernel size equals stride, so the transposed conv windows do not overlap.
        self.conv1 = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=upsample,
            stride=upsample,
            padding=(0, 0),
            dilation=(1, 1),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        # The fusion block sees the upsampled tensor concatenated with the skip
        # tensor, hence twice ``out_channels`` on its input.
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialise the batch norm and the transposed conv with the file's helpers."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, input_tensor, concat_tensor):
        r"""Upsample ``input_tensor``, concatenate ``concat_tensor`` on the channel
        axis (dim 1) and fuse both with ``conv_block2``."""
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        merged = torch.cat((upsampled, concat_tensor), dim=1)
        return self.conv_block2(merged)

In [9]:
    
class MyModel(nn.Module, Base):
    r"""Spectrogram-masking network that maps multi-channel signals to per-target scores.

    Pipeline implemented by ``forward``:
      1. ``wav_to_spectrogram_phase`` (not defined in this class; presumably
         inherited from ``Base``) yields magnitude, cos and sin spectrograms.
      2. Three ``EncoderBlockRes4B`` stages (their un-pooled outputs are chained),
         ``after_conv_block1`` and the 1x1 ``after_conv2`` predict ``K = 4`` mask
         components for every (target source, input channel) pair.
      3. ``feature_maps_to_wav`` applies the masks and runs the ISTFT, producing
         ``target_sources_num * input_channels`` waveforms.
      4. ``out_conv_block`` reduces those waveforms along time to
         ``target_sources_num`` values per sample.
    """

    def __init__(self, input_channels, target_sources_num, signal_length=250):
        r"""Build all layers.

        Args:
            input_channels: number of channels in the input signal.
            target_sources_num: number of output targets (one score each).
            signal_length: width of the ``out_conv_block`` kernel. It should equal
                the number of samples per input signal so that the temporal axis
                collapses to length 1 (assuming ``Conv2d`` follows
                ``torch.nn.Conv2d`` semantics - TODO confirm). Defaults to 250,
                the value that was previously hard-coded.
        """
        super().__init__()

        self.input_channels = input_channels
        self.target_sources_num = target_sources_num
        self.signal_length = signal_length

        window_size = 64
        hop_size = 25
        center = True
        pad_mode = "reflect"
        window = "hann"
        activation = "leaky_relu"
        momentum = 0.01

        self.subbands_num = 1
        self.K = 4  # outputs: |M|, cos∠M, sin∠M, Q

        # forward() zero-pads the time axis to a multiple of this value.
        self.downsample_ratio = 2 ** 3

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # One feature per frequency bin; forward() moves the bins to dim 1 first.
        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.encoder_block1 = EncoderBlockRes4B(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.encoder_block2 = EncoderBlockRes4B(
            in_channels=16,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.encoder_block3 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        # NOTE(review): encoder_block4 is never called in forward(). It is kept so
        # that registered parameters (state_dict keys, RNG consumption during
        # initialisation) stay identical to earlier runs/checkpoints.
        self.encoder_block4 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )

        # 1x1 conv producing K mask components per (target source, input channel).
        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=target_sources_num
            * input_channels
            * self.K
            * self.subbands_num,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        # Kernel spans ``signal_length`` samples of the reconstructed waveforms.
        self.out_conv_block = Conv2d(
            in_channels=target_sources_num
            * input_channels
            * self.subbands_num,
            out_channels=target_sources_num,
            kernel_size=(1, signal_length),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(0,0),
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialise ``bn0`` and ``after_conv2`` with the file's init helpers."""
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert predicted feature maps to waveforms.

        ``sp``, ``sin_in`` and ``cos_in`` are broadcast over a new target-source
        axis, so they carry one spectrogram per input channel.

        Args:
            input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins), input magnitude
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)
            audio_length: number of samples requested from the ISTFT.

        Returns:
            waveform: (batch_size, target_sources_num * input_channels, audio_length)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        linear_mag = torch.tanh(x[:, :, :, 3, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos, out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # |Y|: masked input magnitude plus a learned residual, clipped at zero.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Real and imaginary parts for the ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        x = self.istft(out_real, out_imag, audio_length)
        # x: (batch_size * target_sources_num * input_channels, audio_length)

        waveform = x.reshape(
            batch_size, self.target_sources_num * self.input_channels, audio_length
        )
        # waveform: (batch_size, target_sources_num * input_channels, audio_length)

        return waveform

    def forward(self, x):
        r"""Map a batch of multi-channel signals to per-target scores.

        Args:
            x: (batch_size, input_channels, samples)

        Returns:
            (batch_size, target_sources_num) when ``samples`` equals the
            ``signal_length`` given at construction (assuming ``Conv2d`` follows
            ``torch.nn.Conv2d`` semantics - TODO confirm).
        """
        subband_x = x

        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
        # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Batch normalize on individual frequency bins.
        x = mag.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # x: (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Pad the time axis so it is evenly divisible by the downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)

        # Drop the last frequency bin (restored by padding further below).
        x = x[..., 0 : x.shape[-1] - 1]

        # Only the un-pooled outputs are chained; the pooled outputs are discarded.
        (_, x1) = self.encoder_block1(x)
        (_, x2) = self.encoder_block2(x1)
        (_, x3) = self.encoder_block3(x2)

        (x, _) = self.after_conv_block1(x3)  # (bs, 32, T, F)

        x = self.after_conv2(x)
        # (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')

        # Recover shape: restore the dropped frequency bin, crop the time padding.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        audio_length = subband_x.shape[2]

        # Convert each subband's feature maps to waveforms and concatenate them
        # on the channel axis.
        C1 = x.shape[1] // self.subbands_num
        C2 = mag.shape[1] // self.subbands_num

        separated_subband_audio = torch.cat(
            [
                self.feature_maps_to_wav(
                    input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
                    sp=mag[:, j * C2 : (j + 1) * C2, :, :],
                    sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
                    cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
                    audio_length=audio_length,
                )
                for j in range(self.subbands_num)
            ],
            dim=1,
        )
        # (batch_size, subbands_num * target_sources_num * input_channels, audio_length)

        # Add a singleton height axis so the 2-D conv can slide along time.
        separated_subband_audio = torch.unsqueeze(separated_subband_audio, 2)

        y = self.out_conv_block(separated_subband_audio)

        # Remove the singleton height axis, then the (collapsed) time axis.
        y = torch.squeeze(y, 2)
        y = torch.squeeze(y, 2)

        return y
        
        
        
# Smoke test: push a random batch (batch=3, channels=9, samples=250) through a
# small model instance and verify the output shape.
tmp_x = torch.rand(3, 9, 250)

tmp_layer = MyModel(input_channels=9, target_sources_num=10)
tmp_y = tmp_layer(tmp_x)

# One score per target source for every sample in the batch.
assert tmp_y.shape == (3, 10), f"unexpected output shape: {tuple(tmp_y.shape)}"
tmp_y.shape

# 99 torch.Size([3, 360, 41, 33])


torch.Size([3, 10])

In [10]:
# import torch.nn as nn
# import torch.optim as optim
# train_acc = torchmetrics.Accuracy()

# model = nn.Linear(100, 1) # predict logits for 5 classes
# x = torch.randn(1, 10, 100)
# y = torch.randint(0, 2, (1, 10, 1)).double()
# print(y.shape, y)

# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.SGD(model.parameters(), lr=1e-1)

# for epoch in range(20):
#     optimizer.zero_grad()
#     output = model(x)
#     print("output.shape, y.shape", output.shape, y.shape)
#     loss = criterion(output, y)
#     loss.backward()
#     optimizer.step()
#     acc = train_acc(output, y.long())
#     print('Loss: {:.3f}, Acc: {:.3f} '.format(loss.item(), acc.item()))


# import torch.nn as nn
# import torch.optim as optim
# train_acc = torchmetrics.Accuracy()

# model = nn.Conv1d(10, 10, kernel_size=1000, groups=10)
# x = torch.randn(3, 10, 1000)
# y = torch.randint(0, 3, (3,))
# print(y.shape, y)

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=1e-1)

# for epoch in range(20):
#     optimizer.zero_grad()
#     output = model(x)
#     output = torch.squeeze(output)
#     print("output.shape, y.shape", output.shape, y.shape)
#     loss = criterion(output, y)
#     loss.backward()
#     optimizer.step()
#     acc = train_acc(output, y.long())
#     print('Loss: {:.3f}, Acc: {:.3f} '.format(loss.item(), acc.item()))



In [11]:
import torchmetrics
from splearn.nn.base import LightningModel


class LightningModelClassifier(LightningModel):
    r"""Lightning wrapper that trains a wrapped model as a classifier.

    Tracks one ``torchmetrics.Accuracy`` per stage (train / valid / test) and
    logs the per-step loss plus the accuracy aggregated over each epoch.

    NOTE(review): a later cell runs
    ``from splearn.nn.base import LightningModelClassifier``, which rebinds this
    name; classes defined after that import derive from the imported class, not
    from this one.
    """

    def __init__(
        self,
        optimizer="adamw",
        scheduler="cosine_with_warmup",
        optimizer_learning_rate: float=1e-3,
        optimizer_epsilon: float=1e-6,
        optimizer_weight_decay: float=0.0005,
        scheduler_warmup_epochs: int=10,
        criterion=None
    ):
        r"""Store hyper-parameters and create metrics and the loss.

        Args:
            optimizer, scheduler, optimizer_learning_rate, optimizer_epsilon,
            optimizer_weight_decay, scheduler_warmup_epochs: saved through
                ``save_hyperparameters``; presumably consumed by
                ``LightningModel`` - TODO confirm.
            criterion: loss callable ``criterion(y_hat, y)``; defaults to
                ``nn.CrossEntropyLoss()`` when ``None``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        self.criterion_classifier = (
            criterion if criterion is not None else nn.CrossEntropyLoss()
        )

    def build_model(self, model):
        r"""Attach the network whose outputs are classified."""
        self.model = model

    def forward(self, x):
        r"""Return the wrapped model's output for ``x``."""
        return self.model(x)

    def step(self, batch, batch_idx):
        r"""Shared train/valid/test step: returns ``(y_hat, y, loss)``."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.criterion_classifier(y_hat, y)
        return y_hat, y, loss

    def training_step(self, batch, batch_idx):
        y_hat, y, loss = self.step(batch, batch_idx)
        # Called for its side effect: accumulates state for the epoch-level value.
        self.train_acc(y_hat, y)
        self.log('train_loss', loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, y, loss = self.step(batch, batch_idx)
        self.valid_acc(y_hat, y)
        self.log('valid_loss', loss, on_step=True)
        return loss

    def test_step(self, batch, batch_idx):
        y_hat, y, loss = self.step(batch, batch_idx)
        self.test_acc(y_hat, y)
        self.log('test_loss', loss)
        return loss

    # The epoch-end hooks log a plain tensor obtained from ``compute()``. The
    # metric object itself is not passed to ``self.log``, so it is reset
    # explicitly here; otherwise its state would keep accumulating across epochs.
    def training_epoch_end(self, outs):
        self.log('train_acc_epoch', self.train_acc.compute())
        self.train_acc.reset()

    def validation_epoch_end(self, outs):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

    def test_epoch_end(self, outs):
        self.log('test_acc_epoch', self.test_acc.compute())
        self.test_acc.reset()


In [12]:

from splearn.nn.base import LightningModelClassifier


class MultilabelLClassifier(LightningModelClassifier):
    r"""Classifier wrapper that trains the wrapped model with ``nn.CrossEntropyLoss``.

    Targets are cast to ``long`` before they reach the loss and the accuracy
    metrics. The ``train_acc`` / ``valid_acc`` / ``test_acc`` metrics used below
    are not created here; presumably the parent class provides them - TODO confirm.
    """

    def __init__(
        self,
        optimizer="adamw",
        scheduler="cosine_with_warmup",
        optimizer_learning_rate: float=1e-3,
        optimizer_epsilon: float=1e-6,
        optimizer_weight_decay: float=0.0005,
        scheduler_warmup_epochs: int=10,
    ):
        r"""Save the constructor arguments as hyper-parameters and set the loss.

        The parent constructor is called without arguments; the values given
        here are recorded by this class's own ``save_hyperparameters`` call.
        """
        super().__init__()
        self.save_hyperparameters()
        self.criterion_classifier = nn.CrossEntropyLoss()

    def build_model(self, model, model_output_dim, num_classes, **kwargs):
        r"""Attach the wrapped model.

        ``model_output_dim``, ``num_classes`` and ``kwargs`` are accepted but
        not used by this method.
        """
        self.model = model

    def forward(self, x):
        r"""Return the wrapped model's output unchanged."""
        return self.model(x)

    def train_val_step(self, batch, batch_idx):
        r"""Shared step for all stages: returns ``(y_hat, y, loss)``."""
        inputs, targets = batch
        logits = self.forward(inputs)
        loss = self.criterion_classifier(logits, targets.long())
        return logits, targets, loss

    def training_step(self, batch, batch_idx):
        logits, targets, loss = self.train_val_step(batch, batch_idx)
        # Called for its side effect of updating the metric state.
        self.train_acc(logits, targets.long())
        self.log('train_loss', loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, targets, loss = self.train_val_step(batch, batch_idx)
        self.valid_acc(logits, targets.long())
        self.log('valid_loss', loss, on_step=True)
        return loss

    def test_step(self, batch, batch_idx):
        logits, targets, loss = self.train_val_step(batch, batch_idx)
        self.test_acc(logits, targets.long())
        self.log('test_loss', loss)
        return loss

    
# x = torch.rand(3, 9, config.data.sample_length)
# y = torch.randint(0, 2, (3,))

# unet = ResUNet143_Subbandtime(input_channels=config.data.input_channels, target_sources_num=config.data.target_sources_num)
# model = MultilabelLClassifier(
#     optimizer=config.model.optimizer,
#     scheduler=config.model.scheduler,
#     optimizer_learning_rate=config.training.learning_rate,
#     scheduler_warmup_epochs=config.training.num_warmup_epochs,
# )
# model.build_model(model=unet, model_output_dim=config.data.sample_length, num_classes=config.data.num_classes)

# tmp_y = model(x)
# print("tmp_y", tmp_y.shape)
# print(tmp_y)

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

# for epoch in range(10):
#     optimizer.zero_grad()
#     output = model(x)
#     loss = criterion(output, y)
#     loss.backward()
#     optimizer.step()
#     print('Loss: {:.3f}'.format(loss.item()))


In [13]:
# # x = torch.rand(3, 9, 1000)
# # y = torch.randint(0, 2, (3,)).double()

# # unet = ResUNet143_Subbandtime(input_channels=config.data.input_channels, target_sources_num=config.data.target_sources_num)
# # model = MultilabelLClassifier(
# #     optimizer=config.model.optimizer,
# #     scheduler=config.model.scheduler,
# #     optimizer_learning_rate=config.training.learning_rate,
# #     scheduler_warmup_epochs=config.training.num_warmup_epochs,
# # )
# # model.build_model(model=unet, model_output_dim=config.data.sample_length)

# model = nn.Linear(32,5)
# x = torch.rand(3, 32)
# y = torch.randint(0, 2, (3,))
# print(y)

# tmp_y = model(x)
# print("tmp_y", tmp_y.shape)
# print(tmp_y)

# criterion = nn.CrossEntropyLoss() # torch.nn.BCEWithLogitsLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

# for epoch in range(20):
#     optimizer.zero_grad()
#     output = model(x)
#     loss = criterion(output, y)
#     # loss = F.cross_entropy(output, y)
#     loss.backward()
#     optimizer.step()
#     print('Loss: {:.3f}'.format(loss.item()))


In [14]:
# Single leave-one-subject-out run: subject 33 is held out, fold 0.
test_subject_id = 33
kfold_k = 0

# Data: train / validation / test splits for the chosen subject and fold.
train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(
    test_subject_id=test_subject_id,
    kfold_k=kfold_k,
)
train_loader = DataLoader(
    train_dataset, shuffle=True, batch_size=config.training.batchsize
)
val_loader = DataLoader(
    val_dataset, shuffle=False, batch_size=config.training.batchsize
)
test_loader = DataLoader(
    test_dataset, shuffle=False, batch_size=config.training.batchsize
)

# Model: backbone wrapped in the Lightning classifier.
unet = MyModel(
    input_channels=config.data.input_channels,
    target_sources_num=config.data.target_sources_num,
)
model = MultilabelLClassifier(
    optimizer=config.model.optimizer,
    scheduler=config.model.scheduler,
    optimizer_learning_rate=config.training.learning_rate,
    scheduler_warmup_epochs=config.training.num_warmup_epochs,
)
model.build_model(
    model=unet,
    model_output_dim=config.data.sample_length,
    num_classes=config.data.num_classes,
)

# Training: TensorBoard logs go to one sub-directory per (subject, fold).
sub_dir = f"sub{test_subject_id}_k{kfold_k}"
logger_tb = TensorBoardLogger(
    save_dir="tensorboard_logs", name=config.run_name, sub_dir=sub_dir
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

trainer = Trainer(
    max_epochs=config.training.num_epochs,
    gpus=config.training.gpus,
    logger=logger_tb,
    progress_bar_refresh_rate=0,
    weights_summary=None,
    callbacks=[lr_monitor],
)
trainer.fit(model, train_loader, val_loader)

# Test on the held-out subject and display its epoch-level accuracy.
result = trainer.test(dataloaders=test_loader, verbose=True)
test_acc = result[0]['test_acc_epoch']
test_acc


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 1234
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc_epoch': 0.1875, 'test_loss': 16.77286720275879}
--------------------------------------------------------------------------------


0.1875

In [15]:
# Re-display the test accuracy taken from the `trainer.test` results.
test_acc

0.1875

In [16]:
# test_subject_id = 33
# kfold_k = 0

# train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k)
# train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False)

# print("train_loader", train_loader.dataset.data.shape, train_loader.dataset.targets.shape)
# print("val_loader", val_loader.dataset.data.shape, val_loader.dataset.targets.shape)
# print("test_loader", test_loader.dataset.data.shape, test_loader.dataset.targets.shape)

In [17]:
# tmp_acc = torchmetrics.Accuracy()

# # index = 0:2

# x = torch.tensor(train_loader.dataset.data[0:10])
# y = torch.tensor(train_loader.dataset.targets[0:10])
# # x = torch.unsqueeze(x, 0)
# # y = torch.unsqueeze(y, 0)
# # y = torch.unsqueeze(y, 2)
# print(x.shape, y.shape)
# y_hat = model(x)
# y_hat = torch.sigmoid(y_hat)

# acc = tmp_acc(y_hat, y.long())
# print("acc", acc)

In [18]:
# # trial = pred_y[0]
# # for i in trial:
# #     print(i)

# N = 2
# C = 40

# outputs = torch.squeeze(y_hat)
# labels = torch.squeeze(y)

# outputs = torch.sigmoid(outputs)  # torch.Size([N, C]) e.g. tensor([[0., 0.5, 0.]])
# outputs[outputs < 0.5] = 0
# outputs[outputs >= 0.5] = 1
# accuracy = (outputs == labels).sum()/(N*C)*100


In [19]:
# outputs