In [1]:
import os
import sys
import numpy as np
import gc
import h5py
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset
from data.NoisyDataset import NoisyDataset, FullNoisyDataset, LargeNoisyDataset
from data.MergedNoisyDataset import MergedNoisyDataset

# Global configurations

In [2]:
# Main args
SAME_PARAM = False           # to use the same parameters for all pcoders or not
noise_types = ['Merged']

In [3]:
# Dataset configuration
snr_levels = [None]
BATCH_SIZE = 10
NUM_WORKERS = 2

In [4]:
# Other training params
EPOCH = 15
FF_START = True             # to start from feedforward initialization
MAX_TIMESTEP = 5

In [5]:
# Path names
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'

# Load network arguments

In [6]:
if SAME_PARAM:
    from pbranchednetwork_shared import PBranchedNetwork_SharedSameHP
    PNetClass = PBranchedNetwork_SharedSameHP
    pnet_name = 'all_noNulls'
    fb_state_dict_path = f'{checkpoints_dir}{pnet_name}/{pnet_name}-shared-50-regular.pth'
else:
    from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
    PNetClass = PBranchedNetwork_AllSeparateHP
    pnet_name = 'all_noNulls'
    fb_state_dict_path = f'{checkpoints_dir}{pnet_name}/{pnet_name}-50-regular.pth'
fb_state_dict = torch.load(fb_state_dict_path)

# Helper functions

In [7]:
def load_pnet(
        net, state_dict, build_graph, random_init,
        ff_multiplier, fb_multiplier, er_multiplier,
        same_param, device='cuda:0'):
    
    pnet = PNetClass(
        net, build_graph=build_graph, random_init=random_init,
        ff_multiplier=ff_multiplier, fb_multiplier=fb_multiplier, er_multiplier=er_multiplier
        )

    pnet.load_state_dict(state_dict)
    pnet.eval()
    pnet.to(device)
    return pnet

In [8]:
def evaluate(net, epoch, dataloader, timesteps, loss_function, writer=None, tag='Clean'):
    test_loss = np.zeros((timesteps+1,))
    correct   = np.zeros((timesteps+1,))
    for (images, labels) in dataloader:
        images = images.cuda()
        labels = labels.cuda()
        
        with torch.no_grad():
            for tt in range(timesteps+1):
                if tt == 0:
                    outputs, _ = net(images)
                else:
                    outputs, _ = net()
                
                loss = loss_function(outputs, labels)
                test_loss[tt] += loss.item()
                _, preds = outputs.max(1)
                correct[tt] += preds.eq(labels).sum()

    print()
    for tt in range(timesteps+1):
        test_loss[tt] /= len(dataloader.dataset)
        correct[tt] /= len(dataloader.dataset)
        print('Test set t = {:02d}: Average loss: {:.4f}, Accuracy: {:.4f}'.format(
            tt,
            test_loss[tt],
            correct[tt]
        ))
        if writer is not None:
            writer.add_scalar(
                f"{tag}Perf/Epoch#{epoch}",
                correct[tt], tt
                )
    print()


In [9]:
def train(net, epoch, dataloader, timesteps, loss_function, optimizer, writer=None):
    for batch_index, (images, labels) in enumerate(dataloader):
        net.reset()

        labels = labels.cuda()
        images = images.cuda()

        ttloss = np.zeros((timesteps+1))
        optimizer.zero_grad()

        for tt in range(timesteps+1):
            if tt == 0:
                outputs, _ = net(images)
                loss = loss_function(outputs, labels)
                ttloss[tt] = loss.item()
            else:
                outputs, _ = net()
                current_loss = loss_function(outputs, labels)
                ttloss[tt] = current_loss.item()
                loss += current_loss
        
        loss.backward()
        optimizer.step()
        net.update_hyperparameters()
            
        print(f"Training Epoch: {epoch} [{batch_index * BATCH_SIZE + len(images)}/{len(dataloader.dataset)}]\tLoss: {loss.item():0.4f}\tLR: {optimizer.param_groups[0]['lr']:0.6f}")
        for tt in range(timesteps+1):
            print(f'{ttloss[tt]:0.4f}\t', end='')
        print()
        if writer is not None:
            writer.add_scalar(
                f"TrainingLoss/CE", loss.item(),
                (epoch-1)*len(dataloader) + batch_index
                )


In [10]:
def log_hyper_parameters(net, epoch, sumwriter, same_param=True):
    if same_param:
        sumwriter.add_scalar(f"HyperparamRaw/feedforward", getattr(net,f'ff_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/feedback",    getattr(net,f'fb_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/error",       getattr(net,f'errorm').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/memory",      getattr(net,f'mem_part').item(), epoch)

        sumwriter.add_scalar(f"Hyperparam/feedforward", getattr(net,f'ffm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/feedback",    getattr(net,f'fbm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/error",       getattr(net,f'erm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/memory",      1-getattr(net,f'ffm').item()-getattr(net,f'fbm').item(), epoch)
    else:
        for i in range(1, net.number_of_pcoders+1):
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedforward", getattr(net,f'ffm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", 0, epoch)
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_error", getattr(net,f'erm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item()-getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item(), epoch)


# Main hyperparameter optimization script

In [15]:
def train_and_eval(noise_type, snr_level):
    # Load clean and noisy data
#     clean_ds_path = f'{engram_dir}training_dataset_random_order.hdf5'
#     clean_ds = CleanSoundsDataset(clean_ds_path)
#     clean_loader = torch.utils.data.DataLoader(
#         clean_ds,  batch_size=BATCH_SIZE,
#         shuffle=False, drop_last=False, num_workers=NUM_WORKERS
#         )

    if noise_type == 'Merged':
        noisy_ds = MergedNoisyDataset(subset=0.9, train=True)
        noise_loader = torch.utils.data.DataLoader(
            noisy_ds,  batch_size=BATCH_SIZE,
            shuffle=True, drop_last=False,
            num_workers=NUM_WORKERS
            )
        eval_ds = MergedNoisyDataset(subset=0.9, train=False)
        eval_loader = torch.utils.data.DataLoader(
            eval_ds,  batch_size=BATCH_SIZE,
            shuffle=True, drop_last=False,
            num_workers=NUM_WORKERS
            )
    else:
        noisy_ds = NoisyDataset(bg=noise_type, snr=snr_level)
        noise_loader = torch.utils.data.DataLoader(
            noisy_ds,  batch_size=BATCH_SIZE,
            shuffle=True, drop_last=False,
            num_workers=NUM_WORKERS
            )
        eval_loader = noise_loader

    # Set up logs and network for training
    net_dir = f'hyper_{noise_type}_snr{snr_level}'
    if FF_START:
        net_dir += '_FFstart'
    if SAME_PARAM:
        net_dir += '_shared'

    sumwriter = SummaryWriter(f'{tensorboard_dir}{net_dir}')
    net = BranchedNetwork() # Load original network
    net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))
    pnet_fw = load_pnet( # Load FF PNet
        net, fb_state_dict, build_graph=False, random_init=(not FF_START),
        ff_multiplier=1.0, fb_multiplier=0.0, er_multiplier=0.0,
        same_param=SAME_PARAM, device='cuda:0'
        )
    loss_function = torch.nn.CrossEntropyLoss()
    evaluate(
        pnet_fw, 0, eval_loader, 1,
        loss_function,
        writer=sumwriter, tag='FeedForward')
    del pnet_fw
    gc.collect()

    # Load PNet for hyperparameter optimization
    net = BranchedNetwork()
    net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))
    pnet = load_pnet(
        net, fb_state_dict, build_graph=True, random_init=(not FF_START),
        ff_multiplier=0.33, fb_multiplier=0.33, er_multiplier=0.0,
        same_param=SAME_PARAM, device='cuda:0'
        )

    # Set up loss function and hyperparameters
    loss_function = torch.nn.CrossEntropyLoss()
    hyperparams = [*pnet.get_hyperparameters()]
    if SAME_PARAM:
        optimizer = torch.optim.Adam([
            {'params': hyperparams[:-1], 'lr':0.01},
            {'params': hyperparams[-1:], 'lr':0.0001}], weight_decay=0.00001)
    else:
        fffbmem_hp = []
        erm_hp = []
        for pc in range(pnet.number_of_pcoders):
            fffbmem_hp.extend(hyperparams[pc*4:pc*4+3])
            erm_hp.append(hyperparams[pc*4+3])
        optimizer = torch.optim.Adam([
            {'params': fffbmem_hp, 'lr':0.01},
            {'params': erm_hp, 'lr':0.0001}], weight_decay=0.00001)

    # Log initial hyperparameter and eval values
    log_hyper_parameters(pnet, 0, sumwriter, same_param=SAME_PARAM)
    hps = pnet.get_hyperparameters_values()
    print(hps)
    evaluate(
        pnet, 0, eval_loader,
        MAX_TIMESTEP, loss_function,
        writer=sumwriter, tag='Noisy'
        )

    # Run epochs
    for epoch in range(1, EPOCH+1):
        train(
            pnet, epoch, noise_loader,
            MAX_TIMESTEP, loss_function, optimizer,
            writer=sumwriter
            )
        log_hyper_parameters(pnet, epoch, sumwriter, same_param=SAME_PARAM)
        hps = pnet.get_hyperparameters_values()
        print(hps)

        evaluate(
            pnet, epoch, eval_loader,
            MAX_TIMESTEP, loss_function,
            writer=sumwriter, tag='Noisy'
            )
    # evaluate(
    #     pnet, epoch, clean_loader,
    #     timesteps=MAX_TIMESTEP, writer=sumwriter,
    #     tag='Clean'
    #     )
    sumwriter.close()

In [16]:
for noise_type in noise_types:
    for snr_level in snr_levels:
        print("=====================")
        print(f'{noise_type}, for SNR {snr_level}')
        print("=====================")
        train_and_eval(noise_type, snr_level)

Merged, for SNR None


  "The default behavior for interpolate/upsample with float scale_factor changed "



Test set t = 00: Average loss: 0.2486, Accuracy: 0.5334
Test set t = 01: Average loss: 0.2458, Accuracy: 0.5339

[0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.0, 0.699999988079071, 0.009999999776482582]

Test set t = 00: Average loss: 0.2486, Accuracy: 0.5334
Test set t = 01: Average loss: 0.2458, Accuracy: 0.5339
Test set t = 02: Average loss: 0.2472, Accuracy: 0.5275
Test set t = 03: Average loss: 0.2538, Accuracy: 0.5198
Test set t = 04: Average loss: 0.2651, Accuracy: 0.5069
Test set t = 05: Average loss: 0.2806, Accuracy: 0.4901

Training Epoch: 1 [10/67482]	Loss: 20.0702	LR: 0.010000
3.4022	3.3479	3.3174	3.3045	3.3261	3.3720	
Training Epoch: 1 [20/67482]	Loss: 13.2718	LR

Training Epoch: 1 [750/67482]	Loss: 12.6181	LR: 0.010000
1.9573	2.0002	2.0570	2.1221	2.1926	2.2889	
Training Epoch: 1 [760/67482]	Loss: 18.7988	LR: 0.010000
3.2288	3.1888	3.1422	3.1059	3.0757	3.0574	
Training Epoch: 1 [770/67482]	Loss: 29.0253	LR: 0.010000
5.0700	4.9658	4.8582	4.7644	4.7035	4.6634	
Training Epoch: 1 [780/67482]	Loss: 18.0230	LR: 0.010000
2.9432	2.9431	2.9630	3.0047	3.0536	3.1154	
Training Epoch: 1 [790/67482]	Loss: 21.8598	LR: 0.010000
3.8153	3.7408	3.6575	3.5888	3.5457	3.5117	
Training Epoch: 1 [800/67482]	Loss: 9.9412	LR: 0.010000
1.7320	1.6972	1.6724	1.6441	1.6122	1.5832	
Training Epoch: 1 [810/67482]	Loss: 8.9990	LR: 0.010000
1.4255	1.4484	1.4804	1.5130	1.5474	1.5842	
Training Epoch: 1 [820/67482]	Loss: 20.6061	LR: 0.010000
3.6741	3.5552	3.4451	3.3604	3.3030	3.2681	
Training Epoch: 1 [830/67482]	Loss: 16.1995	LR: 0.010000
2.6656	2.6526	2.6549	2.6846	2.7324	2.8095	
Training Epoch: 1 [840/67482]	Loss: 27.9691	LR: 0.010000
4.9111	4.8371	4.7187	4.5996	4.4953	4.4075	
Tr

Training Epoch: 1 [1570/67482]	Loss: 7.5142	LR: 0.010000
1.2551	1.2496	1.2493	1.2426	1.2484	1.2692	
Training Epoch: 1 [1580/67482]	Loss: 21.0257	LR: 0.010000
3.3507	3.3915	3.4596	3.5322	3.6058	3.6858	
Training Epoch: 1 [1590/67482]	Loss: 10.6563	LR: 0.010000
1.9468	1.8517	1.7634	1.7097	1.6914	1.6935	
Training Epoch: 1 [1600/67482]	Loss: 12.6422	LR: 0.010000
2.1151	2.1089	2.1090	2.1041	2.0981	2.1070	
Training Epoch: 1 [1610/67482]	Loss: 18.1171	LR: 0.010000
3.1299	3.0893	3.0368	2.9918	2.9500	2.9192	
Training Epoch: 1 [1620/67482]	Loss: 19.2151	LR: 0.010000
3.2117	3.1948	3.1909	3.1938	3.2067	3.2172	
Training Epoch: 1 [1630/67482]	Loss: 7.1461	LR: 0.010000
1.2481	1.2069	1.1754	1.1629	1.1661	1.1867	
Training Epoch: 1 [1640/67482]	Loss: 3.9977	LR: 0.010000
0.6494	0.6520	0.6570	0.6661	0.6795	0.6937	
Training Epoch: 1 [1650/67482]	Loss: 10.0742	LR: 0.010000
1.6857	1.6546	1.6404	1.6532	1.6902	1.7502	
Training Epoch: 1 [1660/67482]	Loss: 7.5649	LR: 0.010000
1.2513	1.2342	1.2309	1.2484	1.2828	1.

KeyboardInterrupt: 