In [1]:
import os
import nibabel as nb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [5]:
# Load the per-frame JPEG index: one row per JPEG file, with the columns
# seriesIdentifier, Mfg Model, mode, preprocessed_3_path_jpeg and frame_num.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
jpeg_csv_path = '/w/246/gzk/PPMI/codes/first_model_jpeg_csv_info.csv'
# NOTE(review): the resulting 'Unnamed: 0' column looks like a saved DataFrame index;
# presumably `index_col=0` would absorb it — confirm before changing.
df = pd.read_csv(jpeg_csv_path)
# Preview the first rows to check the schema.
df.head()

Unnamed: 0,seriesIdentifier,Mfg Model,mode,preprocessed_3_path_jpeg,frame_num
0,103294,TrioTim,train,/w/284/gzk/result/JPEG_files/103294/0.jpg,0
1,103294,TrioTim,train,/w/284/gzk/result/JPEG_files/103294/1.jpg,1
2,103294,TrioTim,train,/w/284/gzk/result/JPEG_files/103294/2.jpg,2
3,103294,TrioTim,train,/w/284/gzk/result/JPEG_files/103294/3.jpg,3
4,103294,TrioTim,train,/w/284/gzk/result/JPEG_files/103294/4.jpg,4


In [4]:
# One entry per CSV row (not per unique series), so an id is repeated for every row
# that belongs to the same series.
seriesIdentifier_lists = df['seriesIdentifier'].tolist()

In [8]:
# Sanity check: report the largest frame index of the first series only.
# Not needed downstream: the images have already been padded, so every series
# is expected to reach the same maximum frame index (255).
max_frame_per_subj = {}
for subj in seriesIdentifier_lists[:1]:  # only the first id is inspected
    tmp = df.loc[df['seriesIdentifier'] == subj]
    subj_max_frame = tmp['frame_num'].max()
    print(subj_max_frame)
    max_frame_per_subj[subj] = subj_max_frame

255


In [9]:
import torch
from torch import nn
from torchvision import models

class EncoderVGG(nn.Module):
    '''Encoder of image based on the architecture of VGG-16 with batch normalization.

    The feature layers of the VGG template are reused as-is, except that every
    max-pooling layer is replaced by one that also returns its pooling indices,
    so that a decoder can later undo the pooling with `nn.MaxUnpool2d`.

    Args:
        pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters.
            Defaults to True.

    '''
    # Number of channels expected in the input image.
    channels_in = 3
    # Number of channels of the produced code; read by DecoderVGG.channels_in.
    channels_code = 512

    def __init__(self, pretrained_params=True):
        super(EncoderVGG, self).__init__()

        vgg = models.vgg16_bn(pretrained=pretrained_params)
        # Only the convolutional feature extractor is used; drop the classification head.
        del vgg.classifier
        del vgg.avgpool

        self.encoder = self._encodify_(vgg)

    def _encodify_(self, encoder):
        '''Create list of modules for encoder based on the architecture in VGG template model.

        In the encoder-decoder architecture, the unpooling operations in the decoder require pooling
        indices from the corresponding pooling operation in the encoder. In the VGG template, these
        indices are not returned. Hence the need for this method to extend the pooling operations.

        Args:
            encoder : the template VGG model; its `features` attribute is iterated

        Returns:
            modules (ModuleList): the modules that define the encoder corresponding to the VGG model

        '''
        modules = nn.ModuleList()
        for module in encoder.features:
            if isinstance(module, nn.MaxPool2d):
                # Same pooling geometry, but additionally return the argmax indices.
                module_add = nn.MaxPool2d(kernel_size=module.kernel_size,
                                          stride=module.stride,
                                          padding=module.padding,
                                          return_indices=True)
                modules.append(module_add)
            else:
                modules.append(module)

        return modules

    def forward(self, x):
        '''Execute the encoder on the image input

        Args:
            x (Tensor): image tensor

        Returns:
            x_code (Tensor): code tensor
            pool_indices (list): Pool indices tensors in order of the pooling modules

        '''
        pool_indices = []
        x_current = x
        for module_encode in self.encoder:
            output = module_encode(x_current)

            # A pooling module returns two outputs: the pooled tensor and the pool indices.
            if isinstance(output, tuple) and len(output) == 2:
                x_current = output[0]
                pool_indices.append(output[1])
            else:
                x_current = output

        return x_current, pool_indices

In [10]:
def _encodify_(self, encoder):
        '''Create list of modules for encoder based on the architecture in VGG template model.

        In the encoder-decoder architecture, the unpooling operations in the decoder require pooling
        indices from the corresponding pooling operation in the encoder. In VGG template, these indices
        are not returned. Hence the need for this method to extent the pooling operations.

        Args:
            encoder : the template VGG model

        Returns:
            modules : the list of modules that define the encoder corresponding to the VGG model

        '''
        modules = nn.ModuleList()
        for module in encoder.features:
            if isinstance(module, nn.MaxPool2d):
                module_add = nn.MaxPool2d(kernel_size=module.kernel_size,
                                          stride=module.stride,
                                          padding=module.padding,
                                          return_indices=True)
                modules.append(module_add)
            else:
                modules.append(module)

        return modules

In [11]:
def forward(self, x):
    '''Run the input through every module of `self.encoder` in order.

    Args:
        x (Tensor): image tensor

    Returns:
        x_code (Tensor): output of the last module (or `x` itself if there are no modules)
        pool_indices (list): pool index tensors, in the order the pooling modules were applied

    '''
    indices_per_pool = []
    activation = x
    for layer in self.encoder:
        result = layer(activation)

        # Only pooling modules hand back a (tensor, indices) pair.
        is_pool_output = isinstance(result, tuple) and len(result) == 2
        if not is_pool_output:
            activation = result
            continue

        activation, indices = result
        indices_per_pool.append(indices)

    return activation, indices_per_pool

In [12]:
class DecoderVGG(nn.Module):
    '''Decoder of code based on the architecture of VGG-16 with batch normalization.

    Args:
        encoder: Either an `EncoderVGG` instance or its module list (`EncoderVGG.encoder`).
            The modules are mirrored in reverse order to build the decoder.

    '''
    # The decoder consumes what the encoder produces.
    channels_in = EncoderVGG.channels_code
    channels_out = 3

    def __init__(self, encoder):
        super(DecoderVGG, self).__init__()

        self.decoder = self._invert_(encoder)

    def _invert_(self, encoder):
        '''Invert the encoder in order to create the decoder as a (more or less) mirror image of the encoder

        The decoder is comprised of two principal types: the 2D transpose convolution and the 2D unpooling.
        The 2D transpose convolution is followed by batch normalization and activation. Therefore, as the
        module list of the encoder is iterated over in reverse, a convolution in the encoder is turned into
        transposed convolution plus normalization and activation, and a max-pooling in the encoder is turned
        into unpooling. All other encoder modules are skipped.

        Args:
            encoder (ModuleList or EncoderVGG): the encoder modules, or the encoder that owns them

        Returns:
            decoder (ModuleList): the decoder obtained by "inversion" of encoder
        '''
        # An nn.Module is not iterable, so an `EncoderVGG` instance must be unwrapped
        # to its module list before it can be reversed.
        if isinstance(encoder, EncoderVGG):
            encoder = encoder.encoder

        modules_transpose = []
        for module in reversed(encoder):

            if isinstance(module, nn.Conv2d):
                # Swap in/out channels so the transposed convolution undoes the channel mapping.
                kwargs = {'in_channels' : module.out_channels, 'out_channels' : module.in_channels,
                          'kernel_size' : module.kernel_size, 'stride' : module.stride,
                          'padding' : module.padding}
                module_transpose = nn.ConvTranspose2d(**kwargs)
                module_norm = nn.BatchNorm2d(module.in_channels)
                module_act = nn.ReLU(inplace=True)
                modules_transpose += [module_transpose, module_norm, module_act]

            elif isinstance(module, nn.MaxPool2d):
                kwargs = {'kernel_size' : module.kernel_size, 'stride' : module.stride,
                          'padding' : module.padding}
                module_transpose = nn.MaxUnpool2d(**kwargs)
                modules_transpose += [module_transpose]

        # Discard the final normalization and activation, so final module is convolution with bias.
        # Only trim when the list really ends with that pair; otherwise (e.g. the encoder does not
        # start with a convolution) slicing blindly would drop unrelated modules.
        if (len(modules_transpose) >= 2
                and isinstance(modules_transpose[-1], nn.ReLU)
                and isinstance(modules_transpose[-2], nn.BatchNorm2d)):
            modules_transpose = modules_transpose[:-2]

        return nn.ModuleList(modules_transpose)
                