#### 0. Import modules and define functions

In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from utils import *

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42

def getPermutedTensor(factors, lambdas, tensorX, NDIRS):
    """Recover the circularly-shifted data tensor that matches a permuted decomposition.

    The fitted tensor is rebuilt from the factors. Then, for every (cell, stimulus)
    pair, the circular shift along the direction axis that maximizes the inner
    product between the shifted data and the fit is selected and applied.

    -------------------
    Arguments:

    factors: list of factor matrices, one per tensor mode (first mode first)

    lambdas: ndarray, component weights, broadcast against the columns of factors[0]

    tensorX: ndarray, 3-way data tensor indexed as (cell, stimulus, response)

    NDIRS: int, number of directions; the response axis is unfolded (Fortran order)
        into (NDIRS, response_length // NDIRS) and shifts are applied along NDIRS

    -------------------
    Returns:
    shifted_tensor: ndarray with the shape of tensorX, holding the best-shifted data
        (tensorX itself, not a copy, when NDIRS == 1)

    fittensor: ndarray with the shape of tensorX, reconstruction from the factors
    """

    # Reconstruction: weight the first-mode factors by the lambdas and multiply by
    # the Khatri-Rao product of the remaining modes
    weighted_first_mode = lambdas * factors[0]
    fittensor = np.reshape(weighted_first_mode @ khatri_rao(factors[1:]).T, tensorX.shape)

    # With a single direction there is nothing to shift
    if NDIRS == 1:
        return tensorX, fittensor

    n_cells, n_stims, resp_len = tensorX.shape[:3]

    # Expose the direction axis so that shifts can be applied along it
    unfolded = np.reshape(tensorX, (n_cells, n_stims, NDIRS, resp_len // NDIRS), order='F')
    flat_shape = (n_cells, resp_len)

    # chosen_shift[s, c] is the shift with the lowest cost for cell c, stimulus s
    chosen_shift = np.empty((n_stims, n_cells))
    for stim in range(n_stims):
        stim_data = unfolded[:, stim]
        costs = np.empty((n_cells, NDIRS))
        for k in range(NDIRS):
            # cf. matlab code in `permuted-decomposition/matlab/my_tt_cp_fg.m`
            candidate = np.reshape(np.roll(stim_data, k, 1), flat_shape, order='F')
            costs[:, k] = -np.sum(fittensor[:, stim, :] * candidate, 1)
        chosen_shift[stim] = np.argmin(costs, axis=1)

    # Assemble the output by copying, for each shift, the entries that selected it
    shifted_tensor = np.zeros_like(tensorX)
    for k in range(NDIRS):
        rolled = np.reshape(np.roll(unfolded, k, 2), tensorX.shape, order='F')
        for stim in range(n_stims):
            picked = chosen_shift[stim] == k
            shifted_tensor[picked, stim, :] = rolled[picked, stim, :]

    return shifted_tensor, fittensor

def getNeuralMatrix(scld_permT, factors, lambdas, NDIRS, all_zeroed_stims=None,
                    order='F', verbose=True):
    """Computes the final neural matrix, X, by fitting the permuted tensor scaled by
    relative stimulus magnitudes using the factors obtained from NTF.

    Any previously zeroed out responses are now also permuted by the circular-shift
    producing the best fit.

    Additionally, a rebalancing of the factor magnitudes is applied to attribute
    a meaningful interpretation to the final coefficients: each stimulus factor is
    divided by its maximum, and the per-cell coefficients are re-fit by bounded
    (non-negative) least squares against the rescaled stimulus x response coordinates.

    -------------------
    Arguments:

    scld_permT: ndarray, permuted tensor scaled by relative stimulus FRs

    factors: list, [neural_factors, stimulus_factors, response_factors] (normalized)

    lambdas: ndarray, shape (R,), where R is the number of components being used

    NDIRS: int, number of stimulus directions (rows in original 2D response maps)

    all_zeroed_stims: dict, {cell: (tuple of zeroed stim idxs)}, default None

    order: str, order used to flatten the original 2D response maps, default 'F'

    verbose: bool, if True prints the running cell count every 50 cells, default True

    -------------------
    Returns:
    X: ndarray, shape (Ncells, R), neural encoding matrix (square root of the
        non-negative least-squares coefficients)

    new_scld_permT: ndarray, tensor including previously zeroed out responses (if any)

    -------------------
    Raises:
    ValueError: if no shift yields a finite fit cost for a cell with zeroed stimuli
        (e.g. NDIRS < 1 or non-finite data)

    """

    R = lambdas.size

    # rebalance factor loadings based on relative stimulus contributions:
    # each stimulus factor is rescaled so that its largest entry equals 1
    stim_factors = factors[1].copy()
    stim_scls = stim_factors.max(0, keepdims=True)
    stim_factors /= stim_scls

    # rescaled stim x response coords: one flattened column per component
    new_coords = np.stack([khatri_rao([stim_factors[:,r][:,None],factors[2][:,r][:,None]]).ravel() for r in range(R)],axis=1)

    Ncells = scld_permT.shape[0]

    X = np.zeros((Ncells,R))

    new_scld_permT = scld_permT.copy()

    for c in range(Ncells):

        if verbose and (c+1) % 50 == 0: print(c+1,end=' ')

        if all_zeroed_stims is not None and c in all_zeroed_stims:
            # Any previously zeroed out responses are now also permuted by the
            # circular-shift producing the best fit.

            # reset per cell so that a failed search can never silently reuse the
            # coefficients found for a previous cell
            best_coeffs = None
            lowest_cost = np.inf

            # try each shift, applied to all zeroed stims of this cell together
            for shifti in range(NDIRS):
                shifted_cell_data = scld_permT[c].copy()

                for si in all_zeroed_stims[c]:
                    # rotate the 2D response map of this stimulus along the direction axis
                    si_2d = shifted_cell_data[si].reshape((NDIRS,-1),order=order)
                    shifted_cell_data[si] = np.roll(si_2d,shifti,axis=0).ravel(order=order)

                # compute fit cost
                res = lsq_linear(new_coords,shifted_cell_data.ravel(),bounds=(0,np.inf))
                coeffs, cost = res['x'], res['cost']

                # if lower reconstruction cost, keep this shift (first one wins on ties)
                if cost < lowest_cost:
                    lowest_cost = cost
                    best_coeffs = coeffs
                    new_scld_permT[c] = shifted_cell_data

            if best_coeffs is None:
                raise ValueError(
                    f'could not find a shift with finite fit cost for cell {c}')

            new_coeffs = best_coeffs

        else:  # no zeroed stims for this cell
            # update coefficients to fit our stimulus-rescaled tensor
            new_coeffs = lsq_linear(new_coords,scld_permT[c].ravel(),bounds=(0,np.inf))['x']

        # sqrt so that, for each stimulus, the magnitude of a vector of coeffs for factors
        # representing that stimulus can be equal to 1, even if that stimulus response
        # is split across multiple factors. This ultimately leads to better distances
        # between neurons
        X[c] = np.sqrt(new_coeffs)

    return X, new_scld_permT


#### 1. Load precomputed tensor files

In [None]:
NDIRS = 8 #number of stimulus directions (rows in original 2-D response maps)

# Load precomputed tensor and aux files (see `creating-the-tensor/creating-the-tensor.ipynb`)

sigT = 
allT = 
all_zeroed_stims = 
cell_maxFRs = 

N = tensorX.shape[0]
NSTIMS = tensorX.shape[1]
RLEN = tensorX.shape[2]

# Compute relative FRs between stimuli for each cell
relFRs = [cell_maxFRs[c]/cell_maxFRs[c].max() for c in range(len(cell_maxFRs))]

#### 2. Load factorization results and compute neural encoding matrix

In [None]:
# Load pre-computed optimal factors and corresponding lambdas
# NOTE(review): `allow_pickle=True` unpickles arbitrary objects -- only load files
# that were produced locally / come from a trusted source.

R = 17
best_factors = np.load(f'cp-files/R{R}_factors.npy',allow_pickle=True)
best_lambdas = np.load(f'cp-files/R{R}_lambdas.npy',allow_pickle=True)

# remove any eventual zero-norm factor
posnorms = ~np.isclose(best_lambdas,0)
lambdas = best_lambdas[posnorms]
# make sure they are all normalized (unit-norm columns in every mode)
factors = [f[:,posnorms]/np.linalg.norm(f[:,posnorms],axis=0,keepdims=1) for f in best_factors]

# find the permuted version of sigT that gave rise to the factors -- this is
# necessary for computing the actual rec error
# (NDIRS is a required argument of getPermutedTensor)
permT, fitT = getPermutedTensor(factors, lambdas, sigT, NDIRS)


#now, add non-signif stims
#note: these haven't been shifted by our factorization -- will address that later
for c, zeroed_stims in all_zeroed_stims.items():
    for si in zeroed_stims:
        permT[c,si] = allT[c,si]

#finally, scale the (unit-normed) stimuli by their relative FRs
# np.asarray makes the `[..., None]` indexing valid even if relFRs is a list of arrays
scld_permT = permT * np.asarray(relFRs)[...,None]

# we will now proceed to adjust the neural loadings to reflect this, and to include the non-signif responses

X, all_scld_permT = getNeuralMatrix(
    scld_permT, factors, lambdas, NDIRS, all_zeroed_stims, order='F', verbose=False)

# finally, eliminate possible redundancy among factors in the neural matrix using PCA
# choose number of PCs based on a prespecified explained variance ratio
MIN_EXPL_VAR_RATIO = 0.8

# NOTE(review): `PCA` is not imported explicitly in this notebook; presumably it is
# re-exported by `from utils import *` -- confirm.
pca = PCA(len(lambdas))
newX = pca.fit_transform(X)
# smallest number of leading PCs whose cumulative explained variance exceeds the threshold
nPCs = np.flatnonzero(np.cumsum(pca.explained_variance_ratio_) > MIN_EXPL_VAR_RATIO)[0] + 1
print('nPCs',nPCs)

#### 3. Compute IAN similarity kernel from pairwise distances

In [None]:
# Pairwise squared Euclidean distances between cells in the reduced PC space,
# expanded from condensed form into a full square matrix
D2 = squareform(pdist(newX[:, :nPCs], metric='sqeuclidean'))

In [None]:
# Choice of optimization solver.
# Using a commercial optimization package is highly recommended for faster kernel
# convergence, e.g.:
# solver = 'GUROBI'
# (a free academic license can be obtained at
# https://www.gurobi.com/academia/academic-program-and-licenses/)
#
# Use None if you don't have a preferred solver, or pick one from the list of solvers
# supported by cvxpy: https://www.cvxpy.org/tutorial/advanced/index.html#choosing-a-solver
solver = None

# NOTE(review): `IAN` is only imported explicitly in a later cell
# (`from ian.ian import *`); presumably `utils` re-exports it -- confirm, otherwise
# that import has to run before this cell.
# Plotting and verbose output are disabled; D2 is passed as precomputed squared distances.
G, wG, optScales, disc_pts = IAN(
    'exact-precomputed-sq',
    D2,
    solver=solver,
    plot_interval=0,
    verbose=0,
    plot_final_stats=0,
)

# Optional: instead of picking a single factorization result, can compute separate graphs for each initialization and
# average them together to combine all results into a single weighted graph.

#### 4. Diffusion map embedding

In [None]:
# IAN package: provides the helpers used below and in the following cells.
# NOTE(review): `IAN` itself is already called in an earlier cell, before these
# imports run -- confirm it is also available through `utils`, or run this import first.
from ian.ian import * #https://github.com/dyballa/IAN/
from ian.dset_utils import *

# Compute diffusion map embedding using the IAN weighted graph as similarity matrix
n_components = 3  # number of embedding coordinates requested
# diffmap_y: embedding coordinates; diffmap_evals: presumably the associated
# eigenvalues -- TODO confirm against the IAN documentation
diffmap_y, diffmap_evals = diffusionMapFromK(wG, n_components)

# 3-D scatter plot of the embedding
plot3dScatter(diffmap_y)

#### 5. Local dimensionality estimation

In [None]:
# Estimate local dimension using NCD algorithm

# NOTE(review): `nbrhoodOrder` is defined here but never used -- the call below passes
# the literal 1 as its third argument. If that argument is the neighborhood order,
# the comment (2 hops) and the call (1) disagree; confirm which value is intended.
nbrhoodOrder = 2 #using neighbors-of-neighbors up to 2 hops away
NofNDims, degDims = estimateLocalDims(G, D2, 1) 
# keep, for each point, the larger of the two dimension estimates
dims = np.maximum(degDims,NofNDims)

# 3-D scatter plot of the embedding; `dims` is passed as the second argument
# (presumably used to color the points -- TODO confirm)
plot3dScatter(diffmap_y, dims)