In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import datajoint as dj
from database import LnpFit, Net, Fit, NetFC, FitFC, NetFixedMask, FitFixedMask
from collections import OrderedDict
import tensorflow as tf

DataJoint 0.7.2 (June 1, 2017)
Loading settings from /gpfs01/bethge/home/aecker/.datajoint_config.json
Connecting aecker@52.202.15.166:3306


# Fit models

We run a grid search over multiple network architectures and regularization settings. The model fits are stored in a database (using [DataJoint](https://datajoint.io)). The database is populated by running the Python script `populate.py`.

### Where is the code for fitting the models?

The implementations of the convolutional neural networks can be found in the module `convnet.py` in the class `ConvNet`. The grid search is implemented in the database framework (module `database.py`, class `Fit`).

# Table with results

In [2]:
def fetch_best(rel, *args):
    """Fetch the requested attributes of the first entry of `rel` ordered by `val_loss`.

    Parameters
    ----------
    rel : relation-like
        Object whose `fetch` method accepts attribute specifiers plus the
        `order_by` and `limit` keyword arguments.
    *args
        Attribute specifiers, forwarded unchanged to `rel.fetch`.

    Returns
    -------
    list
        The first element of every item returned by `rel.fetch`, in the
        order in which the attributes were requested.
    """
    # limit=1 keeps a single entry, so each fetched item holds one value.
    fetched = rel.fetch(*args, order_by='val_loss', limit=1)
    best = []
    for column in fetched:
        best.append(column[0])
    return best

def get_n_layer_nets(region_num, num_layers):
    """Return the keys of all `Net` entries of a region with a given layer count.

    `Net` is aggregated over `Net.ConvLayer` with `num_layers` defined as
    `count(*)`, and the result is restricted to the requested `region_num`
    and `num_layers`.

    Parameters
    ----------
    region_num : int
        Region to select.
    num_layers : int
        Required value of the aggregated `count(*)` attribute.

    Returns
    -------
    list
        The fetched keys (`dj.key`) of the matching entries.
    """
    layer_counts = Net().aggregate(Net.ConvLayer(), num_layers='count(*)')
    wanted = dict(region_num=region_num, num_layers=num_layers)
    matching = layer_counts.restrict(wanted)
    return list(matching.fetch(dj.key))

In [3]:
# One entry per region; passed to results_table as the weights of the
# per-row average.
num_neurons = [103, 55, 102]

# Models whose rows start out empty and receive one value per region
# from the database.
_fitted_models = (
    'LNP',
    'CNN 1 layer',
    'CNN 2 layers',
    'CNN 3 layers',
    'CNN fully-connected readout',
    'CNN fixed mask',
)

# The 'Antolik' row is hard-coded rather than fetched: correlations are given
# explicitly, the validation loss is a placeholder of zeros and there is no
# network key. Every other row gets its own fresh list in each table.
test_corrs = OrderedDict(
    [('Antolik', [0.51, 0.43, 0.46])] + [(model, []) for model in _fitted_models])
val_loss = OrderedDict(
    [('Antolik', [0, 0, 0])] + [(model, []) for model in _fitted_models])
best_net_key = OrderedDict(
    [('Antolik', [])] + [(model, []) for model in _fitted_models])

In [4]:
def _cnn_label(num_layers):
    """Return the result-table row label for a CNN with `num_layers` layers."""
    return 'CNN {:d} layer'.format(num_layers) + ('s' if num_layers > 1 else '')


def _record_best(model, rel):
    """Append the best fit in `rel` to the three result tables under `model`.

    The best entry is selected by `fetch_best`; its `avg_corr`, `val_loss`
    and key are appended to `test_corrs`, `val_loss` and `best_net_key`.
    """
    corr, loss, key = fetch_best(rel, 'avg_corr', 'val_loss', dj.key)
    test_corrs[model].append(corr)
    val_loss[model].append(loss)
    best_net_key[model].append(key)


layer_counts = range(1, 4)

# Start from empty rows for every model filled in below, so that re-running
# this cell replaces the results instead of appending a second set of regions.
# Rows this cell does not fill (e.g. 'Antolik') are left untouched.
filled_models = (['LNP'] + [_cnn_label(n) for n in layer_counts] +
                 ['CNN fully-connected readout', 'CNN fixed mask'])
for table in (test_corrs, val_loss, best_net_key):
    for model in filled_models:
        table[model] = []

for region_num in range(1, 4):
    region_key = {'region_num': region_num}
    _record_best('LNP', LnpFit() & region_key)
    for n in layer_counts:
        # Restrict the fits to networks with exactly n convolutional layers.
        keys = get_n_layer_nets(region_num, num_layers=n)
        _record_best(_cnn_label(n), Fit() & region_key & keys)

    _record_best('CNN fully-connected readout', FitFC() & region_key)
    _record_best('CNN fixed mask', FitFixedMask() & region_key)

In [5]:
def results_table(results, n=None):
    """Print a fixed-width table with one row per model and one column per region.

    Parameters
    ----------
    results : mapping of str -> sequence of numbers
        Row label mapped to one value per region; values are printed with
        two decimals.
    n : sequence of numbers, optional
        Per-region weights. If given, an extra 'Avg' column holding the
        weighted mean of each row (weights `n`) is printed; each row must
        then have as many values as `n`.

    The number of region columns in the header follows the longest row
    (three if `results` is empty), and the rule below the header is as wide
    as the header itself.
    """
    print_avg = (n is not None)

    # Size the header from the data instead of assuming three regions.
    row_lengths = [len(val) for val in results.values()]
    num_regions = max(row_lengths) if row_lengths else 3

    header = '{:30s}'.format('Region')
    header += ''.join('  {:5d}'.format(i + 1) for i in range(num_regions))
    if print_avg:
        header += '    Avg'
    print(header)
    print(len(header) * '-')

    for model, val in results.items():
        row = '{:30s}'.format(model)
        row += ''.join('  {:5.2f}'.format(v) for v in val)
        if print_avg:
            # Mean over regions, weighted by n.
            avg = np.sum(np.array(val) * np.array(n)) / np.sum(n)
            row += '  {:5.2f}'.format(avg)
        print(row)
    print(' ')

# Each entry: heading, table to print, and per-region weights for the 'Avg'
# column (None means no average column).
for title, table, weights in (
        ('Average correlations on test set', test_corrs, num_neurons),
        ('Loss on validation set', val_loss, None)):
    print(title)
    results_table(table, weights)

Average correlations on test set
Region                              1      2      3    Avg
----------------------------------------------------------
Antolik                          0.51   0.43   0.46   0.47
LNP                              0.37   0.30   0.38   0.36
CNN 1 layer                      0.38   0.32   0.35   0.35
CNN 2 layers                     0.53   0.43   0.47   0.49
CNN 3 layers                     0.55   0.45   0.49   0.50
CNN fully-connected readout      0.47   0.34   0.43   0.43
CNN fixed mask                   0.45   0.38   0.41   0.42
 
Loss on validation set
Region                              1      2      3
---------------------------------------------------
Antolik                          0.00   0.00   0.00
LNP                             84.09  46.01  85.44
CNN 1 layer                     84.47  46.26  86.31
CNN 2 layers                    80.04  43.65  80.65
CNN 3 layers                    79.57  43.62  80.27
CNN fully-connected readout     82.03  45.25  8

# Network architectures of best-performing networks

### Our CNN

In [6]:
# Show the convolutional layers of the best 3-layer CNN for each region.
for region, net_key in enumerate(best_net_key['CNN 3 layers'], start=1):
    print('Region {:d}'.format(region))
    layers = Net.ConvLayer() & net_key
    print(layers)

Region 1
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
6          1             13             48             1          VALID       1.0            0.0           
6          2             3              48             1          SAME        0.0            1.0           
6          3             3              48             1          SAME        0.0            1.0           
 (3 tuples)

Region 2
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
6          1             13             48             1          VALID       1.0            0.0           
6          2             3              48             1          SAME        0.0            1.0         

### CNN with fixed location mask estimated in advance

In [7]:
# Show the convolutional layers of the best fixed-mask CNN for each region.
for region, net_key in enumerate(best_net_key['CNN fixed mask'], start=1):
    print('Region {:d}'.format(region))
    layers = NetFixedMask.ConvLayer() & net_key
    print(layers)

Region 1
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
2          1             13             48             1          VALID       1.0            0.0           
2          2             3              48             1          SAME        0.0            1.0           
2          3             3              48             1          SAME        0.0            1.0           
 (3 tuples)

Region 2
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
2          1             13             48             1          VALID       1.0            0.0           
2          2             3              48             1          SAME        0.0            1.0         

### CNN with fully-connected readout

In [8]:
# Show the convolutional layers of the best CNN with fully-connected readout
# for each region.
for region, net_key in enumerate(best_net_key['CNN fully-connected readout'], start=1):
    print('Region {:d}'.format(region))
    layers = NetFC.ConvLayer() & net_key
    print(layers)

Region 1
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
8          1             13             48             1          VALID       1.0            0.0           
8          2             3              32             1          SAME        0.0            1.0           
8          3             3              4              1          SAME        0.0            1.0           
 (3 tuples)

Region 2
*net_id    *layer_num    filter_size    out_channels   stride     padding     rel_smooth_wei rel_sparse_wei
+--------+ +-----------+ +------------+ +------------+ +--------+ +---------+ +------------+ +------------+
1          1             13             32             1          VALID       1.0            0.0           
1          2             3              16             1          SAME        0.0            1.0         