# Analysis of the Segmentation of V1, V2, and V3 by Convolutional Neural Networks

## About

### Authors

[Noah C. Benson](mailto:nben@uw.edu)$^{1}$ and [Bogeng Song](mailto:bs4283@nyu.edu)$^{2,3}$

$^1$eScience Institute, University of Washington, Seattle, United States  
$^2$Department of Psychology, New York University, New York City, United States  
$^3$(Current Affiliation) Department of Psychology, Georgia Institute of Technology, Georgia, United States

### How to Use this Notebook

This notebook contains documentation, analyses, and visualizations that were performed by Benson, Song, et al. (2025) in their report on the segmentation of V1, V2, and V3 using convolutional neural networks. It was published together with a virtual machine, containerized in a docker image, with the intention of providing a persistent means of reproducing the computation in the original paper. To reproduce the environment used for the publication's analyses, follow the instructions below:

1. **Obtain access to the Human Connectome Project (HCP) data**.
    1. **Register at the [HCP connectome database page](https://db.humanconnectome.org/)**.
    2. **Obtain and save AWS access credentials**. Once you have an account, log into the database. Near the top of the initial splash page is a cell titled "WU-Minn HCP Data - 1200 Subjects", and inside this cell is a button for activating Amazon S3 Access. When you activate this feature, you will be given a "Key" and a "Secret". Save these to the file `${HOME}/.aws/credentials` under the heading `[hcp]`, where `${HOME}` is your home directory. For details on the format of this file, see [this page](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html). In brief: if you don't already have a credentials file, create one containing the block of text below; if you do, append the block to the end of it. In either case, replace `______________` with your key and `********************` with your secret.
       ```
       [hcp]
       aws_access_key_id = ______________
       aws_secret_access_key = ********************
       ```
    3. **Obtain access to the restricted dataset and save the restricted data file**. To obtain access to the restricted data, you need to sign an agreement and send it to the HCP. For more information on restricted data access, see [this page](https://www.humanconnectome.org/study/hcp-young-adult/document/restricted-data-usage). Once your access has been granted, you should receive instructions on obtaining the restricted data CSV file.
2. **Set up Docker and obtain the docker image**.
    1. **Install and start [docker](https://docker.com/)'s `docker-desktop` service**. See [this page](https://docs.docker.com/get-started/get-docker/) for download and installation instructions. Note that docker requires administrator privileges.
    2. **Download the `analysis.tar.gz` docker image from the repository** (DOI:[10.5281/zenodo.14502583](https://doi.org/10.5281/zenodo.14502583)).
    3. **Unzip the analysis image file**.
       ```bash
       gunzip analysis.tar.gz
       ```
    4. **Load the docker image**.  
       ```bash
       docker load -i analysis.tar
       ```
3. **Use the `docker run` command to start the docker image**. This command should be structured as follows with the exception that text in angle-brackets (`<text>`), including the brackets, should be replaced with an appropriate substitution for the local machine on which you are running the virtual machine.  
   ```bash
   docker run --rm -it \
       -p 8888:8888 \
       -v <HCP-restricted-data>:/data/hcp/meta/RESTRICTED_full.csv \
       -v "${HOME}/.aws:/home/jovyan/.aws" \
       nben/visual_autolabel:benson2025 jupyter
   ```

The above command will start a Jupyter server inside the virtual machine and will expose it on port 8888, allowing you to point a browser to `http://127.0.0.1:8888/` to connect to the virtual machine's compute environment. This environment should remain identical to that used to analyze the data for the paper. The parameter `<HCP-restricted-data>` in the command is the path on your local environment to the restricted data CSV file from the Human Connectome Project (see instruction 1C, above).

You can optionally mount additional volumes in order to save cached data across uses of the docker image or to reduce compute time. If, for example, you have the HCP subject data from the 1200-subject release stored in the directory `/hcp/subjects` on your local computer, you can add the line `-v /hcp/subjects:/data/hcp/subjects \` after the `-p 8888:8888 \` line. If you wish to save cache files across runs so that subsequent computations are faster, you can add the line `-v <cache-dir>:/data \` after the `-p 8888:8888 \` line, where `<cache-dir>` should be replaced with a local directory that will hold the various cache data.
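Before launching the container, it can help to confirm that the credentials from instruction 1B and the restricted CSV from instruction 1C are in place. The following is a minimal sketch using only the Python standard library. The helper name `check_hcp_setup` is hypothetical and is not part of the published image; it only checks that the files exist and that the `[hcp]` profile defines both keys, not that the credentials are valid.

```python
import configparser
from pathlib import Path

def check_hcp_setup(credentials_path, restricted_csv):
    """Return a list of human-readable problems (empty if none were found).

    Hypothetical helper: it checks only that the files exist and that the
    [hcp] profile defines both AWS keys; it does not contact AWS or the HCP.
    """
    problems = []
    credentials_path = Path(credentials_path).expanduser()
    if not credentials_path.is_file():
        problems.append(f"missing credentials file: {credentials_path}")
    else:
        # interpolation=None so that '%' characters in a secret are not parsed.
        parser = configparser.ConfigParser(interpolation=None)
        parser.read(credentials_path)
        if not parser.has_section('hcp'):
            problems.append("no [hcp] profile in the credentials file")
        else:
            for key in ('aws_access_key_id', 'aws_secret_access_key'):
                if not parser.get('hcp', key, fallback='').strip():
                    problems.append(f"[hcp] profile is missing {key}")
    if not Path(restricted_csv).expanduser().is_file():
        problems.append(f"missing restricted data CSV: {restricted_csv}")
    return problems

# Example (the CSV path is a placeholder for your own machine):
# print(check_hcp_setup('~/.aws/credentials', '/path/to/RESTRICTED.csv'))
```

An empty list means both files were found and the `[hcp]` profile looks complete.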

## 1. Configuration

Here we define the configuration items that need to be set locally for the system running this notebook. If you are running this notebook from within the docker image published by Benson, Song, et al. (2025) with their project (accessible [here](https://doi.org/10.5281/zenodo.14502583)), i.e., by following the instructions above, then these configuration items do not need to be changed.

If you are running this notebook from outside of the docker image, then you will most likely need to edit these in order for the code to work correctly.

In [None]:
# data_root
# The root of the data directory, in which this notebook expects by default to
# find the datasets, analysis, and models directories. This directory is
# typically mounted into the docker container using the `-v` (volumes) option
# to the docker run command.
data_root = '/data/visual-autolabel'

# dataset_cache_path
# The directory into which data for the model training should be cached. This
# can be None, but if it is, then the training images will need to be
# regenerated every time the notebook is run.
dataset_cache_path = f'{data_root}/datasets'

# analysis_path
# The directory where analysis data (primarily dataframes that compare the
# predictions of various methods) are stored.
analysis_path = f'{data_root}/analysis'

# model_cache_path
# The directory into which to store models that are generated during training.
# This may be None, but if it is, then the best models will not be saved out to
# disk during rounds of training.
model_cache_path = f'{data_root}/models'

# figures_path
# Where this notebook should save figures out to.
figures_path = f'{data_root}/figures'

# grid_path
# Where the grid-search data either lives or should be downloaded and extracted
# to. Note that despite the relatively large number of files in the grid-search
# data, the data themselves are fairly small (megabytes, not gigabytes).
# The directory must contain another directory named 'grid-search', otherwise
# the grid-search.tar.gz file will be extracted into the grid_path to create
# this directory.
grid_path = f'{data_root}/grid'

# dwi_filename_pattern
# Where and how to load diffusion-weighted imaging data files. This is only
# important if you are regenerating the endpoint tract images from scratch;
# otherwise this can be ignored.
# This may be either a string or a tuple of strings; in either case all strings
# are formatted with the target data ('rater' and 'subject') and 'hemisphere'
# and 'tract_name' values then joined using `os.path.join`.
# How this pattern is interpreted can be changed by editing the code for the
# DWIFeature class.
dwi_filename_pattern = (
    # We load from the directory <data_root>/tracts/<subject_id>
    data_root, 'tracts', '{subject}',
    # The filename is like lh.VOF_normalized.mgz
    '{hemisphere}.{tract_name}_normalized.mgz')

# hcp_restricted_path
# The path of the CSV file containing the HCP restricted dataset. This file
# must include the genetic data from the HCP young adult dataset. If you
# have configured neuropythy to have access to this file, then you may leave
# it as None. See this website for more information: 
# https://www.humanconnectome.org/study/hcp-young-adult/document/restricted-data-usage
hcp_restricted_path = None
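
The comment on `dwi_filename_pattern` describes how the pattern is interpreted: every string is formatted with the target's `rater` and `subject` plus `hemisphere` and `tract_name`, and the pieces are then joined with `os.path.join`. The sketch below mirrors that description so the resulting path can be checked by hand. `format_dwi_path` is a hypothetical helper; the real interpretation lives in the `DWIFeature` class and may differ in details.

```python
import os

def format_dwi_path(pattern, subject, hemisphere, tract_name, rater=None):
    """Hypothetical sketch: format each part of `pattern`, then join them."""
    # The pattern may be a single string or a tuple of strings.
    parts = (pattern,) if isinstance(pattern, str) else tuple(pattern)
    fields = dict(rater=rater, subject=subject,
                  hemisphere=hemisphere, tract_name=tract_name)
    # str.format ignores unused keyword arguments, so parts without
    # placeholders (such as the data root) pass through unchanged.
    return os.path.join(*(part.format(**fields) for part in parts))

# The default pattern from the configuration cell, with data_root inlined:
pattern = ('/data/visual-autolabel', 'tracts', '{subject}',
           '{hemisphere}.{tract_name}_normalized.mgz')
print(format_dwi_path(pattern, subject=115017, hemisphere='lh', tract_name='VOF'))
```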

## 2. Initialization

This section contains initialization code that loads libraries and data relevant to the project.

### 2.1. Import Dependencies

In [None]:
import json
from pathlib import Path

import torch
import numpy as np
import neuropythy as ny
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

import visual_autolabel as va
import visual_autolabel.benson2025 as proj

proj.config.model_cache_path     = model_cache_path
proj.config.dataset_cache_path   = dataset_cache_path
proj.config.analysis_path        = analysis_path
proj.config.dwi_filename_pattern = dwi_filename_pattern

### 2.2. Load Data

In this section we load the various data that we plan to plot in the sections below. This includes the HCP and NYU datasets, both of which are lazily loaded as subject data or model results are requested from the data structures in this section. Additionally, the grid-search data is loaded here for visualization of the hyperparameter search.

#### 2.2.1. HCP and NYU Datasets

In [None]:
# Here we declare the hcp datasets we are using in this notebook.

# Make the HCP partition; if this fails, you probably haven't provided
# the hcp_restricted_path correctly.
(hcp_trn_sids, hcp_val_sids) = proj.hcp.partition(
    hcp_restricted_path=hcp_restricted_path)
# Make the HCP datasets for each input and output type.
hcp_data = proj.hcp.all_datasets()
# From these we can also make subject flatmaps:
hcp_maps = proj.hcp.all_flatmaps(hcp_data)

# Do the same for the NYU datasets.
(nyu_trn_sids, nyu_val_sids) = proj.nyu.partition()
nyu_data = proj.nyu.all_datasets()
nyu_maps = proj.nyu.all_flatmaps(nyu_data)

# We also want to read in the dice dataframes (which contain the dice score
# comparisons of all models/predictors versus all others for all subjects).
dice = proj.all_scores()

# We can add a column that specifies whether the subject is in the validation or
# the training partition; we have to add it here because the twin-status is
# restricted data, and specifying which subjects are trn and which are val
# provides some limited clues about which subjects are twins.
all_val_sids = np.union1d(hcp_val_sids.astype(str), nyu_val_sids)
is_val = np.isin(dice['sid'].astype(str), all_val_sids)
dice.insert(1, 'partition', np.select([is_val], ['val'], 'trn'))
# We can separate these out also:
trndice = dice[dice['partition'] == 'trn']
valdice = dice[dice['partition'] == 'val']
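
The `partition` column above is built by testing each row's subject ID for membership in the union of validation IDs. The toy illustration below applies the same `np.isin` pattern to made-up subject IDs and scores (none of them real HCP or NYU subjects). It uses `np.where`, which is equivalent to the single-condition `np.select` in the cell above.

```python
import numpy as np
import pandas as pd

# Made-up IDs: HCP-style IDs are integers and NYU-style IDs are strings,
# so everything is compared as str, as in the cell above.
toy_hcp_val = np.array([111111, 222222])
toy_nyu_val = np.array(['nyu-a'])
toy_val_ids = np.union1d(toy_hcp_val.astype(str), toy_nyu_val)

toy = pd.DataFrame({'sid': [111111, 333333, 'nyu-a', 'nyu-b'],
                    'score': [0.8, 0.7, 0.9, 0.6]})
is_val_toy = np.isin(toy['sid'].astype(str), toy_val_ids)
# Insert the label as the second column, just as with the dice dataframe.
toy.insert(1, 'partition', np.where(is_val_toy, 'val', 'trn'))
print(toy)
```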

#### 2.2.2. Hyperparameter Grid-search Data

The hyperparameter search results are stored in a gzipped tarball file in the [OSF repository](https://osf.io/c49dv) associated with this notebook. This code cell downloads those data and extracts them if the `'grid-search'` subdirectory isn't found in the directory named by the `grid_path` variable.

In [None]:
gridsearch_path = Path(f"{grid_path}/grid-search")
if gridsearch_path.is_dir():
    print(f"Grid search directory found: {gridsearch_path}")
else:
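    import tarfile
    # Make sure the destination directory exists before downloading into it.
    Path(grid_path).mkdir(parents=True, exist_ok=True)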
    print(f"Downloading and extracting grid-search.tar.gz to {grid_path}... ", end="")
    # This is the permanent URL for the grid-search.tar.gz file on the OSF:
    url = 'https://osf.io/download/9wfde/'
    gridsearch_filename = f'{grid_path}/grid-search.tar.gz'
    ny.util.url_download(url, gridsearch_filename)
    # Extract the tarball:
    with tarfile.open(gridsearch_filename, "r:gz") as tar:
        tar.extractall(str(gridsearch_path), filter=tarfile.fully_trusted_filter)
    print("Done.")

## 3. Data Visualization

### 3.1. Prediction Accuracies

In this section, we visualize the accuracies of the CNN predictions and plot the predictions on cortex. This and later parts of the notebook use a set of 4-character codes to represent the different sets of input data given to the CNNs. Each CNN was trained to predict either visual area boundaries (`'area'` predictions) or iso-eccentric regions (`'ring'` predictions) from one of the following sets of input data:
 * **`'anat'`**: T1-weighted (anatomical) data only (`'x'`, `'y'`, `'z'`, `'curvature'`, `'convexity'`, `'thickness'`, `'surface_area'`).
 * **`'t1t2'`**: T1-weighted and T2-weighted data (everything from **`'anat'`** plus `'myelin'`).
 * **`'trac'`**: T1-weighted and diffusion-weighted data (everything from **`'anat'`** plus `'dwi_OR'` and `'dwi_VOF'`).
 * **`'nofn'`**: T1-weighted, T2-weighted, and diffusion-weighted data (everything from **`'t1t2'`** and from **`'trac'`**).
 * **`'func'`**: T1-weighted and functional data (everything from **`'anat'`** plus `'prf_x'`, `'prf_y'`, `'prf_sigma'`, and `'prf_cod'`).
 * **`'nodw'`**: T1-weighted, T2-weighted, and functional data (everything from **`'func'`** plus `'myelin'`).
 * **`'not2'`**: T1-weighted, diffusion-weighted, and functional data (everything from **`'func'`** plus `'dwi_OR'` and `'dwi_VOF'`).
 * **`'full'`**: All of the above data.

Additionally, the following codes refer to non-CNN data sources:
 * **`'prior'`**: The retinotopic prior by [Benson & Winawer (2018)](https://doi.org/10.7554/eLife.40224).
 * **`'inf'`**: Bayesian-inferred retinotopic maps of the HCP subjects by [Benson & Winawer (2018)](https://doi.org/10.7554/eLife.40224).
 * **`'rely'`**: Inter-rater reliability of the dataset.
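
To make the composition of the CNN input sets explicit, the dictionary below restates the first list as Python tuples and checks, for example, that `'nofn'` is the union of `'t1t2'` and `'trac'`. The names `input_features`, `anat_feats`, `t2_feats`, `dwi_feats`, and `func_feats` are illustrative only; they are not part of the `visual_autolabel` API, and the feature ordering actually used by the CNNs may differ.

```python
# Feature names copied from the list above; the grouping is illustrative.
anat_feats = ('x', 'y', 'z', 'curvature', 'convexity', 'thickness', 'surface_area')
t2_feats = ('myelin',)
dwi_feats = ('dwi_OR', 'dwi_VOF')
func_feats = ('prf_x', 'prf_y', 'prf_sigma', 'prf_cod')

input_features = {
    'anat': anat_feats,
    't1t2': anat_feats + t2_feats,
    'trac': anat_feats + dwi_feats,
    'nofn': anat_feats + t2_feats + dwi_feats,
    'func': anat_feats + func_feats,
    'nodw': anat_feats + func_feats + t2_feats,
    'not2': anat_feats + func_feats + dwi_feats,
    'full': anat_feats + t2_feats + dwi_feats + func_feats,
}
for (code, feats) in input_features.items():
    print(f'{code}: {len(feats)} features')
```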

#### 3.1.1. Plotting Code

In [None]:
# This function generates and returns a matplotlib figure of the CNN accuracies.
# The cells below produce the figures used in the paper using this function.

# Default error-bar options for the function.
default_ebaropts = dict(middle='mean', extent='ste', fw=0.05, lw=0.5, ms=1.5)

# These are the models that are plotted by the function based on the requested
# parcellation--this variable gives the order of the columns in the figure.
benson2025_accuracy_models = (
    # We always plot the prior and the inferred maps as the first two columns.
    ('prior', 'hcp'), ('inf',   'hcp'),
    # We then plot the four CNNs that did not receive functional inputs.
    ('anat',  'hcp'), ('t1t2',  'hcp'), ('trac',  'hcp'), ('nofn',  'hcp'),
    # We then plot the four CNNs that did receive functional inputs.
    ('full',  'hcp'), ('func',  'hcp'), ('nodw',  'hcp'), ('not2',  'hcp'),
    # Finally, we plot the inter-rater reliability.
    ('rely',  'hcp'))
# Now process that shared list into a dict for area and ring parcellations:
benson2025_accuracy_models = {
    # For the visual area boundaries, we also want to show the NYU data.
    'area': (benson2025_accuracy_models 
             + (('anat', 'nyu'), ('func', 'nyu'), ('fnyu', 'nyu'))),
    # For the ring models, we just plot the above.
    'ring': benson2025_accuracy_models}

def benson2025_accuracy_plot(
        parcellation,
        dataframe=valdice,
        colwidth=1,
        lrspace=0.2,
        ebarlw=0.5,
        ebarwidth=0.15,
        ebaropts=default_ebaropts,
        pointms=2.5,
        ebarclr = {'lh':'k', 'rh':'k'},
        pointclr = {'lh': (0.2, 0.8, 0.8), 'rh': (1.0, 0.4, 0.4)},
        printsummary=True):
    """Generates and returns a matplotlib figure of the accuracies of the CNNs
    from Benson, Song, et al. (2025).

    This function must be passed a parcellation, either `'area'` or `'ring'`
    for visual area boundaries or iso-eccentric regions. The resulting figure
    plots the results for this parcellation from the given dataframe, which
    should typically be `valdice`. The remaining options tweak the visuals of
    the plot.

    The resulting figure should be equivalent to those produced by Benson,
    Song, et al. (2025) if run correctly in the associated docker image.
    """
    df0 = dataframe
    yticks = np.linspace(0,1,6)
    df0 = df0[df0['parcellation'] == parcellation]
    rois = np.unique(df0['label'].values)
    nrois = len(rois)
    # The models we will be plotting depend on the parcellation type:
    mdls = benson2025_accuracy_models[parcellation]
    nmdls = len(mdls)
    # We want to perform Bonferroni correction on the confidence intervals we report,
    # so we count up the number of confidence intervals we are displaying. This is
    # the number of models times the number of hemispheres (2).
    bfcount = 2 * nmdls
    # Make the figure itself:
    (fig,axs) = plt.subplots(nrois, 1, figsize=(5,nrois*1.5), dpi=72*8)
    fig.subplots_adjust(0,0,1,1,0.1,0.15)
    # Go through each axis/ROI area first:
    for (roi,ax) in zip(rois, axs):
        df1 = df0[df0['label'] == roi]
        # We don't use spines.
        for sp in ax.spines.values():
            sp.set_visible(False)
        # We do use ticks.
        ax.set_ylim([-0.025, 1.025])
        ax.set_yticks(yticks)
        # We draw our own y-axis vertical bar.
        ax.set_xlim([-0.25, 0.25 + (nmdls-1)*colwidth])
        ax.plot([-0.25,-0.25], [0,1], 'k-', zorder=-99)
        # Plot some mesh lines.
        for y in yticks:
            ax.plot([-0.25, 0.25 + (nmdls-1)*colwidth], [y, y], '-',
                    lw=0.3, c='0.65', zorder=-100)
        for y in np.mean([yticks[:-1], yticks[1:]], 0):
            ax.plot([-0.25, 0.25 + (nmdls-1)*colwidth], [y, y], '-',
                    lw=0.2, c='0.85', zorder=-100)
        # Specify the labels and xticks.
        if roi == 'mean':
            ax.set_ylabel("Mean Dice Score")
        elif parcellation == 'area':
            ax.set_ylabel(f'V{roi} Dice Score')
        else:
            ax.set_ylabel(f'E{roi} Dice Score', weight='bold')
        ax.set_xlabel(None)
        ax.set_xticks([])
        # Plot the lh and rh data side-by-side.
        for (h,dx) in zip(['lh','rh'], [-lrspace/2, lrspace/2]):
            df2 = df1[df1['hemisphere'] == h]
            # We want to go through each model in the dataframe.
            for (mdl,x) in zip(mdls, np.arange(nmdls)*colwidth + dx):
                (tag,ds) = mdl
                df3 = df2[(df2['tag'] == tag)]
                df3 = df3[(df3['dataset'] == ds.upper())]
                if len(df3) == 0:
                    continue
                ys = df3['score'].values
                if roi == 'mean':
                    (mn,mu,mx) = va.plot.summarize_dist(ys, extent='ste', bfcount=bfcount)
                    assert np.abs((mx-mu) - (mu-mn)) < 0.00001
                    if printsummary:
                        print(f"{roi} {h} {tag:5s}: {mu:5.3f} ± {mx-mu:5.3f}")
                va.plot.plot_distbars(x, ys, axes=ax,
                                      lc=ebarclr[h], mc=ebarclr[h], zorder=2,
                                      bfcount=bfcount,
                                      **ebaropts)
                ax.plot([x]*len(ys), ys, '.', ms=pointms, c=pointclr[h],
                        alpha=0.05, zorder=1)
    # And return the figure.
    return fig
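
The comments in `benson2025_accuracy_plot` note that the plotted intervals are Bonferroni-corrected for `bfcount = 2 * nmdls` intervals (14 models and 2 hemispheres, i.e., 28, for the `'area'` figure). The exact computation is delegated to `va.plot.summarize_dist` and is not reproduced here. As a hedged illustration of the general idea only, one common construction is a normal-approximation interval, mean ± z·SEM, with z taken at 1 - α/(2m) for m intervals. `bonferroni_mean_ci` is a hypothetical helper, and its numbers need not match the library's.

```python
import numpy as np
from statistics import NormalDist

def bonferroni_mean_ci(values, bfcount=1, alpha=0.05):
    """Return (lower, mean, upper) for a Bonferroni-corrected CI of the mean.

    Hypothetical helper: normal approximation, mean +/- z * SEM, with the
    two-sided quantile z evaluated at 1 - alpha / (2 * bfcount).
    """
    values = np.asarray(values, dtype=float)
    mu = values.mean()
    sem = values.std(ddof=1) / np.sqrt(len(values))
    z = NormalDist().inv_cdf(1 - alpha / (2 * bfcount))
    return (mu - z * sem, mu, mu + z * sem)

# With more simultaneous intervals, each interval gets wider (toy scores).
for m in (1, 2 * 14):
    (lo, mu, hi) = bonferroni_mean_ci([0.71, 0.74, 0.78, 0.80, 0.69], bfcount=m)
    print(f'bfcount={m:2d}: {mu:.3f} ± {hi - mu:.3f}')
```

The interval stays symmetric about the mean, which is the property the plotting function asserts for its own summaries.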

#### 3.1.2. Accuracies of the CNNs Predicting Visual Area Boundaries (**Figs. 3** & **S3**)

In [None]:
parcellation = 'area'
printsummary = True  # To print the precise means ± SEMs.
fig = benson2025_accuracy_plot(parcellation, printsummary=printsummary)
plt.savefig(f'{figures_path}/summary_{parcellation}.pdf',
            bbox_inches='tight',
            transparent=True)
plt.show();

#### 3.1.3. Accuracies of the CNNs Predicting Iso-Eccentric Regions (**Figs. 4** & **S4**)

In [None]:
parcellation = 'ring'
printsummary = True  # To print the precise means ± SEMs.
fig = benson2025_accuracy_plot(parcellation, printsummary=printsummary)
plt.savefig(f'{figures_path}/summary_{parcellation}.pdf',
            bbox_inches='tight',
            transparent=True)
plt.show();

### 3.2. Predictions on Cortex

#### 3.2.1. Plotting Code

In [None]:
################################################################################
# Default options for individual cortex-plot calls used by the plot function.
default_cortex_plot_opts = dict(
    mask=('prf_variance_explained', 0.1, 1))

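def cortex_prediction_plot(
        prefix1,
        prefix2,
        parcellation,
        sid=115017,
        cortex_plot_opts=default_cortex_plot_opts):
    """Plots two sets of predicted boundaries on a subject's flatmaps.

    The boundaries of the `f'{prefix}_visual_{parcellation}'` properties for
    `prefix1` (drawn in white) and `prefix2` (drawn in black) are overlaid on
    the left and right hemisphere flatmaps of subject `sid` (an integer for
    HCP subjects, otherwise an NYU subject ID). The prefixes `'tmpl'` and
    `'warp'` are aliases for `'prior'` and `'inf'`. Unless `cortex_plot_opts`
    specifies a `'color'`, the flatmaps are colored by pRF polar angle (for
    `'area'`) or pRF eccentricity (for `'ring'`).
    """
    # Make the figure: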
    dpi = 72*8
    (fig,axs) = plt.subplots(1,2, figsize=(5, 2.5), dpi=dpi)
    fig.subplots_adjust(0,0,1,1,0,0)
    # Make/get the flatmaps for the subject:
    if isinstance(sid, int):
        fmaps = hcp_maps[sid]
    else:
        fmaps = nyu_maps[sid]
    # Plot each hemisphere
    for (ax,fmap) in zip(axs, fmaps):
        if 'color' not in cortex_plot_opts:
            if parcellation == 'area':
                clr = 'prf_polar_angle'
            else:
                clr = 'prf_eccentricity'
            opts = dict(cortex_plot_opts, color=clr)
        else:
            opts = cortex_plot_opts
        ny.cortex_plot(fmap, axes=ax, **opts)
        # Add lines.
        for (pre,clr) in zip([prefix1, prefix2], ['w', 'k']):
            if pre == 'tmpl':
                pre = 'prior'
            elif pre == 'warp':
                pre = 'inf'
            pnm = f'{pre}_visual_{parcellation}'
            p = fmap.prop(pnm)
            (u,v) = fmap.tess.indexed_edges
            ii = p[u] != p[v]
            xy = fmap.coordinates
            xy = np.mean([xy[:,u[ii]], xy[:,v[ii]]], axis=0)
            ax.scatter(xy[0],xy[1], c=clr, s=0.5)
        # Turn off axes.
        ax.axis('off')
    # Return the figure.
    return fig
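
`cortex_prediction_plot` draws each boundary as a scatter of the midpoints of those mesh edges whose two endpoints carry different labels, using neuropythy's `fmap.tess.indexed_edges` and `fmap.coordinates`. The sketch below isolates that step on plain NumPy arrays so it can be tried without a flatmap. `boundary_midpoints` and the toy square are illustrative only.

```python
import numpy as np

def boundary_midpoints(coords, edges, labels):
    """Return the 2D midpoints of edges whose endpoints have different labels.

    coords: (2, n) array of vertex coordinates (like a flatmap's coordinates).
    edges:  pair (u, v) of equal-length vertex-index sequences, one per edge.
    labels: length-n array of per-vertex labels.
    """
    (u, v) = (np.asarray(edges[0]), np.asarray(edges[1]))
    labels = np.asarray(labels)
    crossing = labels[u] != labels[v]
    return 0.5 * (coords[:, u[crossing]] + coords[:, v[crossing]])

# Toy data: a unit square with label 1 on the left and label 2 on the right,
# triangulated along the diagonal from vertex 0 to vertex 2.
coords = np.array([[0.0, 1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 1.0]])
edges = ([0, 1, 2, 3, 0], [1, 2, 3, 0, 2])
labels = np.array([1, 2, 2, 1])
print(boundary_midpoints(coords, edges, labels))
```

All three midpoints fall on the line x = 0.5, which is where the boundary between the two labels lies.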

#### 3.2.2. Visual Area Boundaries on Cortex (**Fig. 3**)

In [None]:
sid = 115017
prefix1 = 'A1'
prefix2 = 'anat'
parc = 'area'
dpi = 288

fig = cortex_prediction_plot(prefix1, prefix2, parc, sid=sid)

flnm = f'sample-contours_{sid}_{parc}_{prefix1}-{prefix2}.png'
plt.savefig(
    f'{figures_path}/{flnm}',
    bbox_inches='tight',
    dpi=dpi)

plt.show()

#### 3.2.3. Iso-Eccentric Regions on Cortex (**Fig. 4**)

In [None]:
sid = 115017
prefix1 = 'A1'
prefix2 = 'anat'
parc = 'ring'
dpi = 288

fig = cortex_prediction_plot(prefix1, prefix2, parc, sid=sid)

flnm = f'sample-contours_{sid}_{parc}_{prefix1}-{prefix2}.png'
plt.savefig(
    f'{figures_path}/{flnm}',
    bbox_inches='tight',
    dpi=dpi)

plt.show()

### 3.3. Hyperparameter Search (**Figs. S1** & **S2**)

This section contains code for plotting and examining the results of the hyperparameter search. The hyperparameter grid-search results are stored on the OSF website under the directory `hyperparameters` in a gzipped tarball file named `grid-search.tar.gz`. This file contains one directory for each cell of the grid, named `f'grid{index:05d}'`, where `index` is the 0-based index of the flattened grid (e.g., `grid00000` is the directory of the first cell in the grid, `grid00001` is the directory of the second cell, etc.). Each of these directories contains three files: `opts.json`, `plan.json`, and `run.log`. The `opts.json` file contains the options used in all stages of the fitting, the `plan.json` file contains the specific option settings for each epoch, and the `run.log` file contains a log of the accuracies achieved during the training.

The overall accuracy of the model can be found in the `run.log` file, where it appears as the lowest loss among the values in the validation dice-loss column.
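
The plotting code in this section suggests how the 1,800 cells break down: two input sets (`anat`, `full`), two prediction types (`area`, `ring`), two base models (`resnet18`, `resnet34`), three batch sizes, three initial BCE weights, five values of γ, and five learning rates, and 2 × 2 × 2 × 3 × 3 × 5 × 5 = 1800. The sketch below enumerates such a grid and names its directories with the `grid{index:05d}` convention. The enumeration order produced by `itertools.product` is an assumption made only for illustration; the actual settings of each cell should be read from its `opts.json`.

```python
from itertools import product

# Hyperparameter values as they appear in the plotting code of this section.
grid_axes = {
    'inputs':     ('anat', 'full'),
    'prediction': ('area', 'ring'),
    'base_model': ('resnet18', 'resnet34'),
    'batch_size': (2, 4, 6),
    'bce_weight': (0.5, 0.67, 0.75),
    'gamma':      (0.8, 0.85, 0.9, 0.95, 1),
    'lr':         (0.00167, 0.0025, 0.00375, 0.00562, 0.00844),
}
# ASSUMPTION: this ordering is illustrative and is not guaranteed to match
# the indices of the published grid (check each cell's opts.json instead).
settings = [dict(zip(grid_axes, values)) for values in product(*grid_axes.values())]
dirnames = [f'grid{index:05d}' for index in range(len(settings))]
print(len(settings), dirnames[0], dirnames[-1])
```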

#### 3.3.1. Loading and Plotting Code

The first function below (`load_cell`) loads a single cell of the grid search; the second (`load_gridsearch`) loads all 1,800 cells into a dataframe.

In [None]:
def load_cell(cell_path): 
    """Loads the meta-data for one cell of the grid.
    
    `load_cell(path)` loads the cell at the given `path`, which should be a
    directory containing the `opts.json`, `plan.json`, and `run.log` files.
    """
    cell_path = Path(cell_path)
    with (cell_path / 'opts.json').open('rt') as fl:
        opts = json.load(fl)
    opts = dict({'cell_id': int(opts['model_key'][4:])}, **opts)
    with (cell_path / 'plan.json').open('rt') as fl:
        plan = json.load(fl)
    for (ii,p) in enumerate(plan):
        for (k,v) in p.items():
            opts[f'{k}_{ii}'] = v
    with (cell_path / 'run.log').open('rt') as fl:
        log = fl.readlines()
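    # Each training iteration begins with a line starting with 'Iteration';
    # the final iteration runs to the end of the log.
    it0s = [ii for (ii,ln) in enumerate(log) if ln.startswith('Iteration')]
    it0s.append(len(log))
    # Within each iteration, take the smallest value in column 9 (the
    # validation dice loss) among the lines flagged with a trailing '*'.
    losses = np.array(
        [min([float(ln.split()[9])
              for ln in log[ii0:ii1]
              if ln.rstrip().endswith('*')])
         for (ii0,ii1) in zip(it0s[:-1], it0s[1:])])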
    if len(losses) > 0:
        opts['loss_mean'] = np.mean(losses)
        opts['loss_median'] = np.median(losses)
        opts['loss_std'] = np.std(losses)
        opts['loss_max'] = np.max(losses)
        opts['loss_min'] = np.min(losses)
    else:
        opts['loss_mean'] = np.nan
        opts['loss_median'] = np.nan
        opts['loss_std'] = np.nan
        opts['loss_max'] = np.nan
        opts['loss_min'] = np.nan
    opts['partition'] = va.partition_id(opts['partition'])
    return opts

def load_gridsearch(gridsearch_path=gridsearch_path):
    """Loads all cells from the grid-search data.
    
    This function requires that the grid-search.tar.gz file (available from the OSF repository
    associated with this notebook: osf.io/c49dv) be downloaded and extracted. The
    `gridsearch_path` argument must be the path of the extracted `grid-search` directory.
    """
    gridsearch_path = Path(gridsearch_path)
    cells = [
        load_cell(gridsearch_path / f'grid{k:05d}')
        for k in range(1800)]  # There are 1800 cells in the grid.
    cells_full = pd.DataFrame(cells)
    cells = cells_full.drop(
        columns=[
            'pretrained', 'multiproc', 
            'model_cache_path', 'data_cache_path',
            'model_key', 'partition',
            'lr_1', 'lr_2', 'bce_weight_1', 'bce_weight_2'])
    return cells.rename(columns=dict(lr_0='lr', bce_weight_0='bce_weight'))
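
`load_cell` splits `run.log` into training iterations at the lines that start with `'Iteration'` and then takes a minimum within each block. The generic splitting pattern, with the last block running to the end of the input, is sketched below. The toy strings are placeholders and do not reproduce the real `run.log` format, and `split_at_markers` is a hypothetical helper.

```python
def split_at_markers(lines, prefix):
    """Split `lines` into blocks, each starting at a line beginning with `prefix`.

    Lines before the first marker are dropped, and the last block extends to
    the end of `lines` (the final bound is len(lines), so no line is lost).
    """
    starts = [ii for (ii, ln) in enumerate(lines) if ln.startswith(prefix)]
    bounds = zip(starts, starts[1:] + [len(lines)])
    return [lines[ii0:ii1] for (ii0, ii1) in bounds]

toy_log = ['header', 'Iteration 1', 'a', 'b', 'Iteration 2', 'c', 'd']
print(split_at_markers(toy_log, 'Iteration'))
```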

The next cell contains the visualization code, which plots the grid-search results for one combination of inputs, outputs, and base model. Each plot is a grid of grids: the outer 3×3 grid has one row per batch size (2, 4, 6) and one column per initial BCE weight (0.5, 0.67, 0.75), and each inner 5×5 grid plots `gamma` (rows) against `lr` (columns).

In [None]:
def plotgrid(cells, inputs, outputs, base_model, col='loss_min',
             vmin=0, vmax=0.2, cmap='cividis',
             figsize=(7,7), dpi=512,
             gammas=(0.8, 0.85, 0.9, 0.95, 1),
             lrs=(0.00167, 0.0025 , 0.00375, 0.00562, 0.00844),
             axes=None,
             star=True,
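             printbest=False):
    """Plots the grid-search results for one inputs/outputs/base-model combination.

    Each axis of the 3x3 grid `axes` (rows: batch size; columns: initial BCE
    weight) shows a matrix of `col` values with gamma on the y-axis and the
    learning rate on the x-axis; if `axes` is None, a new figure is created.
    When `star` is True, every cell whose value (rounded to 3 decimals) equals
    the overall minimum is marked with a white star. Returns a dict describing
    the best cell (batch size, BCE weight, gamma, learning rate, and value).
    """
    from warnings import warn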
    # First get the set of cells we are planning to use:
    cells = cells[(cells['inputs'] == inputs) & (cells['prediction'] == outputs)]
    cells = cells[cells['base_model'] == base_model]
    # We now need a grid of 3 x 3 matrices, each of which will be 5x5:
    if axes is None:
        (fig,axs) = plt.subplots(3, 3, dpi=dpi, figsize=figsize, sharex=True, sharey=True)
    else:
        axs = axes
        fig = None
    posmin = None
    totmin = np.inf
    mtcs = []
    for (axrow,batch_size) in zip(axs, [2,4,6]):
        subcells0 = cells[cells['batch_size'] == batch_size]
        mtxrow = []
        for (ax, bcew0) in zip(axrow, [0.5, 0.67, 0.75]):
            # Get the subset of cells that match:
            subcells = subcells0[subcells0['bce_weight'] == bcew0]
            # Now go through and make the matrix (there may be values missing, so we
            # build this up iteratively).
            mtx = []
            g = []
            l = []
            for gamma in gammas:
                row = []
                for lr in lrs:
                    g.append(gamma)
                    l.append(lr)
                    cell = subcells[(subcells['gamma'] == gamma) & (subcells['lr'] == lr)]
                    if len(cell) == 0:
                        warn(f"missing cell: {inputs}/{outputs}/"
                             f"{base_model}/{batch_size}/{bcew0}/{gamma}/{lr}")
                        row.append(np.nan)
                    elif len(cell) > 1:
                        warn(f"identical cells: {[r['cell_id'] for (ii,r) in cell.iterrows()]}")
                        row.append(np.nan)
                    else:
                        row.append(cell[col].values[0])
                mtx.append(row)
            mtx = np.round(np.array(mtx), 3)
            mtcs.append(mtx)
            ax.imshow(mtx, vmin=vmin, vmax=vmax, cmap=cmap)
            ax.invert_yaxis()
            ax.set_title(f'batch={batch_size}, BCE$_0$={bcew0}')
            # We want to track the smallest value
            mtxmin_ii = np.nanargmin(mtx)
            mtxmin = mtx.flat[mtxmin_ii]
            if mtxmin < totmin:
                totmin = mtxmin
                posmin = {
                    'batch_size': batch_size, 'bce_weight': bcew0,
                    'gamma': g[mtxmin_ii], 'lr': l[mtxmin_ii],
                    'value': totmin}
    for ax in axs[:,0]:
        ax.set_ylabel(r'$\gamma$')
        ax.set_yticks(range(len(gammas)))
        ax.set_yticklabels(gammas)
    for ax in axs[-1,:]:
        ax.set_xlabel(r'Learning Rate [$10^{-3}$]')
        ax.set_xticks(range(len(lrs)))
        ax.set_xticklabels([f'{lr*1000:3.2f}' for lr in lrs])
    if star and np.isfinite(totmin):
        for (ax,mtx) in zip(axs.flat, mtcs):
            for (ri,row) in enumerate(mtx):
                for (ci,val) in enumerate(row):
                    if np.round(val, 3) <= totmin:
                        ax.plot(ci, ri, 'w*', zorder=10)
                        if printbest:
                            print(ci, ri, np.round(val, 3))
    if fig is not None:
        fig.subplots_adjust(0,0,1,1,0.2,0.2)
    return posmin
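
`plotgrid` fills each 5×5 matrix with explicit loops so that missing cells become NaN and duplicated cells are flagged. A more compact pandas formulation is sketched below as an alternative, not as the notebook's method: `DataFrame.pivot` raises a `ValueError` if a (gamma, lr) pair is duplicated, and `reindex` inserts NaN for pairs that are missing. `lr_gamma_matrix` is a hypothetical helper, and the toy dataframe uses made-up losses.

```python
import numpy as np
import pandas as pd

def lr_gamma_matrix(cells, gammas, lrs, col='loss_min'):
    """Return a (len(gammas), len(lrs)) array of `col`, NaN where a cell is missing."""
    # pivot raises ValueError on duplicated (gamma, lr) pairs.
    table = cells.pivot(index='gamma', columns='lr', values=col)
    # reindex fixes the row/column order and fills absent pairs with NaN.
    return table.reindex(index=list(gammas), columns=list(lrs)).to_numpy()

toy = pd.DataFrame({'gamma': [0.9, 0.9, 1.0],
                    'lr': [0.0025, 0.00375, 0.0025],
                    'loss_min': [0.12, 0.10, 0.15]})
print(lr_gamma_matrix(toy, gammas=(0.9, 1.0), lrs=(0.0025, 0.00375)))
```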

#### 3.3.2. Visualization of the Grid Search

In [None]:
# First, we load in the grid-search dataframe.
grid = load_gridsearch()

# We might want to adjust the colormap scale for each subpart of the images; however, we might
# alternatively want them all the same. We can adjust that here:
kws = {('anat','area'): dict(vmin=0.05, vmax=0.4, cmap='hot'),
       ('full','area'): dict(vmin=0.05, vmax=0.4, cmap='hot'),
       ('anat','ring'): dict(vmin=0.05, vmax=0.4, cmap='hot'),
       ('full','ring'): dict(vmin=0.05, vmax=0.4, cmap='hot')}

# Now make the plots:
for (inputs,outputs) in [('anat','area'),('full','area'),('anat','ring'),('full','ring')]:
    kw = kws[(inputs, outputs)]
    (fig,axs) = plt.subplots(3, 6, figsize=(18,9), dpi=512, sharex=True, sharey=True)
    min18 = plotgrid(grid, inputs, outputs, 'resnet18', axes=axs[:,:3], **kw)
    min34 = plotgrid(grid, inputs, outputs, 'resnet34', axes=axs[:,3:], **kw)
    # Print the inputs/outputs that this plot represents plus the best result from the
    # ResNet18 and ResNet34 for comparison.
    print(f'[{inputs}:{outputs}]  ResNet18: {min18["value"]:3.2f};  ResNet34: {min34["value"]:3.2f}')
    for ax in axs[:,3]:
        ax.set_ylabel('')
    fig.subplots_adjust(0,0,1,1,0.25,0.25)
    plt.savefig(f'{figures_path}/grid_{inputs}_{outputs}.pdf', bbox_inches='tight')
    plt.show()
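
The figures locate the best cells visually (white stars). To read similar information as a table, one option is to take, for each combination of inputs, prediction, and base model, the row with the smallest `loss_min` using pandas' `groupby(...).idxmin()`. The toy dataframe below stands in for `grid`: its column names match those used by `plotgrid`, but its values are made up, and `best_cells` is a hypothetical helper. Note that `idxmin` returns a single row per group, whereas `plotgrid` stars every cell tied with the minimum after rounding to three decimals.

```python
import pandas as pd

def best_cells(grid, col='loss_min', by=('inputs', 'prediction', 'base_model')):
    """Return the row with the smallest `col` within each group defined by `by`."""
    # Drop rows whose loss could not be computed (NaN) before taking minima.
    valid = grid.dropna(subset=[col])
    best_index = valid.groupby(list(by))[col].idxmin()
    return valid.loc[best_index.values].reset_index(drop=True)

toy = pd.DataFrame({
    'inputs':     ['anat', 'anat', 'full', 'full'],
    'prediction': ['area'] * 4,
    'base_model': ['resnet18'] * 4,
    'lr':         [0.0025, 0.00375, 0.0025, 0.00375],
    'loss_min':   [0.21, 0.18, 0.12, float('nan')],
})
print(best_cells(toy))
```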