In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pickle
import h5py
import seaborn as sns
import pandas as pd
from scipy.stats import pearsonr
import configs

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from models.networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset

# Parameters

In [2]:
# Root of the shared storage for this project. It can be overridden with the
# HCNN_ENGRAM_DIR environment variable; the default is the original lab mount.
# os.path.join(..., '') guarantees a trailing separator, because the paths
# below are built by plain string concatenation.
engram_dir = os.path.join(
    os.environ.get('HCNN_ENGRAM_DIR', '/mnt/smb/locker/abbott-locker/hcnn/'), ''
)
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'
activations_dir = f'{engram_dir}activations_pnet_all/'
pickles_dir = f'{engram_dir}pickles/'

In [3]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

Device: cpu


# Distance functions

In [4]:
from scipy.stats import pearsonr

In [5]:
# A few distance metrics

def row_rms(A, B):
    """
    Average of per-row RMS differences between A and B.

    For 1-D inputs this is the RMS of ``A - B``. For higher-rank inputs each
    slice along the first axis is transposed, the RMS of the difference is
    taken over axis 0 of that transposed slice, and all resulting values are
    averaged together. Torch tensors are converted with ``.numpy()`` and both
    inputs are cast to float.
    """
    left = A.numpy() if torch.is_tensor(A) else A
    right = B.numpy() if torch.is_tensor(B) else B
    left = left.astype(float)
    right = right.astype(float)

    def _rms_over_axis0(x, y):
        # Root-mean-square of the difference along the first axis.
        delta = x - y
        return np.sqrt(np.mean(delta * delta, axis=0))

    if left.ndim == 1:
        return _rms_over_axis0(left, right)

    per_row = [
        _rms_over_axis0(left[row].T, right[row].T)
        for row in range(left.shape[0])
    ]
    return np.mean(per_row)

def rms(A, B):
    """
    Root-mean-square of ``A - B`` over every element.

    Torch tensors are converted with ``.numpy()``; both inputs are cast to
    float and flattened, so any shapes with equal total size may be compared.
    """
    flat = []
    for arr in (A, B):
        if torch.is_tensor(arr):
            arr = arr.numpy()
        flat.append(arr.astype(float).flatten())

    delta = flat[0] - flat[1]
    return np.sqrt(np.mean(delta * delta))

def tanimoto_distance(A, B):
    """
    Tanimoto (extended Jaccard) *similarity* of two flattened vectors.

    NOTE: despite the name, this returns a similarity, not a distance:

        A.B / (|A|^2 + |B|^2 - A.B)

    Identical non-zero vectors give 1.0 and orthogonal vectors give 0.0.
    The name is kept for backward compatibility; the matching distance is
    ``1 - tanimoto_distance(A, B)``.

    Parameters
    ----------
    A, B : np.ndarray or torch.Tensor
        Inputs with the same total number of elements. Tensors are converted
        with ``.numpy()``; both are cast to float and flattened.

    Returns
    -------
    float
        The Tanimoto coefficient, or NaN when both inputs are all zeros
        (the only case in which the denominator is zero).
    """
    if torch.is_tensor(A):
        A = A.numpy()
    if torch.is_tensor(B):
        B = B.numpy()
    A = A.astype(float).flatten()
    B = B.astype(float).flatten()

    dot = np.dot(A, B)
    denom = np.linalg.norm(A) ** 2 + np.linalg.norm(B) ** 2 - dot
    if denom == 0:
        # Both vectors are zero; 0/0 is undefined. Return NaN explicitly rather
        # than relying on a floating-point division warning.
        return np.nan
    return dot / denom
    
def cosine_similarity(A, B):
    """
    Cosine similarity of two flattened vectors: A.B / (|A| |B|).

    Both inputs are flattened before comparison, so multi-dimensional inputs
    are compared as single long vectors. (An earlier per-channel branch could
    never run after flattening, and it referenced an undefined ``n_channels``;
    it has been removed.)

    Parameters
    ----------
    A, B : np.ndarray or torch.Tensor
        Inputs with the same total number of elements. Tensors are converted
        with ``.numpy()``; both are cast to float.

    Returns
    -------
    float
        Value in [-1, 1], or NaN if either input is all zeros.
    """
    if torch.is_tensor(A):
        A = A.numpy()
    if torch.is_tensor(B):
        B = B.numpy()
    A = A.astype(float).flatten()
    B = B.astype(float).flatten()

    norm_product = np.linalg.norm(A) * np.linalg.norm(B)
    if norm_product == 0:
        # Direction is undefined for a zero vector.
        return np.nan
    return np.dot(A, B) / norm_product

def pearsonr_sim(A, B):
    """
    Pearson correlation coefficient between the flattened inputs.

    Torch tensors are converted with ``.numpy()``; both inputs are cast to
    float and flattened before ``scipy.stats.pearsonr`` is applied. Only the
    coefficient is returned; the p-value is discarded.
    """
    vectors = [
        (x.numpy() if torch.is_tensor(x) else x).astype(float).flatten()
        for x in (A, B)
    ]
    coefficient, _p_value = pearsonr(*vectors)
    return coefficient

# Function to collect correlations

In [6]:
def _flatten_rows(dataset, n_rows):
    """Read the first ``n_rows`` entries of ``dataset`` in one slice and flatten each to 1-D."""
    # One sliced read is much cheaper than n_rows separate reads when
    # ``dataset`` is an HDF5 dataset.
    block = np.asarray(dataset[:n_rows])
    if block.shape[0] != n_rows:
        raise ValueError(
            f'expected {n_rows} rows, dataset provides {block.shape[0]}'
        )
    n_features = int(np.prod(block.shape[1:], dtype=int))
    return block.reshape(n_rows, n_features)


def eval_correlations(results, dist_func, undead_units, layers=None, n_timesteps=5):
    """
    Compare noisy vs. clean activations at the population and unit level.

    For every timestep ``t`` in ``range(n_timesteps)`` and every layer ``l``,
    this reads ``results[f'{l}_{t}_activations']`` (noisy) and
    ``results[f'{l}_{t}_clean_activations']`` (clean). It takes the first N
    rows of each (N = size of ``results['label']``), flattens every row to a
    vector, and applies ``dist_func``:

    * population level: one value per row, comparing the full noisy vector
      with the full clean vector, plus a second value restricted to
      ``undead_units[l]``;
    * unit level: one value per unit (column), comparing that unit's values
      across all N rows.

    Parameters
    ----------
    results : mapping, e.g. an open ``h5py.File``
        Must provide ``'label'`` and the activation datasets named above,
        indexed by row along the first axis.
    dist_func : callable
        ``dist_func(x, y)`` comparing two 1-D arrays.
    undead_units : mapping
        ``undead_units[l]`` is used to index the flattened activation vectors
        and, per unit, is stored as the ``unit_shuffle_alive`` value.
        Presumably a boolean mask with one entry per unit -- TODO confirm.
    layers : list of str, optional
        Layer names to evaluate. Defaults to
        ``['conv1', 'conv2', 'conv3', 'conv4_W', 'conv5_W', 'fc6_W']``.
    n_timesteps : int, optional
        Number of timesteps to evaluate (default 5).

    Returns
    -------
    dict
        Parallel lists. 'popln_shuffle', 'popln_shuffle_undead',
        'popln_timestep' and 'popln_layer' have one entry per
        (timestep, layer, row). 'unit_shuffle', 'unit_shuffle_alive',
        'unit_timestep' and 'unit_layer' have one entry per
        (timestep, layer, unit).

    Raises
    ------
    ValueError
        If an activation dataset has fewer rows than ``results['label']``.
    """
    if layers is None:
        layers = ['conv1', 'conv2', 'conv3', 'conv4_W', 'conv5_W', 'fc6_W']
    n_rows = np.array(results['label']).size

    popln_shuffle = []
    popln_shuffle_undead = []
    popln_timestep = []
    popln_layer = []

    unit_shuffle = []
    unit_shuffle_alive = []
    unit_timestep = []
    unit_layer = []

    for t in range(n_timesteps):
        for l in layers:
            undead_units_l = undead_units[l]
            noisy = _flatten_rows(results[f'{l}_{t}_activations'], n_rows)
            clean = _flatten_rows(results[f'{l}_{t}_clean_activations'], n_rows)

            # Population level: one comparison per row.
            for noisy_activ, clean_activ in zip(noisy, clean):
                popln_shuffle.append(dist_func(noisy_activ, clean_activ))
                popln_shuffle_undead.append(dist_func(
                    noisy_activ[undead_units_l],
                    clean_activ[undead_units_l]
                    ))
                popln_timestep.append(t)
                popln_layer.append(l)

            # Unit level: one comparison per unit, across all rows.
            for unit in range(noisy.shape[1]):
                unit_shuffle.append(dist_func(noisy[:, unit], clean[:, unit]))
                unit_shuffle_alive.append(undead_units_l[unit])
                unit_timestep.append(t)
                unit_layer.append(l)

    return {
        'popln_shuffle': popln_shuffle,
        'popln_shuffle_undead': popln_shuffle_undead,
        'popln_timestep': popln_timestep,
        'popln_layer': popln_layer,
        'unit_shuffle': unit_shuffle,
        'unit_shuffle_alive': unit_shuffle_alive,
        'unit_timestep': unit_timestep,
        'unit_layer': unit_layer
        }

# Run and save correlations to pickle files

In [6]:
# This is bad practice! But the warnings are real annoying
import warnings
warnings.filterwarnings("ignore")

In [8]:
# Run configuration: the similarity used for every comparison, the output
# file prefixes, and the background / SNR conditions to process.
dist_func = pearsonr_sim
file_prefix = 'repr_pearsonr'
# NOTE(review): not referenced by the cells shown here -- confirm whether still needed.
shuff_file_prefix = 'shuffle_pearsonr'
bgs = ['AudScene']
snrs = [-9.0, -6.0]

In [9]:
# Load the 'undead_units' entry of the dead-units pickle; it is indexed by
# layer name in eval_correlations. NOTE: pickle.load can execute arbitrary
# code, so only read this file from a trusted location.
dead_units_path = f'{pickles_dir}dead_units.p'
with open(dead_units_path, 'rb') as f:
    dead_units_payload = pickle.load(f)
undead_units = dead_units_payload['undead_units']

In [None]:
# For each background / SNR condition, compute the noisy-vs-clean comparisons
# and pickle them.
for bg in bgs:
    for snr in snrs:
        print(f'{bg}, SNR {snr}')
        results_path = f'{activations_dir}{bg}_snr{int(snr)}.hdf5'
        # Close the HDF5 file after each condition instead of leaking the
        # handle until garbage collection. The returned dict holds only
        # computed values, so it stays valid after the file is closed.
        with h5py.File(results_path, 'r') as activations:
            print('Running correlations')
            results = eval_correlations(
                activations, dist_func, undead_units
                )
        # NOTE: the output name uses the float SNR (e.g. 'snr-9.0'), whereas
        # the input name uses int(snr); kept unchanged so existing file names
        # stay stable.
        with open(f'{pickles_dir}{file_prefix}_{bg}_snr{snr}.p', 'wb') as f:
            pickle.dump(results, f)

AudScene, SNR -9.0
Running correlations
