In [1]:
import argparse
import os

from trainer import Trainer
from dataset import SpectrogramReader, Dataset, DataLoader, logger
from dcnet import DCNet, DCNetDecoder
from utils import nfft, parse_yaml

In [2]:
def uttloader(scp_config, reader_kwargs, loader_kwargs, train=True):
    """Build a DataLoader over a mixture reader and its per-speaker readers.

    Args:
        scp_config: mapping holding a 'mixture' entry plus one entry per
            speaker whose key starts with 'spk'. Each value is passed to
            SpectrogramReader as its first argument.
        reader_kwargs: keyword arguments forwarded to every SpectrogramReader.
        loader_kwargs: keyword arguments forwarded to DataLoader. This mapping
            is left untouched; "shuffle" is overridden on a private copy.
        train: value used for the loader's "shuffle" option (shuffle when
            training, keep order otherwise).

    Returns:
        The DataLoader constructed over the Dataset.
    """
    mix_reader = SpectrogramReader(scp_config['mixture'], **reader_kwargs)
    # One reader per speaker entry, in the iteration order of scp_config.
    target_reader = [
        SpectrogramReader(scp_config[spk_key], **reader_kwargs)
        for spk_key in scp_config if spk_key.startswith('spk')
    ]
    dataset = Dataset(mix_reader, target_reader)

    # Debug aid: show the shapes of one sample. The sample is fetched once
    # instead of three times, and one shape is printed per configured speaker
    # so a single-speaker config no longer fails on a missing second target.
    # NOTE(review): index 12 is arbitrary and assumes the dataset holds at
    # least 13 utterances -- confirm or drop this probe.
    probe = dataset[12]
    print(probe[0].shape)
    for spk_index in range(len(target_reader)):
        print(probe[1][spk_index].shape)

    # Override the shuffle status on a copy so the caller's configuration
    # dict can be reused (e.g. for a validation loader) without side effects.
    loader_options = dict(loader_kwargs)
    loader_options["shuffle"] = train
    # Per-utterance validation is currently disabled:
    # if not train:
    #     loader_options["batch_size"] = 1
    return DataLoader(dataset, **loader_options)

In [3]:
def train(args):
    """Push a single batch through DCNet and DCNetDecoder and print the result.

    Despite its name, this function currently performs no optimisation: the
    validation loader and the Trainer invocation are commented out, so
    ``args.num_epoches`` is not read.

    Args:
        args: object exposing ``debug`` (selects "debug_scp_conf" instead of
            "train_scp_conf") and ``config`` (path handed to parse_yaml).

    Returns:
        None. Shapes and the difference between the decoder output and the
        first element of the batch are printed.
    """
    debug_mode = args.debug
    mode_name = 'debug' if debug_mode else 'normal'
    logger.info("Start training in {} model".format(mode_name))

    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    if batch_size == 1:
        batching = "per utterance"
    else:
        batching = '{} utterance per batch'.format(batch_size)
    logger.info("Training in {}".format(batching))

    # Debug runs read their scp lists from a dedicated section of the config.
    scp_section = "debug_scp_conf" if debug_mode else "train_scp_conf"
    train_loader = uttloader(
        config_dict[scp_section], reader_conf, loader_conf, train=True)

    # Disabled while debugging the forward pass:
    # valid_loader = uttloader(
    #     config_dict["valid_scp_conf"]
    #     if not debug else config_dict["debug_scp_conf"],
    #     reader_conf,
    #     loader_conf,
    #     train=False)
    # checkpoint = config_dict["trainer"]["checkpoint"]
    # logger.info("Training for {} epoches -> {}...".format(
    #     args.num_epoches, "default checkpoint"
    #     if checkpoint is None else checkpoint))

    # Inspect only the first batch produced by the loader.
    batch = next(iter(train_loader))
    for position in range(3):
        print(batch[position].shape)

    # Both networks receive the same num_bins and "dcnet" configuration.
    dcnet = DCNet(num_bins, **dcnnet_conf)
    dcnet_decode = DCNetDecoder(num_bins, **dcnnet_conf)

    # Singleton dimensions are dropped before the network output is decoded.
    squeezed = dcnet(batch[0]).squeeze()
    print(squeezed.shape)
    decode_out = dcnet_decode(squeezed)
    print("decoder output", decode_out - batch[0])

    # Disabled while debugging the forward pass:
    # trainer = Trainer(dcnet, **config_dict["trainer"])
    # trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)

In [4]:
class Args:
    """Hand-written stand-in for parsed command-line arguments.

    train() only reads attributes, so the class object itself is passed
    instead of an instance.
    """
    # Number of epochs; only referenced by the currently disabled Trainer call.
    num_epoches = 20
    # True makes train() read its scp lists from "debug_scp_conf".
    debug = True
    # YAML file handed to parse_yaml() by train().
    config = "conf/train.yaml"


train(Args)

2020-01-23 15:44:30,871 [<ipython-input-3-920bdb836dec>:4 - INFO ] Start training in debug model
2020-01-23 15:44:30,881 [<ipython-input-3-920bdb836dec>:13 - INFO ] Training in per utterance
2020-01-23 15:44:30,887 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_mix.scp with 3000 utterances
2020-01-23 15:44:30,893 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s1.scp with 3000 utterances
2020-01-23 15:44:30,899 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s2.scp with 3000 utterances


(706, 129)
(706, 129)
(706, 129)
torch.Size([658, 129])
torch.Size([658, 129])
torch.Size([658, 129])




torch.Size([84882, 20])
decoder output tensor([[1.8381, 2.3416, 4.0811,  ..., 7.1844, 6.7106, 6.5764],
        [2.1981, 2.7783, 4.7065,  ..., 5.7026, 5.9964, 6.2000],
        [2.9323, 3.2905, 6.9986,  ..., 4.5831, 4.8302, 4.9681],
        ...,
        [2.0774, 1.8685, 1.8357,  ..., 5.2650, 4.8654, 5.0637],
        [5.0002, 3.1249, 3.1457,  ..., 5.0672, 4.7459, 5.1371],
        [3.1505, 3.3748, 5.2880,  ..., 5.5818, 4.4799, 4.3835]],
       grad_fn=<SubBackward0>)


In [19]:
import torch
# Scratch tensor: 24 rows by 20 columns, all zeros.
a = torch.zeros(24, 20)

In [20]:
# Average over dimension 1, leaving one value per row.
# NOTE: rebinds `a`, so re-running this cell alone does not give the same result.
a = torch.mean(a, dim=1)

In [21]:
# Prepend a leading axis, show the resulting shape, then regroup the
# elements into rows of 12 (the row count is inferred by view).
a = torch.unsqueeze(a, 0)
print(a.shape)
a = a.view(-1, 12)

torch.Size([1, 24])
