In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pickle
import h5py
import seaborn as sns
import pandas as pd
from scipy.stats import pearsonr
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset



# Choose which network you're running

In [2]:
# Queue of network configurations; the following cells each append one
# (PNetClass, pnet_name, p_layers, chckpt) tuple to it.
args = list()

In [3]:
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP

# Register the 'all' configuration. pnet_name and chckpt are later used by
# load_pnet to build the checkpoint path; p_layers is a human-readable
# description that is carried along in the tuple.
PNetClass, pnet_name = PBranchedNetwork_AllSeparateHP, 'all'
p_layers, chckpt = 'All Layers', 25
args.append((PNetClass, pnet_name, p_layers, chckpt))

In [4]:
from pbranchednetwork_a1 import PBranchedNetwork_A1SeparateHP

# Register the 'a1' configuration. pnet_name and chckpt are later used by
# load_pnet to build the checkpoint path; p_layers is a human-readable
# description that is carried along in the tuple.
PNetClass, pnet_name = PBranchedNetwork_A1SeparateHP, 'a1'
p_layers, chckpt = 'Layers 1-3', 50
args.append((PNetClass, pnet_name, p_layers, chckpt))

# Parameters

In [6]:
# Root directory that every data, checkpoint and output path in this notebook
# is built from. Set the ENGRAM_DIR environment variable to run against a
# different mount; the default is the original hard-coded location.
engram_dir = os.path.join(
    os.environ.get('ENGRAM_DIR', '/mnt/smb/locker/issa-locker/users/Erica/'),
    ''  # guarantees the trailing separator that the f-string paths rely on
)

In [7]:
# Output locations, all under engram_dir (which must end with a path separator).
fig_dir, pickle_dir, activations_dir = (
    engram_dir + suffix
    for suffix in ('hcnn/figures/', 'hcnn/pickles/', 'hcnn/activations/')
)

In [8]:
# Run on the first CUDA device when one is visible to torch, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

# Where trained checkpoints are read from (see load_pnet) and the
# tensorboard directory, both under engram_dir.
checkpoints_dir = engram_dir + 'hcnn/checkpoints/'
tensorboard_dir = engram_dir + 'hcnn/tensorboard/'

Device: cuda:0


# Helper functions

In [9]:
def load_pnet(PNetClass, pnet_name, chckpt):
    """
    Build a predictive network around a BranchedNetwork and restore a checkpoint.

    Args:
        PNetClass: class that is called as ``PNetClass(net, build_graph=False)``.
        pnet_name: name of the checkpoint sub-directory and the file prefix.
        chckpt: checkpoint identifier inserted into the checkpoint file name.

    Returns:
        The network, moved to ``DEVICE`` and switched to eval mode.
    """
    net = BranchedNetwork(track_encoder_representations=True)
    # map_location='cpu' lets weights that were saved from a GPU deserialize
    # even when DEVICE has fallen back to 'cpu'; this matches the checkpoint
    # load below. The model is moved to DEVICE afterwards.
    net.load_state_dict(torch.load(
        f'{engram_dir}networks_2022_weights.pt',
        map_location='cpu'
        ))
    pnet = PNetClass(net, build_graph=False)
    pnet.load_state_dict(torch.load(
        f"{checkpoints_dir}{pnet_name}/{pnet_name}-{chckpt}-regular.pth",
        map_location='cpu'
        ))
    pnet.to(DEVICE)
    pnet.eval()
    print(f'Loaded Pnet: {pnet_name}')
    print_hps(pnet)
    return pnet

In [10]:
def print_hps(pnet):
    """
    Print the ffm / fbm / erm hyperparameters of every PCoder in ``pnet``.

    Reads ``pnet.number_of_pcoders`` and, for each 1-based PCoder number ``k``,
    the attributes ``ffm{k}``, ``fbm{k}`` and ``erm{k}``; one line is printed
    per PCoder with each value formatted to three decimals.
    """
    for number in range(1, pnet.number_of_pcoders + 1):
        ffm = getattr(pnet, f'ffm{number}')
        fbm = getattr(pnet, f'fbm{number}')
        erm = getattr(pnet, f'erm{number}')
        print(f"PCoder{number} : ffm: {ffm:0.3f} \t fbm: {fbm:0.3f} \t erm: {erm:0.3f}")

# Load Psychophysics Dataset

In [11]:
# HDF5 file with the noisy stimuli, opened read-only. The handle is kept open
# on purpose: save_activations reads f_in['data'][idx] on demand.
f_in = h5py.File(engram_dir + "PsychophysicsWord2017W_not_resampled.hdf5", mode='r')

In [12]:
# Per-stimulus metadata, memory-mapped read-only. The 'word', 'bg', 'snr' and
# 'orig_dset' fields are decoded in the cells below.
f_metadata = np.load(
    engram_dir + "PsychophysicsWord2017W_999c6fc475be1e82e114ab9865aa5459e4fd329d.__META.npy",
    mmap_mode='r',
)

In [13]:
# Key array, memory-mapped read-only. The position of a word in this array is
# used to derive its integer label in the labels cell.
f_key = np.load(
    engram_dir + "PsychophysicsWord2017W_999c6fc475be1e82e114ab9865aa5459e4fd329d.__META_key.npy",
    mmap_mode='r',
)

In [14]:
# Load the 'net_mistakes' entry from the saved network-performance pickle.
# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on files from a trusted location.
with open(engram_dir + "PsychophysicsWord2017W_net_performance.p", mode='rb') as f:
    net_performance = pickle.load(f)
    net_mistakes = net_performance['net_mistakes']

In [15]:
def getPsychophysics2017WCleanCochleagrams(n_stimuli=15300, batch_size=100):
    """
    Load the clean cochleagrams that are stored as consecutive HDF5 batch files.

    Files are named ``batch_<start>_to_<start + batch_size>.hdf5`` and live in
    ``cgrams_for_noise_robustness_analysis/PsychophysicsWord2017W_clean/``
    under ``engram_dir``. The rows of each file's 'data' dataset are appended
    in batch order.

    Args:
        n_stimuli: exclusive upper bound for the batch start indices.
        batch_size: step between batch start indices (also used in file names).

    Returns:
        A list with one array per row of every batch file.
    """
    clean_dir = f'{engram_dir}cgrams_for_noise_robustness_analysis/PsychophysicsWord2017W_clean/'

    cochleagrams = []
    for batch_start in range(0, n_stimuli, batch_size):
        hdf5_path = f'{clean_dir}batch_{batch_start}_to_{batch_start + batch_size}.hdf5'
        with h5py.File(hdf5_path, 'r') as batch_file:
            # Read the whole dataset in one call; iterating an h5py dataset
            # directly issues a separate read for every row.
            cochleagrams.extend(batch_file['data'][...])

    return cochleagrams
clean_in = getPsychophysics2017WCleanCochleagrams()

In [16]:
# Integer label per stimulus: position of its word in f_key, shifted by one.
# Words that are absent from f_key get -1 before the shift, i.e. label 0.
labels = []
for word in f_metadata['word']:
    matches = np.argwhere(f_key == word)
    labels.append(matches.item() if len(matches) else -1)
labels = np.array(labels) + 1

In [17]:
# Background type of every stimulus, decoded from bytes to str.
bg = np.array([str(raw_bg, 'utf-8') for raw_bg in f_metadata['bg']])

In [18]:
def _parse_snr(raw_snr):
    """
    Convert one raw SNR metadata entry (bytes) into a number.

    Entries containing 'inf' map to np.inf. Entries containing 'neg' map to
    -3, -6 or -9 depending on which digit they contain; all other entries map
    to 0 or 3. Digits are tested in that order, and ValueError('Not found') is
    raised when none matches.
    """
    text = str(raw_snr, 'utf-8')
    if 'inf' in text:
        return np.inf
    if 'neg' in text:
        candidates = (('3', -3), ('6', -6), ('9', -9))
    else:
        candidates = (('0', 0), ('3', 3))
    for digit, value in candidates:
        if digit in text:
            return value
    raise ValueError('Not found')

# Numeric SNR of every stimulus.
snr = np.array([_parse_snr(raw_snr) for raw_snr in f_metadata['snr']])

In [19]:
# Source corpus of every stimulus, collapsed to two values: 'WSJ' when the
# decoded metadata entry contains 'WSJ', 'Timit' for everything else.
orig_dset = np.array([
    'WSJ' if 'WSJ' in str(raw_dset, 'utf-8') else 'Timit'
    for raw_dset in f_metadata['orig_dset']
])

# Save network activations

In [20]:
# This is bad practice! But the warnings are real annoying
import warnings
warnings.filterwarnings("ignore")

In [21]:
# Selection settings passed to save_activations.
# bg_types: background types to keep ('Babble8Spkr' is currently commented out).
bg_types = ['AudScene']  # , 'Babble8Spkr']
# exclude_timit: when True, stimuli whose orig_dset is not 'WSJ' are skipped.
exclude_timit = True

In [22]:
@torch.no_grad()
def save_activations(
    pnet, exclude_timit, bg_types, idx_range,
    snr_values=(-3,), n_timesteps=5
    ):
    """
    Collect outputs and encoder representations for clean and noisy stimuli.

    For every index in ``idx_range`` that passes the exclusion criteria, the
    network is reset and run once on the clean cochleagram (``clean_in[idx]``).
    It is then reset again and run for ``n_timesteps`` steps on the noisy
    stimulus (``f_in['data'][idx]``); the input is passed on the first step and
    ``None`` on the following steps.

    Reads the notebook-level ``orig_dset``, ``bg``, ``snr``, ``clean_in``,
    ``f_in`` and ``DEVICE``.

    Args:
        pnet: network exposing ``reset()``, ``__call__`` returning
            ``(logits, _)`` and ``backbone.encoder_repr``.
        exclude_timit: when True, skip stimuli whose ``orig_dset`` is not 'WSJ'.
        bg_types: background types to keep.
        idx_range: iterable of stimulus indices to consider.
        snr_values: SNR values to keep (default: only -3).
        n_timesteps: number of forward passes on the noisy input (default 5).

    Returns:
        dict mapping each kept index to a dict with keys 'clean_output',
        'clean_repr_dict' and 'timestep_results'. The last one maps every
        timestep to a dict with 'noisy_output' and 'noisy_repr_dict'.
    """
    import copy  # local import: only needed for the snapshots below

    def _snapshot_repr():
        # Store a shallow copy rather than the live attribute. If the backbone
        # updates its encoder_repr container in place on every forward pass,
        # keeping the reference would make every stored entry alias the state
        # of the last pass. NOTE(review): whether the backbone reuses the
        # container is not visible here; the copy is harmless either way.
        return copy.copy(pnet.backbone.encoder_repr)

    all_results = {}

    for idx in idx_range:
        # Exclusion criteria
        if exclude_timit and orig_dset[idx] != 'WSJ':
            continue
        if bg[idx] not in bg_types:
            continue
        if snr[idx] not in snr_values:
            continue

        # Activations with clean input
        clean_input = torch.tensor(
            clean_in[idx].reshape((1, 1, 164, 400))).to(DEVICE)
        pnet.reset()
        logits, _ = pnet(clean_input)
        idx_results = {
            'clean_output': logits.max(-1)[1].item(),
            'clean_repr_dict': _snapshot_repr(),
        }

        # Activations with noisy input: the stimulus is only presented on the
        # first timestep, later timesteps are called with None.
        noisy_input = torch.tensor(
            f_in['data'][idx].reshape((1, 1, 164, 400))).to(DEVICE)
        timestep_results = {}
        idx_results['timestep_results'] = timestep_results
        pnet.reset()
        for j in range(n_timesteps):
            logits, _ = pnet(noisy_input if j == 0 else None)
            timestep_results[j] = {
                'noisy_output': logits.max(-1)[1].item(),
                'noisy_repr_dict': _snapshot_repr(),
            }

        all_results[idx] = idx_results

    return all_results

In [23]:
# Run every registered network over the whole stimulus set and pickle the
# results in 20 chunks per network, so that no single file has to hold all
# activations.

# Create the output directory up front; otherwise a missing directory is only
# noticed after the first chunk has been computed.
os.makedirs(activations_dir, exist_ok=True)

# The index range is the same for every network, so build it once.
full_idx_range = np.arange(bg.size)

for PNetClass, pnet_name, p_layers, chckpt in args:
    pnet = load_pnet(PNetClass, pnet_name, chckpt)

    for idx_range in np.array_split(full_idx_range, 20):
        # np.array_split yields empty chunks when there are fewer than 20
        # stimuli; idx_range[0] below would raise IndexError on those.
        if idx_range.size == 0:
            continue
        all_results = save_activations(
            pnet, exclude_timit, bg_types, idx_range
            )
        chunk_path = f'{activations_dir}{pnet_name}_{idx_range[0]}-{idx_range[-1]}.p'
        with open(chunk_path, 'wb') as f:
            pickle.dump(all_results, f)

Loaded Pnet: all
PCoder1 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder2 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder3 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder4 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder5 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
Loaded Pnet: a1
PCoder1 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder2 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
PCoder3 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010
Loaded Pnet: conv1
PCoder1 : ffm: 0.300 	 fbm: 0.300 	 erm: 0.010


OSError: [Errno 5] Input/output error