# Import Modules

In [None]:
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
import glob
import h5py 

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import io_dict_to_hdf5 as ioh5

from tqdm.notebook import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import interpolate 
from scipy import signal
from pathlib import Path
from scipy.linalg import block_diag

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sys.path.append(str(Path('.').absolute().parent))
from utils import check_path, add_colorbar

# Show every row when a DataFrame is displayed in the notebook.
pd.set_option('display.max_rows', None)
# Figure output directory; `check_path` is imported from the project-local `utils`
# module (presumably it joins the two parts and creates the folder - verify in utils).
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/PredCodingModel')
# `device` is already selected right after the torch imports earlier in this cell,
# so it is not recomputed here.

# Set up Input Data

## Natural Images

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm.notebook import tqdm
import scipy.io as sio

np.random.seed(0)

In [None]:
# DoG filter as a model of LGN
def DoG(img, ksize=(5,5), sigma=1.3, k=1.6):
    """Apply a difference-of-Gaussians filter and rescale the result to [0, 1].

    Parameters
    ----------
    img : array_like
        Image to filter. Non-floating-point input is converted to float32
        first; floating-point input is used as is.
    ksize : tuple of int
        Kernel size handed to ``cv2.GaussianBlur``.
    sigma : float
        Standard deviation of the narrow Gaussian.
    k : float
        Ratio of the wide to the narrow standard deviation.

    Returns
    -------
    numpy.ndarray
        ``blur(sigma) - blur(k*sigma)``, min-max normalised to [0, 1].
        An image whose DoG response is constant yields all zeros.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # GaussianBlur keeps the input dtype; with unsigned integers the
        # subtraction below would wrap around instead of going negative.
        img = img.astype(np.float32)
    g1 = cv2.GaussianBlur(img, ksize, sigma)
    g2 = cv2.GaussianBlur(img, ksize, k*sigma)
    dog = g1 - g2
    lowest = dog.min()
    span = dog.max() - lowest
    if span == 0:
        # Flat response: avoid 0/0 (NaN) in the normalisation.
        return np.zeros_like(dog)
    return (dog - lowest) / span

# Gaussian mask for inputs
def GaussianMask(sizex=16, sizey=16, sigma=5):
    """Return a 2-D Gaussian window of shape (sizey, sizex) that sums to 1.

    The bump is centred on pixel (sizex // 2, sizey // 2) and `sigma` is its
    standard deviation in pixels.
    """
    # Column coordinates as a row vector and row coordinates as a column
    # vector; broadcasting them builds the full coordinate grid.
    cols = np.arange(0, sizex, 1, float)[np.newaxis, :]
    rows = np.arange(0, sizey, 1, float)[:, np.newaxis]

    center_col = sizex // 2
    center_row = sizey // 2
    sq_dist = (cols - center_col)**2 + (rows - center_row)**2
    window = np.exp(-sq_dist / (2*(sigma**2)))
    return window / np.sum(window)

In [None]:
# Preprocess of inputs
# NOTE: both names are reassigned by later cells (`num_images` from imgs.shape,
# `num_iter` from the number of patches per image).
num_images = 10
num_iter = 5000

# datasets from http://www.rctn.org/bruno/sparsenet/
# IMAGES.mat -> variable 'IMAGES' (presumably the preprocessed/whitened images - TODO confirm)
mat_images = sio.loadmat(os.path.expanduser('~/Research/Github/PredictiveCoding-RaoBallard-Model/datasets/IMAGES.mat'))
imgs = mat_images['IMAGES']
# IMAGES_RAW.mat -> variable 'IMAGESr' (presumably the unprocessed images - TODO confirm)
mat_images_raw = sio.loadmat(os.path.expanduser('~/Research/Github/PredictiveCoding-RaoBallard-Model/datasets/IMAGES_RAW.mat'))
imgs_raw = mat_images_raw['IMAGESr']

In [None]:
# Plot datasets
# Show the first 10 raw images (third axis of imgs_raw indexes the image) in a 2 x 5 grid.
fig = plt.figure(figsize=(8, 4))
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(imgs_raw[:,:,i], cmap="gray")
    plt.axis("off")
plt.tight_layout()
fig.suptitle("Natural Images", fontsize=20)
# Leave room at the top for the title added after tight_layout().
plt.subplots_adjust(top=0.9)  

## Sim Data

In [None]:
# Dimensions used for the simulated data and the torch experiments below.
in_dim = 256  # input size; a later cell reshapes one sample to 16 x 16
hidden_dims = [32,128]
n_layers = len(hidden_dims)
N = in_dim  # number of input units, used to build the covariance matrix

In [None]:
# Gaussian bump with standard deviation `sd`, sampled on the integers -2N..2N
# and rescaled so that its peak equals 1.
sd = 3
mu = 0
x1 = np.arange(-2*N,2*N+1)
f1 = (1/(sd*np.sqrt(2*np.pi))*np.exp(-1/2*((x1-mu)/sd)**2))
f1 = f1/np.max(f1)

L = x1.shape[0]
# Row k is the bump shifted so that its peak sits in column k, i.e.
# I_mat[k, j] = f1 evaluated at (j - k): an N x N symmetric Toeplitz matrix.
I_mat = np.array([f1[L//2-k:L//2+N-k] for k in np.arange(N)])

In [None]:
# Left: the 1-D bump f1; right: the matrix I_mat built from its shifts.
fig, ax = plt.subplots(1,2,figsize=(10,5))
ax[0].plot(x1,f1)
im2 = ax[1].pcolormesh(I_mat)
# `add_colorbar` is a project-local helper imported from `utils`.
cbar = add_colorbar(im2)
plt.tight_layout()

In [None]:
# Draw `num_samps` zero-mean Gaussian samples with covariance I_mat
# (data_in has shape (num_samps, I_mat.shape[0])) and display them.
num_samps=3
data_in = np.random.multivariate_normal(mean=np.zeros(I_mat.shape[0]),cov=I_mat,size=(num_samps,))
im = plt.pcolormesh(data_in.T)#,aspect='auto')
cbar = add_colorbar(im,linewidth=1)

In [None]:
# View the first simulated sample as a 16 x 16 image (requires 256 values per sample).
plt.imshow(data_in[0].reshape(16,16))

# Creating Time-Dependent Input

In [None]:
def sliding_window(image, stepSize, windowSize):
    """Generate patches of `image` on a regular grid, row by row.

    `windowSize` is (width, height) and `stepSize` is the stride along both
    axes. Each item is ``(x, y, patch)`` with (x, y) the top-left corner of
    the patch. Patches that run past the right or bottom edge are returned
    truncated, so callers that need full-size patches must filter them.
    """
    for top in range(0, image.shape[0], stepSize):
        bottom = top + windowSize[1]
        for left in range(0, image.shape[1], stepSize):
            right = left + windowSize[0]
            patch = image[top:bottom, left:right]
            yield (left, top, patch)

In [None]:
# Cut one natural image into overlapping 16 x 16 patches (stride 8).
# The image is picked with an explicit index: the previous `imgs[:,:,n]` relied on
# `n`, which is not assigned by any cell before this one.
img = imgs[:,:,0]
winW, winH = (16,16)
img2 = []
for (x, y, window) in sliding_window(img, stepSize=8, windowSize=(winW, winH)):
    # windows at the image border are truncated by the slicing; keep only full-size ones
    if window.shape[0] == winH and window.shape[1] == winW:
        img2.append(window)
img2 = np.stack(img2)

In [None]:
# Grid indices of the example patch; with the stride-8 grid its top-left corner
# is placed at pixel (8*(n+1), 8*(m+1)).
n = 10
m = 2
x = 8*(n+1)
y = 8*(m+1)
# Create figure and axes
fig, ax = plt.subplots(1,1, figsize=(10,10))
ax.imshow(img, aspect='auto',cmap='gray')
# Create a Rectangle patch
rect = patches.Rectangle((x, y), 16, 16, linewidth=2, edgecolor='r', facecolor='none')
# Add the patch to the Axes
ax.add_patch(rect)
fig.savefig(FigPath/'NaturalImageEx.png')

In [None]:
# Crop three inputs
# `gmask` and `input_scale` are defined here because the cells that otherwise
# create them come later in the notebook; the values match those cells.
gmask = GaussianMask() # Gaussian mask
input_scale = 40 # scale factor of inputs .05 for sim data, 40 for nat images
# Apply the mask to the first three patches and flatten each one to a vector.
inputs = np.array([(gmask*img2[i]).flatten() for i in range(3)])

# Remove the mean over all three patches, then rescale.
inputs = (inputs - np.mean(inputs)) * input_scale

# Direct Implementation

## Non-Recurrent Model

In [None]:
class RaoBallard1999Model:
    """Two-level predictive coding network in plain NumPy.

    ``num_level1`` level-1 modules share the weights ``U`` and each predicts
    one input patch; a single level-2 module predicts the concatenated level-1
    states through ``Uh``. States are relaxed with step size ``k1`` and the
    weights are learned with step size ``k2``.
    """

    def __init__(self, dt=1, sigma2=1, sigma2_td=10):
        """Set the hyper-parameters and draw random initial weights.

        Parameters
        ----------
        dt : float
            Stored on the instance; not used by the update equations below.
        sigma2 : float
            Variance assigned to the input (bottom-up) prediction error.
        sigma2_td : float
            Variance assigned to the level-1 (top-down) prediction error.
        """
        self.dt = dt
        self.inv_sigma2 = 1/sigma2 # 1 / sigma^2
        self.inv_sigma2_td = 1/sigma2_td # 1 / sigma_td^2

        self.k1 = 0.3 # k_1: update rate
        self.k2 = 0.2 # k_2: learning rate

        self.lam = 0.02 # sparsity rate
        self.alpha = 1
        self.alphah = 0.05

        self.num_units_level0 = 256
        self.num_units_level1 = 32
        self.num_units_level2 = 128
        self.num_level1 = 3

        # Draw U before Uh so that a fixed NumPy seed reproduces the same weights.
        U = np.random.randn(self.num_units_level0,
                            self.num_units_level1)
        Uh = np.random.randn(int(self.num_level1*self.num_units_level1),
                             self.num_units_level2)
        # Scale by sqrt(2 / (fan_in + fan_out)).
        self.U = U.astype(np.float32) * np.sqrt(2/(self.num_units_level0+self.num_units_level1))
        self.Uh = Uh.astype(np.float32) * np.sqrt(2/(int(self.num_level1*self.num_units_level1)+self.num_units_level2))

        self.r = np.zeros((self.num_level1, self.num_units_level1))
        self.rh = np.zeros((self.num_units_level2))

    def initialize_states(self, inputs):
        """Initialise both state vectors with one feed-forward pass.

        ``inputs`` has shape (num_level1, num_units_level0).
        """
        self.r = inputs @ self.U
        self.rh = self.Uh.T @ np.reshape(self.r, (int(self.num_level1*self.num_units_level1)))

    def calculate_total_error(self, error, errorh):
        """Return the scalar objective for the given prediction errors.

        The value is the variance-weighted squared errors plus quadratic
        penalties on the current states (``alpha``, ``alphah``) and on the
        weights (``lam``).
        """
        recon_error = self.inv_sigma2*np.sum(error**2) + self.inv_sigma2_td*np.sum(errorh**2)
        sparsity_r = self.alpha*np.sum(self.r**2) + self.alphah*np.sum(self.rh**2)
        sparsity_U = self.lam*(np.sum(self.U**2) + np.sum(self.Uh**2))
        return recon_error + sparsity_r + sparsity_U

    def __call__(self, inputs, training=False):
        """Run one relaxation step of ``r`` and ``rh``; optionally learn.

        Parameters
        ----------
        inputs : numpy.ndarray
            Array of shape (num_level1, num_units_level0), i.e. (3, 256).
        training : bool
            When true, ``U`` and ``Uh`` are also updated, using the states
            that were just updated in this call.

        Returns
        -------
        tuple
            ``(error, errorh, dr, drh, r, rh)``: the two prediction errors
            computed before the update, the applied state increments, and
            the updated states.
        """
        # inputs : (3, 256)
        r_reshaped = np.reshape(self.r, (int(self.num_level1*self.num_units_level1))) # (96, )

        # Top-down predictions
        fx = self.r @ self.U.T # (3, 256)
        fxh = self.Uh @ self.rh # (96, )

        # Calculate errors
        error = inputs - fx # (3, 256)
        errorh = r_reshaped - fxh # (96, )
        errorh_reshaped = np.reshape(errorh, (self.num_level1, self.num_units_level1)) # (3, 32)

        # State penalty terms
        g_r = self.alpha * self.r / (1 + self.r**2) # (3, 32)
        g_rh = self.alphah * self.rh / (1 + self.rh**2) # (128, )

        # Update r and rh
        dr = self.inv_sigma2 * error @ self.U - self.inv_sigma2_td * errorh_reshaped - g_r
        drh = self.inv_sigma2_td * self.Uh.T @ errorh - g_rh

        dr = self.k1 * dr
        drh = self.k1 * drh

        # Updates
        self.r = self.r + dr
        self.rh = self.rh + drh

        if training:
            # The decay on U is scaled by the number of level-1 modules
            # (presumably because all of them share U - TODO confirm).
            dU = self.inv_sigma2 * error.T @ self.r - self.num_level1*self.lam * self.U
            dUh = self.inv_sigma2_td * np.outer(errorh, self.rh) - self.lam * self.Uh

            self.U = self.U + self.k2 * dU
            self.Uh = self.Uh + self.k2 * dUh

        return error, errorh, dr, drh, self.r, self.rh

In [None]:
# Define model
model = RaoBallard1999Model()
# model = Recurrent_PredCoding()

# Simulation constants
H, W, num_images = imgs.shape
nt_max = 1000 # Maximum number of simulation time
eps = 1e-3 # small value which determines convergence
input_scale = 40 # scale factor of inputs .05 for sim data, 40 for nat images
gmask = GaussianMask() # Gaussian mask
error_list = [] # List to save errors
r_t = [] # List to save r
rh_t = []  # list to save rh
U_t = [] # List to save U after each weight update
Uh_t = [] # List to save Uh after each weight update
J_t = [] # List to save J (was initialised twice in a row; once is enough)
e1, e2 = [], [] # Lists to save the level-0 / level-1 prediction errors
dr_t, drh_t = [], [] # Lists to save the state increments

winW, winH = (16,16)

# Output folders for this (non-recurrent, 'flat') run
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/PredCodingModel')
data_path = check_path(Path('~/Research/SensoryMotorPred_Data/data/').expanduser(),'PredCodingData/flat')
deltaJ_type = 'flat'
FigPath = check_path(FigPath, deltaJ_type)

In [None]:
for n in range(2): #imgs.shape[-1]
    # Pick a random image and cut it into overlapping winW x winH patches (stride winW//2)
    img2 = []
    idx = np.random.randint(0, num_images)
    img = imgs[:, :, idx]
    for (x, y, window) in sliding_window(img, stepSize=winW//2, windowSize=(winW, winH)):
        # windows at the image border are truncated by the slicing; keep only full-size ones
        if window.shape[0] == winH and window.shape[1] == winW:
            img2.append(window)

    img2 = np.stack(img2)
    num_iter = img2.shape[0]-3
    pbar = tqdm(range(num_iter))
    for iter_ in pbar:
        # Crop three inputs: consecutive patches, masked and flattened
        inputs = np.array([(gmask*img2[iter_+i]).flatten() for i in range(3)])
        # Alternative (simulated data):
        #     inputs = np.random.multivariate_normal(mean=np.zeros(I_mat.shape[0]),cov=I_mat,size=(num_samps,))

        inputs = (inputs - np.mean(inputs)) * input_scale

        # Reset states
        model.initialize_states(inputs)

        # Input an image patch until latent variables are converged
        for i in range(nt_max):
            # Update r and rh without update weights
            error, errorh, dr, drh, r, rh = model(inputs, training=False)
            r_t.append(r)
            rh_t.append(rh)
            e1.append(error)
            e2.append(errorh)
            # Compute norm of the state increments
            dr_norm = np.linalg.norm(dr, ord=2)
            drh_norm = np.linalg.norm(drh, ord=2)
            dr_t.append(dr)
            drh_t.append(drh)
            # Check convergence of r and rh, then update weights
            if dr_norm < eps and drh_norm < eps:
                error, errorh, dr, drh, r, rh = model(inputs, training=True)
                U_t.append(model.U)
                Uh_t.append(model.Uh)
                break
        else:
            # All nt_max steps were used without convergence: report and move on
            # (the weights are not updated for this patch).
            print("Error at patch:", iter_)
            print(dr_norm, drh_norm)

        error_list.append(model.calculate_total_error(error, errorh)) # Append errors

        # Decay learning rate
        if iter_ % 40 == 39:
            model.k2 /= 1.015

        # Print moving average error.
        # error_list keeps growing across images while iter_ restarts for each
        # image, so the window is taken from the end of the list.
        if iter_ % 1000 == 999:
            print("iter: "+str(iter_+1)+"/"+str(num_iter)+", Moving error:", np.mean(error_list[-1000:]))
        pbar.set_description('total_error: {:.02f}'.format(error_list[-1]), refresh=True)

# Convert the recorded histories to arrays for plotting / saving
r_t = np.array(r_t)
rh_t = np.array(rh_t)
U_t = np.array(U_t)
Uh_t = np.array(Uh_t)
e1 = np.array(e1)
e2 = np.array(e2)
dr_t = np.array(dr_t)
drh_t = np.array(drh_t)

In [None]:
def moving_average(x, n=100) :
    """Return the running mean of `x` over a sliding window of `n` samples.

    Only complete windows are reported, so for ``n <= len(x)`` the result has
    ``len(x) - n + 1`` entries; it is empty when `x` is shorter than `n`.
    """
    running_total = np.cumsum(x, dtype=float)
    # Subtracting the total from n samples earlier leaves the sum of the last n samples.
    window_sums = np.concatenate((running_total[:n], running_total[n:] - running_total[:-n]))
    return window_sums[n - 1:] / n

# Write all summary figures of the non-recurrent run into one multi-page PDF;
# each pdf.savefig() call stores the current figure as a new page.
with PdfPages(FigPath/ 'PredCoding_flat.pdf') as pdf:
    # Page 1: moving average of the total error
    moving_average_error = moving_average(np.array(error_list))
    plt.figure(figsize=(5, 3))
    plt.ylabel("Error")
    plt.xlabel("Iterations")
    plt.yticks(np.arange(0,5))
    plt.ylim([0,5])
    plt.plot(np.arange(len(moving_average_error)), moving_average_error)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Plot Receptive fields of level 1
    # (each column of U, reshaped to 16 x 16)
    fig = plt.figure(figsize=(8, 4))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in range(32):
        plt.subplot(4, 8, i+1)
        plt.imshow(np.reshape(model.U[:, i], (16, 16)), cmap="gray")
        plt.axis("off")

    fig.suptitle("Receptive fields of level 1", fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Plot Receptive fields of level 2
    # Three copies of U are stacked with an 80-row offset between them
    # (416 x 96 in total) and projected through Uh.
    # NOTE(review): the 80-row offset and the (16, 26) reshape below correspond to
    # patches shifted by 5 columns, whereas the training patches here come from
    # sliding_window with stride winW//2 - confirm this layout is intended.
    zero_padding = np.zeros((80, 32))
    U0 = np.concatenate((model.U, zero_padding, zero_padding))
    U1 = np.concatenate((zero_padding, model.U, zero_padding))
    U2 = np.concatenate((zero_padding, zero_padding, model.U))
    U_ = np.concatenate((U0, U1, U2), axis = 1)
    Uh_ = U_ @ model.Uh  

    # Only the first 36 level-2 units are shown.
    fig = plt.figure(figsize=(8, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in range(36):
        plt.subplot(6, 6, i+1)
        plt.imshow(np.reshape(Uh_[:, i], (16, 26), order='F'), cmap="gray")
        plt.axis("off")

    fig.suptitle("Receptive fields of level 2", fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Covariance across rows (np.cov treats each row as one variable)
    fig, axs = plt.subplots(1,2, figsize=(10,5))
    axs[0].imshow(np.cov(U_))
    axs[0].set_title('Cov(U)')
    axs[1].imshow(np.cov(Uh_))
    axs[1].set_title('Cov(Uh)')
    plt.tight_layout()
    pdf.savefig()

#     fig, axs = plt.subplots(1,2, figsize=(10,5))
#     axs[0].imshow(model.J)
#     axs[0].set_title('J Matrix')
#     axs[1].imshow(np.cov(model.J))
#     axs[1].set_title('Cov(J)')
#     plt.tight_layout()
#     pdf.savefig()

## Recurrent Model

In [None]:
class Recurrent_PredCoding:
    """Two-level predictive coding network with a recurrent matrix ``J`` on level 2.

    Compared with the non-recurrent model, the predictions pass through
    ``tanh``, the level-2 update receives an extra ``tanh(J @ rh)`` term, and
    ``J`` can be learned (``deltaJ=True``) using the running trace ``h``.
    """

    def __init__(self, dt=1, sigma2=1, sigma2_td=1.1, deltaJ=True):
        """Set hyper-parameters and draw random initial weights.

        dt is stored but not used by the updates below; sigma2 / sigma2_td are
        the variances of the input and level-1 prediction errors; deltaJ
        switches learning of J on or off.
        """
        self.dt = dt
        self.inv_sigma2 = 1/sigma2 # 1 / sigma^2        
        self.inv_sigma2_td = 1/sigma2_td # 1 / sigma_td^2
        
        self.deltaJ = deltaJ
        
        self.k1 = 0.03 # k_1: update rate
        self.k2 = 0.02 # k_2: learning rate
        
        self.lam = 0.002 # sparsity rate
        self.alpha = 1
        self.alphah = 0.05
        
        self.num_units_level0 = 256
        self.num_units_level1 = 32
        self.num_units_level2 = 128
        self.num_level1 = 3
        
        U = np.random.randn(self.num_units_level0, 
                            self.num_units_level1)
        Uh = np.random.randn(int(self.num_level1*self.num_units_level1),
                             self.num_units_level2)
        # Scale by sqrt(0.5 / (fan_in + fan_out))
        self.U = U.astype(np.float32) * np.sqrt(.5/(self.num_units_level0+self.num_units_level1))
        self.Uh = Uh.astype(np.float32) * np.sqrt(.5/(int(self.num_level1*self.num_units_level1)+self.num_units_level2)) 
        
        # Recurrent level-2 weights, (128, 128).
        # NOTE(review): the scale uses num_units_level2*num_units_level1 where the
        # other matrices use their own fan-in - confirm this is intended.
        self.J = np.random.randn(self.num_units_level2,
                                 self.num_units_level2).astype(np.float32) * np.sqrt(.5/(int(self.num_units_level2*self.num_units_level1)+self.num_units_level2)) 
        self.r = np.zeros((self.num_level1, self.num_units_level1))
        self.rh = np.zeros((self.num_units_level2))
        
        # Running trace used in the J update, (128, 128)
        self.h = np.zeros((self.num_units_level2,
                             self.num_units_level2))
        
    def initialize_states(self, inputs):
        """Initialise r and rh for a new input of shape (num_level1, num_units_level0).

        rh additionally receives J @ rh computed from the rh left by the
        previous input; h is not reset here.
        """
        self.r = inputs @ self.U 
        self.rh = self.Uh.T @ np.reshape(self.r, (int(self.num_level1*self.num_units_level1))) + self.J @ self.rh # np.reshape(self.rh, (int(self.num_level1*self.num_units_level1)))
        
        
    def calculate_total_error(self, error, errorh):
        """Return the scalar objective: weighted squared errors plus state and weight penalties.

        J does not enter the weight penalty.
        """
        recon_error = self.inv_sigma2*np.sum(error**2) + self.inv_sigma2_td*np.sum(errorh**2)
        sparsity_r = self.alpha*np.sum(self.r**2) + self.alphah*np.sum(self.rh**2)
        sparsity_U = self.lam*(np.sum(self.U**2) + np.sum(self.Uh**2))
        return recon_error + sparsity_r + sparsity_U
        
    def __call__(self, inputs, training=False):
        """Run one update of r, rh and h; when `training`, also update U, Uh and (if deltaJ) J.

        Returns (error, errorh, dr, drh, r, rh): the errors computed before the
        update, the applied state increments and the updated states.
        """
        # inputs : (3, 256)
        r_reshaped = np.reshape(self.r, (int(self.num_level1*self.num_units_level1))) # (96)

        # Top-down predictions through tanh
        fx = np.tanh(self.r @ self.U.T)
        fxh = np.tanh(self.Uh @ self.rh) # (96, )

        # Calculate errors
        # (each error is multiplied by the tanh derivative 1 - f**2)
        error = (inputs - fx)*(1-fx**2)#@ # (3, 256)
        errorh = (r_reshaped - fxh)*(1-fxh**2)#@ # (96, ) 
        errorh_reshaped = np.reshape(errorh, (self.num_level1, self.num_units_level1)) # (3, 32)
        
        g_r = self.alpha * self.r / (1 + self.r**2) # (3, 32)
        g_rh = self.alphah * self.rh / (1 + self.rh**2) # (128, )
        
        # temp = diag(rh); dh adds the current trace to diag(rh) scaled by the
        # tanh derivative of J @ rh.
        # NOTE(review): because dh already contains self.h and is then added to
        # self.h below, h at least doubles on every call and is never reset -
        # confirm this growth is intended.
        temp = np.zeros((len(self.rh),len(self.rh)))
        np.fill_diagonal(temp, self.rh)
        dh = (self.h + (1-np.tanh(self.J@self.rh)**2)*temp) # self.J@self.h + 

        # Update r and rh
        dr = self.inv_sigma2 * (error) @ self.U - self.inv_sigma2_td * errorh_reshaped - g_r
        drh = self.inv_sigma2_td * self.Uh.T @ (errorh) - g_rh + self.inv_sigma2_td * np.tanh(self.J@self.rh)
        
        dr = self.k1 * dr
        drh = self.k1 * drh
        
        # Updates                
        self.r = self.r + dr
        self.rh = self.rh + drh
        self.h = self.h + dh
        
        if training:  
            # Weight updates use the states that were just updated above
            dU = self.inv_sigma2 * (error).T  @ self.r - 3*self.lam * self.U
            dUh = self.inv_sigma2_td * (np.outer(errorh, self.rh)) - self.lam * self.Uh
            
            if self.deltaJ == True:
                # (errorh @ Uh) has shape (128,) and is broadcast against the rows of h
                dJ = self.inv_sigma2_td * (errorh ) @ self.Uh * self.h
                # J is learned with a 10x smaller rate than U and Uh
                self.J = self.J + self.k2*.1 * dJ # self.k2 

            self.U = self.U + self.k2 * dU
            self.Uh = self.Uh + self.k2 * dUh
        return error, errorh, dr, drh, self.r, self.rh

In [None]:
# Define model
# model = RaoBallard1999Model()
deltaJ = True
model = Recurrent_PredCoding(deltaJ=deltaJ)

# Simulation constants
H, W, num_images = imgs.shape
nt_max = 1000 # Maximum number of simulation time
eps = 1e-3 # small value which determines convergence
input_scale = 40 # scale factor of inputs .05 for sim data, 40 for nat images
gmask = GaussianMask() # Gaussian mask
error_list = [] # List to save errors
r_t = [] # List to save r
rh_t = []  # list to save rh
U_t = [] # List to save U after each weight update
Uh_t = [] # List to save Uh after each weight update
J_t = [] # List to save J after each weight update (was initialised twice in a row)
e1, e2 = [], [] # Lists to save the level-0 / level-1 prediction errors
dr_t, drh_t = [], [] # Lists to save the state increments

winW, winH = (16,16)

# Output folders depend on whether J is learned ('deltaJ') or kept random ('randJ')
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/PredCodingModel')
# deltaJ = False
if deltaJ:
    data_path = check_path(Path('~/Research/SensoryMotorPred_Data/data/').expanduser(),'PredCodingData/deltaJ')
    deltaJ_type = 'deltaJ'
    FigPath = check_path(FigPath, deltaJ_type)
else:
    data_path = check_path(Path('~/Research/SensoryMotorPred_Data/data/').expanduser(),'PredCodingData/randJ')
    deltaJ_type = 'randJ'
    FigPath = check_path(FigPath, deltaJ_type)

In [None]:
for n in range(2): #imgs.shape[-1]
    # Pick a random image and cut it into overlapping winW x winH patches (stride winW//2)
    img2 = []
    idx = np.random.randint(0, num_images)
    img = imgs[:, :, idx]
    for (x, y, window) in sliding_window(img, stepSize=winW//2, windowSize=(winW, winH)):
        # windows at the image border are truncated by the slicing; keep only full-size ones
        if window.shape[0] == winH and window.shape[1] == winW:
            img2.append(window)

    img2 = np.stack(img2)
    num_iter = img2.shape[0]-3
#     num_iter = 10
    pbar = tqdm(range(num_iter))
    for iter_ in pbar:
        # Crop three inputs: consecutive patches, masked and flattened
        inputs = np.array([(gmask*img2[iter_+i]).flatten() for i in range(3)])

        inputs = (inputs - np.mean(inputs)) * input_scale

        # Reset states
        model.initialize_states(inputs)

        # Input an image patch until latent variables are converged
        for i in range(nt_max):
            # Update r and rh without update weights
            error, errorh, dr, drh, r, rh = model(inputs, training=False)
            r_t.append(r)
            rh_t.append(rh)

            e1.append(error)
            e2.append(errorh)

            # Compute norm of the state increments
            dr_norm = np.linalg.norm(dr, ord=2)
            drh_norm = np.linalg.norm(drh, ord=2)
            dr_t.append(dr)
            drh_t.append(drh)
            # Check convergence of r and rh, then update weights
            if dr_norm < eps and drh_norm < eps:
                error, errorh, dr, drh, r, rh = model(inputs, training=True)
                U_t.append(model.U)
                Uh_t.append(model.Uh)
                J_t.append(model.J)
                break
        else:
            # All nt_max steps were used without convergence: report and move on
            # (the weights are not updated for this patch).
            print("Error at patch:", iter_)
            print(dr_norm, drh_norm)

        error_list.append(model.calculate_total_error(error, errorh)) # Append errors

        # Decay learning rate
        if iter_ % 40 == 39:
            model.k2 /= 1.015

        # Print moving average error.
        # error_list keeps growing across images while iter_ restarts for each
        # image, so the window is taken from the end of the list.
        if iter_ % 1000 == 999:
            print("iter: "+str(iter_+1)+"/"+str(num_iter)+", Moving error:", np.mean(error_list[-1000:]))
        pbar.set_description('total_error: {:.02f}'.format(error_list[-1]), refresh=True)

# Convert the recorded histories to arrays for plotting / saving
r_t = np.array(r_t)
rh_t = np.array(rh_t)
U_t = np.array(U_t)
Uh_t = np.array(Uh_t)
e1 = np.array(e1)
e2 = np.array(e2)
dr_t = np.array(dr_t)
drh_t = np.array(drh_t)
J_t = np.array(J_t)

In [None]:
    
# Convert the recorded histories to arrays.
# NOTE(review): the training cell above already ends with the same conversions;
# this cell only matters if that loop was interrupted before reaching them.
r_t = np.array(r_t)
rh_t = np.array(rh_t)
U_t = np.array(U_t)
Uh_t = np.array(Uh_t)
e1 = np.array(e1)
e2 = np.array(e2)
dr_t = np.array(dr_t)
drh_t = np.array(drh_t)
J_t = np.array(J_t)

In [None]:
# Increment of the first level-2 unit over all recorded steps
plt.plot(drh_t[:,0])

In [None]:
def moving_average(x, n=100) :
    """Return the running mean of `x` over a sliding window of `n` samples.

    Only complete windows are reported, so for ``n <= len(x)`` the result has
    ``len(x) - n + 1`` entries; it is empty when `x` is shorter than `n`.
    """
    running_total = np.cumsum(x, dtype=float)
    # Subtracting the total from n samples earlier leaves the sum of the last n samples.
    window_sums = np.concatenate((running_total[:n], running_total[n:] - running_total[:-n]))
    return window_sums[n - 1:] / n

# Write all summary figures of the recurrent run into one multi-page PDF;
# each pdf.savefig() call stores the current figure as a new page.
with PdfPages(FigPath/ 'PredCoding_Rec_{}.pdf'.format(deltaJ_type)) as pdf:
    # Page 1: moving average of the total error
    moving_average_error = moving_average(np.array(error_list))
    plt.figure(figsize=(5, 3))
    plt.ylabel("Error")
    plt.xlabel("Iterations")
    plt.yticks(np.arange(0,5))
    plt.ylim([0,5])
    plt.plot(np.arange(len(moving_average_error)), moving_average_error)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Plot Receptive fields of level 1
    # (each column of U, reshaped to 16 x 16)
    fig = plt.figure(figsize=(8, 4))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in range(32):
        plt.subplot(4, 8, i+1)
        plt.imshow(np.reshape(model.U[:, i], (16, 16)), cmap="gray")
        plt.axis("off")

    fig.suptitle("Receptive fields of level 1", fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Plot Receptive fields of level 2
    # Three copies of U are stacked with an 80-row offset between them
    # (416 x 96 in total) and projected through Uh.
    # NOTE(review): the 80-row offset and the (16, 26) reshape below correspond to
    # patches shifted by 5 columns, whereas the training patches here come from
    # sliding_window with stride winW//2 - confirm this layout is intended.
    zero_padding = np.zeros((80, 32))
    U0 = np.concatenate((model.U, zero_padding, zero_padding))
    U1 = np.concatenate((zero_padding, model.U, zero_padding))
    U2 = np.concatenate((zero_padding, zero_padding, model.U))
    U_ = np.concatenate((U0, U1, U2), axis = 1)
    Uh_ = U_ @ model.Uh  

    # Only the first 36 level-2 units are shown.
    fig = plt.figure(figsize=(8, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in range(36):
        plt.subplot(6, 6, i+1)
        plt.imshow(np.reshape(Uh_[:, i], (16, 26), order='F'), cmap="gray")
        plt.axis("off")

    fig.suptitle("Receptive fields of level 2", fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    pdf.savefig()
    plt.show()

    # Covariance across rows (np.cov treats each row as one variable)
    fig, axs = plt.subplots(1,2, figsize=(10,5))
    axs[0].imshow(np.cov(U_))
    axs[0].set_title('Cov(U)')
    axs[1].imshow(np.cov(Uh_))
    axs[1].set_title('Cov(Uh)')
    plt.tight_layout()
    pdf.savefig()

    # Recurrent matrix J and the covariance across its rows
    fig, axs = plt.subplots(1,2, figsize=(10,5))
    axs[0].imshow(model.J)
    axs[0].set_title('J Matrix')
    axs[1].imshow(np.cov(model.J))
    axs[1].set_title('Cov(J)')
    plt.tight_layout()
    pdf.savefig()

In [None]:
# Element-wise comparison of the first and the last recorded J
# (needs at least one recorded weight update, otherwise J_t is empty).
J_t[0]==J_t[-1]

In [None]:
from utils import *
##### Print memory of local variables #####
# Lists the 10 largest names by sys.getsizeof (a shallow size that does not
# follow references into containers). At notebook top level locals() is the
# global namespace. `sizeof_fmt` presumably comes from the star import of the
# project-local `utils` module - verify.
for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

In [None]:
# First unit of the first level-1 module over all recorded steps
plt.plot(np.array(r_t)[:,0,0])

In [None]:
# Convert the weight and error histories to arrays (a no-op copy if they
# already are arrays); the r_t / rh_t conversions are disabled here.
# r_t = np.array(r_t)
# rh_t = np.array(rh_t)
U_t = np.array(U_t)
Uh_t = np.array(Uh_t)
e1 = np.array(e1)
e2 = np.array(e2)

In [None]:
# Collect the recorded histories of the recurrent run ...
pred_data = {
    'r_t':r_t, 
    'rh_t':rh_t, 
    'U_t':U_t, 
    'Uh_t':Uh_t, 
    'e1':e1, 
    'e2':e2, 
    'J_t':J_t,
}

# ... and write them to an HDF5 file via the project-local `io_dict_to_hdf5` module.
ioh5.save((data_path / 'PredCodingRec_{}.h5'.format(deltaJ_type)),pred_data)

In [None]:
# First element of the level-0 and level-1 prediction errors over the first 20 recorded steps
plt.plot(e1[:20,0,0])
plt.plot(e2[:20,0])
# plt.plot(error_list)


In [None]:
# Shapes of the recorded state-increment histories
dr_t.shape,drh_t.shape

In [None]:
# First element of the level-1 and level-2 state increments over the first 100 recorded steps
plt.plot(dr_t[:100,0,0])
plt.plot(drh_t[:100,0])
# plt.plot(r_t[:,0,0])

In [None]:
# First level-1 unit over the first 100 recorded steps
plt.plot(r_t[:100,0,0])

In [None]:
# Residual between the recorded level-1 states and the linear top-down prediction Uh @ rh,
# using the final Uh for every step.
# NOTE(review): despite its name, `fxh` holds a residual (r - prediction), not the prediction,
# and no tanh is applied here, unlike in Recurrent_PredCoding.__call__.
fxh = r_t.reshape(-1,96) - (model.Uh @ rh_t.T).T

In [None]:
# First component of that residual over the first 100 recorded steps
plt.plot(fxh[0:100,0])

In [None]:
# Current level-2 state of the model
model.rh

In [None]:
# Covariance between level-1 units, estimated across the rows of model.r
plt.figure()
plt.imshow(np.cov(model.r.T))

In [None]:
# Shapes of the last input and of the model's weights and states
inputs.shape, model.U.shape, model.r.shape, model.Uh.shape, model.rh.shape

# RNN

In [None]:
# Module-level copy of the model hyper-parameters and weights, used as a scratch pad
# for the commented-out update equations that follow in this cell.
# NOTE: these assignments overwrite same-named notebook globals (e.g. `r`, `rh`, `dt`)
# left behind by earlier cells.
dt=1
sigma2=1 
sigma2_td=10

inv_sigma2 = 1/sigma2 # 1 / sigma^2        
inv_sigma2_td = 1/sigma2_td # 1 / sigma_td^2

k1 = 0.3 # k_1: update rate
k2 = 0.2 # k_2: learning rate

lam = 0.002 # sparsity rate
alpha = 1
alphah = 0.05

num_units_level0 = 256
num_units_level1 = 32
num_units_level2 = 128
num_level1 = 3

U = np.random.randn(num_units_level0, 
                    num_units_level1)
Uh = np.random.randn(int(num_level1*num_units_level1),
                     num_units_level2)
# Scale by sqrt(2 / (fan_in + fan_out))
U = U.astype(np.float32) * np.sqrt(2/(num_units_level0+num_units_level1))
Uh = Uh.astype(np.float32) * np.sqrt(2/(int(num_level1*num_units_level1)+num_units_level2)) 

# Recurrent level-2 weights, (128, 128)
J = np.random.randn(num_units_level2,
                         num_units_level2).astype(np.float32) * np.sqrt(2/(int(num_units_level2*num_units_level1)+num_units_level2)) 
r = np.zeros((num_level1, num_units_level1))
rh = np.zeros((num_units_level2))

# Trace matrix for the J update, (128, 128)
h = np.zeros((num_units_level2,
                     num_units_level2))

# def initialize_states(self, inputs):
# r = inputs @ U 
# rh = Uh.T @ np.reshape(r, (int(num_level1*num_units_level1))) + J @ rh# np.reshape(rh, (int(num_level1*num_units_level1)))


# def calculate_total_error(self, error, errorh):
# recon_error = inv_sigma2*np.sum(error**2) + inv_sigma2_td*np.sum(errorh**2)
# sparsity_r = alpha*np.sum(r**2) + alphah*np.sum(rh**2)
# sparsity_U = lam*(np.sum(U**2) + np.sum(Uh**2))
#     return recon_error + sparsity_r + sparsity_U

# def __call__(self, inputs, training=False):
#     # inputs : (3, 256)
#     r_reshaped = np.reshape(r, (int(num_level1*num_units_level1))) # (96)

#     fx = r @ U.T
#     fxh = Uh @ rh # (96, )

#     # Calculate errors
#     error = inputs - fx # (3, 256)
#     errorh = r_reshaped - fxh # (96, ) 
#     errorh_reshaped = np.reshape(errorh, (num_level1, num_units_level1)) # (3, 32)

#     g_r = alpha * r / (1 + r**2) # (3, 32)
#     g_rh = alphah * rh / (1 + rh**2) # (64, )

#     temp = np.zeros((len(rh),len(rh)))
#     np.fill_diagonal(temp,rh)
#     dh = h + temp

#     # Update r and rh
#     dr = inv_sigma2 * error @ U - inv_sigma2_td * errorh_reshaped - g_r
#     drh = inv_sigma2_td * Uh.T @ errorh - g_rh + inv_sigma2_td * J@rh

#     dr = k1 * dr
#     drh = k1 * drh

#     # Updates                
#     r = r + dr
#     rh = rh + drh

#     if training:  
#         dU = inv_sigma2 * error.T @ r - 3*lam * U
#         dUh = inv_sigma2_td * np.outer(errorh, rh) - lam * Uh

#         dJ = inv_sigma2_td * np.outer(errorh, rh).T @ Uh *h

#         U += k2 * dU
#         Uh += k2 * dUh
#         J += k2 * dJ
#     return error, errorh, dr, drh, r, rh

# NN implementation

In [None]:
class RaoBallard1999Model(nn.Module):
    """PyTorch variant of the two-level predictive coding model.

    The states (``r``, ``rh``) and the weights (``U``, ``Uh``) are leaf
    tensors. Each has its own Adam optimiser, and every forward call takes
    one optimiser step on ``calculate_total_error``.
    """

    def __init__(self, dt=1, sigma2=1, sigma2_td=10):
        """Set hyper-parameters, draw random weights and create the optimisers.

        dt is stored but not used by the updates; sigma2 / sigma2_td are the
        variances of the input and level-1 prediction errors.
        """
        super(RaoBallard1999Model, self).__init__()
        self.dt = dt
        self.inv_sigma2 = 1/sigma2 # 1 / sigma^2
        self.inv_sigma2_td = 1/sigma2_td # 1 / sigma_td^2

        self.k1 = 0.3 # k_1: update rate
        self.k2 = 0.2 # k_2: learning rate

        self.lam = 0.002 # sparsity rate
        self.alpha = 1
        self.alphah = 0.05

        self.num_units_level0 = 256
        self.num_units_level1 = 32
        self.num_units_level2 = 128
        self.num_level1 = 3

        self.U = torch.randn(self.num_units_level0,
                            self.num_units_level1, requires_grad=True)
        self.Uh = torch.randn(int(self.num_level1*self.num_units_level1),
                             self.num_units_level2, requires_grad=True)
#         self.U = torch.Tensor(U*np.sqrt(2/(self.num_units_level0+self.num_units_level1))).requires_grad_()
#         self.Uh = torch.Tensor(Uh*np.sqrt(2/(int(self.num_level1*self.num_units_level1)+self.num_units_level2))).requires_grad_()
        self.r = torch.zeros((self.num_level1, self.num_units_level1), requires_grad=True)
        self.rh = torch.zeros((self.num_units_level2), requires_grad=True)
        self.o_r = torch.optim.Adam([self.r], lr =self.k1)
        self.o_rh = torch.optim.Adam([self.rh], lr =self.k1)
        self.o_U = torch.optim.Adam([self.U], lr =self.k2)
        self.o_Uh = torch.optim.Adam([self.Uh], lr =self.k2)

    def initialize_states(self, inputs):
        """Reset ``r`` and ``rh`` with one feed-forward pass over ``inputs``.

        ``inputs`` must have shape (num_level1, num_units_level0). The values
        are written in place, so ``r`` and ``rh`` remain the leaf tensors
        that the optimisers update and that receive ``.grad``. Rebinding the
        attributes to the result of a matmul would make them non-leaf tensors
        that the optimisers never see.
        """
        with torch.no_grad():
            self.r.copy_(inputs @ self.U)
            self.rh.copy_(self.Uh.T @ torch.reshape(self.r, (int(self.num_level1*self.num_units_level1),)))
        # Start every new input with fresh Adam moment estimates for the states.
        self.o_r = torch.optim.Adam([self.r], lr =self.k1)
        self.o_rh = torch.optim.Adam([self.rh], lr =self.k1)

    def calculate_total_error(self, error, errorh):
        """Return the scalar loss: weighted squared errors plus state and weight penalties."""
        recon_error = self.inv_sigma2*torch.sum(error**2) + self.inv_sigma2_td*torch.sum(errorh**2)
        sparsity_r = self.alpha*torch.sum(self.r**2) + self.alphah*torch.sum(self.rh**2)
        sparsity_U = self.lam*(torch.sum(self.U**2) + torch.sum(self.Uh**2))
        return recon_error + sparsity_r + sparsity_U

    def forward(self, inputs, training=False):
        """Take one optimiser step on the states and, if `training`, on the weights.

        Parameters
        ----------
        inputs : torch.Tensor
            Tensor of shape (num_level1, num_units_level0), i.e. (3, 256).
        training : bool
            When true, ``U`` and ``Uh`` are stepped as well.

        Returns
        -------
        tuple
            ``(error, errorh, r.grad, rh.grad)``: the prediction errors
            computed before the step and the gradients of the loss with
            respect to the states (not the applied Adam updates).
        """
        # inputs : (3, 256)
        r_reshaped = torch.reshape(self.r, (int(self.num_level1*self.num_units_level1),)) # (96)

        fx = self.r @ self.U.T
        fxh = self.Uh @ self.rh # (96, )

        # Calculate errors
        error = inputs - fx # (3, 256)
        errorh = r_reshaped - fxh # (96, )

        self.o_r.zero_grad()
        self.o_rh.zero_grad()
        if training:
            self.o_U.zero_grad()
            self.o_Uh.zero_grad()
        loss = self.calculate_total_error(error,errorh)
        # The graph is rebuilt on every call, so it does not need to be retained.
        loss.backward()
        self.o_r.step()
        self.o_rh.step()
        if training:
            for weight_opt in (self.o_U, self.o_Uh):
                # Follow the current value of k2 so that callers can decay
                # the learning rate by changing the attribute.
                for group in weight_opt.param_groups:
                    group['lr'] = self.k2
                weight_opt.step()

        return error, errorh, self.r.grad, self.rh.grad

In [None]:
# Instantiate the torch model (the class defined in the cell above shadows the
# NumPy class of the same name) and initialise its states from the simulated samples.
model = RaoBallard1999Model()
inputs = torch.Tensor(data_in)#.to(device)
model.initialize_states(inputs.cpu())
# print(model.r)
# error, errorh, dr, drh = model(inputs.cpu())
# print(model.r)

# model(inputs,training=True)

In [None]:
# NOTE(review): `num_iter`, `input_scale`, `nt_max`, `eps` and `error_list` are the
# notebook globals left by earlier cells (input_scale was commented as ".05 for sim
# data") - confirm their values before running this cell.
for iter_ in tqdm(range(num_iter)):
    # Draw simulated inputs, centre and rescale them
    inputs = np.random.multivariate_normal(mean=np.zeros(I_mat.shape[0]),cov=I_mat,size=(num_samps,))
    inputs = (inputs - np.mean(inputs)) * input_scale
    inputs = torch.Tensor(inputs)
    # Reset states
    model.initialize_states(inputs)

    # Input an image patch until latent variables are converged
    for i in range(nt_max):
        # Update r and rh without update weights.
        # forward() returns four values: the two errors and the gradients w.r.t. r and rh.
        error, errorh, dr, drh = model(inputs, training=False)

        # Compute norm of the state gradients (as NumPy arrays, detached from the graph)
        dr_norm = np.linalg.norm(dr.detach().cpu().numpy(), ord=2)
        drh_norm = np.linalg.norm(drh.detach().cpu().numpy(), ord=2)

        # Check convergence of r and rh, then update weights
        if dr_norm < eps and drh_norm < eps:
            error, errorh, dr, drh = model(inputs, training=True)
            break
    else:
        # All nt_max steps were used without convergence: report and move on
        print("Error at patch:", iter_)
        print(dr_norm, drh_norm)

    # Store a plain float so the list does not keep autograd graphs alive
    error_list.append(model.calculate_total_error(error, errorh).item()) # Append errors

    # Decay learning rate
    if iter_ % 40 == 39:
        model.k2 /= 1.015

    # Print moving average error over the most recent 1000 entries
    if iter_ % 1000 == 999:
        print("iter: "+str(iter_+1)+"/"+str(num_iter)+", Moving error:", np.mean(error_list[-1000:]))

In [None]:
# Build the feed-forward (input -> level 1) and feedback (level 1 -> level 2)
# linear maps as nn.ModuleDicts and one Adam optimiser over all their parameters.
num_units_level0 = 256
num_units_level1 = 32
num_units_level2 = 128
num_level1 = 3
feedforward = nn.ModuleDict()
feedback = nn.ModuleDict()
# for layer in range(len(hidden_dims)):
#     if layer == len(hidden_dims) - 1:
#         out_dim = hidden_dims[layer]
#     else: 
#         out_dim = hidden_dims[layer+1]
layer=0
# 'U0': Linear(256 -> 32); its weight is stored as (32, 256)
feedforward.add_module('U{:d}'.format(layer),nn.Linear(num_units_level0, num_units_level1, bias=False))
feedforward.add_module('NonLinearity{:d}'.format(layer), nn.ReLU())
# feedforward.add_module('U{:d}_T'.format(layer),nn.Linear(num_units_level1, num_units_level0, bias=False))
#     in_dim = out_dim
#     if layer < len(hidden_dims)-1:
# 'Uh0': Linear(96 -> 128); its weight is stored as (128, 96)
feedback.add_module('Uh{:d}'.format(layer),nn.Linear(num_level1*num_units_level1,num_units_level2, bias=False))
feedback.add_module('NonLinearity{:d}'.format(layer), nn.ReLU())
# with torch.no_grad():
#     feedback['Uh0'].weight = nn.Parameter(feedforward['U0'].weight.T)
optimizer = torch.optim.Adam(list(feedforward.parameters()) + list(feedback.parameters()), lr =.001)
feedforward.to(device)
feedback.to(device)
print(feedforward)
print(feedback)

In [None]:
# nn.Linear stores its weight as (out_features, in_features)
feedback['Uh0'].weight.shape

In [None]:
# Manual bottom-up / top-down pass through the linear maps
# (assumes data_in has shape (3, 256), as produced by the simulated-data cell).
inputs = torch.Tensor(data_in).to(device)
r = feedforward['U0'](inputs)      # bottom-up: (3, 256) -> (3, 32)
rh = feedback['Uh0'](r.view(-1))   # bottom-up: (96,) -> (128,)


# Top-down predictions reuse the same weights in the transposed direction.
# nn.Linear stores weight as (out_features, in_features), so the transposed map
# is a plain product with `weight` itself. The earlier `weight.T` product and the
# second call of the Linear layer both had mismatching shapes.
fx = torch.matmul(r, feedforward['U0'].weight)    # (3, 32) @ (32, 256) -> (3, 256)
fxh = torch.matmul(rh, feedback['Uh0'].weight)    # (128,) @ (128, 96) -> (96,)

error = inputs - fx
errorh = r.view(-1) - fxh

In [None]:
# Shape of the transposed feed-forward weight (the trailing comma makes the cell output a 1-tuple)
feedforward['U0'].weight.T.shape, 

In [None]:
# Shapes of the level-1 prediction and of the level-1 activity
fxh.shape,r.shape

In [None]:
# Train for Nepochs full-batch steps, minimising the mean absolute residual that
# is left after each feedback layer has been subtracted from its input.
# NOTE(review): the modules built above are registered as 'Uh0' / 'NonLinearity0',
# so the 'Layer{:d}' key used below does not exist in `feedback`, and 'Uh0' expects
# 96 input features rather than the 256 of data_in - this cell appears to belong
# to an earlier layer-naming scheme; confirm before running.
Nepochs = 100
# log = pd.DataFrame([],columns=['Loss'])
log = {'loss': []}
for epoch in tqdm(range(Nepochs)):
    result = torch.Tensor(data_in).to(device)
    for layer in range(n_layers-1):
        fback = feedback['Layer{:d}'.format(layer)](result)
        result = torch.abs(result - fback)
    loss = result.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    log['loss'].append(loss.item())

In [None]:
# Record, without tracking gradients, the output of every layer plus the final
# residual; log['activations'] has n_layers + 1 slots shaped like `result`.
# NOTE(review): `encoder` is not defined in any cell visible in this notebook -
# presumably a leftover name for the layer container; confirm.
log['activations'] = np.zeros((n_layers+1,) + tuple(result.shape))
with torch.no_grad():
    result = torch.Tensor(data_in).to(device)
    for layer in range(n_layers):
        l_activations = encoder['Layer{:d}'.format(layer)](result)
        result = torch.abs(result - l_activations)
        log['activations'][layer] = l_activations.cpu()
    log['activations'][layer+1] = result.cpu()

In [None]:
# One panel per recorded activation slot (three panels, i.e. n_layers + 1 with n_layers = 2):
# covariance between units, estimated across samples.
fig, axs = plt.subplots(1,3,figsize=(15,5))
for n in range(log['activations'].shape[0]):
    axs[n].pcolormesh(np.cov(log['activations'][n].T))
    axs[n].set_title('Layer {:d} Cov Matrix'.format(n))

In [None]:
# Training loss per epoch
plt.plot(log['loss'])

In [None]:
# Root folder for saved data
base_dir = Path('~/Research/SensoryMotorPred_Data').expanduser()

In [None]:
# Target HDF5 file as a POSIX-style path string
h5file = (base_dir / 'PredCoding_data_logs.h5').as_posix()

In [None]:
# Write the log dictionary to disk via the project-local `io_dict_to_hdf5` module
ioh5.save(h5file,log)

In [None]:
# Read the file back (presumably returns the saved dictionary - verify in io_dict_to_hdf5)
out = ioh5.load(h5file)