In [1]:
import os, pathlib
import numpy as np
import neuropythy as ny
import pimms
import matplotlib as mpl, matplotlib.pyplot as plt
import ipyvolume as ipv
import pandas as pd

import torch
from collections.abc import Mapping

In [3]:
# Configuration: Where is the visual_autolabel library?
# The VISUAL_AUTOLABEL_PATH environment variable, when set, overrides the
# default location so the notebook can run outside the original cluster
# account without editing this cell.
visual_autolabel_path = os.environ.get(
    'VISUAL_AUTOLABEL_PATH', '/scratch/bs4283/visual_autolabel')

In [4]:
import sys
if visual_autolabel_path not in sys.path:
    sys.path.append(visual_autolabel_path)
import visual_autolabel

In [5]:
# Pseudo-path for the NYU Retinotopy dataset (OpenNeuro ds003787). The local
# cache lives under the configured visual_autolabel directory instead of
# repeating that directory as a second hard-coded string.
nyu_retinotopy_pp = ny.util.pseudo_path(
    's3://openneuro.org/ds003787/',
    cache_path=os.path.join(
        visual_autolabel_path, 'data', 'openneuro', 'ds003787'))

In [6]:
def nyu_fssubj(subj_id, path=nyu_retinotopy_pp, max_eccen=12.4):
    """Loads an NYU Retinotopy subject with pRF and label properties attached.

    Parameters
    ----------
    subj_id : str
        The subject identifier used in the dataset's directory names (for
        example `'sub-wlsubj001'`).
    path : pseudo-path, optional
        The pseudo-path of the dataset root; the FreeSurfer, pRF, and ROI
        derivatives are all looked up beneath it. Defaults to the module-level
        `nyu_retinotopy_pp`.
    max_eccen : float, optional
        The eccentricity that is mapped onto 8 when computing the scaled
        properties (default: 12.4).
        NOTE(review): presumably the stimulus extent of this dataset in
        degrees, with 8 being the extent the model was trained on -- confirm.

    Returns
    -------
    subject
        The FreeSurfer subject whose `'lh'` and `'rh'` hemispheres carry the
        properties `prf_polar_angle`, `prf_eccentricity`,
        `prf_variance_explained`, `prf_radius`, `prf_x`, `prf_y`, `label`,
        and the rescaled `prf_scaled_eccentricity`, `prf_scaled_x`, and
        `prf_scaled_y`.
    """
    # Get the FreeSurfer subject directory:
    subpp = path.subpath(f'derivatives/freesurfer/{subj_id}')
    sub = ny.freesurfer_subject(subpp)
    # Get the PRF and ROI-label subpaths:
    prfpp = path.subpath(f'derivatives/prfanalyze-vista/{subj_id}/ses-nyu3t01/')
    labpp = path.subpath(f'derivatives/ROIs/{subj_id}/')
    for h in ['lh', 'rh']:
        hem = sub.hemis[h]
        # Load the retinotopy data for this hemisphere:
        prfs = dict(
            prf_polar_angle=ny.load(prfpp.local_path(f'{h}.angle_adj.mgz')),
            prf_eccentricity=ny.load(prfpp.local_path(f'{h}.eccen.mgz')),
            prf_variance_explained=ny.load(prfpp.local_path(f'{h}.vexpl.mgz')),
            prf_radius=ny.load(prfpp.local_path(f'{h}.sigma.mgz')),
            prf_x=ny.load(prfpp.local_path(f'{h}.x.mgz')),
            prf_y=ny.load(prfpp.local_path(f'{h}.y.mgz')),
            label=ny.load(labpp.local_path(f'{h}.ROIs_V1-4.mgz')))
        # Add the rescaled properties. Each scaled property is derived from
        # its own source: the scaled x and y come from prf_x and prf_y, not
        # from the eccentricity.
        prfs['prf_scaled_eccentricity'] = prfs['prf_eccentricity'] / max_eccen * 8
        prfs['prf_scaled_x'] = prfs['prf_x'] / max_eccen * 8
        prfs['prf_scaled_y'] = prfs['prf_y'] / max_eccen * 8
        hem = hem.with_prop(prfs)
        sub = sub.with_hemi({h: hem})
    return sub

## The ImageCache class for the NYU Retinotopy Dataset

See the `HCPImageCache` class in the `visual_autolabel.image._hcp` namespace.

In [7]:
from visual_autolabel.image import (
    BilateralFlatmapImageCache,
    FlatmapFeature,
    LabelFeature
)

class NYURetinotopyImageCache(BilateralFlatmapImageCache):
    """An ImageCache subclass that handles features of the NYU
    Retinotopy Dataset Occipital Pole.

    Targets are either a subject id (one of `subject_list`) or a mapping with
    a `'subject'` key. Flatmaps are made per hemisphere from the
    `'occipital_pole'` projection, and NaN pixels in filled images are
    replaced with 0.
    """
    # The pseudo-path for the NYU retinotopy dataset.
    # NOTE(review): this uses a different cache_path than the module-level
    # `nyu_retinotopy_pp`, and `subjects` below loads through `nyu_fssubj`
    # with its default path, so this attribute does not appear to be used
    # for loading -- confirm which cache location is intended.
    nyu_retinotopy_pp = ny.util.pseudo_path(
        's3://openneuro.org/ds003787/',
        cache_path='/data/openneuro/ds003787')
    # The subject list:
    subject_list = [
        'sub-wlsubj001',
        'sub-wlsubj004',
        'sub-wlsubj006',
        'sub-wlsubj007',
        'sub-wlsubj014',
        'sub-wlsubj019',
        'sub-wlsubj023',
        'sub-wlsubj042',
        'sub-wlsubj043',
        'sub-wlsubj045',
        'sub-wlsubj046',
        'sub-wlsubj055',
        'sub-wlsubj056',
        'sub-wlsubj057',
        'sub-wlsubj062',
        'sub-wlsubj064',
        'sub-wlsubj067',
        'sub-wlsubj071',
        'sub-wlsubj076',
        'sub-wlsubj079',
        'sub-wlsubj081',
        'sub-wlsubj083',
        'sub-wlsubj084',
        'sub-wlsubj085',
        'sub-wlsubj086',
        'sub-wlsubj087',
        'sub-wlsubj088',
        'sub-wlsubj090',
        'sub-wlsubj091',
        'sub-wlsubj092',
        'sub-wlsubj094',
        'sub-wlsubj095',
        'sub-wlsubj104',
        'sub-wlsubj105',
        'sub-wlsubj109',
        'sub-wlsubj114',
        'sub-wlsubj115',
        'sub-wlsubj116',
        'sub-wlsubj117',
        'sub-wlsubj118',
        'sub-wlsubj120',
        'sub-wlsubj121',
        'sub-wlsubj122',
        'sub-wlsubj126']
    # The NYU Retinotopy subjects, loaded lazily (one `nyu_fssubj` call per
    # subject, made only when that subject is first requested).
    subjects = pimms.lazy_map(
        {s: ny.util.curry(lambda s: nyu_fssubj(s), s)
         for s in subject_list})
    # The features we know how to make.
    _builtin_features = {
        # Functional Features first.
        'prf_polar_angle':  FlatmapFeature('prf_polar_angle', 'nearest'),
        'prf_eccentricity': FlatmapFeature('prf_eccentricity', 'linear'),
        'prf_cod':          FlatmapFeature('prf_variance_explained', 'linear'),
        'prf_sigma':        FlatmapFeature('prf_radius', 'linear'),
        'prf_x':            FlatmapFeature('prf_x', 'linear'),
        'prf_y':            FlatmapFeature('prf_y', 'linear'),
        'prf_scaled_eccentricity': FlatmapFeature('prf_scaled_eccentricity', 'linear'),
        'prf_scaled_x':            FlatmapFeature('prf_scaled_x', 'linear'),
        'prf_scaled_y':            FlatmapFeature('prf_scaled_y', 'linear'),
        # The vertex coordinates themselves; we add these in make_flatmap.
        'x': FlatmapFeature('midgray_x', 'linear'),
        'y': FlatmapFeature('midgray_y', 'linear'),
        'z': FlatmapFeature('midgray_z', 'linear'),
        # The visual area labels, read from the 'label' property that
        # nyu_fssubj attaches to each hemisphere.
        'V1':  LabelFeature('label:1', 'nearest'),
        'V2':  LabelFeature('label:2', 'nearest'),
        'V3':  LabelFeature('label:3', 'nearest')
    }
    @classmethod
    def builtin_features(cls):
        """Returns the parent class's built-in features merged with this
        class's features (this class's entries win on a name clash).
        """
        fs = NYURetinotopyImageCache._builtin_features
        return dict(BilateralFlatmapImageCache.builtin_features(), **fs)
    @classmethod
    def unpack_target(cls, target):
        """Returns the 1-tuple `(subject_id,)` for the given target, which
        may be a mapping with a `'subject'` key or the subject id itself.
        """
        if isinstance(target, Mapping):
            sid = target['subject']
        else:
            sid = target
        return (sid,)
    def cache_filename(self, target, feature):
        """Returns the relative cache filename `<feature>/<subject_id>.pt`."""
        # Unpack the target so that a mapping target is cached under its
        # subject id (as make_flatmap treats it) and not under the repr of
        # the mapping. Non-mapping targets produce the same name as before.
        (sid,) = self.unpack_target(target)
        return os.path.join(feature, f"{sid}.pt")
    def make_flatmap(self, target, view=None):
        """Returns the occipital-pole flatmap of one hemisphere of a subject.

        `view` is required and must provide the key `'hemisphere'`; a
        `ValueError` is raised if `view` is `None`. The returned flatmap
        carries the mid-gray vertex coordinates as the properties
        `midgray_x`, `midgray_y`, and `midgray_z`, and the meta-data
        `subject_id` and `hemisphere`.
        """
        # The target is either a subject id or a mapping with a 'subject' key.
        (sid,) = self.unpack_target(target)
        if view is None:
            raise ValueError("NYURetinotopyImageCache requires a view")
        h = view['hemisphere']
        # Get the subject and hemi.
        sub = NYURetinotopyImageCache.subjects[sid]
        hem = sub.hemis[h]
        # Add the mid-gray coordinates as properties so that the 'x', 'y',
        # and 'z' features can be drawn from the flatmap:
        (x,y,z) = hem.surface('midgray').coordinates
        hem = hem.with_prop(midgray_x=x, midgray_y=y, midgray_z=z)
        # Make the flatmap:
        fmap = ny.to_flatmap('occipital_pole', hem, radius=np.pi/2.25)
        fmap = fmap.with_meta(subject_id=sid, hemisphere=h)
        # And return!
        return fmap
    # We overload fill_image so that we can call down then turn NaNs into 0s.
    def fill_image(self, target, feature, im):
        """Fills `im` via the parent class, then zeroes NaN pixels in place."""
        # NOTE(review): the parent's return value is ignored, so this assumes
        # the parent fills `im` in place -- confirm against the base class.
        super().fill_image(target, feature, im)
        im[torch.isnan(im)] = 0
        return im

In [8]:
# The image cache instance for the NYU Retinotopy dataset, built with the
# class's default constructor arguments.
retinotopy_cache = NYURetinotopyImageCache()


1. image shape changed

In [1]:
import torch.nn.functional as F
def resize_tensor(tensor,new_shape):
    """Resamples an image tensor to `new_shape` with bilinear interpolation.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to resize; bilinear mode requires a 4D input whose last
        two dimensions are the ones resampled.
    new_shape : tuple of int
        The output size of the last two dimensions.

    Returns
    -------
    torch.Tensor
        The resampled tensor. Corner pixels of the input and output are
        aligned (`align_corners=True`), so corner values are preserved.
    """
    interpolation_options = dict(size=new_shape,
                                 mode='bilinear',
                                 align_corners=True)
    return F.interpolate(tensor, **interpolation_options)

Model and model weights

In [None]:
# Load the trained weights; the result is indexed below by state-dict style
# keys such as 'layer0.0.weight'.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. No map_location is given, so the tensors are restored onto
# the device they were saved from.
model_ww = torch.load('/scratch/bs4283/visual_autolabel/best_func.pt')

In [None]:
import torchvision
def convrelu(in_channels, out_channels,
             kernel=3, padding=None, stride=1, bias=True, inplace=True):
    """Shortcut for creating a PyTorch 2D convolution followed by a ReLU.

    Parameters
    ----------
    in_channels : int
        The number of input channels in the convolution.
    out_channels : int
        The number of output channels in the convolution.
    kernel : int or tuple of int, optional
        The kernel size for the convolution (default: 3).
    padding : int, tuple of int, or None, optional
        The padding size for the convolution; if `None` (the default), then
        uses `(kernel - 1) // 2` per dimension, which maintains the image
        size for odd kernels at stride 1.
    stride : int, optional
        The stride to use in the convolution (default: 1).
    bias : boolean, optional
        Whether the convolution has a learnable bias (default: True).
    inplace : boolean, optional
        Whether to perform the ReLU operation in-place (default: True).

    Returns
    -------
    torch.nn.Sequential
        The model of a 2D-convolution followed by a ReLU operation.
    """
    if padding is None:
        # Conv2d cannot run with padding=None, so resolve the documented
        # size-preserving default here.
        if isinstance(kernel, (tuple, list)):
            padding = tuple((k - 1) // 2 for k in kernel)
        else:
            padding = (kernel - 1) // 2
    return torch.nn.Sequential(
        # The stride is forwarded; it was previously accepted but ignored.
        torch.nn.Conv2d(in_channels, out_channels, kernel,
                        stride=stride, padding=padding, bias=bias),
        torch.nn.ReLU(inplace=inplace))

In [None]:
class UNet(torch.nn.Module):
    """a U-Net with a ResNet18 backbone for learning visual area labels.

    The `UNet` class implements a ["U-Net"](https://arxiv.org/abs/1505.04597)
    with a [ResNet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone.
    The class inherits from `torch.nn.Module`.
    
    The original implementation of this class was by Shaoling Chen
    (sc6995@nyu.edu), and additional modifications have been made by Noah C.
    Benson (nben@uw.edu).

    Parameters
    ----------
    feature_count : int
        The number of channels (features) in the input image. When using an
        `HCPVisualDataset` object for training, this value should be set to 4
        if the dataset uses the `'anat'` or `'func'` features and 8 if it uses
        the `'both'` features.
    segment_count : int
        The number of segments (AKA classes, labels) in the output data. For
        V1-V3 this is typically either 3 (V1, V2, V3) or 6 (LV1, LV2, LV3, RV1,
        RV2, RV3).
    base_model : str, optional
        The name of the constructor in `torchvision.models` that is used as
        the base/backbone of the UNet. The default is `'resnet18'`. The
        decoder layers are sized for stage widths of 64, 64, 128, 256, and
        512 channels, so only backbones with those widths will work.
    pretrained : boolean, optional
        Whether to use a pretrained base model for the backbone (`True`) or not
        (`False`). The default is `False`.
    logits : boolean, optional
        Whether the model should return logits (`True`) or probabilities
        (`False`). The default is `False`.

    Attributes
    ----------
    pretrained : boolean
        `True` if the base model used in this `UNet` was originally pre-trained
        and `False` otherwise.
    feature_count : int
        The number of input channels (features) that the model expects in input
        images.
    segment_count : int
        The number of segments (labels) predicted by the model.
    logits : bool
        `True` if the output of the model is in logits and `False` if its output
        is in probabilities.
    """
    def __init__(self, feature_count, segment_count,
                 base_model='resnet18',
                 pretrained=False,
                 logits=False):
        import torch.nn as nn
        # Initialize the super-class.
        super().__init__()
        # Store some basic attributes.
        self.feature_count = feature_count
        self.segment_count = segment_count
        self.pretrained = pretrained
        self.logits = logits
        # Set up the base model and base layers for the model.
        if pretrained:
            weights = 'IMAGENET1K_V1'
            # The pretrained checkpoint fixes the size of the classifier head,
            # so requesting a different num_classes makes torchvision raise.
            # The head is not used by the U-Net (only the first 8 children of
            # the base model are), so we simply do not resize it.
            head_options = {}
        else:
            weights = None
            head_options = {'num_classes': segment_count}
        import torchvision.models as mdls
        base_model = getattr(mdls, base_model)
        try:
            base_model = base_model(weights=weights, **head_options)
        except TypeError:
            # Older torchvision versions use `pretrained` instead of `weights`.
            base_model = base_model(pretrained=pretrained, **head_options)
        # The base model itself is not stored on the module; only the layers
        # extracted from it below are registered.
        # Because the input may not have 3 channels, we may need to replace
        # the first convolution of the base model.
        if feature_count != 3:
            # Adjust the first convolution's number of input channels.
            c1 = base_model.conv1
            base_model.conv1 = nn.Conv2d(
                feature_count, c1.out_channels,
                kernel_size=c1.kernel_size, stride=c1.stride,
                padding=c1.padding,
                # Conv2d expects a boolean here, not the bias tensor itself.
                bias=(c1.bias is not None))
        base_layers = list(base_model.children())
        # Make the U-Net layers out of the base-layers.
        # size = (N, 64, H/2, W/2)
        self.layer0 = nn.Sequential(*base_layers[:3]) 
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        # size = (N, 64, H/4, W/4)
        self.layer1 = nn.Sequential(*base_layers[3:5])
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        # size = (N, 128, H/8, W/8)        
        self.layer2 = base_layers[5]
        self.layer2_1x1 = convrelu(128, 128, 1, 0)  
        # size = (N, 256, H/16, W/16)
        self.layer3 = base_layers[6]  
        self.layer3_1x1 = convrelu(256, 256, 1, 0)  
        # size = (N, 512, H/32, W/32)
        self.layer4 = base_layers[7]
        self.layer4_1x1 = convrelu(512, 512, 1, 0)
        # The up-swing of the UNet; we will need to upsample the image.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
        self.conv_original_size0 = convrelu(feature_count, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
        self.conv_last = nn.Conv2d(64, segment_count, 1)
    def forward(self, input):
        """Runs the U-Net on a batch of images.

        Parameters
        ----------
        input : torch.Tensor
            The input batch with `feature_count` channels.
            NOTE(review): each 2x upsampled map is concatenated with the
            skip connection of the matching scale, so the image height and
            width presumably need to be multiples of 32 -- confirm.

        Returns
        -------
        torch.Tensor
            A tensor with `segment_count` channels at the input resolution;
            logits if `self.logits` is true, otherwise sigmoid probabilities.
        """
        # Do the original size convolutions.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)
        # Now the front few layers, which we save for adding back in on the UNet
        # up-swing below.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)        
        layer4 = self.layer4(layer3)
        # Now, we start the up-swing; each step must upsample the image.
        layer4 = self.layer4_1x1(layer4)
        # Up-swing Step 1
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)
        # Up-swing Step 2
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)
        # Up-swing Step 3
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)
        # Up-swing Step 4
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)
        # Up-swing Step 5
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)        
        # And the final convolution.
        out = self.conv_last(x)
        if not self.logits:
            out = torch.sigmoid(out)
        return out

In [None]:
# Build the network with 11 input channels (the length of the feature lists
# defined further down in this notebook) and 3 output segments.
model_unet  = UNet(feature_count = 11, segment_count = 3)

Load the trained weights into the model

In [None]:
import torch.nn as nn

# Copy the trained weights into the model in one step. The checkpoint keys
# ('layer0.0.weight', 'layer1.1.0.bn1.running_mean', ..., 'conv_last.bias')
# are exactly the model's own state-dict names, so no per-tensor assignment is
# needed. This also keeps the BatchNorm statistics (running_mean, running_var,
# num_batches_tracked) as buffers instead of re-registering them as
# nn.Parameter objects.

# Keep the model on the device of the checkpoint tensors, as assigning the
# checkpoint tensors directly used to do.
model_unet.to(model_ww['conv_last.weight'].device)

# strict=False tolerates extra entries in the checkpoint (they were ignored
# before as well); a weight the model needs but the checkpoint lacks is still
# an error.
load_result = model_unet.load_state_dict(model_ww, strict=False)
if load_result.missing_keys:
    raise KeyError(
        f"checkpoint is missing weights: {list(load_result.missing_keys)}")

# The model is only used for inference: freeze every parameter and switch
# BatchNorm to its stored running statistics.
model_unet.requires_grad_(False)
model_unet.eval()


Organize the data structure

In [10]:
# First group of NYU Retinotopy subject ids (17 subjects), generated from the
# zero-padded subject numbers.
subject_list1 = [
    f'sub-wlsubj{subject_number:03d}'
    for subject_number in (1, 4, 6, 7, 14, 19, 23, 42, 43, 45, 46,
                           55, 56, 57, 62, 64, 67)]

# Second group of subject ids (26 subjects).
subject_list2 = [
    f'sub-wlsubj{subject_number:03d}'
    for subject_number in (71, 76, 79, 81, 83, 84, 85, 86, 87, 88,
                           90, 91, 92, 94, 95, 104, 105, 109,
                           114, 115, 116, 117, 118, 120,
                           # 121 is left out of this list.
                           122, 126)]

# Input features in channel order: mid-gray coordinates, anatomical
# properties, then the pRF properties.
feature_list = [
    'x', 'y', 'z',
    'curvature', 'convexity', 'thickness', 'surface_area',
    'prf_x', 'prf_y', 'prf_sigma', 'prf_cod']
# The same features in the same order, with the pRF x and y positions swapped
# for their rescaled versions.
rescale_feature_list = [
    {'prf_x': 'prf_scaled_x', 'prf_y': 'prf_scaled_y'}.get(name, name)
    for name in feature_list]

Loss functions (copied from `visual_autolabel.util._core`)

In [None]:
def is_logits(data):
    """Heuristically decides whether a PyTorch tensor holds logits.

    Returns `True` if any value in `data` is greater than 1 or less than 0;
    otherwise returns `False`. Values confined to the interval [0, 1] are
    thus treated as probabilities rather than logits.
    """
    above_one = (data > 1).any()
    if above_one:
        return True
    return bool((data < 0).any())

In [None]:
def dice_loss(pred, gold, logits=None, smoothing=1, graph=False, metrics=None):
    """Computes a loss derived from the dice coefficient.

    `dice_loss(pred, gold)` compares the tensors `pred` and `gold`, whose first
    two dimensions are interpreted as the batch and the classes. All dimensions
    after the first two are summed out, so each (batch, class) pair gets one
    dice term. Inputs with two or fewer dimensions are not reduced at all
    before the dice formula is applied.

    Parameters
    ----------
    pred : tensor
        The predicted probabilities (or logits) of each class.
    gold : tensor
        The gold-standard labels for each class.
    logits : boolean or None, optional
        Whether `pred` contains logits, i.e., unnormalized scores that have not
        yet been passed through a sigmoid. If `True`, the sigmoid of `pred` is
        taken before the loss is calculated. If `None` (the default), the
        function `is_logits(pred)` is used to guess.
    smoothing : number or None, optional
        The smoothing coefficient `s` added to both the numerator and the
        denominator. `None` is treated as `0`. The default is `1`.
    graph : boolean, optional
        Not used by this function; it is accepted only as part of the
        signature. The default is `False`.
    metrics : dict or None, optional
        An optional dictionary in which the key `'dice'` accumulates the loss
        multiplied by the size of the first (batch) dimension.

    Returns
    -------
    tensor
        A scalar tensor containing the dice-coefficient loss: the per-class
        losses are averaged over classes and then over batch elements.
    """
    probs = pred.contiguous()
    truth = gold.contiguous()
    if logits is None:
        logits = is_logits(probs)
    if logits:
        probs = torch.sigmoid(probs)
    overlap = probs * truth
    probs_sq = probs ** 2
    truth_sq = truth ** 2
    # Collapse the trailing dimensions one at a time until only the first two
    # (batch, classes) remain.
    for _ in range(len(overlap.shape) - 2):
        overlap = overlap.sum(dim=-1)
        probs_sq = probs_sq.sum(dim=-1)
        truth_sq = truth_sq.sum(dim=-1)
    eps = 0 if smoothing is None else smoothing
    numerator = 2 * overlap + eps
    denominator = probs_sq + truth_sq + eps
    per_class = 1 - (numerator / denominator)
    # Mean over classes first, then over the batch.
    loss = per_class.mean(dim=1).mean()
    if metrics is not None:
        metrics.setdefault('dice', 0.0)
        metrics['dice'] += loss.data.cpu().numpy() * truth_sq.size(0)
    return loss

In [None]:
# Compute the Dice loss of each saved prediction against the cached label for
# every subject in subject_list2 and every area in cor_list, then save the
# resulting (subjects x areas) table.
cor_list = ['V1','V2','V3']
loss_sum = torch.ones(len(subject_list2), len(cor_list))

for ots, sub_name in enumerate(subject_list2):
    pred_all = np.load('func_result/' + sub_name + '.npy')
    for ii, cor_name in enumerate(cor_list):
        pred = torch.from_numpy(pred_all[0,ii,:,:].squeeze())
        # NOTE(review): presumably returns the gold-standard label image for
        # this subject and area -- confirm against retinotopy_cache.
        true_val = retinotopy_cache.get(sub_name, cor_name, multiproc=False)
        true_val = resize_tensor(true_val.unsqueeze(dim=0).unsqueeze(dim=0),(128,256)).squeeze()
        # dice_loss treats the first two dimensions as (batch, classes) and
        # only sums over the dimensions after them. Passing the squeezed 2-D
        # images directly would therefore skip the reduction and yield a
        # per-pixel average instead of a Dice loss, so the leading batch and
        # class dimensions are restored here.
        dice_l = dice_loss(pred[None, None], true_val[None, None])
        loss_sum[ots, ii] = dice_l

torch.save(loss_sum,'func_result/loss2.pt')