In [1]:
import os
import sys
import numpy as np
import gc
import h5py
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset
from data.NoisyDataset import NoisyDataset, FullNoisyDataset

# Global configurations

In [2]:
#total training epoches
EPOCH = 15

SAME_PARAM = False           # to use the same parameters for all pcoders or not
FF_START = True             # to start from feedforward initialization
MAX_TIMESTEP = 5

In [4]:
# Dataset configuration
BATCH_SIZE = 10
NUM_WORKERS = 2
noise_types = ['AudScene', 'Babble8Spkr', 'SpeakerShapedNoise']
snr_levels = [-9., -6., -3.,  0.,  3.]

In [5]:
# Path names
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'

# Load network arguments

In [6]:
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
PNetClass = PBranchedNetwork_AllSeparateHP
pnet_name = 'all'

In [7]:
fb_state_dict_path = f'{checkpoints_dir}{pnet_name}/{pnet_name}-50-regular.pth'
fb_state_dict = torch.load(fb_state_dict_path)

# Helper functions

In [8]:
def load_pnet(
        net, state_dict, build_graph, random_init,
        ff_multiplier, fb_multiplier, er_multiplier,
        same_param, device='cuda:0'):
    
    if same_param:
        raise Exception('Not implemented!')
    else:
        pnet = PNetClass(
            net, build_graph=build_graph, random_init=random_init,
            ff_multiplier=ff_multiplier, fb_multiplier=fb_multiplier, er_multiplier=er_multiplier
            )

    pnet.load_state_dict(state_dict)
    pnet.eval()
    pnet.to(device)
    return pnet

In [22]:
def evaluate(net, epoch, dataloader, timesteps, loss_function, writer=None, tag='Clean'):
    test_loss = np.zeros((timesteps+1,))
    correct   = np.zeros((timesteps+1,))
    for (images, labels) in dataloader:
        images = images.cuda()
        labels = labels.cuda()
        
        with torch.no_grad():
            for tt in range(timesteps+1):
                if tt == 0:
                    outputs, _ = net(images)
                else:
                    outputs, _ = net()
                
                loss = loss_function(outputs, labels)
                test_loss[tt] += loss.item()
                _, preds = outputs.max(1)
                correct[tt] += preds.eq(labels).sum()

    print()
    for tt in range(timesteps+1):
        test_loss[tt] /= len(dataloader.dataset)
        correct[tt] /= len(dataloader.dataset)
        print('Test set t = {:02d}: Average loss: {:.4f}, Accuracy: {:.4f}'.format(
            tt,
            test_loss[tt],
            correct[tt]
        ))
        if writer is not None:
            writer.add_scalar(
                f"{tag}Perf/Epoch#{epoch}",
                correct[tt], tt
                )
    print()


In [23]:
def train(net, epoch, dataloader, timesteps, loss_function, optimizer, writer=None):
    for batch_index, (images, labels) in enumerate(dataloader):
        net.reset()

        labels = labels.cuda()
        images = images.cuda()

        ttloss = np.zeros((timesteps+1))
        optimizer.zero_grad()

        for tt in range(timesteps+1):
            if tt == 0:
                outputs, _ = net(images)
                loss = loss_function(outputs, labels)
                ttloss[tt] = loss.item()
            else:
                outputs, _ = net()
                current_loss = loss_function(outputs, labels)
                ttloss[tt] = current_loss.item()
                loss += current_loss
        
        loss.backward()
        optimizer.step()
        net.update_hyperparameters()
            
        print(f"Training Epoch: {epoch} [{batch_index * 16 + len(images)}/{len(dataloader.dataset)}]\tLoss: {loss.item():0.4f}\tLR: {optimizer.param_groups[0]['lr']:0.6f}")
        for tt in range(timesteps+1):
            print(f'{ttloss[tt]:0.4f}\t', end='')
        print()
        if writer is not None:
            writer.add_scalar(
                f"TrainingLoss/CE", loss.item(),
                (epoch-1)*len(dataloader) + batch_index
                )


In [24]:
def log_hyper_parameters(net, epoch, sumwriter, same_param=True):
    if same_param:
        sumwriter.add_scalar(f"HyperparamRaw/feedforward", getattr(net,f'ff_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/feedback",    getattr(net,f'fb_part').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/error",       getattr(net,f'errorm').item(), epoch)
        sumwriter.add_scalar(f"HyperparamRaw/memory",      getattr(net,f'mem_part').item(), epoch)

        sumwriter.add_scalar(f"Hyperparam/feedforward", getattr(net,f'ffm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/feedback",    getattr(net,f'fbm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/error",       getattr(net,f'erm').item(), epoch)
        sumwriter.add_scalar(f"Hyperparam/memory",      1-getattr(net,f'ffm').item()-getattr(net,f'fbm').item(), epoch)
    else:
        for i in range(1, net.number_of_pcoders+1):
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedforward", getattr(net,f'ffm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_feedback", 0, epoch)
            sumwriter.add_scalar(f"Hyperparam/pcoder{i}_error", getattr(net,f'erm{i}').item(), epoch)
            if i < net.number_of_pcoders:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item()-getattr(net,f'fbm{i}').item(), epoch)
            else:
                sumwriter.add_scalar(f"Hyperparam/pcoder{i}_memory",      1-getattr(net,f'ffm{i}').item(), epoch)


# Main hyperparameter optimization script

In [25]:
def train_and_eval(noise_type, snr_level):
    # Load clean and noisy data
    clean_ds_path = f'{engram_dir}training_dataset_random_order.hdf5'
    clean_ds = CleanSoundsDataset(clean_ds_path)
    clean_loader = torch.utils.data.DataLoader(
        clean_ds,  batch_size=BATCH_SIZE,
        shuffle=False, drop_last=False, num_workers=NUM_WORKERS
        )

    noisy_ds = NoisyDataset(bg=noise_type, snr=snr_level)
    noise_loader = torch.utils.data.DataLoader(
        noisy_ds,  batch_size=BATCH_SIZE,
        shuffle=True, drop_last=False,
        num_workers=NUM_WORKERS
        )

    # Set up logs and network for training
    net_dir = f'hyper_{noise_type}_snr{snr_level}'
    if FF_START:
        net_dir += '_FFstart'
    if SAME_PARAM:
        net_dir += '_shared'

    sumwriter = SummaryWriter(f'{tensorboard_dir}{net_dir}')
    net = BranchedNetwork() # Load original network
    net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))
    pnet_fw = load_pnet( # Load FF PNet
        net, fb_state_dict, build_graph=False, random_init=(not FF_START),
        ff_multiplier=1.0, fb_multiplier=0.0, er_multiplier=0.0,
        same_param=SAME_PARAM, device='cuda:0'
        )
    loss_function = torch.nn.CrossEntropyLoss()
    evaluate(
        pnet_fw, 0, noise_loader, 1,
        loss_function,
        writer=sumwriter, tag='FeedForward')
    del pnet_fw
    gc.collect()

    # Load PNet for hyperparameter optimization
    net = BranchedNetwork()
    net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))
    pnet = load_pnet(
        net, fb_state_dict, build_graph=True, random_init=(not FF_START),
        ff_multiplier=0.33, fb_multiplier=0.33, er_multiplier=0.0,
        same_param=SAME_PARAM, device='cuda:0'
        )

    # Set up loss function and hyperparameters
    loss_function = torch.nn.CrossEntropyLoss()
    hyperparams = [*pnet.get_hyperparameters()]
    if SAME_PARAM:
        optimizer = optim.Adam([
            {'params': hyperparams[:-1], 'lr':0.01},
            {'params': hyperparams[-1:], 'lr':0.0001}], weight_decay=0.00001)
    else:
        fffbmem_hp = []
        erm_hp = []
        for pc in range(pnet.number_of_pcoders):
            fffbmem_hp.extend(hyperparams[pc*4:pc*4+3])
            erm_hp.append(hyperparams[pc*4+3])
        optimizer = torch.optim.Adam([
            {'params': fffbmem_hp, 'lr':0.01},
            {'params': erm_hp, 'lr':0.0001}], weight_decay=0.00001)

    # Log initial hyperparameter and eval values
    log_hyper_parameters(pnet, 0, sumwriter, same_param=SAME_PARAM)
    hps = pnet.get_hyperparameters_values()
    print(hps)
    evaluate(
        pnet, 0, noise_loader,
        MAX_TIMESTEP, loss_function,
        writer=sumwriter, tag='Noisy'
        )

    # Run epochs
    for epoch in range(1, EPOCH+1):
        train(
            pnet, epoch, noise_loader,
            MAX_TIMESTEP, loss_function, optimizer,
            writer=sumwriter
            )
        log_hyper_parameters(pnet, epoch, sumwriter, same_param=SAME_PARAM)
        hps = pnet.get_hyperparameters_values()
        print(hps)

        evaluate(
            pnet, epoch, noise_loader,
            MAX_TIMESTEP, loss_function,
            writer=sumwriter, tag='Noisy'
            )
    # evaluate(
    #     pnet, epoch, clean_loader,
    #     timesteps=MAX_TIMESTEP, writer=sumwriter,
    #     tag='Clean'
    #     )
    sumwriter.close()

In [26]:
for noise_type in noise_types:
    for snr_level in snr_levels:
        print("=====================")
        print(f'{noise_type}, for SNR {snr_level}')
        print("=====================")
        train_and_eval(noise_type, snr_level)


Test set t = 00: Average loss: 0.5143, Accuracy: 0.1812
Test set t = 01: Average loss: 0.4853, Accuracy: 0.1758

[0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.0, 0.699999988079071, 0.009999999776482582]

Test set t = 00: Average loss: 0.5163, Accuracy: 0.1812
Test set t = 01: Average loss: 0.4876, Accuracy: 0.1758
Test set t = 02: Average loss: 0.4772, Accuracy: 0.1829
Test set t = 03: Average loss: 0.4891, Accuracy: 0.1723
Test set t = 04: Average loss: 0.5168, Accuracy: 0.1421
Test set t = 05: Average loss: 0.5555, Accuracy: 0.1208

Training Epoch: 1 [10/563]	Loss: 30.0502	LR: 0.010000
5.5910	5.0671	4.7328	4.7108	4.8560	5.0926	
Training Epoch: 1 [26/563]	Loss: 34.3668	LR: 0.

  self.prediction_error  = nn.functional.mse_loss(self.prd, target)



Test set t = 00: Average loss: 0.6229, Accuracy: 0.0636
Test set t = 01: Average loss: 0.5989, Accuracy: 0.0671

[0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.30000001192092896, 0.3999999761581421, 0.009999999776482582, 0.30000001192092896, 0.0, 0.699999988079071, 0.009999999776482582]

Test set t = 00: Average loss: 0.6225, Accuracy: 0.0636
Test set t = 01: Average loss: 0.5986, Accuracy: 0.0671
Test set t = 02: Average loss: 0.5848, Accuracy: 0.0866
Test set t = 03: Average loss: 0.5827, Accuracy: 0.0830
Test set t = 04: Average loss: 0.5870, Accuracy: 0.0919
Test set t = 05: Average loss: 0.5952, Accuracy: 0.0936

Training Epoch: 1 [10/566]	Loss: 33.5975	LR: 0.010000
5.7073	5.5811	5.5474	5.5442	5.5762	5.6413	
Training Epoch: 1 [26/566]	Loss: 31.5419	LR: 0.