In [33]:
import warnings
warnings.simplefilter("ignore", UserWarning)

import numpy as np
import json
import pickle 
import h5py
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import random
import soundfile as sf
import os
from pprint import pprint

In [34]:
# Restrict this process to GPUs 0, 1 and 2. CUDA documents this variable as a
# comma-separated list of device identifiers; the previous value '0 , 1, 2'
# had whitespace inside the list, which is not part of that documented format,
# so the list is written without spaces here.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'

# Seed every random number generator this notebook imports. Python's own
# `random` module is imported at the top but was previously left unseeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Project-local modules (kept after the environment setup, as in the original).
from aia_trans import dual_aia_trans_merge_crm
from solver_merge import Solver
from train_utils import parser_all, SEDataset, SEDataLoader, get_noise_clean_path

# Parse with an explicit empty argument list so only the parser's defaults are
# used (presumably an argparse parser - the kernel's own argv is not consulted).
args = parser_all.parse_args(args=[])
# The model is intentionally not built in this cell; the line below shows how
# it is constructed when needed.
# model = dual_aia_trans_merge_crm().cuda()

In [9]:
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # pickle can run arbitrary code while loading: only use trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pickled inputs: two noise collections (train / val) and the map that is
# passed alongside them to get_noise_clean_path.
train_noise = _read_pickle('/workspace/SE_2022/train_noise_by_type.pkl')
val_noise = _read_pickle('/workspace/SE_2022/val_noise_by_type.pkl')
noise_clean_map = _read_pickle('/workspace/SE_2022/train_map.pkl')

# Both splits are resolved against the same data directory and the same map.
data_dir = '/workspace/data/train'
train_path_array = get_noise_clean_path(data_dir, train_noise, noise_clean_map)
val_path_array = get_noise_clean_path(data_dir, val_noise, noise_clean_map)

# NOTE(review): the training dataset receives a hard-coded 2 where the
# validation dataset receives args.batch_size - confirm this is intended.
train_dataset = SEDataset(train_path_array, 2)
val_dataset = SEDataset(val_path_array, args.batch_size)

# The two loaders differ only in the dataset they wrap.
loader_options = dict(batch_size=1,
                      num_workers=args.num_workers,
                      pin_memory=True)
train_dataloader = SEDataLoader(data_set=train_dataset, **loader_options)
val_dataloader = SEDataLoader(data_set=val_dataset, **loader_options)

# Bundle handed to the Solver later in the notebook.
data = {'tr_loader': train_dataloader, 'cv_loader': val_dataloader}

In [10]:
# Walk the whole training loader once. `count` ends up as the number of
# batches seen and `i` stays bound to the last batch that was yielded.
count = 0
for count, i in enumerate(train_dataloader.get_data_loader(), start=1):
    pass

In [12]:
# Keep the batch left in `i` by the loader loop under a descriptive name.
batch_info = i

In [29]:
def _compress_complex(spec):
    """Square-root-compress the magnitude of ``spec`` while keeping its phase.

    The phase is atan2(last channel, first channel) along dim 1 and the
    magnitude is the norm over dim 1 raised to the power 0.5. The result
    stacks magnitude*cos(phase) and magnitude*sin(phase) on a new dim 1.

    Returns the tuple ``(compressed, phase)``.
    """
    phase = torch.atan2(spec[:, -1, :, :], spec[:, 0, :, :])
    magnitude = (torch.norm(spec, dim=1)) ** 0.5
    compressed = torch.stack(
        (magnitude * torch.cos(phase), magnitude * torch.sin(phase)),
        dim=1,
    )
    return compressed, phase


# Noisy input and clean target go through the same compression.
raw_feat, raw_label = batch_info.feats, batch_info.labels
batch_feat, noisy_phase = _compress_complex(raw_feat)
batch_label, clean_phase = _compress_complex(raw_label)

# Read here but not used further in this cell.
batch_frame_mask_list = batch_info.frame_mask_list

In [32]:
# Display entry 0 (along dim 0) of batch_feat.
batch_feat[0]

tensor([[[-1.8366e+00, -2.2440e+00,  3.5939e+00,  ...,  2.9526e-01,
          -2.5061e-01, -1.2549e-01],
         [ 1.7138e+00, -6.8231e-01,  1.8847e+00,  ..., -2.2729e-01,
           2.6962e-01, -3.0862e-01],
         [-1.8269e+00,  1.0885e+00, -1.9100e+00,  ..., -3.6195e-01,
           1.8243e-01,  1.6676e-01],
         ...,
         [-1.8517e+00, -2.1858e-01,  2.3331e+00,  ..., -3.3536e-01,
           2.8231e-01, -3.8742e-01],
         [-1.3706e+00, -1.7271e+00,  3.0054e+00,  ...,  2.1722e-01,
           2.4538e-01, -2.0539e-01],
         [-1.5715e+00,  9.4538e-02,  1.7350e+00,  ...,  2.4488e-01,
           1.5673e-01, -5.4930e-01]],

        [[-1.6056e-07,  3.3883e-07,  1.0858e-07,  ..., -1.8989e-06,
           3.6228e-06, -1.0971e-08],
         [ 0.0000e+00,  2.6148e+00,  1.7652e+00,  ...,  1.0331e-01,
          -2.5558e-01, -2.6980e-08],
         [-1.5972e-07, -3.7770e+00,  3.1621e+00,  ...,  1.7544e-01,
           1.3127e-02,  0.0000e+00],
         ...,
         [-1.6188e-07, -1

In [30]:
# Display the shape of batch_feat.
batch_feat.shape

torch.Size([2, 2, 279, 161])

In [20]:
# Rebind both names to the tensors stored on batch_info.
batch_feat, batch_label = batch_info.feats, batch_info.labels

In [14]:
# Display the shape of batch_feat.
batch_feat.shape

torch.Size([2, 2, 279, 161])

In [15]:
# Phase angle per element: atan2(last channel, first channel) along dim 1,
# computed for the noisy features and for the clean labels.
noisy_phase = batch_feat[:, -1, :, :].atan2(batch_feat[:, 0, :, :])
clean_phase = batch_label[:, -1, :, :].atan2(batch_label[:, 0, :, :])

In [17]:
# Display the shape of noisy_phase (dim 1 has been indexed away).
noisy_phase.shape

torch.Size([2, 279, 161])

In [None]:

# Phase angle per element: atan2(last channel, first channel) along dim 1.
noisy_phase = torch.atan2(batch_feat[:,-1,:,:], batch_feat[:,0,:,:])
clean_phase = torch.atan2(batch_label[:,-1,:,:], batch_label[:,0,:,:])
# Frame mask list carried by the batch; not used by the lines above.
batch_frame_mask_list = batch_info.frame_mask_list

In [21]:
# Compressed magnitude: norm over dim 1, then raised to the power 0.5.
batch_feat_norm = torch.pow(torch.norm(batch_feat, dim=1), 0.5)
batch_label_norm = torch.pow(torch.norm(batch_label, dim=1), 0.5)

In [24]:
# Same compression applied to entry 0 alone: after indexing with [0] the
# channel axis is dim 0, hence dim=0 here instead of dim=1.
batch_feat_norm_single = (torch.norm(batch_feat[0], dim=0)) ** 0.5

In [28]:
# Display the magnitude computed from entry 0 alone.
batch_feat_norm_single

tensor([[1.8366, 2.2440, 3.5939,  ..., 0.2953, 0.2506, 0.1255],
        [1.7138, 2.7023, 2.5822,  ..., 0.2497, 0.3715, 0.3086],
        [1.8269, 3.9308, 3.6942,  ..., 0.4022, 0.1829, 0.1668],
        ...,
        [1.8517, 1.8779, 2.3342,  ..., 0.3354, 0.4045, 0.3874],
        [1.3706, 2.0784, 3.0258,  ..., 0.2200, 0.2920, 0.2054],
        [1.5715, 0.7433, 1.7964,  ..., 0.5606, 0.5866, 0.5493]])

In [27]:
# Display entry 0 of the batched magnitude, to compare against
# batch_feat_norm_single.
batch_feat_norm[0]

tensor([[1.8366, 2.2440, 3.5939,  ..., 0.2953, 0.2506, 0.1255],
        [1.7138, 2.7023, 2.5822,  ..., 0.2497, 0.3715, 0.3086],
        [1.8269, 3.9308, 3.6942,  ..., 0.4022, 0.1829, 0.1668],
        ...,
        [1.8517, 1.8779, 2.3342,  ..., 0.3354, 0.4045, 0.3874],
        [1.3706, 2.0784, 3.0258,  ..., 0.2200, 0.2920, 0.2054],
        [1.5715, 0.7433, 1.7964,  ..., 0.5606, 0.5866, 0.5493]])

In [None]:
# Rebuild the two-channel tensor from the compressed magnitude and the noisy
# phase: channel 0 = magnitude*cos(phase), channel 1 = magnitude*sin(phase).
#
# The magnitude must come from batch_feat_norm, not from batch_feat itself:
# batch_feat still has its channel axis, so multiplying it by the
# channel-less phase relies on broadcasting across mismatched axes. That
# either fails or, when the batch size happens to equal the channel count,
# silently produces a wrongly aligned tensor with an extra dimension.
# Reading from batch_feat_norm also makes this cell safe to re-run.
batch_feat = torch.stack(
    (batch_feat_norm * torch.cos(noisy_phase),
     batch_feat_norm * torch.sin(noisy_phase)),
    dim=1,
)

In [4]:
# Build the network in this cell: its only other construction in the notebook
# is commented out, so `model` would otherwise be undefined on a fresh kernel
# and the optimizer line below would raise NameError.
model = dual_aia_trans_merge_crm().cuda()

# Adam over all model parameters; learning rate and weight decay come from
# the parsed arguments (args.lr and args.l2).
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.l2)

# The solver receives the loader dict, the model, the optimizer and the args.
solver = Solver(data, model, optimizer, args)

In [5]:

# Start training. The recorded run below was stopped by hand
# (KeyboardInterrupt), so its log is incomplete.
solver.train()

Begin to train.....
Epoch:1, Iter:0, the average_loss:0.445519, current_loss:0.445519, 1895 ms/batch.
Epoch:1, Iter:1, the average_loss:0.318780, current_loss:0.192041, 1597 ms/batch.
Epoch:1, Iter:2, the average_loss:0.281174, current_loss:0.205961, 1483 ms/batch.
Epoch:1, Iter:3, the average_loss:0.266311, current_loss:0.221722, 1430 ms/batch.
Epoch:1, Iter:4, the average_loss:0.265650, current_loss:0.263008, 1399 ms/batch.
Epoch:1, Iter:5, the average_loss:0.264356, current_loss:0.257886, 1380 ms/batch.
Epoch:1, Iter:6, the average_loss:0.243883, current_loss:0.121040, 1365 ms/batch.
Epoch:1, Iter:7, the average_loss:0.239832, current_loss:0.211475, 1356 ms/batch.
Epoch:1, Iter:8, the average_loss:0.226150, current_loss:0.116698, 1347 ms/batch.
Epoch:1, Iter:9, the average_loss:0.227096, current_loss:0.235613, 1341 ms/batch.
Epoch:1, Iter:10, the average_loss:0.230714, current_loss:0.266894, 1339 ms/batch.
Epoch:1, Iter:11, the average_loss:0.222600, current_loss:0.133346, 1334 ms/b

KeyboardInterrupt: 