In [1]:
import numpy as np
import torch
from sklearn.model_selection import train_test_split
import random
import time
import logging
import matplotlib as mpl, matplotlib.pyplot as plt
import torch.nn as nn

Model: load the trained weights and define the U-Net architecture

In [None]:
# Trained U-Net weights, indexed below by state-dict keys such as
# 'layer0.0.weight'; they are copied into `model_unet` further down.
model_ww = torch.load(f='/scratch/bs4283/visual_autolabel/best_func.pt')

In [None]:
import torchvision
def convrelu(in_channels, out_channels,
             kernel=3, padding=None, stride=1, bias=True, inplace=True):
    """Shortcut for creating a PyTorch 2D convolution followed by a ReLU.

    Parameters
    ----------
    in_channels : int
        The number of input channels in the convolution.
    out_channels : int
        The number of output channels in the convolution.
    kernel : int or tuple of int, optional
        The kernel size for the convolution (default: 3).
    padding : int, tuple of int, or None, optional
        The padding size for the convolution; if `None` (the default), then
        `kernel // 2` is used (per dimension for tuple kernels), which keeps
        the image size unchanged for odd kernels with a stride of 1.
    stride : int, optional
        The stride to use in the convolution (default: 1).
    bias : boolean, optional
        Whether the convolution has a learnable bias (default: True).
    inplace : boolean, optional
        Whether to perform the ReLU operation in-place (default: True).

    Returns
    -------
    torch.nn.Sequential
        The model of a 2D-convolution (index 0) followed by a ReLU operation
        (index 1).
    """
    # Conv2d cannot take `padding=None`, so resolve the documented default
    # ("same"-size padding for odd kernels) here.
    if padding is None:
        if isinstance(kernel, int):
            padding = kernel // 2
        else:
            padding = tuple(k // 2 for k in kernel)
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_channels, out_channels, kernel,
                        stride=stride, padding=padding, bias=bias),
        torch.nn.ReLU(inplace=inplace))

In [None]:
class UNet(torch.nn.Module):
    """A U-Net with a ResNet-18 backbone for learning visual area labels.

    The `UNet` class implements a ["U-Net"](https://arxiv.org/abs/1505.04597)
    with a [ResNet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone.
    The class inherits from `torch.nn.Module`.

    The original implementation of this class was by Shaoling Chen
    (sc6995@nyu.edu), and additional modifications have been made by Noah C.
    Benson (nben@uw.edu).

    Parameters
    ----------
    feature_count : int
        The number of channels (features) in the input image. When using an
        `HCPVisualDataset` object for training, this value should be set to 4
        if the dataset uses the `'anat'` or `'func'` features and 8 if it uses
        the `'both'` features.
    segment_count : int
        The number of segments (AKA classes, labels) in the output data. For
        V1-V3 this is typically either 3 (V1, V2, V3) or 6 (LV1, LV2, LV3, RV1,
        RV2, RV3).
    base_model : str, optional
        The name of the `torchvision.models` constructor used as the backbone
        of the UNet. The default is `'resnet18'`. The decoder's channel counts
        (64/128/256/512) are hard-coded, so the backbone must have the
        ResNet-18/34 layer layout.
    pretrained : boolean, optional
        Whether to use ImageNet-pretrained weights (`'IMAGENET1K_V1'`) for the
        backbone (`True`) or not (`False`). The default is `False`. When
        `feature_count` is not 3, the backbone's first convolution is replaced
        by a freshly initialized one, so its pretrained weights are not used.
    logits : boolean, optional
        Whether the model should return logits (`True`) or sigmoid
        probabilities (`False`). The default is `False`.

    Attributes
    ----------
    pretrained : boolean
        The `pretrained` argument: `True` if the backbone was built with
        pretrained weights and `False` otherwise.
    feature_count : int
        The number of input channels (features) that the model expects in input
        images.
    segment_count : int
        The number of segments (labels) predicted by the model.
    logits : bool
        `True` if the output of the model is in logits and `False` if its output
        is in probabilities.
    layer0, layer1, layer2, layer3, layer4 : PyTorch Modules
        The encoder stages, built from the backbone's child modules.

    Notes
    -----
    The encoder downsamples by a factor of 32, and the decoder upsamples by 2
    five times, concatenating with a skip connection after each step. Input
    heights and widths should therefore be multiples of 32 so that the
    concatenated tensors line up.
    """
    def __init__(self, feature_count, segment_count,
                 base_model='resnet18',
                 pretrained=False,
                 logits=False):
        import torch.nn as nn
        import torchvision.models as mdls
        # Initialize the super-class.
        super().__init__()
        # Store some basic attributes.
        self.feature_count = feature_count
        self.segment_count = segment_count
        self.pretrained = pretrained
        self.logits = logits
        # Build the backbone. Only its first eight children (the convolutional
        # stages) are used below; the avgpool/fc head is dropped, so
        # `num_classes` only sizes an unused layer. torchvision rejects a
        # `num_classes` that disagrees with pretrained weights, so it is passed
        # only for untrained backbones.
        weights = 'IMAGENET1K_V1' if pretrained else None
        head_kwargs = {} if pretrained else {'num_classes': segment_count}
        constructor = getattr(mdls, base_model)
        try:
            backbone = constructor(weights=weights, **head_kwargs)
        except TypeError:
            # Older torchvision versions take `pretrained=` instead of
            # `weights=`.
            backbone = constructor(pretrained=pretrained, **head_kwargs)
        # The backbone expects 3-channel input; for any other channel count,
        # swap in a first convolution with the same geometry but
        # `feature_count` input channels.
        if feature_count != 3:
            c1 = backbone.conv1
            backbone.conv1 = nn.Conv2d(
                feature_count, c1.out_channels,
                kernel_size=c1.kernel_size, stride=c1.stride,
                padding=c1.padding,
                # Conv2d's `bias` is a flag; passing the bias tensor itself
                # would fail for backbones whose conv1 has a bias.
                bias=c1.bias is not None)
        base_layers = list(backbone.children())
        # Make the U-Net layers out of the base-layers.
        # size = (N, 64, H/2, W/2)
        self.layer0 = nn.Sequential(*base_layers[:3])
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        # size = (N, 64, H/4, W/4)
        self.layer1 = nn.Sequential(*base_layers[3:5])
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        # size = (N, 128, H/8, W/8)
        self.layer2 = base_layers[5]
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        # size = (N, 256, H/16, W/16)
        self.layer3 = base_layers[6]
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        # size = (N, 512, H/32, W/32)
        self.layer4 = base_layers[7]
        self.layer4_1x1 = convrelu(512, 512, 1, 0)
        # The up-swing of the UNet; we will need to upsample the image.
        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
        self.conv_original_size0 = convrelu(feature_count, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
        self.conv_last = nn.Conv2d(64, segment_count, 1)

    def forward(self, input):
        """Return per-pixel segment scores of shape (N, segment_count, H, W).

        The scores are logits if `self.logits` is `True`; otherwise they are
        passed through a sigmoid.
        """
        # Do the original size convolutions.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)
        # Now the front few layers, which we save for adding back in on the UNet
        # up-swing below.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        # Now, we start the up-swing; each step must upsample the image.
        layer4 = self.layer4_1x1(layer4)
        # Up-swing Step 1
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)
        # Up-swing Step 2
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)
        # Up-swing Step 3
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)
        # Up-swing Step 4
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)
        # Up-swing Step 5
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)
        # And the final convolution.
        out = self.conv_last(x)
        if not self.logits:
            out = torch.sigmoid(out)
        return out

In [None]:
# U-Net taking 11 input channels and predicting 3 segments; the trained
# tensors in `model_ww` are copied into it in the next cell.
# NOTE(review): 11 must match both the channel count of `image` and the
# checkpoint's first-layer weights — confirm.
model_unet = UNet(11, 3)

In [None]:
import torch.nn as nn
# Copy the trained tensors in `model_ww` (a state dict whose keys such as
# 'layer0.0.weight' and 'conv_last.bias' mirror `model_unet.state_dict()`)
# into the freshly built U-Net.
#
# `load_state_dict` copies every parameter *and* buffer in place. This means:
# - the BatchNorm statistics (`running_mean`, `running_var`,
#   `num_batches_tracked`) stay registered as buffers, instead of being
#   re-registered as `nn.Parameter`s that then get handed to the optimizer;
# - every parameter keeps its existing `requires_grad=True`;
# - values are copied into the existing tensors, adopting their device and
#   dtype;
# - in strict mode (the default), any missing key, unexpected key, or shape
#   mismatch raises an error instead of silently swapping in a wrongly
#   shaped tensor.
model_unet.load_state_dict(model_ww)



Loss function: a weighted combination of the dice and binary cross-entropy (BCE) losses

In [21]:
def is_logits(data):
    """Guess whether the PyTorch tensor `data` holds logits.

    Values that all lie in the closed interval [0, 1] could be probabilities,
    so `False` is returned for them. `True` is returned as soon as any value
    is negative or greater than 1.
    """
    out_of_unit_range = (data < 0) | (data > 1)
    return bool(out_of_unit_range.any())

In [22]:
def dice_loss(pred, gold, logits=None, smoothing=1, graph=False, metrics=None):
    """Returns the loss based on the (soft) dice coefficient.

    `dice_loss(pred, gold)` returns the dice-coefficient loss between the
    tensors `pred` and `gold`. They must have the same shape and should
    represent probabilities. The first two dimensions of both `pred` and
    `gold` must represent the batch-size and the classes; all remaining
    dimensions are summed over.

    For each (batch, class) pair the loss is
    `1 - (2 * sum(pred * gold) + s) / (sum(pred**2) + sum(gold**2) + s)`.
    The result is averaged over classes and then over the batch.

    Parameters
    ----------
    pred : tensor
        The predicted probabilities (or logits, see `logits`) of each class.
    gold : tensor
        The gold-standard labels for each class.
    logits : boolean, optional
        Whether the values in `pred` are logits--i.e., unnormalized scores that
        have not been run through a sigmoid calculation already. If this is
        `True`, then the sigmoid of `pred` is taken before the dice
        coefficient is computed. If `None`, then `is_logits(pred)` is used to
        guess. The default is `None`.
    smoothing : number or None, optional
        The smoothing coefficient `s`; `None` is treated as 0. The default is
        `1`.
    graph : boolean, optional
        Unused; accepted only for backward compatibility.
    metrics : dict or None, optional
        An optional dictionary whose `'dice'` entry (created as 0.0 if absent)
        is incremented by the loss multiplied by the batch size.

    Returns
    -------
    torch.Tensor
        A scalar tensor holding the dice-coefficient loss of the prediction.
    """
    pred = pred.contiguous()
    gold = gold.contiguous()
    if logits is None:
        logits = is_logits(pred)
    if logits:
        pred = torch.sigmoid(pred)
    intersection = pred * gold
    pred_sq = pred ** 2
    gold_sq = gold ** 2
    # Reduce every dimension after (batch, class), innermost first.
    while len(intersection.shape) > 2:
        intersection = intersection.sum(dim=-1)
        pred_sq = pred_sq.sum(dim=-1)
        gold_sq = gold_sq.sum(dim=-1)
    if smoothing is None:
        smoothing = 0
    loss = 1 - ((2 * intersection + smoothing) / (pred_sq + gold_sq + smoothing))
    # Average the loss across classes then take the mean across batch elements.
    loss = loss.mean(dim=1).mean()
    if metrics is not None:
        if 'dice' not in metrics:
            metrics['dice'] = 0.0
        # Scale by the batch size (presumably so that dividing the running
        # total by the number of samples yields a per-sample mean).
        metrics['dice'] += loss.detach().cpu().numpy() * gold_sq.size(0)
    return loss

In [23]:
def bce_loss(pred, gold, logits=None, reweight=True, metrics=None):
    """Returns the loss based on the binary cross entropy.

    `bce_loss(pred, gold)` returns the binary cross entropy loss between the
    tensors `pred` and `gold`. They must have the same shape and should
    represent probabilities. The first two dimensions of both `pred` and
    `gold` must represent the batch-size and the classes.

    Parameters
    ----------
    pred : tensor
        The predicted probabilities (or logits, see `logits`) of each class.
    gold : tensor
        The gold-standard labels for each class.
    logits : boolean, optional
        Whether the values in `pred` are logits--i.e., unnormalized scores that
        have not been run through a sigmoid calculation already. If this is
        `True`, then `binary_cross_entropy_with_logits` is used; otherwise
        `binary_cross_entropy` is used. If `None`, then `is_logits(pred)` is
        used to guess. The default is `None`.
    reweight : boolean, optional
        Whether to reweight the classes (default: `True`). If `True`, the BCE
        is computed separately for each class. Each class's BCE is multiplied
        by `(n - sum(gold_k)) / n`, where `n` is the number of elements per
        class (batch size times all spatial sizes); for binary labels this is
        the fraction of negatives. The weighted terms are then summed over
        classes. If `False`, the raw BCE averaged across all pixels, classes,
        and batches is returned.
    metrics : dict or None, optional
        An optional dictionary whose `'bce'` entry (created as 0.0 if absent)
        is incremented by the BCE loss multiplied by the batch size.

    Returns
    -------
    torch.Tensor
        A scalar tensor holding the binary cross entropy loss of the
        prediction.
    """
    if logits is None:
        logits = is_logits(pred)
    if logits:
        f = torch.nn.functional.binary_cross_entropy_with_logits
    else:
        f = torch.nn.functional.binary_cross_entropy
    if reweight:
        # Elements per class: batch size times every trailing (spatial) size.
        # For 4-D input this equals B*H*W.
        n = pred.shape[0]
        for size in pred.shape[2:]:
            n *= size
        r = 0
        for k in range(pred.shape[1]):
            (p, t) = (pred[:, [k]], gold[:, [k]])
            # Classes with few positives get a weight near 1, so sparse
            # classes are not drowned out by dense ones.
            r += f(p, t) * (n - torch.sum(t)) / n
    else:
        r = f(pred, gold)
    if metrics is not None:
        if 'bce' not in metrics:
            metrics['bce'] = 0.0
        metrics['bce'] += r.detach().cpu().numpy() * gold.size(0)
    return r

In [24]:
class my_loss(nn.Module):
    """Weighted sum of the BCE and dice losses.

    The loss is `bce_loss(inputs, targets) * bce_weight +
    dice_loss(inputs, targets) * (1 - bce_weight)`. Both component losses are
    called with their default arguments.

    Parameters
    ----------
    bce_weight : float
        Weight given to the BCE term; the dice term receives
        `1 - bce_weight`.
    """

    def __init__(self, bce_weight):
        super().__init__()
        self.bce_weight = bce_weight

    def forward(self, inputs, targets):
        """Return the combined loss of `inputs` against `targets`."""
        weight = self.bce_weight
        return (bce_loss(inputs, targets) * weight
                + dice_loss(inputs, targets) * (1 - weight))

Logging setup

In [50]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Executing this cell again would otherwise attach another pair of handlers,
# so every message would be emitted once per execution. Detach and close any
# handlers left over from earlier runs before adding fresh ones.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()

# Create a file handler and set the level to info
file_handler = logging.FileHandler('training_tran_re.log')
file_handler.setLevel(logging.INFO)

# Create a stream handler to print the log messages to the console
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)

# Create a formatter and set it for both handlers
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

# Add the handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

The data were first converted to PyTorch tensors and saved as `image_sum.pt` (input images) and `pred_sum.pt` (gold-standard labels).

In [25]:
# Input images and gold-standard labels, previously saved as tensors.
image, gold = (
    torch.load('/scratch/bs4283/visual_autolabel/image_sum.pt'),
    torch.load('/scratch/bs4283/visual_autolabel/pred_sum.pt'),
)

Fine-tune the model, starting from the pretrained weights, for 60 epochs (each epoch is one full-batch gradient step).

In [None]:
# Full-batch fine-tuning: each "epoch" below is a single gradient step on the
# whole training split. The split seed is still random, but it is logged so a
# run can be reproduced.
split_seed = random.randint(3, 999)
logger.info(f'train_test_split random_state: {split_seed}')
X_train, X_test, y_train, y_test = train_test_split(
    image, gold, test_size=0.3, random_state=split_seed)
iters = 60
# Build the optimizer once so Adam's moment estimates persist across steps;
# re-creating it every iteration reset them before each update.
optimizer = torch.optim.Adam(model_unet.parameters(), lr=0.00375)
for iteration in range(iters):
    # Three equal phases that shift weight from BCE to dice and decay the LR.
    # The phase depends on the current `iteration`, not the constant `iters`.
    if iteration < iters // 3:
        bce_weight, lr = 0.67, 0.00375
    elif iteration < 2 * iters // 3:
        bce_weight, lr = 0.33, 0.00250
    else:
        bce_weight, lr = 0.0, 0.00125
    loss_fn = my_loss(bce_weight=bce_weight)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # Training pass (BatchNorm uses and updates batch statistics).
    model_unet.train()
    y_pred = model_unet(X_train)
    train_loss = loss_fn(y_pred, y_train)
    # Validation pass in eval mode, so validation data do not leak into the
    # BatchNorm running statistics; no_grad avoids building a graph.
    model_unet.eval()
    with torch.no_grad():
        y_valid = model_unet(X_test)
        val_loss = loss_fn(y_valid, y_test)
    logger.info(f'Epoch {iteration}, Training Loss: {train_loss.item()}, Validation Loss: {val_loss.item()}')
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    
    


        

In [None]:
import os

# Directory for the 60-epoch run's tensors; torch.save does not create it.
out_dir = '/scratch/bs4283/visual_autolabel/model_e60'
os.makedirs(out_dir, exist_ok=True)

# Pickles the whole module (not just its state dict), so reloading this file
# requires the `UNet` class to be importable.
torch.save(model_unet, '/scratch/bs4283/visual_autolabel/model_e60.pt')

# Inference over the full dataset: eval mode so BatchNorm uses (and does not
# update) its running statistics, and no_grad so no graph is kept in memory.
model_unet.eval()
with torch.no_grad():
    pred_image = model_unet(image)
torch.save(pred_image, os.path.join(out_dir, 'output_image.pt'))
torch.save(y_pred, os.path.join(out_dir, 'train_output.pt'))
torch.save(y_valid, os.path.join(out_dir, 'valid_output.pt'))
torch.save(y_train, os.path.join(out_dir, 'train_true.pt'))
torch.save(y_test, os.path.join(out_dir, 'valid_true.pt'))

Train for 300 epochs on small random subsets of the training split (about 5 samples, i.e. 17% of the split, redrawn every epoch).

In [None]:
# NOTE(review): this continues training `model_unet` from wherever earlier
# cells left it; rebuild/reload the model first if a fresh run is intended.
split_seed = random.randint(3, 999)
logger.info(f'train_test_split random_state: {split_seed}')
X_train, X_test, y_train, y_test = train_test_split(
    image, gold, test_size=0.3, random_state=split_seed)

iters = 300
# Build the optimizer once so Adam's moment estimates persist across steps.
optimizer = torch.optim.Adam(model_unet.parameters(), lr=0.00375)
for iteration in range(iters):
    # Three equal phases that shift weight from BCE to dice and decay the LR;
    # the phase depends on the current `iteration`, not the constant `iters`.
    if iteration < iters // 3:
        bce_weight, lr = 0.67, 0.00375
    elif iteration < 2 * iters // 3:
        bce_weight, lr = 0.33, 0.00250
    else:
        bce_weight, lr = 0.0, 0.00125
    loss_fn = my_loss(bce_weight=bce_weight)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # Gradient step on a fresh ~17% subsample of the training split.
    x_sample_train, _, y_sample_train, _ = train_test_split(
        X_train, y_train, test_size=0.83, random_state=random.randint(3, 999))
    model_unet.train()
    sample_t_pred = model_unet(x_sample_train)
    sample_loss = loss_fn(sample_t_pred, y_sample_train)
    # Monitoring passes over the full train/validation splits: eval mode so
    # BatchNorm running statistics are not updated with these data, and
    # no_grad because these losses are never back-propagated.
    model_unet.eval()
    with torch.no_grad():
        y_pred = model_unet(X_train)
        train_loss = loss_fn(y_pred, y_train)
        y_valid = model_unet(X_test)
        val_loss = loss_fn(y_valid, y_test)
    logger.info(f'Epoch {iteration}, Training Loss: {train_loss.item()}, Sample loss: {sample_loss.item()}, Validation Loss: {val_loss.item()}')
    optimizer.zero_grad()
    sample_loss.backward()
    optimizer.step()

Show the training log.

In [2]:
# NOTE(review): the logger configured earlier writes to 'training_tran_re.log';
# this cell reads 'training.log' (presumably an earlier run) — confirm which
# file is intended.
with open('training.log') as log_file:
    contents = log_file.read()
print(contents)

2023-09-23 13:58:00,282 - INFO - Epoch 0, Training Loss: 0.7129805684089661, Validation Loss: 0.6832046508789062
2023-09-23 13:58:00,282 - INFO - Epoch 0, Training Loss: 0.7129805684089661, Validation Loss: 0.6832046508789062
2023-09-23 13:58:00,282 - INFO - Epoch 0, Training Loss: 0.7129805684089661, Validation Loss: 0.6832046508789062
2023-09-23 13:58:10,547 - INFO - Epoch 1, Training Loss: 0.5419676303863525, Validation Loss: 0.5152378678321838
2023-09-23 13:58:10,547 - INFO - Epoch 1, Training Loss: 0.5419676303863525, Validation Loss: 0.5152378678321838
2023-09-23 13:58:10,547 - INFO - Epoch 1, Training Loss: 0.5419676303863525, Validation Loss: 0.5152378678321838
2023-09-23 13:58:20,799 - INFO - Epoch 2, Training Loss: 0.4520067572593689, Validation Loss: 0.4290185272693634
2023-09-23 13:58:20,799 - INFO - Epoch 2, Training Loss: 0.4520067572593689, Validation Loss: 0.4290185272693634
2023-09-23 13:58:20,799 - INFO - Epoch 2, Training Loss: 0.4520067572593689, Validation Loss: 0.