In [1]:
import argparse
import os
import torch
import torch.nn.functional as F
import wandb

from trainer import Trainer
from dataset import SpectrogramReader, Dataset, DataLoader, logger
from dcnet import DCNet, DCNetDecoder
from utils import nfft, parse_yaml

In [9]:
# `!export ...` runs in a throwaway subshell, so the variable never reaches the
# kernel process that wandb reads it from. Set it on os.environ instead so it
# is visible to every later cell.
os.environ["WANDB_NOTEBOOK_NAME"] = "experiments"
# %pip install --upgrade wandb
# %pip install --upgrade pip
# Echo the value back to confirm it is set in this kernel.
os.environ["WANDB_NOTEBOOK_NAME"]




In [11]:
# # Weight and Bias configuration

# run = wandb.init(project="Speaker Separation")
# wandb_config = run.config
# wandb_config.num_epoches = 1
# wandb_config.lr = 0.001

# Sweep definition: the search strategy, the metric the sweep is told to
# optimise, and the hyperparameter values to sample from.
# NOTE(review): no code visible in this notebook logs a metric named
# 'accuracy' (train() logs 'Speaker <i> loss') — confirm the metric name.
sweep_config = dict(
    method='random',  # the other option noted by the author is 'grid'
    metric=dict(name='accuracy', goal='maximize'),
    parameters=dict(
        lr=dict(values=[1e-2, 1e-3]),
        optim=dict(values=['adam', 'nadam', 'sgd', 'rmsprop']),
    ),
)
# Registers the sweep with wandb and keeps its id for a later wandb.agent call.
sweep_id = wandb.sweep(sweep_config, entity="kantologist", project="SpeakerSeparation")

Create sweep with ID: tyteotv4
Sweep URL: https://app.wandb.ai/kantologist/SpeakerSeparation/sweeps/tyteotv4


In [12]:
def uttloader(scp_config, reader_kwargs, loader_kwargs, train=True):
    """Build a DataLoader over a mixture and its per-speaker spectrograms.

    Args:
        scp_config: mapping with a 'mixture' entry plus one entry per speaker;
            every key that starts with 'spk' is read as a speaker target.
        reader_kwargs: keyword arguments forwarded to each SpectrogramReader.
        loader_kwargs: keyword arguments forwarded to DataLoader. The mapping
            is not modified; any 'shuffle' entry in it is overridden by
            ``train``.
        train: value passed to the loader as ``shuffle`` (shuffle when
            training, keep order otherwise).

    Returns:
        The constructed DataLoader.
    """
    mix_reader = SpectrogramReader(scp_config['mixture'], **reader_kwargs)
    target_reader = [
        SpectrogramReader(scp_config[spk_key], **reader_kwargs)
        for spk_key in scp_config if spk_key.startswith('spk')
    ]
    dataset = Dataset(mix_reader, target_reader)
    # Merge into a fresh dict: writing "shuffle" into the caller's mapping
    # would leak the setting into every later use of the same config.
    effective_kwargs = dict(loader_kwargs, shuffle=train)
    return DataLoader(dataset, **effective_kwargs)

In [19]:
# Directory prefix (with trailing slash) that train() prepends to the saved
# encoder/decoder weight file names.
PATH = '{}/weights/'.format(os.getcwd())
# loss_fn = torch.nn.MSELoss(reduction='mean')

def loss(x, x_, mu, var):
    """Return the summed squared reconstruction error plus a KL-style penalty.

    Args:
        x: reference tensor.
        x_: reconstruction that is compared against ``x``.
        mu: tensor that enters the penalty as ``mu ** 2``.
        var: tensor that enters the penalty as ``var`` and ``exp(var)``, i.e.
            it is used the way a log-variance would be — TODO confirm
            against what DCNet actually returns.

    Returns:
        Scalar tensor equal to
        ``sum((x_ - x) ** 2) - 0.5 * sum(1 + var - mu ** 2 - exp(var))``.
    """
    reconstruction = F.mse_loss(x_, x, reduction="sum")
    penalty_terms = 1 + var - mu.pow(2) - var.exp()
    divergence = -0.5 * torch.sum(penalty_terms)
    return reconstruction + divergence

def optimizer(optim_config, lr_config, encoder, decoder):
    """Create one optimizer over the encoder's and the decoder's parameters.

    Args:
        optim_config: optimizer name; one of 'adam', 'nadam', 'sgd',
            'rmsprop'.
        lr_config: learning rate passed to the optimizer as ``lr``.
        encoder: module exposing ``parameters()``.
        decoder: module exposing ``parameters()``.

    Returns:
        The ``torch.optim`` optimizer instance.

    Raises:
        ValueError: if ``optim_config`` is not a supported name, or the
            installed torch does not provide the requested optimizer.
    """
    # torch spells the class ``NAdam``; ``torch.optim.Nadam`` does not exist,
    # so the 'nadam' sweep value used to fail with AttributeError.
    class_names = {
        'adam': 'Adam',
        'nadam': 'NAdam',
        'sgd': 'SGD',
        'rmsprop': 'RMSprop',
    }
    if optim_config not in class_names:
        # Previously an unknown name fell through every `if` and surfaced as
        # an UnboundLocalError on the return statement.
        raise ValueError(
            "Unsupported optimizer {!r}; expected one of {}".format(
                optim_config, sorted(class_names)))
    # Looked up lazily so a torch build without NAdam can still use the others.
    optimizer_cls = getattr(torch.optim, class_names[optim_config], None)
    if optimizer_cls is None:
        raise ValueError(
            "torch.optim.{} is not available in the installed torch".format(
                class_names[optim_config]))
    params = list(encoder.parameters()) + list(decoder.parameters())
    return optimizer_cls(params, lr=lr_config)

def train():
    """Run one wandb-tracked training run of the DCNet encoder/decoder pair.

    Hyperparameters (``num_epoches``, ``lr``, ``optim``) are read from
    ``wandb.config`` and fall back to the defaults defined below. The data
    loader and both models are built from ``conf/train.yaml``. A loss is
    logged to wandb for every optimisation step, and both state dicts are
    written under ``PATH`` at the end.

    Takes no arguments and returns nothing, so it can be handed directly to
    ``wandb.agent``.
    """
    class Args:
        config = "conf/train.yaml"
        debug = True
    args = Args

    # Defaults used for any hyperparameter wandb does not supply.
    config_defaults = {
        'num_epoches': 1,
        'lr': 1e-3,
        'optim': 'adam'
    }

    wandb.init(config=config_defaults)
    wandb_config = wandb.config

    # Debug mode comes from the run arguments. It used to be derived from the
    # epoch count, which made every run with at least one epoch a debug run.
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    scp_conf = (config_dict["debug_scp_conf"]
                if debug else config_dict["train_scp_conf"])
    train_loader = uttloader(scp_conf, reader_conf, loader_conf, train=True)

    dcnet = DCNet(num_bins, **dcnnet_conf)
    dcnet_decode = DCNetDecoder(num_bins, **dcnnet_conf)
    optimizer_ = optimizer(wandb_config.optim, wandb_config.lr, dcnet, dcnet_decode)

    # Each batch is indexed at positions 0..2.
    # NOTE(review): presumably index 0 is the mixture and 1..2 are per-speaker
    # targets — confirm against Dataset.
    num_streams = 3
    # Each pass over the loader stops after the batch with this index,
    # i.e. after 51 batches.
    last_batch_index = 50

    for _ in range(wandb_config.num_epoches):
        for i in range(num_streams):
            for j, a in enumerate(train_loader):
                if i != 0:
                    # Streams 1 and 2 are multiplied element-wise with entry 0.
                    input_ = torch.mul(a[i].float(), a[0])
                else:
                    input_ = a[i]
                z, mu, var = dcnet(input_)
                decode_out = dcnet_decode(z)
                optimizer_.zero_grad()
                loss_ = loss(input_, decode_out, mu, var)
                # Log a plain float instead of a tensor attached to the graph.
                wandb.log({'Speaker ' + str(i) + ' loss': loss_.item()})
                loss_.backward()
                optimizer_.step()

                if j == last_batch_index:
                    break

    # torch.save does not create missing directories.
    os.makedirs(PATH, exist_ok=True)
    # The suffix used to come from the leaked loop variable `i`, which is
    # undefined when num_epoches is 0. The value is kept (index of the last
    # stream) so existing file names do not change.
    checkpoint_tag = str(num_streams - 1)
    torch.save(dcnet.state_dict(), PATH + "encoder_" + checkpoint_tag)
    torch.save(dcnet_decode.state_dict(), PATH + "decoder_" + checkpoint_tag)

In [20]:
# Notebook-level copy of the run arguments. train() takes no parameters and
# defines its own identical Args class, so this one is not read by the call
# below.
class Args:
    config = "conf/train.yaml"
    debug = True
#     num_epoches = wandb_config.num_epoches
# args ={"config":"conf/train.yaml", "debug":False, "num_epoches":20} 
# print(args["debug"])
# Uncomment to let the sweep agent drive train() instead of a single run:
# wandb.agent(sweep_id, train)
# Single run using the default hyperparameters declared inside train().
train()

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
2020-02-12 14:23:54,572 [<ipython-input-19-50248903de17>:45 - INFO ] Start training in debug model
2020-02-12 14:23:54,581 [<ipython-input-19-50248903de17>:54 - INFO ] Training in per utterance
2020-02-12 14:23:54,586 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_mix.scp with 3000 utterances
2020-02-12 14:23:54,591 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s1.scp with 3000 utterances
2020-02-12 14:23:54,596 [/home/ubuntu/deep-clustering/dataset.py:40 - INFO ] Create SpectrogramReader for ./data/2spk/test/wav8k_min_s2.scp with 3000 utterances


In [19]:
import torch
# Scratch tensor: 24 rows by 20 columns of zeros.
a = torch.zeros(24, 20)

In [20]:
# Average over dim 1. This rebinds `a`, so the cell is not idempotent:
# if `a` is the 2-D tensor created above, a second run fails because the
# result no longer has a dim 1.
a = torch.mean(a, dim=1)

In [21]:
# Add a leading axis, show the resulting shape, then regroup the elements
# into rows of 12.
a = torch.unsqueeze(a, 0)
print(a.shape)
a = a.view(-1, 12)

torch.Size([1, 24])


In [73]:
# Display the weights directory prefix built from the current working directory.
'{}/weights/'.format(os.getcwd())

'/home/ubuntu/deep-clustering/weights'