In [11]:
import argparse
import os
import torch
import wandb

from trainer import Trainer
from dataset import SpectrogramReader, Dataset, DataLoader, logger
from dcnet import DCNet, DCNetDecoder
from utils import nfft, parse_yaml

In [13]:
# `!export` runs in a throw-away subshell, so the variable never reaches this
# kernel's environment (hence wandb's "Failed to query for notebook name"
# warning). Set it in-process instead; this must run before `wandb.init`.
os.environ["WANDB_NOTEBOOK_NAME"] = "experiments"

In [21]:
# Weights & Biases setup: open a run for this project and store the training
# hyperparameters on its config (Args reads num_epoches, optimizer reads lr).
run = wandb.init(project="Speaker Separation")
wandb_config = run.config
for _hp_name, _hp_value in (("num_epoches", 1), ("lr", 0.001)):
    setattr(wandb_config, _hp_name, _hp_value)
del _hp_name, _hp_value

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


In [22]:
def uttloader(scp_config, reader_kwargs, loader_kwargs, train=True):
    """Build a DataLoader over a mixture reader and its per-speaker target readers.

    Args:
        scp_config: mapping with a ``'mixture'`` scp entry plus one entry per
            speaker whose key starts with ``'spk'``; other keys are ignored.
        reader_kwargs: keyword arguments forwarded to every ``SpectrogramReader``.
        loader_kwargs: keyword arguments forwarded to ``DataLoader``. This
            mapping is left unmodified; ``shuffle`` is overridden on a copy.
        train: when True the loader shuffles; when False (validation) it does not.

    Returns:
        The ``DataLoader`` wrapping ``Dataset(mix_reader, target_readers)``.
    """
    mix_reader = SpectrogramReader(scp_config['mixture'], **reader_kwargs)
    # Target readers follow scp_config's key order (e.g. spk1, spk2, ...).
    target_reader = [
        SpectrogramReader(scp_config[spk_key], **reader_kwargs)
        for spk_key in scp_config if spk_key.startswith('spk')
    ]
    dataset = Dataset(mix_reader, target_reader)
    # Override shuffle on a copy: mutating loader_kwargs would leak the
    # train/validation choice into the caller's shared loader config.
    utt_loader_kwargs = dict(loader_kwargs, shuffle=train)
    return DataLoader(dataset, **utt_loader_kwargs)

In [25]:
# Directory (with trailing slash) that trained encoder/decoder weights go to.
PATH = "{}/weights/".format(os.getcwd())
# Mean-squared-error criterion shared by every training step.
loss_fn = torch.nn.MSELoss(reduction='mean')


def loss(x, x_):
    """Return the mean-squared error between ``x`` and ``x_`` via ``loss_fn``."""
    return loss_fn(x, x_)

def optimizer(encoder, decoder, lr=None):
    """Build one AdamW optimizer over the encoder's and decoder's parameters.

    Args:
        encoder: module whose parameters are optimized (here ``DCNet``).
        decoder: module whose parameters are optimized (here ``DCNetDecoder``).
        lr: learning rate. When ``None`` (the default) it falls back to
            ``wandb_config.lr``, so existing two-argument callers are unchanged.

    Returns:
        A ``torch.optim.AdamW`` with a single parameter group covering both modules.
    """
    if lr is None:
        lr = wandb_config.lr
    params = list(encoder.parameters()) + list(decoder.parameters())
    return torch.optim.AdamW(params, lr=lr)

def train(args):
    """Train the DCNet encoder/decoder pair to reconstruct its own input.

    Args:
        args: object exposing ``config`` (path to the YAML config), ``debug``
            (bool; when True the ``debug_scp_conf`` entry is used instead of
            ``train_scp_conf``) and ``num_epoches`` (number of epochs).

    Each epoch iterates the loader once per index ``i`` in ``range(3)``. Index 0
    feeds ``a[0]`` directly; other indices feed ``a[i].float() * a[0]``. The
    decoder output is regressed onto that input with ``loss`` (MSE), and the
    per-step loss is logged to wandb. After each pass both models'
    ``state_dict`` are saved to ``PATH + "encoder_<i>"`` / ``PATH + "decoder_<i>"``.
    """
    debug = args.debug
    logger.info(
        "Start training in {} mode".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    scp_conf = config_dict["debug_scp_conf" if debug else "train_scp_conf"]
    train_loader = uttloader(scp_conf, reader_conf, loader_conf, train=True)

    dcnet = DCNet(num_bins, **dcnnet_conf)
    dcnet_decode = DCNetDecoder(num_bins, **dcnnet_conf)
    optimizer_ = optimizer(dcnet, dcnet_decode)
    # torch.save does not create missing directories.
    os.makedirs(PATH, exist_ok=True)

    for epoch in range(args.num_epoches):
        for i in range(3):
            for a in train_loader:
                # NOTE(review): a[i] for i > 0 is presumably a speaker target
                # used to mask the mixture a[0] — confirm against Dataset.
                if i != 0:
                    input_ = torch.mul(a[i].float(), a[0])
                else:
                    input_ = a[i]
                out = dcnet(input_)
                decode_out = dcnet_decode(out.squeeze())
                optimizer_.zero_grad()
                loss_ = loss(input_, decode_out)
                # Log a plain float rather than a graph-attached tensor.
                wandb.log({'epoch': epoch,
                           'Speaker ' + str(i) + ' loss': loss_.item()})
                loss_.backward()
                optimizer_.step()
            # Save inside the loop so the "_<i>" suffix names the pass just
            # finished. Previously the save ran after the loops with the leaked
            # `i`, so it always wrote "_2" and crashed when num_epoches == 0.
            # NOTE(review): one shared model is trained on every index; these
            # files are snapshots of it, not separate per-speaker models.
            torch.save(dcnet.state_dict(), PATH + "encoder_" + str(i))
            torch.save(dcnet_decode.state_dict(), PATH + "decoder_" + str(i))

In [None]:
class Args:
    """Hard-coded stand-in for the command-line arguments ``train`` expects."""
    # Epoch count comes from the wandb run config so it is tracked with the run.
    num_epoches = wandb_config.num_epoches
    # True -> train on ``debug_scp_conf`` instead of ``train_scp_conf``.
    debug = True
    # YAML file holding the reader, loader and model settings.
    config = "conf/train.yaml"


train(Args)

2020-02-06 13:52:59,292 [<ipython-input-25-d18d4d514533>:14 - INFO ] Start training in debug model
2020-02-06 13:52:59,300 [<ipython-input-25-d18d4d514533>:23 - INFO ] Training in per utterance
2020-02-06 13:52:59,304 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_mix.scp with 3000 utterances
2020-02-06 13:52:59,308 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s1.scp with 3000 utterances
2020-02-06 13:52:59,312 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s2.scp with 3000 utterances


torch.Size([77787, 20])


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


torch.Size([67983, 20])
torch.Size([99330, 20])
torch.Size([100233, 20])
torch.Size([66306, 20])
torch.Size([90945, 20])
torch.Size([88107, 20])
torch.Size([79206, 20])
torch.Size([79722, 20])
torch.Size([77916, 20])
torch.Size([77271, 20])
torch.Size([146544, 20])
torch.Size([72369, 20])
torch.Size([99201, 20])
torch.Size([84108, 20])
torch.Size([132483, 20])
torch.Size([54696, 20])
torch.Size([73401, 20])
torch.Size([87333, 20])
torch.Size([71982, 20])
torch.Size([89139, 20])
torch.Size([87720, 20])
torch.Size([84108, 20])
torch.Size([90816, 20])
torch.Size([78174, 20])
torch.Size([102297, 20])
torch.Size([60888, 20])
torch.Size([89784, 20])
torch.Size([78948, 20])
torch.Size([88365, 20])
torch.Size([143706, 20])
torch.Size([89913, 20])
torch.Size([111456, 20])
torch.Size([81786, 20])
torch.Size([93267, 20])
torch.Size([161379, 20])
torch.Size([93654, 20])
torch.Size([92880, 20])
torch.Size([100233, 20])
torch.Size([44892, 20])
torch.Size([45279, 20])
torch.Size([93396, 20])
torch.Si

In [19]:
import torch
# Scratch tensor used to check the mean / unsqueeze / view reshaping below.
a = torch.zeros(24, 20)

In [20]:
# Average over dim 1; with `a` from the zeros cell (24, 20) this gives (24,).
# Not idempotent: re-running this cell reduces `a` a second time.
a = a.mean(dim=1)

In [21]:
# Add a leading axis: (24,) -> (1, 24), as the printed shape below shows.
a = a.unsqueeze(dim=0)
print(a.shape)
# Regroup the 24 values into rows of 12: (1, 24) -> (2, 12).
a = a.view(-1, 12)

torch.Size([1, 24])


In [73]:
# Show the weights directory `train` actually saves to, rather than
# re-deriving it here where it can drift from the PATH constant.
PATH

'/home/ubuntu/deep-clustering/weights'