In [1]:
import argparse
import yaml
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from src.data_utilities import EnhancementDataset
from src.loss_utilities import snr_loss
from models import EnhancementNet

In [2]:
# Command-line options for the enhancement training run.
parser = argparse.ArgumentParser(description='enh')

parser.add_argument('--batch-size', type=int, default=6,
                    help='input batch size for training')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train')
# CUDA is enabled by default, so ``--cuda`` on its own can never change the
# value. It is kept for backward compatibility; ``--no-cuda`` writes to the
# same destination and is the way to turn CUDA off.
parser.add_argument('--cuda', dest='cuda', action='store_true', default=True,
                    help='enables CUDA training')
parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                    help='disables CUDA training')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate')
parser.add_argument('--config', default='./configs/config_enhancement.yaml', type=str,
                    help='model config')
parser.add_argument('--seed', type=int, default=2023028,
                    help='random seed')
# NOTE(review): the path defaults below are absolute, machine-specific
# locations; override them on the command line when running elsewhere.
parser.add_argument('--training-file-path', default='/engram/naplab/users/ch3212/google_moving_large_stage2/tr', type=str,
                    help='training file path')
parser.add_argument('--validation-file-path', default='/engram/naplab/users/ch3212/google_moving_large_stage2/cv', type=str,
                    help='validation file path')
parser.add_argument('--checkpoint-path', type=str,  default='/engram/naplab/users/ch3212/NMI/enh',
                    help='path to save the model')

_StoreAction(option_strings=['--checkpoint-path'], dest='checkpoint_path', nargs=None, const=None, default='/engram/naplab/users/ch3212/NMI/enh', type=<class 'str'>, choices=None, required=False, help='path to save the model', metavar=None)

In [3]:
def load_config(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)


def save_checkpoint(filepath, obj):
    """Write ``obj`` to ``filepath`` via ``torch.save``.

    A message is printed before the save starts and another once it finishes.
    """
    start_message = "Saving checkpoint to {}".format(filepath)
    print(start_message)
    torch.save(obj, filepath)
    print("Complete.")

In [4]:
def _assemble_batch(data, use_cuda):
    """Turn one loader batch into ``(batch_mix, batch_est, batch_clean)``.

    ``data`` is indexed positionally as (s1, s2, est_s1, est_s2, noise).
    The mixture is ``s1 + s2 + noise``. Each returned tensor is stacked along
    dim 0 so that the first half belongs to s1 and the second half to s2; the
    mixture is simply repeated for both halves.
    """
    batch_s1 = data[0].contiguous()
    batch_s2 = data[1].contiguous()
    batch_est_s1 = data[2].contiguous()
    batch_est_s2 = data[3].contiguous()
    batch_noise = data[4].contiguous()
    # The mixture is formed before the device transfer, so the noise tensor
    # itself never has to be moved.
    batch_mix = batch_s1 + batch_s2 + batch_noise

    if use_cuda:
        batch_mix = batch_mix.cuda()
        batch_s1 = batch_s1.cuda()
        batch_s2 = batch_s2.cuda()
        batch_est_s1 = batch_est_s1.cuda()
        batch_est_s2 = batch_est_s2.cuda()

    batch_mix = torch.cat([batch_mix, batch_mix], dim=0)
    batch_est = torch.cat([batch_est_s1, batch_est_s2], dim=0)
    batch_clean = torch.cat([batch_s1, batch_s2], dim=0)
    return batch_mix, batch_est, batch_clean


def _mean_snr_loss(batch_clean, batch_output):
    """Return the mean of ``snr_loss`` between the targets and the model output.

    Both tensors are first viewed as
    ``(batch_clean.size(0) * 2, batch_clean.size(2))``.
    """
    # NOTE(review): this view only works if dim 1 of a 3-D tensor has size 2,
    # i.e. inputs are presumably (batch, 2, samples) -- TODO confirm.
    num_rows = batch_clean.size(0) * 2
    num_samples = batch_clean.size(2)
    return torch.mean(snr_loss(batch_clean.view(num_rows, num_samples),
                               batch_output.view(num_rows, num_samples)))


def train(train_loader, validation_loader, model, optimizer, scheduler, summary_writer, args):
    """Train ``model`` for ``args.epochs`` epochs, validating after each one.

    Per epoch: one optimisation pass over ``train_loader`` (gradient norm
    clipped to 5), a checkpoint written to
    ``"{args.checkpoint_path}/epoch_{epoch:03d}"``, then a gradient-free pass
    over ``validation_loader``. The mean training and validation losses are
    printed and logged to ``summary_writer`` under ``training/train_loss`` and
    ``training/val_loss``. ``scheduler`` is stepped after every second epoch.

    Args:
        train_loader: iterable of training batches, see ``_assemble_batch``.
        validation_loader: iterable of validation batches, same layout.
        model: network called as ``model(batch_mix, batch_est)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler.
        summary_writer: object exposing ``add_scalar(tag, value, step)``.
        args: namespace providing ``epochs``, ``cuda`` and ``checkpoint_path``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    for epoch in range(1, args.epochs + 1):

        model.train()
        train_loss = 0.
        num_train_batches = 0

        for data in train_loader:
            batch_mix, batch_est, batch_clean = _assemble_batch(data, args.cuda)
            batch_output = model(batch_mix, batch_est)

            optimizer.zero_grad()
            loss = _mean_snr_loss(batch_clean, batch_output)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.)
            train_loss += loss.item()
            optimizer.step()
            num_train_batches += 1

        if num_train_batches == 0:
            raise ValueError("train_loader produced no batches; nothing to train on")
        train_loss /= num_train_batches

        print('train', epoch, train_loss)

        summary_writer.add_scalar("training/train_loss", train_loss, epoch)

        checkpoint_path = "{}/epoch_{:03d}".format(args.checkpoint_path, epoch)

        save_checkpoint(
            checkpoint_path,
            {
                # Stored as epoch + 1 -- presumably the epoch to resume from;
                # verify against whatever code loads these checkpoints.
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            })

        model.eval()
        validation_loss = 0.
        # Counted separately from the training loop: reusing the training
        # loop's index would silently report 0.0 for an empty validation set.
        num_validation_batches = 0

        with torch.no_grad():
            for data in validation_loader:
                batch_mix, batch_est, batch_clean = _assemble_batch(data, args.cuda)
                batch_output = model(batch_mix, batch_est)
                loss = _mean_snr_loss(batch_clean, batch_output)

                validation_loss += loss.item()
                num_validation_batches += 1

        if num_validation_batches:
            validation_loss /= num_validation_batches
        else:
            # No validation data: report NaN rather than a misleading number.
            validation_loss = float('nan')
        summary_writer.add_scalar("training/val_loss", validation_loss, epoch)

        print('eval', epoch, validation_loss)

        # The learning rate is decayed once every two epochs.
        if epoch % 2 == 0:
            scheduler.step()

In [5]:
# parse_known_args tolerates arguments the parser does not define (for
# example the ones a notebook kernel is launched with) instead of exiting.
args, _ = parser.parse_known_args()

# CUDA is used only when it was requested and a device is actually available.
args.cuda = args.cuda and torch.cuda.is_available()

# Extra DataLoader options: worker processes and pinned memory on CUDA runs,
# library defaults otherwise.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

# Seed every random number generator in use with the same value.
random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x2b8e74feef90>

In [6]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset=EnhancementDataset(args.training_file_path),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)

# Validation batches keep the dataset order.
validation_loader = DataLoader(
    dataset=EnhancementDataset(args.validation_file_path),
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs,
)

In [7]:
# Read the model hyper-parameters from the YAML file given by --config.
config = load_config(args.config)

In [8]:
# Build the network: every constructor argument is taken from the config
# entry of the same name.
model = EnhancementNet(**{
    name: config[name]
    for name in (
        'enc_dim',
        'feature_dim',
        'hidden_dim',
        'enc_win',
        'enc_stride',
        'num_block',
        'num_layer',
        'kernel_size',
        'num_spk',
    )
})

# Move the parameters to the GPU for CUDA runs.
if args.cuda:
    model.cuda()

Receptive field: 1271 frames.


In [9]:
# Adam at the command-line learning rate; each scheduler step multiplies the
# learning rate by 0.98.
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

In [10]:
# TensorBoard events go to a "logs" folder inside the checkpoint directory.
sw = SummaryWriter(log_dir=os.path.join(args.checkpoint_path, 'logs'))

In [None]:
# Run the full training loop. The writer is closed afterwards, even when
# training fails or is interrupted, so that buffered TensorBoard events are
# flushed to disk.
try:
    train(train_loader, validation_loader, model, optimizer, scheduler, sw, args)
finally:
    sw.close()

train 1 -11.376143914057561
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_001
Complete.
eval 1 -12.099213304519653
train 2 -12.957555113077163
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_002
Complete.
eval 2 -12.840064935684204
train 3 -13.495704421758651
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_003
Complete.
eval 3 -13.191324367523194
train 4 -13.828021122217178
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_004
Complete.
eval 4 -13.37281294822693
train 5 -14.090777539491654
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_005
Complete.
eval 5 -13.637734718322754
train 6 -14.291232374429702
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_006
Complete.
eval 6 -13.749120535850524
train 7 -14.472587305545806
Saving checkpoint to /engram/naplab/users/ch3212/NMI/enh/epoch_007
Complete.
eval 7 -13.849336700439453
train 8 -14.62489506649971
Saving checkpoint to /engram/naplab/users/ch3212/N