In [21]:
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import gamma
from models.networks_2022 import BranchedNetwork
from models.priming_pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
from data.ValidationDataset import NoisyDataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from tensorboard.backend.event_processing import event_accumulator

# Set network parameters

In [2]:
# Which network to test
pnet_name = 'pnet'
chckpt = 1960

In [3]:
engram_dir = '/Users/chingfang/temp_locker/'
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'

In [18]:
# Set up parameters
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Device: {DEVICE}')
checkpoints_dir = f'{engram_dir}1_checkpoints/'
tensorboard_dir = f'{engram_dir}1_tensorboard/'
activations_dir = f'{engram_dir}3_validation_activations/{pnet_name}/'
hyp_dir = f'{engram_dir}2_hyperp/{pnet_name}/'
PNetClass = PBranchedNetwork_AllSeparateHP
n_timesteps = 5
layers = ['conv1', 'conv2', 'conv3', 'conv4_W', 'conv5_W', 'fc6_W']

Device: cuda:0


In [None]:
bg = 'pinkNoise'
snr = -9.0

# Load network

In [5]:
net = BranchedNetwork()
pnet = PNetClass(net, build_graph=True)
def print_hps(pnet):
    for pc in range(pnet.number_of_pcoders):
        string = f"PCoder{pc+1} : ffm: {getattr(pnet,f'ffm{pc+1}'):0.3f} \t"
        string += f"fbm: {getattr(pnet,f'fbm{pc+1}'):0.3f} \t"
        string += f"erm: {getattr(pnet,f'erm{pc+1}'):0.3f}"
        print(string)
pnet.load_state_dict(torch.load(
    f"{checkpoints_dir}{pnet_name}/{pnet_name}-{chckpt}-regular.pth",
    map_location='cpu'
    ))



<All keys matched successfully>

# Load hyperparameters

In [17]:
def get_hyperparams(tf_filepath, bg, snr):
    hyperparams = []
    ea = event_accumulator.EventAccumulator(tf_filepath)
    ea.Reload()
    eval_score = [0]
    epoch = 0
    while True:
        try:
            score_over_t = 0.
            for t in np.arange(1,5):
                score_over_t += ea.Scalars(f'NoisyPerf/Epoch#{epoch}')[t].value
                epoch += 1
            score_over_t /= 4
            eval_score.append(score_over_t)
        except Exception as e:
            break
    for i in range(1, 6):
        hps = {}
        ffm = ea.Scalars(f'Hyperparam/pcoder{i}_feedforward')[-1].value
        fbm = ea.Scalars(f'Hyperparam/pcoder{i}_feedback')[-1].value
        erm = ea.Scalars(f'Hyperparam/pcoder{i}_error')[-1].value
        if np.isnan(ffm) or np.isnan(fbm) or np.isnan(erm):
            return None, 0.
        hps['ffm'] = ffm
        hps['fbm'] = fbm
        hps['erm'] = erm
        hyperparams.append(hps)
    return hyperparams, eval_score[-1]

def load_pnet(PNetClass, pnet_name, chckpt, hyperparams=None):
    net = BranchedNetwork(track_encoder_representations=True)
    net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))
    pnet = PNetClass(net, build_graph=False)
    pnet.load_state_dict(torch.load(
        f"{checkpoints_dir}{pnet_name}/{pnet_name}-{chckpt}-regular.pth",
        map_location='cpu'
        ))
    if hyperparams is not None:
        pnet.set_hyperparameters(hyperparams)
    pnet.to(DEVICE)
    pnet.eval();
    print(f'Loaded Pnet: {pnet_name}')
    print_hps(pnet)
    return pnet

In [22]:
tf_dir = f'{hyp_dir}hyper_{bg}_snr{snr}/'
best_score = 0.
best_hyperparams = None
best_tf_file = None
for tf_file in os.listdir(tf_dir):
    if not tf_file.startswith('event'): continue
    tf_filepath = f'{tf_dir}{tf_file}'
    tf_file = tf_file.split('edu.')[-1]
    hyperparams, score = get_hyperparams(tf_filepath, bg, snr)
    if score > best_score:
        best_score = score
        best_hyperparams = hyperparams
        best_tf_file = tf_file
print(f'{bg}, SNR {snr} uses {best_tf_file} with valid score {best_score}')

# Use the best hyperparameter set
pnet = load_pnet(PNetClass, pnet_name, chckpt, best_hyperparams)

pinkNoise, SNR -9.0 uses 34545.1 with valid score 0.37640000134706497
Loaded Pnet: pnet
PCoder1 : ffm: 0.749 	fbm: 0.248 	erm: 0.026
PCoder2 : ffm: 0.969 	fbm: 0.031 	erm: -0.002
PCoder3 : ffm: 0.743 	fbm: 0.256 	erm: 0.016
PCoder4 : ffm: 0.133 	fbm: 0.850 	erm: 0.026
PCoder5 : ffm: 0.301 	fbm: 0.000 	erm: -0.041


In [6]:
# pnet.to(DEVICE)
# pnet.build_graph = False
# pnet.eval();
# print_hps(pnet)

PCoder1 : ffm: 0.300 	fbm: 0.300 	erm: 0.010
PCoder2 : ffm: 0.300 	fbm: 0.300 	erm: 0.010
PCoder3 : ffm: 0.300 	fbm: 0.300 	erm: 0.010
PCoder4 : ffm: 0.300 	fbm: 0.300 	erm: 0.010
PCoder5 : ffm: 0.300 	fbm: 0.300 	erm: 0.010


# Set up dataset

In [8]:
dset = NoisyDataset(bg=bg, snr=snr)

# Identify data indices for correct/incorrect

In [9]:
len(dset)

526

In [23]:
ff_correct = []
pred_correct = []
ff_output = []
pred_output = []

for i in range(200):
    cgram, label = dset[i]
    label = label.item()
    pnet.reset()
    for t in range(n_timesteps):
        _input = cgram if t == 0 else None # Clean cochleagram
        if _input is not None:
            _input = _input.unsqueeze(0)
            _input = _input.to(DEVICE)
        output_logits, _ = pnet(_input)
        output = np.argmax(output_logits.cpu().numpy())
        if t == 0:
            ff_correct.append(output == label)
            ff_output.append(output)
        if t == 4:
            pred_correct.append(output == label)
            pred_output.append(output)

In [24]:
np.sum(ff_correct)/len(ff_correct)

0.145

In [25]:
np.sum(pred_correct)/len(pred_correct)

0.175

# Attempt priming experiment

In [32]:
primed_ff_correct = []
primed_ff_output = []

for i in range(200):
    priming_cgram, _ = dset[np.random.choice(len(dset))]
    
    # Prime with cgram
    pnet.reset()
    for t in range(n_timesteps):
        _input = priming_cgram if t == 0 else None
        if _input is not None:
            _input = _input.unsqueeze(0).to(DEVICE)
        _, _ = pnet(_input)
    
    # Now run a regular forward pass
    cgram, label = dset[i]
    label = label.item()
    _input = cgram.unsqueeze(0).to(DEVICE)
    output_logits, _ = pnet(_input, force_no_reset=True)
    output = np.argmax(output_logits.cpu().numpy())
    primed_ff_correct.append(output == label)
    primed_ff_output.append(output)

In [33]:
np.sum(primed_ff_correct)/len(primed_ff_correct)

0.01