In [1]:
cd /scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding

/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding


In [2]:
# Checkpoint-loading switch: when truthy, load_model_checkpoint() calls
# checkpointer.load(...) on the path it is given; when falsy that call is skipped.
LOAD = 1

import torch
from torch import optim as optim
import torch.utils.data
from tqdm import tqdm as tqdm
import numpy as np
import argparse, os, json, yaml
from networks import *
from model import Model
from dataset import *
from tracker import LossTracker
from utils.custom_adam import LREQAdam
from utils.checkpointer import Checkpointer
from utils.launcher import run
from utils.defaults import get_cfg_defaults
from utils.save import save_sample
# Target device string used by the `.to(device)` calls below: "cuda" when a CUDA
# device is available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

cd /scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding
python train_e2a.py --OUTPUT_DIR output/resnet_NY869 --trainsubject NY869 --testsubject NY869 --param_file configs/e2a_production.yaml --batch_size 16 --MAPPING_FROM_ECOG ECoGMapping_ResNet --reshape 1 --DENSITY "LD" --wavebased 1 --dynamicfiltershape 0 --n_filter_samples 80 --n_fft 512 --formant_supervision 1  --intensity_thres -1 --epoch_num 60 --pretrained_model_dir output/a2a/NY869 --causal 0

In [3]:
pwd

'/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding'

In [4]:
# Plain attribute namespace used in place of parsed command-line arguments.
# The attribute names mirror the flags of the `train_e2a.py` command shown
# earlier in this notebook; the values below are the ones this notebook runs with.
class Args:
    # --- configuration files / free-form overrides ---
    opts = None
    config_file = 'configs/e2a_production.yaml'
    param_file = 'configs/e2a_production.yaml'

    # --- subjects and directories ---
    trainsubject = 'NY869'
    testsubject = 'NY869'
    OUTPUT_DIR = 'output/resnet'
    pretrained_model_dir = 'output/a2a/NY869'
    COMPONENTKEY = ''

    # --- input / encoder selection ---
    DENSITY = 'LD'
    MAPPING_FROM_ECOG = 'ECoGMapping_ResNet'
    reshape = -1
    causal = 0
    anticausal = 0
    BIDIRECTION = 1
    RNN_COMPUTE_DB_LOUDNESS = 1
    lar_cap = 0
    reverse_order = 1
    rdropout = 0

    # --- synthesis / spectrogram options ---
    wavebased = 1
    noise_db = -50
    bgnoise_fromdata = 1
    learnedmask = 0
    dynamicfiltershape = 0
    n_filter_samples = 20
    n_fft = 512
    intensity_thres = -1
    use_denoise = 0

    # --- supervision switches and loss weights ---
    formant_supervision = 0
    pitch_supervision = 0
    intensity_supervision = 0
    ld_loss_weight = 1
    alpha_loss_weight = 1
    consonant_loss_weight = 0
    use_stoi = 0

    # --- run control ---
    batch_size = 8
    epoch_num = 100
    ignore_loading = 0
    finetune = 0


# Single shared instance read by the functions defined in the following cells.
args_ = Args()

In [5]:
# Subject-level metadata (JSON).
with open("configs/AllSubjectInfo.json", "r") as rfile:
    allsubj_param = json.load(rfile)

# Parameter file named by args_.param_file (YAML, parsed with the safe loader).
with open(args_.param_file, 'r') as stream:
    param = yaml.safe_load(stream)

# Eight independent, initially empty dicts keyed later by subject id.
(
    ecog_all,
    wave_orig_all,
    x_orig_all,
    x_orig_amp_all,
    labels_all,
    gender_train_all,
    on_stage_all,
    on_stage_wider_all,
) = tuple({} for _ in range(8))

# 21-point symmetric Hann window, reshaped to (1, 1, 21, 1) and scaled to sum to 1.
hann_win = torch.hann_window(21, periodic=False).reshape(1, 1, 21, 1)
hann_win = hann_win.div(hann_win.sum())

In [8]:
def get_train_data(
    ecog_all,
    wave_orig_all,
    x_orig_all,
    x_orig_amp_all,
    labels_all,
    gender_train_all,
    on_stage_all,
    on_stage_wider_all,
    sample_dict_train=None,
    subject=None,
):
    """Copy one training batch into the per-subject dictionaries.

    Every tensor entry of ``sample_dict_train`` is moved to the module-level
    ``device`` and cast to float before being stored under ``subject``; the
    labels are stored as they are. The two spectrogram entries are only stored
    when ``cfg.MODEL.WAVE_BASED`` is truthy.

    Args:
        ecog_all, wave_orig_all, x_orig_all, x_orig_amp_all, labels_all,
        gender_train_all, on_stage_all, on_stage_wider_all: dicts keyed by
            subject; they are updated in place.
        sample_dict_train: mapping that provides the ``*_batch_all`` /
            ``gender_all`` entries read below.
        subject: key under which the batch is stored.

    Returns:
        The same eight dicts, in the order they were passed in.
    """

    def on_device(key):
        # Fetch one batch entry, move it to the target device, cast to float32.
        return sample_dict_train[key].to(device).float()

    wave_orig_all[subject] = on_device("wave_re_batch_all")
    gender_train_all[subject] = on_device('gender_all')

    if cfg.MODEL.WAVE_BASED:
        x_orig_all[subject] = on_device("wave_spec_re_batch_all")
        x_orig_amp_all[subject] = on_device("wave_spec_re_amp_batch_all")

    on_stage_all[subject] = on_device("on_stage_re_batch_all")
    on_stage_wider_all[subject] = on_device("on_stage_wider_re_batch_all")

    # Labels are kept in whatever form the dataset yields them (no device move).
    labels_all[subject] = sample_dict_train["label_batch_all"]
    ecog_all[subject] = on_device("ecog_re_batch_all")

    updated = (
        ecog_all,
        wave_orig_all,
        x_orig_all,
        x_orig_amp_all,
        labels_all,
        gender_train_all,
        on_stage_all,
        on_stage_wider_all,
    )
    return updated

In [9]:
def _shared_model_kwargs(subject, param, gender_subject):
    """Return the keyword arguments used for both ``Model`` instances.

    ``load_model_checkpoint`` builds a trainable model and a second, frozen copy
    from exactly the same settings; collecting them here keeps the two in sync.
    """
    return dict(
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        ecog_encoder_name=cfg.MODEL.MAPPING_FROM_ECOG,
        spec_chans=cfg.DATASET.SPEC_CHANS,
        n_formants=cfg.MODEL.N_FORMANTS,
        n_formants_noise=cfg.MODEL.N_FORMANTS_NOISE,
        n_formants_ecog=cfg.MODEL.N_FORMANTS_ECOG,
        wavebased=cfg.MODEL.WAVE_BASED,
        n_fft=cfg.MODEL.N_FFT,
        noise_db=cfg.MODEL.NOISE_DB,
        max_db=cfg.MODEL.MAX_DB,
        with_ecog=cfg.MODEL.ECOG,
        do_mel_guide=cfg.MODEL.DO_MEL_GUIDE,
        noise_from_data=cfg.MODEL.BGNOISE_FROMDATA and cfg.DATASET.PROD,
        specsup=cfg.FINETUNE.SPECSUP,
        power_synth=cfg.MODEL.POWER_SYNTH,
        apply_flooding=cfg.FINETUNE.APPLY_FLOODING,
        normed_mask=cfg.MODEL.NORMED_MASK,
        dummy_formant=cfg.MODEL.DUMMY_FORMANT,
        A2A=cfg.VISUAL.A2A,
        causal=cfg.MODEL.CAUSAL,
        anticausal=cfg.MODEL.ANTICAUSAL,
        pre_articulate=cfg.DATASET.PRE_ARTICULATE,
        alpha_sup=param["Subj"][subject]["AlphaSup"],
        ld_loss_weight=cfg.MODEL.ld_loss_weight,
        alpha_loss_weight=cfg.MODEL.alpha_loss_weight,
        consonant_loss_weight=cfg.MODEL.consonant_loss_weight,
        component_regression=cfg.MODEL.component_regression,
        amp_formant_loss_weight=cfg.MODEL.amp_formant_loss_weight,
        freq_single_formant_loss_weight=cfg.MODEL.freq_single_formant_loss_weight,
        amp_minmax=cfg.MODEL.amp_minmax,
        amp_energy=cfg.MODEL.amp_energy,
        f0_midi=cfg.MODEL.f0_midi,
        alpha_db=cfg.MODEL.alpha_db,
        network_db=cfg.MODEL.network_db,
        consistency_loss=cfg.MODEL.consistency_loss,
        delta_time=cfg.MODEL.delta_time,
        delta_freq=cfg.MODEL.delta_freq,
        cumsum=cfg.MODEL.cumsum,
        distill=cfg.MODEL.distill,
        learned_mask=cfg.MODEL.LEARNED_MASK,
        n_filter_samples=cfg.MODEL.N_FILTER_SAMPLES,
        patient=subject,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        rdropout=cfg.MODEL.rdropout,
        dynamic_filter_shape=cfg.MODEL.DYNAMIC_FILTER_SHAPE,
        learnedbandwidth=cfg.MODEL.LEARNEDBANDWIDTH,
        gender_patient=allsubj_param["Subj"][gender_subject]["Gender"],
        reverse_order=args_.reverse_order,
        larger_capacity=args_.lar_cap,
        use_stoi=args_.use_stoi,
    )


def _build_optimizer(encoder, decoder, ecog_encoder, decoder_mel):
    """Create the ``LREQAdam`` optimizer for the sub-modules selected by ``cfg``.

    * ``cfg.MODEL.ECOG`` set: the ECoG encoder, plus the decoder unless
      ``cfg.MODEL.SUPLOSS_ON_ECOGF`` is set.
    * otherwise: encoder and decoder, plus ``decoder_mel`` when
      ``cfg.MODEL.DO_MEL_GUIDE`` is set.

    Raises:
        RuntimeError: if a sub-module required by the configuration is missing.
    """
    if cfg.MODEL.ECOG:
        selected = {"ecog_encoder": ecog_encoder}
        if not cfg.MODEL.SUPLOSS_ON_ECOGF:
            selected["decoder"] = decoder
    else:
        selected = {"encoder": encoder, "decoder": decoder}
        if cfg.MODEL.DO_MEL_GUIDE:
            selected["decoder_mel"] = decoder_mel

    missing = [name for name, module in selected.items() if module is None]
    if missing:
        raise RuntimeError(
            "The configuration requires sub-module(s) {} but the model does not "
            "provide them".format(", ".join(missing))
        )

    return LREQAdam(
        [{"params": module.parameters()} for module in selected.values()],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0,
    )


def load_model_checkpoint(
    logger,
    local_rank,
    distributed,
    tracker=None,
    tracker_test=None,
    dataset_all=None,
    subject="NY742",
    load_dir="",
    single_patient_mapping=0,param=None
):
    """Build the models, optimizer, trackers and checkpointer for one subject.

    Two ``Model`` instances are created from identical settings: ``model`` (put
    in train mode, optionally wrapped in ``DistributedDataParallel``) and
    ``model_s`` (eval mode, gradients disabled). When the module-level ``LOAD``
    flag is truthy, weights are then restored from ``load_dir`` through the
    checkpointer.

    Args:
        logger: logger handed to the ``Checkpointer`` and used for info messages.
        local_rank: CUDA device index for this process; rank 0 also registers
            the ``model_s`` sub-modules with the checkpointer and saves.
        distributed: wrap ``model`` in ``DistributedDataParallel`` when truthy.
        tracker, tracker_test: accepted for call compatibility but not used;
            fresh ``LossTracker`` objects are always created and returned.
        dataset_all: mapping subject -> dataset; ``dataset_all[subject].noise_dist``
            is read to initialise the models' noise distribution.
        subject: subject id this model belongs to.
        load_dir: checkpoint file passed to ``checkpointer.load`` as ``file_name``.
        single_patient_mapping: accepted for call compatibility; unused.
        param: parameter dict; ``param["Subj"][subject]["AlphaSup"]`` is read.

    Returns:
        Tuple ``(checkpointer, model, model_s, encoder, decoder, ecog_encoder,
        optimizer, tracker, tracker_test)``. ``ecog_encoder`` is ``None`` when
        the model has no such sub-module.
    """
    if args_.trainsubject != "":
        train_subject_info = args_.trainsubject.split(",")
    else:
        train_subject_info = cfg.DATASET.SUBJECT
    model_kwargs = _shared_model_kwargs(subject, param, train_subject_info[0])

    model = Model(**model_kwargs)
    if torch.cuda.is_available():
        model.cuda(local_rank)
    model.train()

    model_s = Model(**model_kwargs)
    if torch.cuda.is_available():
        model_s.cuda(local_rank)
    model_s.eval()
    model_s.requires_grad_(False)

    # Keep a handle on the bare Model: the DDP wrapper does not forward custom
    # attributes/methods (ecog_encoder, decoder_mel, noise_dist_init, ...), so
    # every attribute lookup below goes through `core`, never through the wrapper.
    core = model
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True,
        )
        model.device_ids = None

    decoder = core.decoder
    encoder = core.encoder
    # Optional sub-modules default to None so later code can test for them
    # instead of hitting an unbound local.
    ecog_encoder = getattr(core, "ecog_encoder", None)
    if ecog_encoder is not None and torch.cuda.is_available():
        ecog_encoder = ecog_encoder.cuda(local_rank)
    decoder_mel = getattr(core, "decoder_mel", None)

    logger.info("Trainable parameters generator:")
    logger.info("Trainable parameters discriminator:")

    # Modules the checkpointer saves / restores, keyed by checkpoint entry name.
    model_dict = {
        "encoder": encoder,
        "generator": decoder,
    }
    if ecog_encoder is not None:
        model_dict["ecog_encoder"] = ecog_encoder
    if decoder_mel is not None:
        model_dict["decoder_mel"] = decoder_mel
    if local_rank == 0:
        model_dict["encoder_s"] = model_s.encoder.to(device)
        model_dict["generator_s"] = model_s.decoder.to(device)
        if hasattr(model_s, "ecog_encoder"):
            model_dict["ecog_encoder_s"] = model_s.ecog_encoder.to(device)
        if hasattr(model_s, "decoder_mel"):
            model_dict["decoder_mel_s"] = model_s.decoder_mel

    noise_dist = torch.from_numpy(dataset_all[subject].noise_dist).to(device).float()
    if cfg.MODEL.BGNOISE_FROMDATA:
        model_s.noise_dist_init(noise_dist)
        core.noise_dist_init(noise_dist)

    # Built once, after noise initialisation (the original built an identical
    # optimizer earlier as well and immediately discarded it).
    optimizer = _build_optimizer(encoder, decoder, ecog_encoder, decoder_mel)

    # NOTE: the tracker / tracker_test arguments are intentionally not consulted;
    # this keeps the original behavior of returning freshly created trackers.
    tracker = LossTracker(cfg.OUTPUT_DIR)
    tracker_test = LossTracker(cfg.OUTPUT_DIR, test=True)
    auxiliary = {
        "optimizer": optimizer,
        "tracker": tracker,
        "tracker_test": tracker_test,
    }
    checkpointer = Checkpointer(
        cfg, model_dict, auxiliary, logger=logger, save=local_rank == 0
    )
    if LOAD:
        # Load exactly `load_dir` (not the last checkpoint recorded in the output
        # directory) and skip the optimizer/tracker state stored with it.
        checkpointer.load(
            ignore_last_checkpoint=True,
            ignore_auxiliary=True,
            file_name=load_dir,
        )
    return (
        checkpointer,
        model,
        model_s,
        encoder,
        decoder,
        ecog_encoder,
        optimizer,
        tracker,
        tracker_test,
    )

In [10]:
def _latest_a2a_checkpoint(pretrained_dir):
    """Locate the highest-epoch ``<sub>_a2a_model_epoch<N>.pth`` in ``pretrained_dir``.

    ``<sub>`` is the first ``/``-separated component of ``pretrained_dir`` that
    contains ``"NY"``. Only files that follow this exact naming scheme are
    considered, so the returned path always refers to an existing file.

    Returns:
        ``(sub_name, checkpoint_path)``.

    Raises:
        ValueError: no path component contains ``"NY"``.
        FileNotFoundError: the directory holds no matching checkpoint file.
    """
    tagged_parts = [part for part in pretrained_dir.split("/") if "NY" in part]
    if not tagged_parts:
        raise ValueError(
            "Cannot infer the subject name from pretrained_model_dir "
            "{!r}: no path component contains 'NY'".format(pretrained_dir)
        )
    sub_name = tagged_parts[0]

    prefix = "{}_a2a_model_epoch".format(sub_name)
    suffix = ".pth"
    epoch_texts = []
    for file_name in os.listdir(pretrained_dir):
        if file_name.startswith(prefix) and file_name.endswith(suffix):
            epoch_text = file_name[len(prefix):-len(suffix)]
            if epoch_text.isdigit():
                epoch_texts.append(epoch_text)
    if not epoch_texts:
        raise FileNotFoundError(
            "No '{}<N>{}' checkpoint found in {!r}".format(prefix, suffix, pretrained_dir)
        )

    latest = max(epoch_texts, key=int)
    return sub_name, pretrained_dir + "/{}{}{}".format(prefix, latest, suffix)


def train(cfg, logger, local_rank, world_size, distributed):
    """Train the ECoG-to-speech model and save a checkpoint plus samples each epoch.

    Steps:
      1. Build a training dataset for every train/test subject and a test
         dataset for every test subject.
      2. Build model/optimizer/checkpointer per subject via
         ``load_model_checkpoint`` (restoring the latest pretrained a2a
         checkpoint from ``args_.pretrained_model_dir`` when it is set), then
         make every subject's model share the first training subject's ECoG
         encoder.
      3. Cache one full test batch per test subject.
      4. For ``cfg.TRAIN.TRAIN_EPOCHS`` epochs: run the optimisation loop, then
         run the test forward pass, save a checkpoint and call ``save_sample``.

    Args:
        cfg: experiment configuration node.
        logger: logger passed to datasets and checkpointers.
        local_rank: CUDA device index / process rank.
        world_size: number of processes, forwarded to the datasets.
        distributed: forwarded to ``load_model_checkpoint``.

    Note:
        The optimisation loop only runs when there is a single training
        subject; with several subjects no weights are updated (a warning is
        logged once).
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    with open('configs/train_param_production.json', 'r') as stream:
        param = json.load(stream)

    train_subject_info = (
        args_.trainsubject.split(",") if args_.trainsubject != "" else cfg.DATASET.SUBJECT
    )
    test_subject_info = (
        args_.testsubject.split(",") if args_.testsubject != "" else cfg.DATASET.SUBJECT
    )
    all_subjects = np.union1d(train_subject_info, test_subject_info)

    def build_dataset(subject, **overrides):
        # Train and test datasets differ only in the extra `train=False` override.
        return TFRecordsDataset(
            cfg,
            logger,
            rank=local_rank,
            world_size=world_size,
            SUBJECT=[subject],
            buffer_size_mb=1024,
            channels=cfg.MODEL.CHANNELS,
            param=param,
            allsubj_param=allsubj_param,
            ReshapeAsGrid=1,
            rearrange_elec=0,
            low_density=cfg.DATASET.DENSITY == "LD",
            process_ecog=True,
            **overrides
        )

    dataset_all = {subject: build_dataset(subject) for subject in all_subjects}
    dataset_test_all = {
        subject: build_dataset(subject, train=False) for subject in test_subject_info
    }

    tracker = LossTracker(cfg.OUTPUT_DIR)
    tracker_test = LossTracker(cfg.OUTPUT_DIR, test=True)

    (
        checkpointer_all,
        model_all,
        model_s_all,
        encoder_all,
        decoder_all,
        ecog_encoder_all,
        optimizer_all,
    ) = ({}, {}, {}, {}, {}, {}, {})

    # The pretrained checkpoint depends only on the directory, not on the
    # subject being built, so resolve it once.
    if args_.pretrained_model_dir != "":
        load_sub_name, load_sub_dir = _latest_a2a_checkpoint(args_.pretrained_model_dir)
    else:
        load_sub_name, load_sub_dir = None, ''

    for single_patient_mapping, subject in enumerate(all_subjects):
        if load_sub_name is not None:
            print("subject, load_sub_name", subject, load_sub_name)
            print("pretrained load dir", load_sub_dir)
        else:
            print ('No pretrainde a2a model provided!')
        (
            checkpointer_all[subject],
            model_all[subject],
            model_s_all[subject],
            encoder_all[subject],
            decoder_all[subject],
            ecog_encoder_all[subject],
            optimizer_all[subject],
            tracker,
            tracker_test,
        ) = load_model_checkpoint(
            logger,
            local_rank,
            distributed,
            tracker=tracker,
            tracker_test=tracker_test,
            dataset_all=dataset_all,
            subject=subject,
            load_dir=load_sub_dir,
            single_patient_mapping=single_patient_mapping,param=param
        )

    # Every subject's model (and its frozen copy) uses the ECoG encoder of the
    # first training subject.
    ecog_encoder_shared = ecog_encoder_all[train_subject_info[0]]
    for subject in all_subjects:
        model_all[subject].ecog_encoder = ecog_encoder_shared
        model_s_all[subject].ecog_encoder = ecog_encoder_shared

    (
        ecog_test_all,
        sample_wave_test_all,
        sample_spec_test_all,
        sample_spec_amp_test_all,
        sample_label_test_all,
        gender_test_all,
        on_stage_test_all,
        on_stage_wider_test_all,
    ) = ({}, {}, {}, {}, {}, {}, {}, {})

    # Cache one batch containing the whole test set for each test subject.
    for subject in test_subject_info:
        dataset_test_all[subject].reset(
            cfg.DATASET.MAX_RESOLUTION_LEVEL, len(dataset_test_all[subject].dataset)
        )
        sample_dict_test = next(iter(dataset_test_all[subject].iterator))
        gender_test_all[subject] = sample_dict_test['gender_all'].to(device).float()
        if cfg.DATASET.PROD:
            sample_wave_test_all[subject] = (
                sample_dict_test["wave_re_batch_all"].to(device).float()
            )
            if cfg.MODEL.WAVE_BASED:
                sample_spec_test_all[subject] = (
                    sample_dict_test["wave_spec_re_batch_all"].to(device).float()
                )
                sample_spec_amp_test_all[subject] = (
                    sample_dict_test["wave_spec_re_amp_batch_all"].to(device).float()
                )
            sample_label_test_all[subject] = sample_dict_test["label_batch_all"]
            if cfg.MODEL.ECOG:
                ecog_test_all[subject] = (
                    sample_dict_test["ecog_re_batch_all"].to(device).float()
                )
            on_stage_test_all[subject] = (
                sample_dict_test["on_stage_re_batch_all"].to(device).float()
            )
            on_stage_wider_test_all[subject] = (
                sample_dict_test["on_stage_wider_re_batch_all"].to(device).float()
            )

    duomask = True
    x_amp_from_denoise = False
    n_iter = 0
    # Interpolation factor handed to model_s.lerp(); constant for the whole run.
    betta = 0.5 ** (cfg.TRAIN.BATCH_SIZE / (10 * 1000.0))

    (
        ecog_all,
        wave_orig_all,
        x_orig_all,
        x_orig_amp_all,
        labels_all,
        gender_train_all,
        on_stage_all,
        on_stage_wider_all,
    ) = ({}, {}, {}, {}, {}, {}, {}, {})

    single_subject = len(train_subject_info) <= 1
    if not single_subject:
        logger.warning(
            "train() only optimises when there is a single training subject; "
            "got {} subjects, so no weights will be updated.".format(
                len(train_subject_info)
            )
        )

    for epoch in tqdm(range(cfg.TRAIN.TRAIN_EPOCHS)):
        # ---- train ----
        for subject in train_subject_info:
            model_all[subject].train()

        if single_subject:
            for sample_dict_train in tqdm(iter(dataset_all[train_subject_info[0]].iterator)):
                n_iter += 1
                for subject in train_subject_info:
                    if n_iter % 200 == 0:
                        print(tracker.register_means(n_iter))
                    (
                        ecog_all,
                        wave_orig_all,
                        x_orig_all,
                        x_orig_amp_all,
                        labels_all,
                        gender_train_all,
                        on_stage_all,
                        on_stage_wider_all,
                    ) = get_train_data(
                        ecog_all,
                        wave_orig_all,
                        x_orig_all,
                        x_orig_amp_all,
                        labels_all,
                        gender_train_all,
                        on_stage_all,
                        on_stage_wider_all,
                        sample_dict_train,
                        subject=subject,
                    )

                    optimizer_all[subject].zero_grad()
                    # NOTE(review): on_stage_wider is fed from on_stage_all, not
                    # on_stage_wider_all; kept as in the original because changing
                    # it alters the loss -- confirm whether this is intended.
                    Lrec, tracker = model_all[subject](
                        x_orig_all[subject],
                        ecog=ecog_all[subject],
                        on_stage=on_stage_all[subject],
                        on_stage_wider=on_stage_all[subject],
                        ae=False,
                        tracker=tracker,
                        encoder_guide=cfg.MODEL.W_SUP,
                        duomask=duomask,
                        x_amp=x_orig_amp_all[subject],
                        x_amp_from_denoise=x_amp_from_denoise,
                        gender=gender_train_all[subject],
                    )
                    Lrec.backward()
                    optimizer_all[subject].step()

                    # Blend the freshly updated weights into the frozen copy.
                    model_s_all[subject].lerp(
                        model_all[subject],
                        betta,
                        w_classifier=cfg.MODEL.W_CLASSIFIER,
                    )

        # ---- test ----
        for subject in test_subject_info:
            test_model = model_all[subject]
            print(2 ** (torch.tanh(test_model.encoder.formant_bandwitdh_slop)))
            print("save test result!")

            # The ECoG batch is only cached when cfg.MODEL.ECOG is set.
            ecog_test = ecog_test_all[subject] if cfg.MODEL.ECOG else None

            test_model.eval()
            # Return value is not needed here; the call is made for the metrics
            # it records on tracker_test.
            # NOTE(review): on_stage_wider again receives the on_stage tensor.
            test_model(
                sample_spec_test_all[subject],
                x_denoise=None,
                x_mel=None,
                ecog=ecog_test,
                on_stage=on_stage_test_all[subject],
                ae=not cfg.MODEL.ECOG,
                tracker=tracker_test,
                encoder_guide=cfg.MODEL.W_SUP,
                pitch_aug=False,
                duomask=duomask,
                debug=False,
                x_amp=sample_spec_amp_test_all[subject],
                hamonic_bias=False,
                gender=gender_test_all[subject],
                on_stage_wider=on_stage_test_all[subject],
            )

            # Checkpoint and sample dump happen every epoch.
            checkpointer_all[subject].save(
                "model_epoch{}_{}".format(epoch, subject)
            )
            save_sample(
                cfg,
                sample_spec_test_all[subject],
                ecog_test,
                encoder_all[subject],
                decoder_all[subject],
                ecog_encoder_shared if hasattr(test_model, "ecog_encoder") else None,
                # Optional sub-modules are looked up on the model itself; the
                # original referenced names that are not defined in this function.
                getattr(test_model, "encoder2", None),
                x_denoise=None,
                decoder_mel=(
                    getattr(test_model, "decoder_mel", None)
                    if cfg.MODEL.DO_MEL_GUIDE
                    else None
                ),
                epoch=epoch,
                label=sample_label_test_all[subject],
                mode="test",
                path=cfg.OUTPUT_DIR,
                tracker=tracker_test,
                linear=cfg.MODEL.WAVE_BASED,
                n_fft=cfg.MODEL.N_FFT,
                duomask=duomask,
                x_amp=sample_spec_amp_test_all[subject],
                gender=gender_test_all[subject],
                sample_wave=sample_wave_test_all[subject],
                sample_wave_denoise=None,
                on_stage_wider=on_stage_test_all[subject],
                auto_regressive=False,
                seq_out_start=None,
                suffix=subject,
            )

        

In [11]:
gpu_count = torch.cuda.device_count()

# Merge the experiment YAML into the defaults *before* anything is read from
# cfg; otherwise cfg.DATASET.SUBJECT / cfg.DATASET.PROD below would reflect the
# built-in defaults rather than the chosen configuration file.
config_file = args_.param_file
cfg = get_cfg_defaults()
cfg.merge_from_file(config_file)
args_.config_file = config_file

# Explicit subject lists from args_ take precedence over the configuration.
if args_.trainsubject != "":
    train_subject_info = args_.trainsubject.split(",")
else:
    train_subject_info = cfg.DATASET.SUBJECT
if args_.testsubject != "":
    test_subject_info = args_.testsubject.split(",")
else:
    test_subject_info = cfg.DATASET.SUBJECT

with open("configs/AllSubjectInfo.json", "r") as rfile:
    allsubj_param = json.load(rfile)

# Use the resolved subject list so the fallback to cfg.DATASET.SUBJECT also
# works here (indexing with args_.trainsubject would look up "" in that case).
subj_param = allsubj_param["Subj"][train_subject_info[0]]
Gender = subj_param["Gender"] if cfg.DATASET.PROD else "Female"

In [12]:
# Hand `train` and the merged configuration to utils.launcher.run.
# world_size is the CUDA device count computed in the previous cell and
# default_config is the YAML path taken from args_.param_file.
# NOTE(review): description="StyleGAN" looks like a label inherited from the
# launcher's original project; how `run` uses it is not visible here -- confirm.
run(
    train,
    cfg,
    description="StyleGAN",
    default_config=config_file,
    world_size=gpu_count,
    args_=args_,
)

rank in _run 0
2023-06-11 00:17:42,576 logger INFO: <__main__.Args object at 0x14c7020cc850>
2023-06-11 00:17:42,579 logger INFO: World size: 1
2023-06-11 00:17:42,580 logger INFO: Loaded configuration file configs/e2a_production.yaml
TestNum_cum 1
ecog_alldataset 1
end_ind_re_valid_alldataset 1
formant_re_alldataset 1
intensity_re_alldataset 1
label_alldataset 1
noisesample_re_alldataset 1
pitch_re_alldataset 1
start_ind_re_valid_alldataset 1
wave_re_alldataset 1
wave_re_spec_alldataset 1
wave_re_spec_amp_alldataset 1
self.meta_data[ TestNum_cum s] [50]
dict_keys(['TestNum_cum', 'ecog_alldataset', 'end_ind_re_valid_alldataset', 'formant_re_alldataset', 'intensity_re_alldataset', 'label_alldataset', 'noisesample_re_alldataset', 'pitch_re_alldataset', 'start_ind_re_valid_alldataset', 'wave_re_alldataset', 'wave_re_spec_alldataset', 'wave_re_spec_amp_alldataset'])
self.ReshapeAsGrid:  1 ECoGMapping_ResNet
dict_keys(['TestNum_cum', 'ecog_alldataset', 'end_ind_re_valid_alldataset', 'forman

  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:45,  2.29s/it][A
  4%|▍         | 2/47 [00:04<01:38,  2.18s/it][A
  6%|▋         | 3/47 [00:06<01:33,  2.13s/it][A
  9%|▊         | 4/47 [00:08<01:33,  2.17s/it][A
 11%|█         | 5/47 [00:10<01:28,  2.10s/it][A
 13%|█▎        | 6/47 [00:12<01:24,  2.06s/it][A
 15%|█▍        | 7/47 [00:14<01:22,  2.06s/it][A
 17%|█▋        | 8/47 [00:16<01:18,  2.02s/it][A
 19%|█▉        | 9/47 [00:18<01:15,  1.99s/it][A
 21%|██▏       | 10/47 [00:20<01:11,  1.93s/it][A
 23%|██▎       | 11/47 [00:22<01:09,  1.94s/it][A
 26%|██▌       | 12/47 [00:24<01:08,  1.95s/it][A
 28%|██▊       | 13/47 [00:26<01:07,  1.99s/it][A
 30%|██▉       | 14/47 [00:28<01:05,  1.99s/it][A
 32%|███▏      | 15/47 [00:30<01:02,  1.96s/it][A
 34%|███▍      | 16/47 [00:32<01:01,  1.97s/it][A
 36%|███▌      | 17/47 [00:34<00:59,  1.98s/it][A
 38%|███▊      | 18/47 [00:36<00:56,  1.96s/it][A
 40%|████      | 19/47 [00:38<00:54,  1.96s/it][A
 43%|████▎     | 20/47 [00:40<00:55,  2

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:19:21,759 logger INFO: Saving checkpoint to output/resnet/model_epoch0_NY869.pth
registering means,n_iter,self.n_iters 0 []
0 Lae_a1 0.0
0 Lae_a_l21 0.0
0 Lae_db1 0.2677802
0 Lae_db_l21 0.267818
0 Lae_a2 0.0
0 Lae_a_l22 0.0
0 Lae_db2 0.26032233
0 Lae_db_l22 0.26033404
0 Lrec 21.124102
0 loudness_metric 0.015242103
0 loudness 4.00646
0 f0_metric 1.9358493
0 f0_hz 0.5807548
0 amplitudes_metric 0.013717631
0 amplitudes 6.700783
0 amplitude_formants_hamon_metric 0.0007394142
0 amplitude_formants_hamon 0.29576567
0 freq_formants_hamon_hz_metric_2 0.11886589
0 freq_formants_hamon_hz_metric_6 0.1609572
0 freq_formants_hamon 0.37046427
0 amplitude_formants_noise_metric 0.0019222703
0 amplitude_formants_noise 0.76890814
0 freq_formants_noise_metric 0.4208961
0 freq_formants_noise 1.9831544
0 bandwidth_formants_noise_hz_metric 2.548035
0 bandwidth_formants_noise_hz 7.644105
0 Ldiff 0.04359938
0 Lexp -0.16299137


  1%|          | 1/100 [03:34<5:54:23, 214.78s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:37,  2.12s/it][A
  4%|▍         | 2/47 [00:04<01:34,  2.09s/it][A
  6%|▋         | 3/47 [00:06<01:31,  2.07s/it][A
  9%|▊         | 4/47 [00:08<01:29,  2.08s/it][A
 11%|█         | 5/47 [00:10<01:26,  2.06s/it][A
 13%|█▎        | 6/47 [00:12<01:22,  2.02s/it][A
 15%|█▍        | 7/47 [00:14<01:19,  1.99s/it][A
 17%|█▋        | 8/47 [00:16<01:16,  1.95s/it][A
 19%|█▉        | 9/47 [00:17<01:12,  1.92s/it][A
 21%|██▏       | 10/47 [00:19<01:10,  1.91s/it][A
 23%|██▎       | 11/47 [00:21<01:09,  1.92s/it][A
 26%|██▌       | 12/47 [00:23<01:07,  1.92s/it][A
 28%|██▊       | 13/47 [00:25<01:05,  1.91s/it][A
 30%|██▉       | 14/47 [00:27<01:03,  1.94s/it][A
 32%|███▏      | 15/47 [00:29<01:01,  1.93s/it][A
 34%|███▍      | 16/47 [00:31<00:59,  1.93s/it][A
 36%|███▌      | 17/47 [00:33<00:58,  1.95s/it][A
 38%|███▊      | 18/47 [00:35<00:55,  1.92s/it][A
 40%|████      | 19/47 [00:37<00:52,  1.88s/it][A
 43%|████▎     | 20/47 [00:39<00:51,  1

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:22:53,088 logger INFO: Saving checkpoint to output/resnet/model_epoch1_NY869.pth
registering means,n_iter,self.n_iters 1 [0]
1 Lae_a1 0.0
1 Lae_a_l21 0.0
1 Lae_db1 0.26282498
1 Lae_db_l21 0.26286075
1 Lae_a2 0.0
1 Lae_a_l22 0.0
1 Lae_db2 0.25013015
1 Lae_db_l22 0.25014192
1 Lrec 20.518206
1 loudness_metric 0.015204156
1 loudness 3.9717407
1 f0_metric 0.32003784
1 f0_hz 0.096011356
1 amplitudes_metric 0.012421683
1 amplitudes 6.7395463
1 amplitude_formants_hamon_metric 0.000703727
1 amplitude_formants_hamon 0.2814908
1 freq_formants_hamon_hz_metric_2 0.08216901
1 freq_formants_hamon_hz_metric_6 0.11941582
1 freq_formants_hamon 0.24883237
1 amplitude_formants_noise_metric 0.0017236979
1 amplitude_formants_noise 0.6894792
1 freq_formants_noise_metric 0.3423767
1 freq_formants_noise 1.6101981
1 bandwidth_formants_noise_hz_metric 2.4412599
1 bandwidth_formants_noise_hz 7.3237796
1 Ldiff 0.03959763
1 Lexp -0.1582362


  2%|▏         | 2/100 [07:06<5:48:19, 213.26s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:39,  2.16s/it][A
  4%|▍         | 2/47 [00:04<01:30,  2.01s/it][A
  6%|▋         | 3/47 [00:05<01:24,  1.93s/it][A
  9%|▊         | 4/47 [00:07<01:22,  1.91s/it][A
 11%|█         | 5/47 [00:09<01:18,  1.88s/it][A
 13%|█▎        | 6/47 [00:11<01:17,  1.88s/it][A
 15%|█▍        | 7/47 [00:13<01:15,  1.89s/it][A
 17%|█▋        | 8/47 [00:15<01:13,  1.88s/it][A
 19%|█▉        | 9/47 [00:17<01:11,  1.89s/it][A
 21%|██▏       | 10/47 [00:18<01:09,  1.87s/it][A
 23%|██▎       | 11/47 [00:20<01:08,  1.90s/it][A
 26%|██▌       | 12/47 [00:22<01:06,  1.89s/it][A
 28%|██▊       | 13/47 [00:24<01:05,  1.93s/it][A
 30%|██▉       | 14/47 [00:26<01:02,  1.90s/it][A
 32%|███▏      | 15/47 [00:28<01:00,  1.90s/it][A
 34%|███▍      | 16/47 [00:30<01:00,  1.94s/it][A
 36%|███▌      | 17/47 [00:32<00:58,  1.94s/it][A
 38%|███▊      | 18/47 [00:34<00:55,  1.91s/it][A
 40%|████      | 19/47 [00:36<00:53,  1.92s/it][A
 43%|████▎     | 20/47 [00:38<00:51,  1

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:26:24,626 logger INFO: Saving checkpoint to output/resnet/model_epoch2_NY869.pth
registering means,n_iter,self.n_iters 2 [0, 1]
2 Lae_a1 0.0
2 Lae_a_l21 0.0
2 Lae_db1 0.25020963
2 Lae_db_l21 0.25024873
2 Lae_a2 0.0
2 Lae_a_l22 0.0
2 Lae_db2 0.2423929
2 Lae_db_l22 0.24240509
2 Lrec 19.704102
2 loudness_metric 0.014805636
2 loudness 3.8310122
2 f0_metric 0.07176988
2 f0_hz 0.021530963
2 amplitudes_metric 0.011415317
2 amplitudes 6.524787
2 amplitude_formants_hamon_metric 0.0006941898
2 amplitude_formants_hamon 0.27767593
2 freq_formants_hamon_hz_metric_2 0.056592673
2 freq_formants_hamon_hz_metric_6 0.0981735
2 freq_formants_hamon 0.1747631
2 amplitude_formants_noise_metric 0.001883906
2 amplitude_formants_noise 0.7535624
2 freq_formants_noise_metric 0.36421186
2 freq_formants_noise 1.7191356
2 bandwidth_formants_noise_hz_metric 2.4856262
2 bandwidth_formants_noise_hz 7.4568787
2 Ldiff 0.031988136
2 Lexp -0.152664

  3%|▎         | 3/100 [10:37<5:42:57, 212.14s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:41,  2.20s/it][A
  4%|▍         | 2/47 [00:04<01:31,  2.03s/it][A
  6%|▋         | 3/47 [00:06<01:27,  1.98s/it][A
  9%|▊         | 4/47 [00:07<01:24,  1.97s/it][A
 11%|█         | 5/47 [00:10<01:23,  2.00s/it][A
 13%|█▎        | 6/47 [00:11<01:19,  1.94s/it][A
 15%|█▍        | 7/47 [00:13<01:17,  1.93s/it][A
 17%|█▋        | 8/47 [00:15<01:14,  1.92s/it][A
 19%|█▉        | 9/47 [00:17<01:12,  1.92s/it][A
 21%|██▏       | 10/47 [00:19<01:10,  1.90s/it][A
 23%|██▎       | 11/47 [00:21<01:07,  1.87s/it][A
 26%|██▌       | 12/47 [00:23<01:06,  1.89s/it][A
 28%|██▊       | 13/47 [00:25<01:04,  1.91s/it][A
 30%|██▉       | 14/47 [00:27<01:03,  1.93s/it][A
 32%|███▏      | 15/47 [00:28<01:00,  1.88s/it][A
 34%|███▍      | 16/47 [00:30<00:57,  1.86s/it][A
 36%|███▌      | 17/47 [00:32<00:55,  1.85s/it][A
 38%|███▊      | 18/47 [00:34<00:55,  1.90s/it][A
 40%|████      | 19/47 [00:36<00:52,  1.89s/it][A
 43%|████▎     | 20/47 [00:38<00:50,  1

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:29:54,623 logger INFO: Saving checkpoint to output/resnet/model_epoch3_NY869.pth
registering means,n_iter,self.n_iters 3 [0, 1, 2]
3 Lae_a1 0.0
3 Lae_a_l21 0.0
3 Lae_db1 0.2529148
3 Lae_db_l21 0.25295347
3 Lae_a2 0.0
3 Lae_a_l22 0.0
3 Lae_db2 0.24246448
3 Lae_db_l22 0.2424764
3 Lrec 19.81517
3 loudness_metric 0.014596612
3 loudness 3.9722514
3 f0_metric 0.05654022
3 f0_hz 0.016962066
3 amplitudes_metric 0.011870064
3 amplitudes 6.6364255
3 amplitude_formants_hamon_metric 0.0006661497
3 amplitude_formants_hamon 0.26645988
3 freq_formants_hamon_hz_metric_2 0.06628768
3 freq_formants_hamon_hz_metric_6 0.09544511
3 freq_formants_hamon 0.16504097
3 amplitude_formants_noise_metric 0.0019245798
3 amplitude_formants_noise 0.76983196
3 freq_formants_noise_metric 0.3396586
3 freq_formants_noise 1.6041679
3 bandwidth_formants_noise_hz_metric 2.4651892
3 bandwidth_formants_noise_hz 7.395568
3 Ldiff 0.033016272
3 Lexp -0.156

  4%|▍         | 4/100 [14:08<5:38:29, 211.56s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:37,  2.11s/it][A
  4%|▍         | 2/47 [00:04<01:32,  2.06s/it][A
  6%|▋         | 3/47 [00:05<01:26,  1.96s/it][A
  9%|▊         | 4/47 [00:07<01:22,  1.92s/it][A
 11%|█         | 5/47 [00:09<01:19,  1.89s/it][A
 13%|█▎        | 6/47 [00:11<01:16,  1.88s/it][A
 15%|█▍        | 7/47 [00:13<01:14,  1.86s/it][A
 17%|█▋        | 8/47 [00:15<01:12,  1.87s/it][A
 19%|█▉        | 9/47 [00:17<01:10,  1.85s/it][A
 21%|██▏       | 10/47 [00:18<01:08,  1.84s/it][A
 23%|██▎       | 11/47 [00:20<01:06,  1.84s/it][A

registering means,n_iter,self.n_iters 200 []
200 Lae_a1 0.0
200 Lae_a_l21 0.0
200 Lae_db1 0.27113318
200 Lae_db_l21 0.2711629
200 Lae_a2 0.0
200 Lae_a_l22 0.0
200 Lae_db2 0.2693394
200 Lae_db_l22 0.26935026
200 Lrec 21.618902
200 loudness_metric 0.020677581
200 loudness 4.934557
200 f0_metric 1.1471866
200 f0_hz 0.344156
200 amplitudes_metric 0.015297581
200 amplitudes 2.818128
200 amplitude_formants_hamon_metric 0.0009018101
200 amplitude_formants_hamon 0.36072406
200 freq_formants_hamon_hz_metric_2 0.112008184
200 freq_formants_hamon_hz_metric_6 0.1562096
200 freq_formants_hamon 0.34439036
200 amplitude_formants_noise_metric 0.0022607953
200 amplitude_formants_noise 0.9043181
200 freq_formants_noise_metric 0.49177328
200 freq_formants_noise 2.3247125
200 bandwidth_formants_noise_hz_metric 2.889857
200 bandwidth_formants_noise_hz 8.669572
200 Ldiff 0.1407026
200 Lexp -0.28287488
None



 26%|██▌       | 12/47 [00:22<01:05,  1.87s/it][A
 28%|██▊       | 13/47 [00:24<01:02,  1.85s/it][A
 30%|██▉       | 14/47 [00:26<01:01,  1.87s/it][A
 32%|███▏      | 15/47 [00:28<00:59,  1.86s/it][A
 34%|███▍      | 16/47 [00:30<00:57,  1.86s/it][A
 36%|███▌      | 17/47 [00:31<00:55,  1.85s/it][A
 38%|███▊      | 18/47 [00:33<00:53,  1.86s/it][A
 40%|████      | 19/47 [00:35<00:51,  1.82s/it][A
 43%|████▎     | 20/47 [00:37<00:48,  1.81s/it][A
 45%|████▍     | 21/47 [00:39<00:46,  1.80s/it][A
 47%|████▋     | 22/47 [00:40<00:45,  1.82s/it][A
 49%|████▉     | 23/47 [00:42<00:43,  1.83s/it][A
 51%|█████     | 24/47 [00:44<00:41,  1.81s/it][A
 53%|█████▎    | 25/47 [00:46<00:40,  1.82s/it][A
 55%|█████▌    | 26/47 [00:48<00:38,  1.82s/it][A
 57%|█████▋    | 27/47 [00:50<00:36,  1.83s/it][A
 60%|█████▉    | 28/47 [00:51<00:34,  1.84s/it][A
 62%|██████▏   | 29/47 [00:53<00:33,  1.87s/it][A
 64%|██████▍   | 30/47 [00:55<00:31,  1.86s/it][A
 66%|██████▌   | 31/47 [00:57<

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:33:23,584 logger INFO: Saving checkpoint to output/resnet/model_epoch4_NY869.pth
registering means,n_iter,self.n_iters 4 [0, 1, 2, 3]
4 Lae_a1 0.0
4 Lae_a_l21 0.0
4 Lae_db1 0.25042203
4 Lae_db_l21 0.25045952
4 Lae_a2 0.0
4 Lae_a_l22 0.0
4 Lae_db2 0.23778373
4 Lae_db_l22 0.23779581
4 Lrec 19.528229
4 loudness_metric 0.013983068
4 loudness 3.9601338
4 f0_metric 0.06554264
4 f0_hz 0.019662792
4 amplitudes_metric 0.012523084
4 amplitudes 6.6537027
4 amplitude_formants_hamon_metric 0.00063714007
4 amplitude_formants_hamon 0.25485602
4 freq_formants_hamon_hz_metric_2 0.090070926
4 freq_formants_hamon_hz_metric_6 0.100258276
4 freq_formants_hamon 0.18413046
4 amplitude_formants_noise_metric 0.0018124974
4 amplitude_formants_noise 0.72499895
4 freq_formants_noise_metric 0.3656872
4 freq_formants_noise 1.7125032
4 bandwidth_formants_noise_hz_metric 2.3439436
4 bandwidth_formants_noise_hz 7.031831
4 Ldiff 0.029060617
4 Le

  5%|▌         | 5/100 [17:37<5:33:43, 210.78s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:58,  2.57s/it][A
  4%|▍         | 2/47 [00:04<01:48,  2.42s/it][A
  6%|▋         | 3/47 [00:07<01:45,  2.40s/it][A
  9%|▊         | 4/47 [00:09<01:39,  2.32s/it][A
 11%|█         | 5/47 [00:11<01:39,  2.36s/it][A
 13%|█▎        | 6/47 [00:14<01:35,  2.33s/it][A
 15%|█▍        | 7/47 [00:16<01:32,  2.32s/it][A
 17%|█▋        | 8/47 [00:18<01:30,  2.31s/it][A
 19%|█▉        | 9/47 [00:21<01:27,  2.29s/it][A
 21%|██▏       | 10/47 [00:23<01:23,  2.25s/it][A
 23%|██▎       | 11/47 [00:25<01:20,  2.24s/it][A
 26%|██▌       | 12/47 [00:27<01:18,  2.24s/it][A
 28%|██▊       | 13/47 [00:29<01:15,  2.22s/it][A
 30%|██▉       | 14/47 [00:31<01:11,  2.18s/it][A
 32%|███▏      | 15/47 [00:34<01:11,  2.23s/it][A
 34%|███▍      | 16/47 [00:36<01:07,  2.19s/it][A
 36%|███▌      | 17/47 [00:38<01:05,  2.19s/it][A
 38%|███▊      | 18/47 [00:40<01:03,  2.19s/it][A
 40%|████      | 19/47 [00:42<01:01,  2.18s/it][A
 43%|████▎     | 20/47 [00:45<00:58,  2

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:37:08,363 logger INFO: Saving checkpoint to output/resnet/model_epoch5_NY869.pth
registering means,n_iter,self.n_iters 5 [0, 1, 2, 3, 4]
5 Lae_a1 0.0
5 Lae_a_l21 0.0
5 Lae_db1 0.24447273
5 Lae_db_l21 0.24451733
5 Lae_a2 0.0
5 Lae_a_l22 0.0
5 Lae_db2 0.23439074
5 Lae_db_l22 0.23440348
5 Lrec 19.154537
5 loudness_metric 0.014511393
5 loudness 3.9464347
5 f0_metric 0.07050653
5 f0_hz 0.02115196
5 amplitudes_metric 0.012662956
5 amplitudes 6.6178403
5 amplitude_formants_hamon_metric 0.0006158994
5 amplitude_formants_hamon 0.24635975
5 freq_formants_hamon_hz_metric_2 0.09932348
5 freq_formants_hamon_hz_metric_6 0.09987568
5 freq_formants_hamon 0.19431394
5 amplitude_formants_noise_metric 0.0019648909
5 amplitude_formants_noise 0.7859563
5 freq_formants_noise_metric 0.3594472
5 freq_formants_noise 1.6902386
5 bandwidth_formants_noise_hz_metric 2.2135465
5 bandwidth_formants_noise_hz 6.6406393
5 Ldiff 0.028682545
5 Lex

  6%|▌         | 6/100 [21:21<5:37:14, 215.26s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:44,  2.26s/it][A
  4%|▍         | 2/47 [00:04<01:37,  2.16s/it][A
  6%|▋         | 3/47 [00:06<01:33,  2.13s/it][A
  9%|▊         | 4/47 [00:08<01:33,  2.16s/it][A
 11%|█         | 5/47 [00:10<01:28,  2.11s/it][A
 13%|█▎        | 6/47 [00:12<01:25,  2.08s/it][A
 15%|█▍        | 7/47 [00:14<01:24,  2.11s/it][A
 17%|█▋        | 8/47 [00:17<01:23,  2.13s/it][A
 19%|█▉        | 9/47 [00:19<01:20,  2.12s/it][A
 21%|██▏       | 10/47 [00:21<01:18,  2.12s/it][A
 23%|██▎       | 11/47 [00:23<01:16,  2.11s/it][A
 26%|██▌       | 12/47 [00:25<01:13,  2.11s/it][A
 28%|██▊       | 13/47 [00:27<01:13,  2.16s/it][A
 30%|██▉       | 14/47 [00:29<01:11,  2.17s/it][A
 32%|███▏      | 15/47 [00:32<01:08,  2.16s/it][A
 34%|███▍      | 16/47 [00:34<01:05,  2.12s/it][A
 36%|███▌      | 17/47 [00:36<01:03,  2.10s/it][A
 38%|███▊      | 18/47 [00:38<01:01,  2.11s/it][A
 40%|████      | 19/47 [00:40<00:58,  2.07s/it][A
 43%|████▎     | 20/47 [00:42<00:55,  2

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:40:46,226 logger INFO: Saving checkpoint to output/resnet/model_epoch6_NY869.pth
registering means,n_iter,self.n_iters 6 [0, 1, 2, 3, 4, 5]
6 Lae_a1 0.0
6 Lae_a_l21 0.0
6 Lae_db1 0.24395002
6 Lae_db_l21 0.24399738
6 Lae_a2 0.0
6 Lae_a_l22 0.0
6 Lae_db2 0.23053822
6 Lae_db_l22 0.23055123
6 Lrec 18.97953
6 loudness_metric 0.016198717
6 loudness 4.2241344
6 f0_metric 0.07860848
6 f0_hz 0.023582546
6 amplitudes_metric 0.011738509
6 amplitudes 6.5832634
6 amplitude_formants_hamon_metric 0.00063209364
6 amplitude_formants_hamon 0.25283745
6 freq_formants_hamon_hz_metric_2 0.101337716
6 freq_formants_hamon_hz_metric_6 0.09719759
6 freq_formants_hamon 0.20050831
6 amplitude_formants_noise_metric 0.0019563402
6 amplitude_formants_noise 0.7825361
6 freq_formants_noise_metric 0.37296078
6 freq_formants_noise 1.74068
6 bandwidth_formants_noise_hz_metric 2.2654195
6 bandwidth_formants_noise_hz 6.7962584
6 Ldiff 0.029629562
6

  7%|▋         | 7/100 [24:59<5:35:07, 216.21s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:48,  2.36s/it][A
  4%|▍         | 2/47 [00:04<01:39,  2.21s/it][A
  6%|▋         | 3/47 [00:06<01:36,  2.19s/it][A
  9%|▊         | 4/47 [00:08<01:29,  2.07s/it][A
 11%|█         | 5/47 [00:10<01:27,  2.09s/it][A
 13%|█▎        | 6/47 [00:12<01:22,  2.02s/it][A
 15%|█▍        | 7/47 [00:14<01:20,  2.02s/it][A
 17%|█▋        | 8/47 [00:16<01:17,  1.99s/it][A
 19%|█▉        | 9/47 [00:18<01:14,  1.96s/it][A
 21%|██▏       | 10/47 [00:20<01:10,  1.91s/it][A
 23%|██▎       | 11/47 [00:22<01:09,  1.92s/it][A
 26%|██▌       | 12/47 [00:24<01:08,  1.95s/it][A
 28%|██▊       | 13/47 [00:26<01:07,  1.98s/it][A
 30%|██▉       | 14/47 [00:28<01:04,  1.96s/it][A
 32%|███▏      | 15/47 [00:29<01:01,  1.91s/it][A
 34%|███▍      | 16/47 [00:31<00:58,  1.89s/it][A
 36%|███▌      | 17/47 [00:33<00:58,  1.93s/it][A
 38%|███▊      | 18/47 [00:35<00:55,  1.92s/it][A
 40%|████      | 19/47 [00:37<00:53,  1.92s/it][A
 43%|████▎     | 20/47 [00:39<00:52,  1

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:44:20,312 logger INFO: Saving checkpoint to output/resnet/model_epoch7_NY869.pth
registering means,n_iter,self.n_iters 7 [0, 1, 2, 3, 4, 5, 6]
7 Lae_a1 0.0
7 Lae_a_l21 0.0
7 Lae_db1 0.24728993
7 Lae_db_l21 0.24733184
7 Lae_a2 0.0
7 Lae_a_l22 0.0
7 Lae_db2 0.23472999
7 Lae_db_l22 0.23474239
7 Lrec 19.280796
7 loudness_metric 0.013411522
7 loudness 3.9291961
7 f0_metric 0.08220282
7 f0_hz 0.024660848
7 amplitudes_metric 0.009881485
7 amplitudes 6.021891
7 amplitude_formants_hamon_metric 0.0006018083
7 amplitude_formants_hamon 0.24072333
7 freq_formants_hamon_hz_metric_2 0.11517214
7 freq_formants_hamon_hz_metric_6 0.101912625
7 freq_formants_hamon 0.21743372
7 amplitude_formants_noise_metric 0.0019196235
7 amplitude_formants_noise 0.7678494
7 freq_formants_noise_metric 0.36800614
7 freq_formants_noise 1.7432581
7 bandwidth_formants_noise_hz_metric 2.4021375
7 bandwidth_formants_noise_hz 7.2064123
7 Ldiff 0.0252294

  8%|▊         | 8/100 [28:34<5:30:53, 215.80s/it]
  0%|          | 0/47 [00:00<?, ?it/s][A

length, self.TestNum_cum [50]
length, self.TestNum_cum [50]
length, self.TestNum_cum [50]



  2%|▏         | 1/47 [00:02<01:43,  2.25s/it][A
  4%|▍         | 2/47 [00:04<01:32,  2.05s/it][A
  6%|▋         | 3/47 [00:06<01:25,  1.95s/it][A
  9%|▊         | 4/47 [00:08<01:25,  2.00s/it][A
 11%|█         | 5/47 [00:10<01:23,  1.99s/it][A
 13%|█▎        | 6/47 [00:11<01:20,  1.95s/it][A
 15%|█▍        | 7/47 [00:13<01:18,  1.96s/it][A
 17%|█▋        | 8/47 [00:15<01:16,  1.96s/it][A
 19%|█▉        | 9/47 [00:17<01:14,  1.95s/it][A
 21%|██▏       | 10/47 [00:19<01:10,  1.91s/it][A
 23%|██▎       | 11/47 [00:21<01:07,  1.88s/it][A
 26%|██▌       | 12/47 [00:23<01:05,  1.88s/it][A
 28%|██▊       | 13/47 [00:25<01:03,  1.88s/it][A
 30%|██▉       | 14/47 [00:27<01:02,  1.89s/it][A
 32%|███▏      | 15/47 [00:29<01:01,  1.92s/it][A
 34%|███▍      | 16/47 [00:30<00:58,  1.87s/it][A
 36%|███▌      | 17/47 [00:32<00:55,  1.86s/it][A
 38%|███▊      | 18/47 [00:34<00:56,  1.94s/it][A
 40%|████      | 19/47 [00:36<00:54,  1.94s/it][A
 43%|████▎     | 20/47 [00:38<00:52,  1

registering means,n_iter,self.n_iters 400 [200]
400 Lae_a1 0.0
400 Lae_a_l21 0.0
400 Lae_db1 0.2463071
400 Lae_db_l21 0.24634786
400 Lae_a2 0.0
400 Lae_a_l22 0.0
400 Lae_db2 0.24559443
400 Lae_db_l22 0.24560662
400 Lrec 19.676062
400 loudness_metric 0.01755958
400 loudness 3.782256
400 f0_metric 0.25544336
400 f0_hz 0.07663301
400 amplitudes_metric 0.015976695
400 amplitudes 2.8186743
400 amplitude_formants_hamon_metric 0.00073728967
400 amplitude_formants_hamon 0.29491585
400 freq_formants_hamon_hz_metric_2 0.12903315
400 freq_formants_hamon_hz_metric_6 0.12295094
400 freq_formants_hamon 0.2576533
400 amplitude_formants_noise_metric 0.0023036078
400 amplitude_formants_noise 0.9214431
400 freq_formants_noise_metric 0.4097411
400 freq_formants_noise 1.932734
400 bandwidth_formants_noise_hz_metric 2.2368677
400 bandwidth_formants_noise_hz 6.710603
400 Ldiff 0.07741421
400 Lexp -0.28138924
None



 51%|█████     | 24/47 [00:46<00:45,  1.98s/it][A
 53%|█████▎    | 25/47 [00:48<00:43,  1.96s/it][A
 55%|█████▌    | 26/47 [00:50<00:41,  1.96s/it][A
 57%|█████▋    | 27/47 [00:52<00:38,  1.93s/it][A
 60%|█████▉    | 28/47 [00:54<00:36,  1.93s/it][A
 62%|██████▏   | 29/47 [00:56<00:34,  1.90s/it][A
 64%|██████▍   | 30/47 [00:58<00:32,  1.92s/it][A
 66%|██████▌   | 31/47 [01:00<00:31,  1.99s/it][A
 68%|██████▊   | 32/47 [01:01<00:28,  1.93s/it][A
 70%|███████   | 33/47 [01:04<00:28,  2.01s/it][A
 72%|███████▏  | 34/47 [01:05<00:25,  1.96s/it][A
 74%|███████▍  | 35/47 [01:07<00:22,  1.89s/it][A
 77%|███████▋  | 36/47 [01:09<00:20,  1.86s/it][A
 79%|███████▊  | 37/47 [01:11<00:19,  1.92s/it][A
 81%|████████  | 38/47 [01:13<00:17,  1.96s/it][A
 83%|████████▎ | 39/47 [01:15<00:15,  1.94s/it][A
 85%|████████▌ | 40/47 [01:17<00:13,  1.91s/it][A
 87%|████████▋ | 41/47 [01:19<00:11,  1.93s/it][A
 89%|████████▉ | 42/47 [01:21<00:09,  1.92s/it][A
 91%|█████████▏| 43/47 [01:23<

length, self.TestNum_cum [50]
tensor([1.])
save test result!
2023-06-11 00:47:53,555 logger INFO: Saving checkpoint to output/resnet/model_epoch8_NY869.pth
registering means,n_iter,self.n_iters 8 [0, 1, 2, 3, 4, 5, 6, 7]
8 Lae_a1 0.0
8 Lae_a_l21 0.0
8 Lae_db1 0.22700496
8 Lae_db_l21 0.22705537
8 Lae_a2 0.0
8 Lae_a_l22 0.0
8 Lae_db2 0.22505563
8 Lae_db_l22 0.22506937
8 Lrec 18.082424
8 loudness_metric 0.013084595
8 loudness 3.1863556
8 f0_metric 0.09715257
8 f0_hz 0.029145772
8 amplitudes_metric 0.011779965
8 amplitudes 6.414452
8 amplitude_formants_hamon_metric 0.00062535494
8 amplitude_formants_hamon 0.25014198
8 freq_formants_hamon_hz_metric_2 0.14478149
8 freq_formants_hamon_hz_metric_6 0.11209976
8 freq_formants_hamon 0.25550583
8 amplitude_formants_noise_metric 0.0021536024
8 amplitude_formants_noise 0.86144096
8 freq_formants_noise_metric 0.36376724
8 freq_formants_noise 1.7274042
8 bandwidth_formants_noise_hz_metric 2.2072618
8 bandwidth_formants_noise_hz 6.621785
8 Ldiff 0.0250

  8%|▊         | 8/100 [31:59<6:07:56, 239.97s/it]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-12-b13e55f54dc2>", line 7, in <module>
    args_=args_,
  File "/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding/utils/launcher.py", line 174, in run
    _run(0, world_size, fn, defaults, write_log, no_cuda, args_)
  File "/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding/utils/launcher.py", line 158, in _run
    fn(**matching_args_)
  File "<ipython-input-10-a02f4c11cfd6>", line 309, in train
    suffix=subject,
  File "/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding/utils/save.py", line 981, in save_sample
    129,
  File "/scratch/xc1490/projects/ecog/ALAE_1023/neural_speech_decoding/utils/save.py", line 117, in mygriffinlim
    wave_gt[i] = griffinlim(msgram[i] ** 0.5, hop_length=hop_length)
  File "/ext3/miniconda3/lib/p

TypeError: object of type 'NoneType' has no len()

In [5]:
rm data/NY869.h5 

In [6]:
# Build a reduced copy of the NY869 subject file that holds only the datasets
# named in `keys_to_store`.
#
# One-off preparation of the full source file (kept for provenance, not executed):
#!rsync ../data/data/LD_data_extracted/meta_data/NY869.h5 data/
#mv data/NY869.h5 data/NY869_full.h5
from tqdm.notebook import tqdm

# Dataset names copied verbatim from the full file into the reduced file.
keys_to_store = [
    'ecog_alldataset', 'label_alldataset',
    'formant_re_alldataset', 'pitch_re_alldataset', 'intensity_re_alldataset',
    'start_ind_re_valid_alldataset', 'end_ind_re_valid_alldataset', 'wave_re_spec_alldataset',
    'wave_re_alldataset', 'wave_re_spec_amp_alldataset', 'noisesample_re_alldataset',
]

# NOTE(review): h5py is not imported in this cell; presumably it is already in
# scope from an earlier cell (e.g. a star import) -- confirm.
# The source is opened first and read-only, so a missing source fails before the
# output file is truncated by mode 'w'. Both handles are closed when the block
# exits, which means `meta_data` is a closed file after this cell.
with h5py.File('data/NY869_full.h5', 'r') as meta_data:
    with h5py.File('data/NY869.h5', 'w') as hf:
        for key in tqdm(keys_to_store):
            print(key)
            # `[:]` reads the whole dataset into memory before writing it out.
            hf.create_dataset(key, data=meta_data[key][:])

  0%|          | 0/11 [00:00<?, ?it/s]

ecog_alldataset
label_alldataset
formant_re_alldataset
pitch_re_alldataset
intensity_re_alldataset
start_ind_re_valid_alldataset
end_ind_re_valid_alldataset
wave_re_spec_alldataset
wave_re_alldataset
wave_re_spec_amp_alldataset
noisesample_re_alldataset
