In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pickle
import h5py
import seaborn as sns
import pandas as pd
from tensorboard.backend.event_processing import event_accumulator
from scipy.stats import pearsonr
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset
from data.NoisyDataset import NoisyDataset, FullNoisyDataset

# PNet parameters

In [2]:
# Predictive-coding network class that `load_pnet` wraps around the backbone.
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
PNetClass = PBranchedNetwork_AllSeparateHP
# Name and checkpoint index; `load_pnet` uses both to build the checkpoint
# path `{checkpoints_dir}{pnet_name}/{pnet_name}-{chckpt}-regular.pth`.
pnet_name = 'pnet'
chckpt = 50

In [3]:
# Number of forward passes `run_pnet` performs for each input.
n_timesteps = 5

# Layers whose activations are written to disk by `save_activations`.
layers = [
    'conv1',
    'conv2',
    'conv3',
    'conv4_W',
    'fc6_W',
]

# Paths to relevant directories

In [4]:
# Base directory; the other paths are built by appending to it, so the
# trailing slashes are significant.
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'
activations_dir = engram_dir + 'activations_pnet/'
checkpoints_dir = engram_dir + 'checkpoints/'
tensorboard_dir = engram_dir + 'tensorboard/'

In [5]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

Device: cuda:0


# Helper functions to load network

In [6]:
def get_hyperparams(tf_dir, bg, snr, shared=False, n_pcoders=5):
    """
    Read the last logged hyperparameters of each pcoder from a TensorBoard log.

    The log directory is `{tf_dir}hyper_{bg}_snr{snr}/`. A single event file
    from that directory is loaded, and for i = 1..n_pcoders the final value of
    the scalars `Hyperparam/pcoder{i}_feedforward`, `Hyperparam/pcoder{i}_feedback`
    and `Hyperparam/pcoder{i}_error` is collected.

    Parameters
    ----------
    tf_dir : str
        Prefix of the log directory. The path is built by plain string
        concatenation, so it must end with a path separator.
    bg : str
        Background name, inserted verbatim into the directory name.
    snr : float
        SNR value, inserted verbatim (via f-string formatting) into the
        directory name.
    shared : bool
        Must be False; shared hyperparameters are not supported.
    n_pcoders : int
        Number of pcoders to read (defaults to 5, the previously hard-coded
        value).

    Returns
    -------
    list of dict
        One dict per pcoder, in pcoder order, with keys 'ffm', 'fbm', 'erm'.

    Raises
    ------
    ValueError
        If `shared` is True.
    FileNotFoundError
        If the log directory does not exist or contains no entries.
    """
    if shared:
        raise ValueError('Not implemented for shared hyperparameters.')

    tf_file_dir = f'{tf_dir}hyper_{bg}_snr{snr}/'
    # os.listdir gives entries in arbitrary order; sorting makes the choice of
    # event file reproducible across runs and machines.
    tf_files = sorted(os.listdir(tf_file_dir))
    if not tf_files:
        # Previously an empty directory silently produced an empty list.
        raise FileNotFoundError(
            f'No TensorBoard event files found in {tf_file_dir}'
            )

    # Only one event file is read, as before.
    ea = event_accumulator.EventAccumulator(f'{tf_file_dir}{tf_files[0]}')
    ea.Reload()

    hyperparams = []
    for i in range(1, n_pcoders + 1):
        hyperparams.append({
            'ffm': ea.Scalars(f'Hyperparam/pcoder{i}_feedforward')[-1].value,
            'fbm': ea.Scalars(f'Hyperparam/pcoder{i}_feedback')[-1].value,
            'erm': ea.Scalars(f'Hyperparam/pcoder{i}_error')[-1].value,
            })
    return hyperparams

In [7]:
def load_pnet(PNetClass, pnet_name, chckpt, hyperparams=None):
    """
    Build the backbone, wrap it in a predictive-coding network and load weights.

    Parameters
    ----------
    PNetClass : callable
        Called as `PNetClass(net, build_graph=False)` to wrap the backbone.
    pnet_name : str
        Used for the checkpoint directory and file name.
    chckpt : int
        Checkpoint index used in the checkpoint file name.
    hyperparams : optional
        When not None, passed to `pnet.set_hyperparameters`.

    Returns
    -------
    The network, moved to `DEVICE` and switched to eval mode.
    """
    net = BranchedNetwork(track_encoder_representations=True)
    # Load onto the CPU first, like the checkpoint below, so that weights
    # saved from a GPU can also be read on a CPU-only machine; the network
    # is moved to DEVICE afterwards.
    net.load_state_dict(torch.load(
        f'{engram_dir}networks_2022_weights.pt',
        map_location='cpu'
        ))
    pnet = PNetClass(net, build_graph=False)
    pnet.load_state_dict(torch.load(
        f"{checkpoints_dir}{pnet_name}/{pnet_name}-{chckpt}-regular.pth",
        map_location='cpu'
        ))
    if hyperparams is not None:
        pnet.set_hyperparameters(hyperparams)
    pnet.to(DEVICE)
    pnet.eval()
    print(f'Loaded Pnet: {pnet_name}')
    print_hps(pnet)
    return pnet

In [8]:
def print_hps(pnet):
    """
    Print one line per pcoder with its `ffm`, `fbm` and `erm` attributes.

    The attributes are looked up on `pnet` as `ffm{i}`, `fbm{i}` and `erm{i}`
    for i = 1..pnet.number_of_pcoders and formatted with three decimals.
    """
    for idx in range(1, pnet.number_of_pcoders + 1):
        ffm = getattr(pnet, f'ffm{idx}')
        fbm = getattr(pnet, f'fbm{idx}')
        erm = getattr(pnet, f'erm{idx}')
        print(f"PCoder{idx} : ffm: {ffm:0.3f} \t fbm: {fbm:0.3f} \t erm: {erm:0.3f}")

# Helper functions to save activations

In [9]:
# Per-sample activation shape of each layer, keyed by layer name.
# `save_activations` prepends the number of samples to get the dataset shape.
n_units_per_layer = {
    'conv1': (96, 55, 134),
    'conv2': (256, 14, 34),
    'conv3': (512, 7, 17),
    'conv4_W': (1024, 7, 17),
    'conv5_W': (512, 7, 17),
    'fc6_W': (4096,),
}

In [10]:
def run_pnet(pnet, _input, n_steps=None):
    """
    Run `pnet` for several timesteps on one input and collect its outputs.

    The network is reset, then called once per timestep. The input is passed
    only on the first call; every later call receives None.

    Parameters
    ----------
    pnet
        Network exposing `reset()`, `pcoder1.prd`, `backbone.encoder_repr`
        and returning a 2-tuple whose first element is the logits tensor.
    _input
        Input passed to `pnet` on the first timestep.
    n_steps : int, optional
        Number of timesteps. Defaults to the module-level `n_timesteps`.

    Returns
    -------
    reconstructions : list
        `pnet.pcoder1.prd[0, 0]` as a numpy array, one per timestep.
    activations : list
        `pnet.backbone.encoder_repr` as found after each timestep.
    logits : list
        Squeezed numpy logits, one per timestep.
    output : list of int
        Index of the largest logit along the last axis, one per timestep.
    """
    if n_steps is None:
        n_steps = n_timesteps

    pnet.reset()
    reconstructions = []
    activations = []
    logits = []
    output = []
    for t in range(n_steps):
        _input_t = _input if t == 0 else None
        logits_t, _ = pnet(_input_t)
        # detach() keeps the numpy conversion valid even when the caller is
        # not inside a torch.no_grad() context.
        reconstructions.append(pnet.pcoder1.prd[0, 0].detach().cpu().numpy())
        # NOTE(review): this stores a reference, not a copy. If the backbone
        # updates `encoder_repr` in place between timesteps, all entries
        # would alias the final state -- confirm against BranchedNetwork.
        activations.append(pnet.backbone.encoder_repr)
        logits.append(logits_t.detach().cpu().numpy().squeeze())
        output.append(logits_t.max(-1)[1].item())
    return reconstructions, activations, logits, output

In [13]:
def _create_activation_datasets(f_out, n_data):
    """Create every float32 output dataset in `f_out`; return them keyed by name."""
    reconstr_dim = (n_data, 164, 400)
    logit_dim = (n_data, 531)

    def _new(name, shape):
        return f_out.create_dataset(name, shape, dtype='float32')

    data_dict = {
        'label': _new('label', n_data),
        'clean_correct': _new('clean_correct', n_data),
        }
    # Per-timestep datasets that do not depend on the layer.
    for timestep in range(n_timesteps):
        per_step_shapes = {
            f'{timestep}_reconstructions': reconstr_dim,
            f'{timestep}_logits': logit_dim,
            f'{timestep}_output': n_data,
            f'{timestep}_clean_logits': logit_dim,
            f'{timestep}_clean_output': n_data,
            }
        for name, shape in per_step_shapes.items():
            data_dict[name] = _new(name, shape)
    # Per-layer, per-timestep activation datasets (noisy and clean input).
    for layer in layers:
        activ_dim = (n_data,) + n_units_per_layer[layer]
        for timestep in range(n_timesteps):
            for suffix in ('activations', 'clean_activations'):
                name = f'{layer}_{timestep}_{suffix}'
                data_dict[name] = _new(name, activ_dim)
    return data_dict


@torch.no_grad()
def save_activations(pnet, dset, hdf5_path):
    """
    Run `pnet` on every sample of `dset` and store the results in an HDF5 file.

    For each index in `range(dset.n_data)` the network is run (via `run_pnet`)
    on the sample returned by `dset[idx]` and on `dset.clean_in[idx]` reshaped
    to (1, 1, 164, 400). The following float32 datasets are written, with
    `t` ranging over timesteps and `layer` over the module-level `layers`:

    - 'label', 'clean_correct'
    - '{layer}_{t}_activations', '{layer}_{t}_clean_activations'
    - '{t}_reconstructions', '{t}_logits', '{t}_output'
    - '{t}_clean_logits', '{t}_clean_output'

    Parameters
    ----------
    pnet
        Network passed to `run_pnet`.
    dset
        Dataset exposing `n_data`, `clean_in` and integer indexing that
        returns `(input, label)`.
    hdf5_path : str
        Destination file. It is opened in h5py mode 'x' (exclusive create),
        so an already existing file is not overwritten.
    """
    with h5py.File(hdf5_path, 'x') as f_out:
        data_dict = _create_activation_datasets(f_out, dset.n_data)

        for idx in range(dset.n_data):
            # Noisy input
            noisy_in, label = dset[idx]
            data_dict['label'][idx] = label
            noisy_in = noisy_in.to(DEVICE)
            reconstructions, activations, logits, output = run_pnet(pnet, noisy_in)
            for timestep in range(n_timesteps):
                for layer in layers:
                    data_dict[f'{layer}_{timestep}_activations'][idx] = \
                        activations[timestep][layer]
                data_dict[f'{timestep}_reconstructions'][idx] = \
                    reconstructions[timestep]
                data_dict[f'{timestep}_logits'][idx] = \
                    logits[timestep]
                data_dict[f'{timestep}_output'][idx] = output[timestep]

            # Clean input
            clean_in = torch.tensor(
                dset.clean_in[idx].reshape((1, 1, 164, 400))
                ).to(DEVICE)
            _, activations, logits, output = run_pnet(pnet, clean_in)
            # `output` holds one prediction per timestep, so the label must be
            # compared with a single entry, not with the list as a whole.
            # NOTE(review): timestep 0 (the first pass) is used here; switch to
            # output[-1] if correctness after the last timestep is intended.
            data_dict['clean_correct'][idx] = bool(label == output[0])
            for timestep in range(n_timesteps):
                for layer in layers:
                    data_dict[f'{layer}_{timestep}_clean_activations'][idx] = \
                        activations[timestep][layer]
                data_dict[f'{timestep}_clean_logits'][idx] = \
                    logits[timestep]
                data_dict[f'{timestep}_clean_output'][idx] = output[timestep]
                    

# Run activation-saving functions

In [12]:
# Background types and SNR values to process, and the TensorBoard log
# directory that `get_hyperparams` reads from.
bgs = [
    'AudScene',
    'Babble8Spkr',
]
snrs = [
    -9.0,
    -6.0,
    -3.0,
    0.0,
    3.0,
]
tf_dir = tensorboard_dir + 'lr_0.01x/'

In [14]:
# Process every (SNR, background) combination, with SNR as the outer loop:
# read the hyperparameters, load the network, and write one HDF5 file per
# combination.
for snr, bg in ((s, b) for s in snrs for b in bgs):
    print(f'{bg}, SNR = {snr}')
    hyperparams = get_hyperparams(tf_dir, bg, snr)
    pnet = load_pnet(PNetClass, pnet_name, chckpt, hyperparams)
    dset = NoisyDataset(bg, snr)
    # The output file name uses the SNR truncated to an integer.
    hdf5_path = f'{activations_dir}{bg}_snr{int(snr)}.hdf5'
    save_activations(pnet, dset, hdf5_path)

AudScene, SNR = -9.0




Loaded Pnet: pnet
PCoder1 : ffm: 0.055 	 fbm: 0.816 	 erm: 0.007
PCoder2 : ffm: 0.385 	 fbm: 0.244 	 erm: -0.000
PCoder3 : ffm: 0.163 	 fbm: 0.554 	 erm: 0.011
PCoder4 : ffm: 0.220 	 fbm: 0.420 	 erm: 0.021
PCoder5 : ffm: 0.148 	 fbm: 0.000 	 erm: 0.009


  "The default behavior for interpolate/upsample with float scale_factor changed "
  self.prediction_error  = nn.functional.mse_loss(self.prd, target)


Babble8Spkr, SNR = -9.0
Loaded Pnet: pnet
PCoder1 : ffm: 0.485 	 fbm: 0.309 	 erm: 0.017
PCoder2 : ffm: 0.074 	 fbm: 0.505 	 erm: -0.006
PCoder3 : ffm: 0.626 	 fbm: 0.168 	 erm: -0.009
PCoder4 : ffm: 0.485 	 fbm: 0.355 	 erm: 0.015
PCoder5 : ffm: 0.246 	 fbm: 0.000 	 erm: 0.015
AudScene, SNR = -6.0
Loaded Pnet: pnet
PCoder1 : ffm: 0.145 	 fbm: 0.680 	 erm: 0.010
PCoder2 : ffm: 0.565 	 fbm: 0.184 	 erm: 0.006
PCoder3 : ffm: 0.178 	 fbm: 0.362 	 erm: 0.002
PCoder4 : ffm: 0.150 	 fbm: 0.415 	 erm: 0.023
PCoder5 : ffm: 0.116 	 fbm: 0.000 	 erm: 0.012
Babble8Spkr, SNR = -6.0
Loaded Pnet: pnet
PCoder1 : ffm: 0.500 	 fbm: 0.159 	 erm: 0.021
PCoder2 : ffm: 0.226 	 fbm: 0.473 	 erm: 0.004
PCoder3 : ffm: 0.342 	 fbm: 0.450 	 erm: 0.005
PCoder4 : ffm: 0.333 	 fbm: 0.377 	 erm: 0.014
PCoder5 : ffm: 0.191 	 fbm: 0.000 	 erm: 0.012
AudScene, SNR = -3.0
Loaded Pnet: pnet
PCoder1 : ffm: 0.223 	 fbm: 0.555 	 erm: 0.003
PCoder2 : ffm: 0.651 	 fbm: 0.179 	 erm: 0.004
PCoder3 : ffm: 0.137 	 fbm: 0.355 	 e