In [2]:
%matplotlib notebook

In [3]:
import sys, os, time, gc, click, logging, pprint
from os.path import expanduser
from collections import defaultdict
import numpy as np
from mpi4py import MPI
import neuroh5
from neuroh5.io import append_cell_attributes, read_population_ranges, bcast_cell_attributes, \
    scatter_read_cell_attributes, read_cell_attribute_selection, NeuroH5ProjectionGen
import h5py
import matplotlib.pyplot as plt


In [4]:
user_home = expanduser("~")
neuron_home = '%s/bin/nrnpython3/lib/python' % user_home
model_home = '%s/src/model' % user_home
dentate_home = '%s/dentate' % model_home
sys.path.append(neuron_home) 
sys.path.append(model_home) 

import dentate
from dentate.env import Env
from dentate import utils, stimulus, synapses
from dentate.utils import *

  return f(*args, **kwds)


In [5]:
def detect_topological_peaks(im):
    """
    Peak detect on 2D images via persistent homology.
    Author: Stefan Huber <shuber@sthu.org>"
    """

    def get(im, p):
        return im[p[0]][p[1]]

    def iter_neighbors(p, w, h):
        y, x = p

        # 8-neighborship
        neigh = [(y+j, x+i) for i in [-1, 0, 1] for j in [-1, 0, 1]]

        for j, i in neigh:
            if j < 0 or j >= h:
                continue
            if i < 0 or i >= w:
                continue
            if j == y and i == x:
                continue
            yield j, i

    h, w = im.shape

    # Get indices orderd by value from high to low
    indices = [(i, j) for i in range(h) for j in range(w)]
    indices.sort(key=lambda p: get(im, p), reverse=True)

    # Maintains the growing sets
    uf = UnionFind()

    groups0 = {}

    def get_comp_birth(p):
        return get(im, uf[p])

    # Process pixels from high to low
    for i, p in enumerate(indices):
        v = get(im, p)
        ni = [uf[q] for q in iter_neighbors(p, w, h) if q in uf]
        nc = sorted([(get_comp_birth(q), q) for q in set(ni)], reverse=True)

        if i == 0:
            groups0[p] = (v, v, None)

        uf.add(p, -i)

        if len(nc) > 0:
            oldp = nc[0][1]
            uf.union(oldp, p)

            # Merge all others with oldp
            for bl, q in nc[1:]:
                if uf[q] not in groups0:
                    groups0[uf[q]] = (bl, bl-v, p)
                uf.union(oldp, q)

    groups0 = [(k, groups0[k][0], groups0[k][1], groups0[k][2]) for k in groups0]
    groups0.sort(key=lambda g: g[2], reverse=True)

    return groups0

In [6]:
template_paths='%s/templates' % dentate_home
dataset_prefix='%s/datasets' % dentate_home
config_prefix='%s/config' % dentate_home


In [124]:
destination = 'MC'
sources = ['CA3c']
non_structured_sources = ['MC', 'GC']  # []
synapse_name='AMPA'
config_file='Network_Clamp_GC_Exc_Sat_SLN_IN_Izh_proximal_pf.yaml'
output_weights_namespace='Structured Weights'
h5types_path='%s/dentate_h5types.h5' % dataset_prefix
connections_path='%s/Slice/dentatenet_Full_Scale_GC_Exc_Sat_SLN_proximal_pf_20210625.h5' % dataset_prefix
input_features_path="%s/Full_Scale_Control/DG_input_features_20200910_compressed.h5" % dataset_prefix
initial_weights_namespace="Log-Normal Weights"  # None  # 
initial_weights_path='%s/Slice/dentatenet_Full_Scale_GC_Exc_Sat_SLN_proximal_pf_20210625.h5' % dataset_prefix
non_structured_weights_namespace="Normal Weights"  # None  # 
non_structured_weights_path='%s/Slice/dentatenet_Full_Scale_GC_Exc_Sat_SLN_proximal_pf_20210625.h5' % dataset_prefix

arena_id='A' 
activity_dependent=True
use_arena_margin=True
coordinates=(None, None) 
optimize_tol=1e-3
optimize_grad=True
max_delta_weight=20
max_weight_decay_fraction = 0.9
field_width_scale = 1.25
target_amplitude=3.0
#target_gid_set = set([1008431, 1021757, 1011284])
target_gid_set = set([1008431])


In [104]:
np.seterr(all='raise')
os.chdir(dentate_home)
env = Env(config_prefix=config_prefix,config_file=config_file,template_paths=template_paths,dataset_prefix=dataset_prefix)


In [105]:
input_features_namespaces = ['Place Selectivity', 'Grid Selectivity']
this_input_features_namespaces = ['%s %s' % (input_features_namespace, arena_id) 
                                  for input_features_namespace in input_features_namespaces]

selectivity_type_index = { i: n for n, i in viewitems(env.selectivity_types) }
target_selectivity_type_name = 'place'
target_selectivity_type = env.selectivity_types[target_selectivity_type_name]
features_attrs = defaultdict(dict)
source_features_attr_names = ['Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
                                  'Module ID', 'Grid Spacing', 'Grid Orientation', 'Field Width Concentration Factor', 
                                  'X Offset', 'Y Offset']
target_features_attr_names = ['Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate', 
                                  'X Offset', 'Y Offset']

local_random = np.random.RandomState()

seed_offset = int(env.model_config['Random Seeds']['GC Structured Weights'])
spatial_resolution = env.stimulus_config['Spatial Resolution'] # cm

arena = env.stimulus_config['Arena'][arena_id]
default_run_vel = arena.properties['default run velocity']  # cm/s

gid_count = 0
all_sources = sources + non_structured_sources


In [106]:
connection_gen_list = [ NeuroH5ProjectionGen(connections_path, source, destination, namespaces=['Synapses']) 
                        for source in all_sources ]
field_width = None
peak_rate = None
structured_weights_dict = {}


for iter_count, attr_gen_package in enumerate(zip_longest(*connection_gen_list)):
        
        this_gid = attr_gen_package[0][0]
        if not all([attr_gen_items[0] == this_gid for attr_gen_items in attr_gen_package]):
            raise Exception('destination: %s; this_gid not matched across multiple attribute '
                            'generators: %s' % (destination,
                                                [attr_gen_items[0] for attr_gen_items in attr_gen_package]))

        if (target_gid_set is not None) and (this_gid not in target_gid_set):
            continue


        if this_gid is None:
            selection = []
            logger.info('received None')
        else:
            selection = [this_gid]
            local_random.seed(int(this_gid + seed_offset))

        has_structured_weights = False

        dst_input_features_attr_dict = {}
        for input_features_namespace in this_input_features_namespaces:
            print(input_features_namespace, destination, selection)
            input_features_iter = read_cell_attribute_selection(input_features_path, destination, 
                                                                namespace=input_features_namespace,
                                                                mask=set(target_features_attr_names), 
                                                                selection=selection)
            count = 0
            for gid, attr_dict in input_features_iter:
                dst_input_features_attr_dict[gid] = attr_dict
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' % (input_features_namespace, count, destination))

        arena_margin = 0.
        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}
        for gid in selection:
            target_selectivity_features_dict[gid] = dst_input_features_attr_dict.get(gid, {})
            target_selectivity_features_dict[gid]['Selectivity Type'] = np.asarray([target_selectivity_type], dtype=np.uint8)

            num_fields = target_selectivity_features_dict[gid]['Num Fields'][0]
            
            if coordinates[0] is not None:
                num_fields = 1
                target_selectivity_features_dict[gid]['X Offset'] =  np.asarray([coordinates[0]], dtype=np.float32)
                target_selectivity_features_dict[gid]['Y Offset'] =  np.asarray([coordinates[1]], dtype=np.float32)
                target_selectivity_features_dict[gid]['Num Fields'] = np.asarray([num_fields], dtype=np.uint8)
                
                
            if field_width is not None:
                target_selectivity_features_dict[gid]['Field Width'] = np.asarray([field_width]*num_fields, dtype=np.float32)
            else:
                this_field_width = target_selectivity_features_dict[gid]['Field Width']
                target_selectivity_features_dict[gid]['Field Width'] = this_field_width[:num_fields]
                
            if peak_rate is not None:
                target_selectivity_features_dict[gid]['Peak Rate'] = np.asarray([peak_rate]*num_fields, dtype=np.float32)

            input_cell_config = stimulus.get_input_cell_config(target_selectivity_type,
                                                               selectivity_type_index,
                                                               selectivity_attr_dict=target_selectivity_features_dict[gid])
            if input_cell_config.num_fields > 0:
                arena_margin = max(arena_margin, np.max(input_cell_config.field_width) / 2.) if use_arena_margin else 0.
                target_field_width_dict[gid] = input_cell_config.field_width
                target_selectivity_config_dict[gid] = input_cell_config
                has_structured_weights = True

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(arena, spatial_resolution, margin=arena_margin)
        arena_x_grid, arena_y_grid = stimulus.get_2D_arena_grid(arena, spatial_resolution, margin=arena_margin)
               
        if not has_structured_weights:
            selection = []
                
        initial_weights_by_syn_id_dict = defaultdict(lambda: dict())
        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if initial_weights_path is not None:
            initial_weights_iter = \
              read_cell_attribute_selection(initial_weights_path, destination,
                                            namespace=initial_weights_namespace,
                                            selection=selection)

            initial_weights_gid_count = 0
            for this_gid, syn_weight_attr_dict in initial_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    initial_weights_by_syn_id_dict[this_gid][int(syn_id)] = float(weight)
                initial_weights_gid_count += 1

            logger.info('destination: %s; read initial synaptic weights for %i gids' %
                        (destination, initial_weights_gid_count))
            
        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = defaultdict(lambda: dict())
            non_structured_weights_by_source_gid_dict = defaultdict(lambda: dict())
        else:
            non_structured_weights_by_syn_id_dict = None
            
        if non_structured_weights_path is not None:
            non_structured_weights_iter = \
                read_cell_attribute_selection(initial_weights_path, destination,
                                              namespace=non_structured_weights_namespace,
                                              selection=selection)

            non_structured_weights_gid_count = 0
            non_structured_weights_syn_count = 0
            for this_gid, syn_weight_attr_dict in non_structured_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    non_structured_weights_by_syn_id_dict[this_gid][int(syn_id)] = float(weight)
                non_structured_weights_gid_count += 1
                non_structured_weights_syn_count += len(syn_ids)     

        syn_count_by_source_gid_dict = defaultdict(int)
        source_gid_set_dict = defaultdict(set)
        syn_ids_by_source_gid_dict = defaultdict(list)
        structured_syn_id_count = 0

        if has_structured_weights:
            for source, (destination_gid, (source_gid_array, conn_attr_dict)) in zip_longest(all_sources, attr_gen_package):
                syn_ids = conn_attr_dict['Synapses']['syn_id']
                count = 0
                this_initial_weights_by_syn_id_dict = None
                this_initial_weights_by_source_gid_dict = None
                this_non_structured_weights_by_syn_id_dict = None
                this_non_structured_weights_by_source_gid_dict = None
                
                if destination_gid is not None:
                    this_initial_weights_by_syn_id_dict = initial_weights_by_syn_id_dict[destination_gid]
                    this_initial_weights_by_source_gid_dict = initial_weights_by_source_gid_dict[destination_gid]
                    this_non_structured_weights_by_syn_id_dict = non_structured_weights_by_syn_id_dict[destination_gid]
                    this_non_structured_weights_by_source_gid_dict = non_structured_weights_by_source_gid_dict[destination_gid]


                for i in range(len(source_gid_array)):
                    this_source_gid = source_gid_array[i]
                    this_syn_id = syn_ids[i]
                    if this_syn_id in this_initial_weights_by_syn_id_dict:
                        this_syn_wgt = this_initial_weights_by_syn_id_dict[this_syn_id]
                        if this_source_gid not in this_initial_weights_by_source_gid_dict:
                            this_initial_weights_by_source_gid_dict[this_source_gid] = this_syn_wgt
                        
                    elif this_syn_id in this_non_structured_weights_by_syn_id_dict:
                        this_syn_wgt = this_non_structured_weights_by_syn_id_dict[this_syn_id]
                        if this_source_gid not in this_non_structured_weights_by_source_gid_dict:
                            this_non_structured_weights_by_source_gid_dict[this_source_gid] = this_syn_wgt
                    source_gid_set_dict[source].add(this_source_gid)
                    syn_ids_by_source_gid_dict[this_source_gid].append(this_syn_id)
                    syn_count_by_source_gid_dict[this_source_gid] += 1
                    
                    count += 1
                if source not in non_structured_sources:
                    structured_syn_id_count += len(syn_ids)
                logger.info('destination: %s; gid %i; %d edges from source population %s' %
                            (destination, this_gid, count, source))


        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None
        for source in all_sources:
            if has_structured_weights:
                source_gids = list(source_gid_set_dict[source])
            else:
                source_gids = []
            logger.info('Reading %s feature data for %i cells in population %s...' % (input_features_namespace, len(source_gids), source))
            for input_features_namespace in this_input_features_namespaces:
                input_features_iter = read_cell_attribute_selection(input_features_path, source, 
                                                                    namespace=input_features_namespace,
                                                                    mask=set(source_features_attr_names), 
                                                                    comm=env.comm, selection=source_gids)
                count = 0
                for gid, attr_dict in input_features_iter:
                    this_selectivity_type = attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_index[this_selectivity_type]
                    input_cell_config = stimulus.get_input_cell_config(this_selectivity_type,
                                                                       selectivity_type_index,
                                                                       selectivity_attr_dict=attr_dict)
                    this_arena_rate_map = np.asarray(input_cell_config.get_rate_map(arena_x, arena_y),
                                                     dtype=np.float32)
                    if source in non_structured_sources:
                        non_structured_input_rate_maps_by_source_gid_dict[gid] = this_arena_rate_map
                    else:
                        input_rate_maps_by_source_gid_dict[gid] = this_arena_rate_map
                    count += 1
                
                logger.info('Read %s feature data for %i cells in population %s' % (input_features_namespace, count, source))



        if has_structured_weights:

            num_fields = target_selectivity_features_dict[this_gid]['Num Fields'][0]
            this_input_cell_config = target_selectivity_config_dict[this_gid]
            this_initial_weights_by_source_gid_dict = initial_weights_by_source_gid_dict[this_gid]
            
            
            random_target_map = np.asarray(this_input_cell_config.get_rate_map(arena_x, arena_y,
                                                                               scale=field_width_scale),
                                    dtype=np.float32)

            structured_source_input = None
            for source_gid in sorted(input_rate_maps_by_source_gid_dict):
                w = this_initial_weights_by_source_gid_dict[source_gid]
                c = syn_count_by_source_gid_dict[source_gid]
                this_input_rate_map = input_rate_maps_by_source_gid_dict[source_gid]
                this_input_rate_map[np.isclose(this_input_rate_map, 0., atol=1e-4, rtol=1e-4)] = 0.
                if np.max(this_input_rate_map) > 0.:
                    if structured_source_input is None:
                        structured_source_input = np.multiply(float(c) * w, this_input_rate_map)
                    else:
                        structured_source_input += np.multiply(float(c) * w, this_input_rate_map)
    
            
            print(f"this_input_cell_config.x0 = {this_input_cell_config.x0}")
            print(f"this_input_cell_config.y0 = {this_input_cell_config.y0}")
            g0 = None
            if activity_dependent:
                g0 = detect_topological_peaks(structured_source_input)
                n_peaks = len(g0)
                print(f"activity dependent: n peaks = {n_peaks} num_fields = {num_fields}")
                if n_peaks < num_fields:
                    num_fields = n_peaks
                this_input_cell_config.num_fields = num_fields
                this_input_cell_config.x0 = this_input_cell_config.x0[:num_fields]
                this_input_cell_config.y0 = this_input_cell_config.y0[:num_fields]
                target_selectivity_features_dict[this_gid]['Num Fields'][0] = num_fields
                for i in range(num_fields):
                    if i < n_peaks:
                        p_birth, birth_level, pers, p_death = g0[i]
                        xp = arena_x_grid[p_birth[0]]
                        yp = arena_y_grid[p_birth[1]]
                        print(f"activity dependent {i}: this_input_cell_config.x0 = {this_input_cell_config.x0[i]}")
                        print(f"activity dependent {i}: this_input_cell_config.y0 = {this_input_cell_config.y0[i]}")
                        this_input_cell_config.x0[i] = xp
                        this_input_cell_config.y0[i] = yp
                        print(f"activity dependent {i}: this_input_cell_config.x0 = {this_input_cell_config.x0[i]}")
                        print(f"activity dependent {i}: this_input_cell_config.y0 = {this_input_cell_config.y0[i]}")
            
            this_target_map = np.asarray(this_input_cell_config.get_rate_map(arena_x, arena_y,
                                                                             scale=field_width_scale),
                                    dtype=np.float32)
            target_selectivity_features_dict[this_gid]['Arena Rate Map'] = this_target_map

            structured_weights_dict[this_gid] = \
                {'target_map': this_target_map,
                 'random_target_map': random_target_map,
                 'initial_weight_dict': initial_weights_by_source_gid_dict[this_gid],
                 'non_structured_weight_dict': non_structured_weights_by_source_gid_dict[this_gid],
                 'input_rate_map_dict': input_rate_maps_by_source_gid_dict,
                 'non_structured_input_rate_map_dict': non_structured_input_rate_maps_by_source_gid_dict,
                 'syn_count_dict': syn_count_by_source_gid_dict,
                 'max_delta_weight': max_delta_weight, 
                 'arena_x': arena_x, 
                 'arena_y': arena_y,
                 'topological_peaks': g0,
                 'structured_input_rate_map': structured_source_input,
                }



Place Selectivity A MC [1008431]
Grid Selectivity A MC [1008431]
this_input_cell_config.x0 = [ -47.802284  -39.686314  111.69287   -45.15478   -53.77981    -5.095423
   87.844055  -41.783607 -143.88994   102.21789 ]
this_input_cell_config.y0 = [ 139.51607    -75.85112    -89.03725     57.371353  -132.41411
   14.532869    91.721176   -31.647476   -99.01918      1.3338622]
activity dependent: n peaks = 12 num_fields = 10
activity dependent 0: this_input_cell_config.x0 = -47.802284240722656
activity dependent 0: this_input_cell_config.y0 = 139.5160675048828
activity dependent 0: this_input_cell_config.x0 = -11.584270477294922
activity dependent 0: this_input_cell_config.y0 = 0.4157295227050781
activity dependent 1: this_input_cell_config.x0 = -39.68631362915039
activity dependent 1: this_input_cell_config.y0 = -75.85111999511719
activity dependent 1: this_input_cell_config.x0 = -5.584270477294922
activity dependent 1: this_input_cell_config.y0 = 117.41572570800781
activity dependent 2: t

In [107]:
for gid, this_structured_weights_dict in structured_weights_dict.items():
    fig, axs = plt.subplots(1, 3, figsize=(15,8))
    #fig.set_title("gid %d" % gid)
    structured_input_rate_map = this_structured_weights_dict['structured_input_rate_map']
    axs[0].imshow(structured_input_rate_map, aspect='equal', interpolation='nearest')
    target_rate_map = this_structured_weights_dict['target_map']
    axs[1].imshow(target_rate_map, aspect='equal', interpolation='nearest')
    random_rate_map = this_structured_weights_dict['random_target_map']
    axs[2].imshow(random_rate_map, aspect='equal', interpolation='nearest')


<IPython.core.display.Javascript object>

In [108]:
def get_input_arrays(structured_weights_dict, gid):

    target_map = structured_weights_dict[gid]['target_map']
    initial_weight_dict = structured_weights_dict[gid]['initial_weight_dict']
    input_rate_map_dict = structured_weights_dict[gid]['input_rate_map_dict']
    non_structured_input_rate_map_dict = structured_weights_dict[gid]['non_structured_input_rate_map_dict']
    non_structured_weights_dict = structured_weights_dict[gid]['non_structured_weight_dict']
    syn_count_dict = structured_weights_dict[gid]['syn_count_dict']
    arena_x = structured_weights_dict[gid]['arena_x']
    arena_y = structured_weights_dict[gid]['arena_y']

    input_matrix = np.empty((target_map.size, len(input_rate_map_dict)),
                            dtype=np.float64)
    source_gid_array = np.empty(len(input_rate_map_dict), dtype=np.uint32)
    syn_count_array = np.empty(len(input_rate_map_dict), dtype=np.uint32)
    initial_weight_array = np.empty(len(input_rate_map_dict), dtype=np.float64)
    for i, source_gid in enumerate(input_rate_map_dict):
        source_gid_array[i] = source_gid
        this_syn_count = syn_count_dict[source_gid]
        this_input = input_rate_map_dict[source_gid].ravel() * this_syn_count
        input_matrix[:, i] = this_input
        syn_count_array[i] = this_syn_count
        initial_weight_array[i] = initial_weight_dict[source_gid]


    non_structured_input_matrix = None
    if non_structured_input_rate_map_dict is not None:
        non_structured_input_matrix = np.empty((target_map.size, len(non_structured_input_rate_map_dict)),
                                               dtype=np.float32)
        non_structured_weight_array = np.empty(len(non_structured_input_rate_map_dict), dtype=np.float32)
        for i, source_gid in enumerate(non_structured_input_rate_map_dict):
            this_syn_count = syn_count_dict[source_gid]
            this_input = non_structured_input_rate_map_dict[source_gid].ravel() * this_syn_count
            non_structured_input_matrix[:, i] = this_input
            non_structured_weight_array[i] = non_structured_weights_dict.get(source_gid, 1.0)

    return {'target_map': target_map,
            'input_matrix': input_matrix, 
            'initial_weight_array': initial_weight_array, 
            'non_structured_input_matrix': non_structured_input_matrix, 
            'non_structured_weight_array': non_structured_weight_array, 
            'syn_count_array': syn_count_array, 
            'source_gid_array': source_gid_array}


In [109]:
def get_scaled_input_maps(target_amplitude, structured_weights_dict, gid):
    
    input_arrays_dict = get_input_arrays(structured_weights_dict, gid)
    
    target_map = input_arrays_dict['target_map']
    initial_weight_array = input_arrays_dict['initial_weight_array']
    input_matrix = input_arrays_dict['input_matrix']
    non_structured_weight_array = input_arrays_dict['non_structured_weight_array']
    non_structured_input_matrix = np.asarray(input_arrays_dict['non_structured_input_matrix'], dtype=np.float64)
    
    mean_initial_weight = np.mean(initial_weight_array)
    initial_background_map = np.dot(input_matrix, initial_weight_array) + \
                             np.dot(non_structured_input_matrix, non_structured_weight_array)
    if np.mean(initial_background_map)<= 0.:
       raise RuntimeError('generate_structured_delta_weights: initial weights must produce positive activation')

    mean_initial_background = np.mean(initial_background_map)
    scaled_background_map = initial_background_map / mean_initial_background
    scaled_background_map -= 1.
    # I don't think the min of target_map should be subtracted here. If a target map has nonzero background, we need
    # structured weights to provide that (e.g. very wide fields or multiple fields)
    # scaled_target_map = np.asarray(target_map.flat - np.min(target_map), dtype=np.float64)
    scaled_target_map = np.asarray(target_map, dtype=np.float64)

    if np.max(scaled_target_map) > 0.:
       target_map_scaling_factor = target_amplitude / np.max(target_map)
       scaled_target_map = scaled_target_map * target_map_scaling_factor
    scaled_target_map = scaled_target_map

    scaled_input_matrix = input_matrix / mean_initial_background
    scaled_non_structured_input_matrix = non_structured_input_matrix / mean_initial_background

    return {'scaled_input_matrix' : scaled_input_matrix,
            'scaled_non_structured_input_matrix': scaled_non_structured_input_matrix,
            'scaled_target_map': scaled_target_map,
            'scaled_background_map': scaled_background_map,
            'initial_background_map': initial_background_map,
            'mean_initial_background': mean_initial_background,
            'mean_initial_weight': mean_initial_weight
           }
    

# The activity of the inputs is designed to be uniform inside the arena. By adding an additional arena margin,
# you are revealing a decrease in mean input activity at the margins. Perhaps the mean could be calculated only
# within arena.

In [110]:

scaled_maps_dict = { gid: get_scaled_input_maps (target_amplitude, structured_weights_dict, gid) 
                   for gid, structured_weights in structured_weights_dict.items() }


for gid, scaled_maps in scaled_maps_dict.items():
    fig, axs = plt.subplots(1, 2, figsize=(15,8))
    #plt.title("gid %d" % gid)
    scaled_target_map = scaled_maps['scaled_target_map']
    scaled_background_map = scaled_maps['scaled_background_map'].reshape(scaled_target_map.shape)
    axs[0].imshow(scaled_target_map, aspect='equal', interpolation='nearest')
    axs[1].imshow(scaled_background_map, aspect='equal', interpolation='nearest')

<IPython.core.display.Javascript object>

In [117]:
lsqr_dict = {}
for gid, scaled_maps in scaled_maps_dict.items():
    scaled_target_map = scaled_maps['scaled_target_map'].flat
    scaled_background_map = scaled_maps['scaled_background_map']
    scaled_input_matrix = scaled_maps['scaled_input_matrix']

    lsqr_target_map = scaled_target_map - scaled_background_map

    res = scipy.sparse.linalg.lsmr(scaled_input_matrix,
                                   lsqr_target_map,
                                   damp=0.1, show=True)
    lsqr_delta_weights = np.asarray(res[0], dtype=np.float32)    
    
    lsqr_dict[gid] = {'lsqr_delta_weights': lsqr_delta_weights,
                       'lsqr_target_map': lsqr_target_map}


 
LSMR            Least-squares solution of  Ax = b

The matrix A has 10609 rows and 162 columns
damp = 1.00000000000000e-01

atol = 1.00e-06                 conlim = 1.00e+08

btol = 1.00e-06             maxiter =      162

 
   itn      x(1)       norm r    norm Ar  compatible   LS      norm A   cond A
     0  0.00000e+00  1.318e+02  1.887e+02   1.0e+00  1.1e-02
     1  0.00000e+00  8.777e+01  5.395e+01   6.7e-01  3.2e-01  1.9e+00  1.0e+00
     2  0.00000e+00  6.935e+01  3.230e+01   5.3e-01  2.1e-01  2.2e+00  2.1e+00
     3  0.00000e+00  4.341e+01  1.525e+01   3.3e-01  1.4e-01  2.5e+00  3.2e+00
     4  0.00000e+00  3.626e+01  8.149e+00   2.8e-01  8.6e-02  2.6e+00  2.8e+00
     5  0.00000e+00  3.312e+01  5.673e+00   2.5e-01  6.3e-02  2.7e+00  3.7e+00
     6  0.00000e+00  3.017e+01  3.746e+00   2.3e-01  4.4e-02  2.8e+00  4.5e+00
     7  0.00000e+00  2.862e+01  2.640e+00   2.2e-01  3.2e-02  2.9e+00  4.7e+00
     8  0.00000e+00  2.754e+01  1.776e+00   2.1e-01  2.2e-02  3.0e+00  5.2e+00
 

In [125]:
for gid, lsqr_arrays in lsqr_dict.items():

    lsqr_delta_weights = lsqr_arrays['lsqr_delta_weights']
    scaled_maps = scaled_maps_dict[gid]
    scaled_target_map = scaled_maps['scaled_target_map'].flat
    scaled_background_map = scaled_maps['scaled_background_map']
    scaled_input_matrix = scaled_maps['scaled_input_matrix']
    scaled_non_structured_input_matrix = scaled_maps['scaled_non_structured_input_matrix']
    
    input_arrays_dict = get_input_arrays(structured_weights_dict, gid)
    initial_weight_array = input_arrays_dict['initial_weight_array']
    non_structured_weight_array = input_arrays_dict['non_structured_weight_array']

    opt_bounds = [ (-(max_weight_decay_fraction * x), max_delta_weight)
                   for x in initial_weight_array ]

    initial_LS_bound_arrays = (np.asarray([b[0] for b in opt_bounds]), np.asarray([b[1] for b in opt_bounds]))
    initial_LS_delta_weights = np.clip(lsqr_delta_weights, initial_LS_bound_arrays[0], initial_LS_bound_arrays[1])

    lsqr_arrays['LS_weight_bounds'] = opt_bounds
    lsqr_arrays['initial_LS_delta_weights'] = initial_LS_delta_weights

    lsqr_map = np.dot(scaled_input_matrix, lsqr_delta_weights + initial_weight_array) - 1
    if scaled_non_structured_input_matrix is not None:
        lsqr_map += np.dot(scaled_non_structured_input_matrix, non_structured_weight_array)
    initial_LS_map = np.dot(scaled_input_matrix, initial_LS_delta_weights + initial_weight_array) - 1
    if scaled_non_structured_input_matrix is not None:
        initial_LS_map += np.dot(scaled_non_structured_input_matrix, non_structured_weight_array)

    
    fig, axes = plt.subplots(1, 2, figsize=(9., 4.6))
    plt.title('gid %d' % gid)
    axes[0].plot(scaled_target_map, c='r', alpha=0.25, label='Target')
    axes[1].plot(scaled_target_map, c='r', alpha=0.25, label='Target')
    axes[0].plot(lsqr_map, c='b', alpha=0.25, label='LSQR')
    axes[1].plot(initial_LS_map, c='g', alpha=0.25, label='Truncated LSQR')
    axes[0].legend(loc='best', frameon=False, framealpha=0.5)
    axes[1].legend(loc='best', frameon=False, framealpha=0.5)

    plt.figure()
    hist, edges = np.histogram(lsqr_delta_weights, bins=50)
    plt.semilogy(edges[:-1], hist, color='r', alpha=0.25, label='LSQR')
    hist, edges = np.histogram(initial_LS_delta_weights, bins=edges)
    plt.semilogy(edges[:-1], hist, color='b', alpha=0.25, label='Truncated LSQR')
    plt.legend(loc='best', frameon=False, framealpha=0.5)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [126]:
def activation_map_residual(delta_weights, input_matrix, target_map):
    a = np.dot(input_matrix, delta_weights)
    e = np.subtract(target_map, a)
                                           
    res = np.square(e).mean() / 2.
    return np.asarray(res, dtype=np.float64)

def activation_map_residual_grad(weights, input_matrix, target_map):
    N = weights.shape[0]
    a = np.dot(input_matrix, weights)
    e = np.subtract(target_map, a)
    grad = -1./N * np.dot(input_matrix.T, e)
    return np.asarray(grad, dtype=np.float64)


In [127]:
import scipy.optimize as opt
optimize_method = 'L-BFGS-B'

method_options = {'disp': True, 'maxiter': 10000}
method_options['maxfun'] = 1000000


LS_weights_dict = {}
for gid, lsqr_arrays in lsqr_dict.items():
    lsqr_target_map = lsqr_arrays['lsqr_target_map']
    initial_LS_delta_weights = lsqr_arrays['initial_LS_delta_weights']
    LS_weight_bounds = lsqr_arrays['LS_weight_bounds']
    scaled_input_matrix = scaled_maps_dict[gid]['scaled_input_matrix']
    result = opt.minimize(activation_map_residual,
                          initial_LS_delta_weights, 
                          jac=activation_map_residual_grad if optimize_grad else None,
                          args=(scaled_input_matrix, lsqr_target_map.flat),
                          method=optimize_method,
                          bounds=LS_weight_bounds,
                          tol=optimize_tol,
                          options=method_options)
    LS_weights_dict[gid]  = {'LS_delta_weights': np.array(result.x)}


In [128]:
for gid, LS_arrays in LS_weights_dict.items():
    
    LS_delta_weights = LS_arrays['LS_delta_weights']
    
    scaled_maps = scaled_maps_dict[gid]
    scaled_input_matrix = scaled_maps['scaled_input_matrix']
    scaled_non_structured_input_matrix = scaled_maps['scaled_non_structured_input_matrix']
    scaled_target_map = scaled_maps['scaled_target_map'].flat
 
    input_arrays_dict = get_input_arrays(structured_weights_dict, gid)
    initial_weight_array = input_arrays_dict['initial_weight_array']
    non_structured_weight_array = input_arrays_dict['non_structured_weight_array']

    lsqr_arrays = lsqr_dict[gid]
    initial_LS_delta_weights = lsqr_arrays['initial_LS_delta_weights']

    LS_delta_map = np.dot(scaled_input_matrix, LS_delta_weights + initial_weight_array) - 1
    if scaled_non_structured_input_matrix is not None:
        LS_delta_map += np.dot(scaled_non_structured_input_matrix, non_structured_weight_array)

    
    initial_LS_map = np.dot(scaled_input_matrix, initial_LS_delta_weights + initial_weight_array) - 1
    if scaled_non_structured_input_matrix is not None:
        initial_LS_map += np.dot(scaled_non_structured_input_matrix, non_structured_weight_array)
    
    normalized_delta_weights_array = LS_delta_weights / np.max(LS_delta_weights)
    scaled_LS_delta_weights = LS_delta_weights * target_amplitude / np.max(LS_delta_map)
    
    scaled_LS_delta_map = np.dot(scaled_input_matrix, scaled_LS_delta_weights + initial_weight_array) - 1
    if scaled_non_structured_input_matrix is not None:
        scaled_LS_delta_map += np.dot(scaled_non_structured_input_matrix, non_structured_weight_array)
    print('max_delta_weight would have to increase to %.1f to reach target_amplitude' % (np.max(scaled_LS_delta_weights))) 

    fig, axes = plt.subplots(1, 3, figsize=(10., 4.8))
    plt.title('gid %d' % gid)
    axes[0].plot(scaled_target_map, c='r', alpha=0.25, label='Target')
    axes[1].plot(scaled_target_map, c='r', alpha=0.25, label='Target')
    axes[2].plot(scaled_target_map, c='r', alpha=0.25, label='Target')
    axes[0].plot(initial_LS_map, c='b', alpha=0.25, label="Truncated LSQR")
    axes[1].plot(LS_delta_map, c='g', alpha=0.25, label="L-BFGS")
    axes[2].plot(scaled_LS_delta_map, c='orange', alpha=0.25, label="Scaled L-BFGS")
    axes[0].legend(loc='best', frameon=False, framealpha=0.5)
    axes[1].legend(loc='best', frameon=False, framealpha=0.5)
    axes[2].legend(loc='best', frameon=False, framealpha=0.5)

    plt.figure()
    plt.title('gid %d' % gid)
    hist, edges = np.histogram(initial_LS_delta_weights, bins=50)
    plt.semilogy(edges[:-1], hist, color='r', alpha=0.25, label='Trucated LSQR')
    hist, edges = np.histogram(LS_delta_weights, bins=edges)
    plt.semilogy(edges[:-1], hist, color='b', alpha=0.25, label='L-BFGS')
    plt.legend(loc='best', frameon=False, framealpha=0.5)

max_delta_weight would have to increase to 23.5 to reach target_amplitude


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>