In [1]:
import os
import sys
import numpy as np
import gc
import h5py
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset
from data.NoisyDataset import NoisyDataset, FullNoisyDataset

# Global configurations

In [2]:
#total training epoches
EPOCH = 10

SAME_PARAM = False           # to use the same parameters for all pcoders or not
FF_START = True             # to start from feedforward initialization
MAX_TIMESTEP = 5

In [3]:
# Dataset configuration
BATCH_SIZE = 10
NUM_WORKERS = 2
NOISE_TYPE = 'AudScene'
NOISE_SNR = 3

In [4]:
# Path names
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'

net_dir = 'hyper_audscene_snr3'
if FF_START:
    net_dir += '_FFstart'
if SAME_PARAM:
    net_dir += '_shared'

# Load network arguments

In [5]:
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
PNetClass = PBranchedNetwork_AllSeparateHP
pnet_name = 'all'

In [6]:
fb_state_dict_path = f'{checkpoints_dir}{pnet_name}/{pnet_name}-50-regular.pth'
fb_state_dict = torch.load(fb_state_dict_path)

# Load clean and noisy data

In [7]:
clean_ds_path = f'{engram_dir}training_dataset_random_order.hdf5'
clean_ds = CleanSoundsDataset(clean_ds_path)
clean_loader = torch.utils.data.DataLoader(
    clean_ds,  batch_size=BATCH_SIZE,
    shuffle=False, drop_last=False, num_workers=NUM_WORKERS
    )

In [8]:
noisy_ds = NoisyDataset(bg=NOISE_TYPE, snr=NOISE_SNR)
noise_loader = torch.utils.data.DataLoader(
    noisy_ds,  batch_size=BATCH_SIZE,
    shuffle=True, drop_last=False,
    num_workers=NUM_WORKERS
    )

# Helper functions

In [9]:
def load_pnet(
        net, state_dict, build_graph, random_init,
        ff_multiplier, fb_multiplier, er_multiplier,
        same_param, device='cuda:0'):
    
    if same_param:
        raise Exception('Not implemented!')
    else:
        pnet = PNetClass(
            net, build_graph=build_graph, random_init=random_init,
            ff_multiplier=ff_multiplier, fb_multiplier=fb_multiplier, er_multiplier=er_multiplier
            )

    pnet.load_state_dict(state_dict)
    pnet.eval()
    pnet.to(device)
    return pnet

In [10]:
def evaluate(net, epoch, dataloader, timesteps, writer=None, tag='Clean'):
    test_loss = np.zeros((timesteps+1,))
    correct   = np.zeros((timesteps+1,))
    for (images, labels) in dataloader:
        images = images.cuda()
        labels = labels.cuda()
        
        with torch.no_grad():
            for tt in range(timesteps+1):
                if tt == 0:
                    outputs, _ = net(images)
                else:
                    outputs, _ = net()
                
                loss = loss_function(outputs, labels)
                test_loss[tt] += loss.item()
                _, preds = outputs.max(1)
                correct[tt] += preds.eq(labels).sum()

    print()
    for tt in range(timesteps+1):
        test_loss[tt] /= len(dataloader.dataset)
        correct[tt] /= len(dataloader.dataset)
        print('Test set t = {:02d}: Average loss: {:.4f}, Accuracy: {:.4f}'.format(
            tt,
            test_loss[tt],
            correct[tt]
        ))
        if writer is not None:
            writer.add_scalar(
                f"{tag}Perf/Epoch#{epoch}",
                correct[tt], tt
                )
    print()


In [11]:
def train(net, epoch, dataloader, timesteps, writer=None):
    for batch_index, (images, labels) in enumerate(dataloader):
        net.reset()

        labels = labels.cuda()
        images = images.cuda()

        ttloss = np.zeros((timesteps+1))
        optimizer.zero_grad()

        for tt in range(timesteps+1):
            if tt == 0:
                outputs, _ = net(images)
                loss = loss_function(outputs, labels)
                ttloss[tt] = loss.item()
            else:
                outputs, _ = net()
                current_loss = loss_function(outputs, labels)
                ttloss[tt] = current_loss.item()
                loss += current_loss
        
        loss.backward()
        optimizer.step()
        net.update_hyperparameters()
            
        print(f"Training Epoch: {epoch} [{batch_index * 16 + len(images)}/{len(dataloader.dataset)}]\tLoss: {loss.item():0.4f}\tLR: {optimizer.param_groups[0]['lr']:0.6f}")
        for tt in range(timesteps+1):
            print(f'{ttloss[tt]:0.4f}\t', end='')
        print()
        if writer is not None:
            writer.add_scalar(
                f"TrainingLoss/CE", loss.item(),
                (epoch-1)*len(dataloader) + batch_index
                )


In [12]:
def log_hyper_parameters(net, epoch, sumwriter, same_param=True):
    if same_param:
        sumwriter.add_scalar(f"HyperparamRaw/feedforward", getattr(net,f'ff_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/feedback",    getattr(net,f'fb_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/error",       getattr(net,f'errorm').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/memory",      getattr(net,f'mem_part').item(), epoch)

        sumwriter.add_scalar(f"Hyperparam/feedforward", getattr(net,f'ffm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/feedback",    getattr(net,f'fbm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/error",       getattr(net,f'erm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/memory",      1-getattr(net,f'ffm').item()-getattr(net,f'fbm').item(), epoch)
    else:
        for i in range(1, net.number_of_pcoders+1):
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedforward", getattr(net,f'ffm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", 0, epoch)
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_error", getattr(net,f'erm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item()-getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item(), epoch)


# Set up logs and files for training

In [13]:
sumwriter = SummaryWriter(f'{tensorboard_dir}{net_dir}')

# Load original network
net = BranchedNetwork()
net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))

# Load FF PNet
pnet_fw = load_pnet(
    net, fb_state_dict, build_graph=False, random_init=(not FF_START),
    ff_multiplier=1.0, fb_multiplier=0.0, er_multiplier=0.0,
    same_param=SAME_PARAM, device='cuda:0'
    )

loss_function = torch.nn.CrossEntropyLoss()
evaluate(
    pnet_fw, 0, noise_loader, timesteps=1,
    writer=sumwriter, tag='FeedForward')
del pnet_fw
gc.collect()

  "The default behavior for interpolate/upsample with float scale_factor changed "
  self.prediction_error  = nn.functional.mse_loss(self.prd, target)
  self.prediction_error  = nn.functional.mse_loss(self.prd, target)



Test set t = 00: Average loss: 0.3199, Accuracy: 0.5520
Test set t = 01: Average loss: 0.2807, Accuracy: 0.5485



982

# Run training script for hyperparameters

In [14]:
# Load original network
net = BranchedNetwork()
net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))

# Load PNet with hyperparameters
pnet = load_pnet(
    net, fb_state_dict, build_graph=True, random_init=(not FF_START),
    ff_multiplier=0.33, fb_multiplier=0.33, er_multiplier=0.0,
    same_param=SAME_PARAM, device='cuda:0'
    )

# Set up loss function and hyperparameters
loss_function = torch.nn.CrossEntropyLoss()
hyperparams = [*pnet.get_hyperparameters()]
if SAME_PARAM:
    optimizer = optim.Adam([
        {'params': hyperparams[:-1], 'lr':0.01},
        {'params': hyperparams[-1:], 'lr':0.0001}], weight_decay=0.00001)
else:
    fffbmem_hp = []
    erm_hp = []
    for pc in range(pnet.number_of_pcoders):
        fffbmem_hp.extend(hyperparams[pc*4:pc*4+3])
        erm_hp.append(hyperparams[pc*4+3])
    optimizer = torch.optim.Adam([
        {'params': fffbmem_hp, 'lr':0.01},
        {'params': erm_hp, 'lr':0.0001}], weight_decay=0.00001)

# Log initial hyperparameter and eval values
log_hyper_parameters(pnet, 0, sumwriter, same_param=SAME_PARAM)
hps = pnet.get_hyperparameters_values()
print(hps)
evaluate(
    pnet, 0, noise_loader,
    timesteps=MAX_TIMESTEP, writer=sumwriter,
    tag='Noisy'
    )

# Run epochs
for epoch in range(1, EPOCH+1):
    train(
        pnet, epoch, noise_loader,
        timesteps=MAX_TIMESTEP, writer=sumwriter
        )
    log_hyper_parameters(pnet, epoch, sumwriter, same_param=SAME_PARAM)
    hps = pnet.get_hyperparameters_values()
    print(hps)

    evaluate(
        pnet, epoch, noise_loader,
        timesteps=MAX_TIMESTEP, writer=sumwriter,
        tag='Noisy'
        )
evaluate(
    pnet, epoch, clean_loader,
    timesteps=MAX_TIMESTEP, writer=sumwriter,
    tag='Clean'
    )
sumwriter.close()

[0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.0, 0.699999988079071, 0.009999999776482582]

Test set t = 00: Average loss: 0.3173, Accuracy: 0.5520
Test set t = 01: Average loss: 0.2786, Accuracy: 0.5485
Test set t = 02: Average loss: 0.2525, Accuracy: 0.5397
Test set t = 03: Average loss: 0.2559, Accuracy: 0.5115
Test set t = 04: Average loss: 0.2961, Accuracy: 0.4356
Test set t = 05: Average loss: 0.3597, Accuracy: 0.3122

Training Epoch: 1 [10/567]	Loss: 15.6441	LR: 0.010000
2.6925	2.4166	2.2594	2.4088	2.7410	3.1259	
Training Epoch: 1 [26/567]	Loss: 18.6096	LR: 0.010000
2.4786	2.4531	2.5728	2.9338	3.6694	4.5019	
Training Epoch: 1 [42/567]	Loss: 19.5718	LR: 0.010000
3.8044	3.3

Training Epoch: 2 [202/567]	Loss: 32.4499	LR: 0.010000
6.7204	6.1929	5.5751	5.0159	4.6092	4.3364	
Training Epoch: 2 [218/567]	Loss: 13.2903	LR: 0.010000
2.3873	2.2416	2.1141	2.0839	2.1566	2.3068	
Training Epoch: 2 [234/567]	Loss: 18.8558	LR: 0.010000
3.9015	3.5164	3.1312	2.8675	2.7300	2.7092	
Training Epoch: 2 [250/567]	Loss: 16.7343	LR: 0.010000
3.0398	2.8921	2.7122	2.6228	2.6707	2.7967	
Training Epoch: 2 [266/567]	Loss: 15.4668	LR: 0.010000
3.0659	2.8170	2.5577	2.3807	2.3024	2.3431	
Training Epoch: 2 [282/567]	Loss: 8.0748	LR: 0.010000
1.0585	1.0271	1.1037	1.2926	1.6046	1.9883	
Training Epoch: 2 [298/567]	Loss: 14.9866	LR: 0.010000
2.6323	2.4960	2.3877	2.3796	2.4525	2.6384	
Training Epoch: 2 [314/567]	Loss: 7.5732	LR: 0.010000
0.9877	1.0143	1.0846	1.2297	1.4682	1.7887	
Training Epoch: 2 [330/567]	Loss: 15.6200	LR: 0.010000
3.0035	2.8171	2.5793	2.4059	2.3559	2.4582	
Training Epoch: 2 [346/567]	Loss: 8.8130	LR: 0.010000
1.2427	1.2493	1.3151	1.4476	1.6461	1.9121	
Training Epoch: 2 [362/

Training Epoch: 3 [522/567]	Loss: 21.8353	LR: 0.010000
4.6471	4.1985	3.6969	3.2901	3.0523	2.9502	
Training Epoch: 3 [538/567]	Loss: 5.1670	LR: 0.010000
0.6672	0.6666	0.7050	0.8155	1.0114	1.3012	
Training Epoch: 3 [554/567]	Loss: 13.1157	LR: 0.010000
2.5038	2.2727	2.0801	2.0031	2.0628	2.1932	
Training Epoch: 3 [570/567]	Loss: 7.3844	LR: 0.010000
1.1937	1.1894	1.1770	1.1866	1.2561	1.3816	
Training Epoch: 3 [586/567]	Loss: 19.6589	LR: 0.010000
4.2143	3.7437	3.2552	2.9360	2.7670	2.7427	
Training Epoch: 3 [602/567]	Loss: 19.2584	LR: 0.010000
3.8686	3.6028	3.2633	2.9775	2.7979	2.7484	
Training Epoch: 3 [618/567]	Loss: 23.5781	LR: 0.010000
4.8502	4.4851	4.0381	3.6281	3.3593	3.2173	
Training Epoch: 3 [634/567]	Loss: 18.9232	LR: 0.010000
3.4542	3.1764	3.0561	3.0434	3.0741	3.1190	
Training Epoch: 3 [650/567]	Loss: 14.2267	LR: 0.010000
2.6237	2.4256	2.2247	2.1743	2.2810	2.4973	
Training Epoch: 3 [666/567]	Loss: 26.3749	LR: 0.010000
5.3027	4.9127	4.4409	4.0616	3.8514	3.8055	
Training Epoch: 3 [682

Training Epoch: 4 [842/567]	Loss: 16.7439	LR: 0.010000
2.8508	2.7176	2.6523	2.6785	2.8138	3.0309	
Training Epoch: 4 [858/567]	Loss: 32.5206	LR: 0.010000
7.1336	6.4337	5.6091	4.8592	4.3692	4.1159	
Training Epoch: 4 [874/567]	Loss: 22.5295	LR: 0.010000
4.5484	4.2251	3.8129	3.4756	3.2714	3.1962	
Training Epoch: 4 [890/567]	Loss: 13.8323	LR: 0.010000
2.5727	2.3808	2.2274	2.1474	2.1801	2.3240	
Training Epoch: 4 [903/567]	Loss: 26.0788	LR: 0.010000
5.0866	4.7583	4.4158	4.1109	3.8890	3.8183	
[0.3494298756122589, 0.35379189252853394, 0.29677823185920715, 0.012126980349421501, 0.4683662950992584, 0.2489951103925705, 0.2826385945081711, 0.013437588699162006, 0.44842904806137085, 0.20646457374095917, 0.34510637819767, 0.0003646443656180054, 0.3593849241733551, 0.2684766352176666, 0.37213844060897827, 0.0035985689610242844, 0.18999452888965607, 0.0, 0.8100054711103439, 0.004441494587808847]

Test set t = 00: Average loss: 0.3182, Accuracy: 0.5520
Test set t = 01: Average loss: 0.2919, Accuracy: 0.

Training Epoch: 6 [122/567]	Loss: 23.6785	LR: 0.010000
4.9776	4.5090	3.9778	3.5836	3.3487	3.2817	
Training Epoch: 6 [138/567]	Loss: 16.3311	LR: 0.010000
2.9148	2.6455	2.5033	2.5681	2.7372	2.9622	
Training Epoch: 6 [154/567]	Loss: 9.1145	LR: 0.010000
1.4940	1.3850	1.3498	1.4223	1.5945	1.8690	
Training Epoch: 6 [170/567]	Loss: 8.0244	LR: 0.010000
1.3692	1.2730	1.2000	1.2153	1.3611	1.6058	
Training Epoch: 6 [186/567]	Loss: 33.0393	LR: 0.010000
7.3950	6.5384	5.6091	4.8839	4.4348	4.1781	
Training Epoch: 6 [202/567]	Loss: 20.2575	LR: 0.010000
3.9290	3.5632	3.3498	3.1857	3.1112	3.1187	
Training Epoch: 6 [218/567]	Loss: 22.1575	LR: 0.010000
4.3969	4.0397	3.6990	3.4367	3.2844	3.3009	
Training Epoch: 6 [234/567]	Loss: 8.9955	LR: 0.010000
1.7890	1.5486	1.3726	1.3269	1.3957	1.5626	
Training Epoch: 6 [250/567]	Loss: 8.7914	LR: 0.010000
1.2134	1.1767	1.2659	1.4391	1.6986	1.9976	
Training Epoch: 6 [266/567]	Loss: 10.2953	LR: 0.010000
1.5708	1.5557	1.6175	1.7217	1.8385	1.9911	
Training Epoch: 6 [282/5

Training Epoch: 7 [442/567]	Loss: 21.0196	LR: 0.010000
4.3577	4.0045	3.5989	3.2564	2.9998	2.8023	
Training Epoch: 7 [458/567]	Loss: 32.8525	LR: 0.010000
7.7095	6.7181	5.7205	4.8455	4.1599	3.6989	
Training Epoch: 7 [474/567]	Loss: 22.9620	LR: 0.010000
4.5197	4.1058	3.7347	3.5214	3.4936	3.5867	
Training Epoch: 7 [490/567]	Loss: 11.0449	LR: 0.010000
1.6396	1.6367	1.6863	1.8185	2.0115	2.2522	
Training Epoch: 7 [506/567]	Loss: 11.5072	LR: 0.010000
1.9837	1.8914	1.8457	1.8456	1.9107	2.0301	
Training Epoch: 7 [522/567]	Loss: 16.1328	LR: 0.010000
3.3116	3.0226	2.6817	2.4367	2.3327	2.3475	
Training Epoch: 7 [538/567]	Loss: 14.0621	LR: 0.010000
2.1862	2.1571	2.1876	2.3002	2.4887	2.7424	
Training Epoch: 7 [554/567]	Loss: 31.4773	LR: 0.010000
6.8432	6.0740	5.3234	4.7362	4.3617	4.1387	
Training Epoch: 7 [570/567]	Loss: 24.7553	LR: 0.010000
5.5042	4.8624	4.2222	3.6887	3.3327	3.1452	
Training Epoch: 7 [586/567]	Loss: 10.4458	LR: 0.010000
1.8275	1.7243	1.6414	1.6434	1.7312	1.8781	
Training Epoch: 7 [6

Training Epoch: 8 [762/567]	Loss: 31.9389	LR: 0.010000
7.2643	6.3544	5.4138	4.7056	4.2474	3.9533	
Training Epoch: 8 [778/567]	Loss: 14.7420	LR: 0.010000
2.5878	2.4343	2.3613	2.3432	2.4210	2.5945	
Training Epoch: 8 [794/567]	Loss: 7.7950	LR: 0.010000
1.0097	1.0284	1.0988	1.2689	1.5273	1.8619	
Training Epoch: 8 [810/567]	Loss: 10.4473	LR: 0.010000
1.6667	1.6259	1.6208	1.7030	1.8235	2.0073	
Training Epoch: 8 [826/567]	Loss: 11.0622	LR: 0.010000
1.9902	1.8753	1.7922	1.7421	1.7679	1.8946	
Training Epoch: 8 [842/567]	Loss: 11.4687	LR: 0.010000
1.5149	1.5697	1.7089	1.9431	2.2163	2.5157	
Training Epoch: 8 [858/567]	Loss: 19.1759	LR: 0.010000
3.9756	3.5525	3.1752	2.9242	2.7731	2.7753	
Training Epoch: 8 [874/567]	Loss: 16.6193	LR: 0.010000
3.4229	3.0756	2.7674	2.5178	2.4090	2.4266	
Training Epoch: 8 [890/567]	Loss: 12.9327	LR: 0.010000
2.0723	2.0013	1.9953	2.0808	2.2547	2.5282	
Training Epoch: 8 [903/567]	Loss: 8.7241	LR: 0.010000
1.5855	1.4336	1.3378	1.3320	1.4264	1.6087	
[0.32534468173980713, 

Training Epoch: 10 [42/567]	Loss: 31.9217	LR: 0.010000
6.6724	5.9229	5.3081	4.8861	4.6266	4.5057	
Training Epoch: 10 [58/567]	Loss: 16.5814	LR: 0.010000
3.4764	3.1210	2.7242	2.4724	2.3661	2.4213	
Training Epoch: 10 [74/567]	Loss: 19.2602	LR: 0.010000
4.0364	3.5621	3.1377	2.9071	2.7959	2.8211	
Training Epoch: 10 [90/567]	Loss: 28.3567	LR: 0.010000
6.5535	5.7170	4.8779	4.1757	3.6785	3.3541	
Training Epoch: 10 [106/567]	Loss: 33.6577	LR: 0.010000
7.4906	6.6256	5.7074	5.0190	4.5405	4.2746	
Training Epoch: 10 [122/567]	Loss: 10.4049	LR: 0.010000
1.9087	1.7643	1.6173	1.5745	1.6666	1.8735	
Training Epoch: 10 [138/567]	Loss: 17.6404	LR: 0.010000
3.1366	2.9428	2.8489	2.8327	2.8916	2.9878	
Training Epoch: 10 [154/567]	Loss: 11.4782	LR: 0.010000
2.4179	2.1189	1.8381	1.6727	1.6681	1.7624	
Training Epoch: 10 [170/567]	Loss: 8.9545	LR: 0.010000
1.4574	1.3921	1.3769	1.4273	1.5495	1.7514	
Training Epoch: 10 [186/567]	Loss: 20.7433	LR: 0.010000
4.6400	4.0335	3.4502	3.0150	2.8227	2.7818	
Training Epoch:

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.