In [1]:
# IMPLEMENT THE RESNET
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Type, Any, Callable, Union, List, Optional
import blocks
import SimpleITK as sitk
import numpy as np

In [27]:
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Adapted from
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/f13aab8148bd5f15b9eb47b690496df8dadbab0c/models/base_model.py#L219
    Passing ``requires_grad=False`` freezes the networks, which avoids
    unnecessary gradient computations.

    Parameters:
        nets (network or network list) -- a single network or a list of networks;
            ``None`` entries are skipped
        requires_grad (bool)           -- value assigned to ``param.requires_grad``
            for every parameter

    Returns:
        ``nets`` exactly as passed in (the single network, or the list). The
        networks are modified in place, so ``model = set_requires_grad(model, False)``
        keeps ``model`` bound to the same object. An empty list is returned unchanged.
    """
    net_list = nets if isinstance(nets, list) else [nets]
    for net in net_list:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
    # Return the argument itself (not the last loop variable): the original returned
    # only the last network of a list and raised UnboundLocalError on an empty list.
    return nets

# Smoke test: freezing then unfreezing a small model must toggle requires_grad
# on every one of its parameters.
test = torch.randn(3)  # sample input; not used by the checks below
model = nn.Sequential(nn.Linear(3, 10), nn.Linear(10, 10), nn.Linear(10, 1))

model = set_requires_grad(model, False)
assert not any(p.requires_grad for p in model.parameters()), "model should be frozen"

model = set_requires_grad(model, True)
assert all(p.requires_grad for p in model.parameters()), "model should be trainable again"



In [33]:
# TESTING PIX2PIX
# Smoke test for the Pix2Pix building blocks in the project-local `blocks` module.

a = torch.randn(2,1,512,512)
b = torch.randn(2,1,1024,1024)

# NOTE(review): `m` is not used in this cell and is rebound by a later cell.
m = nn.ConvTranspose2d(1, 22, kernel_size=3, stride=2)
enc = blocks.Pix2Pix_Encoder_Block(1,22, _normType=None)
#print(enc(a).shape)
dec = blocks.Pix2Pix_DecoderBlock( _in_channels=1, _out_channels=22, _kernel_size=(4,4), _stride=(2,2), _padding=(1,1), _dilation=(1,1), _normType="BatchNorm", _dropoutType=None)
#print(dec(a,b).shape)

# Generator
generator = blocks.Generator_Pix2Pix(a.shape)
#print(generator(a).shape)

# Discriminator: `b` is rebound here to a (2, 2, 512, 512) batch.
b = torch.randn(2,2,512,512)
disc = blocks.Discriminator_Pix2Pix(_input_array_size=b.shape, _first_out_channels=64, _normType="BatchNorm")
# Only the number of output elements is inspected, so no autograd graph is needed.
# numel() equals view(-1).size(0) but does not require a contiguous output.
with torch.no_grad():
    print(disc(b).numel())

7688


In [3]:
# TESTING RESUNET (Zhang 2018)
def weights_init(m):
    """DCGAN-style weight initialisation, intended for ``module.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: weight drawn from N(mean=0, std=0.02).
    - BatchNorm2d with ``affine=True``: weight drawn from N(mean=1, std=0.02), bias set to 0.
    - Any direct child whose class is named ``ResUNet_shortcut``: every Conv2d /
      ConvTranspose2d inside it (at any depth) gets weight 1, and all of the
      shortcut's parameters are frozen (``requires_grad = False``).

    Convolution biases are not modified.

    ``nn.Module.apply`` calls ``fn`` on children before their parent. The shortcut
    override (done while visiting the shortcut's parent) therefore runs after, and
    wins over, the generic N(0, 0.02) init already applied to the shortcut's convs.

    Parameters:
        m (nn.Module) -- the module currently being visited.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d) and m.affine:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
    for child in m.children():
        # Specific weight setting for the ResUNet shortcut. It is matched by class
        # name because the class lives in the project-local `blocks` module.
        if child.__class__.__name__ != "ResUNet_shortcut":
            continue
        # modules() is recursive, matching the recursive parameters() freeze below,
        # so convs nested in containers are not left frozen with random weights.
        for sub in child.modules():
            if isinstance(sub, conv_types):
                nn.init.constant_(sub.weight, 1.)
        for param in child.parameters():
            param.requires_grad = False
                

a = torch.randn(2,1,512,512)
gen = blocks.Generator_ResUNet(input_array_shape=a.shape, _first_out_channels=64, _reluType="leaky")
gen.apply(weights_init)
# Shape-only smoke test: skip building an autograd graph for the forward pass.
with torch.no_grad():
    print(gen(a).shape)


# NOTE(review): unlike `gen`, `gen2` is not passed through `weights_init`,
# so it keeps the modules' default initialisation. Confirm this is intended.
gen2 = blocks.Generator_ResUNet_PixelShuffle(input_array_shape=a.shape, _first_out_channels=64, _reluType="leaky")
with torch.no_grad():
    print(gen2(a).shape)

torch.Size([2, 1, 512, 512])
torch.Size([2, 1, 512, 512])


In [4]:
# TESTING RESUNET-A COMPONENTS
# Shape smoke tests for the ResUNet-A building blocks. Only output shapes are
# inspected, so every forward pass runs under torch.no_grad().

m = blocks.ResUNet_A_miniBlock(16)
n = blocks.Conv2DN(16,20)
p = blocks.ResUNet_A_Block_4(16, _kernel_size=(3,3), _dilation_rates=[1,3,5,7])
ds = blocks.DownSample(16)

a = torch.randn(1, 16, 256, 256)
c = torch.randn(1, 2, 512, 512)
d = torch.randn(1, 1024, 8, 8)

enc = blocks.Encoder_ResUNet_A_d7(32, c.shape)
msc = blocks.MultiScale_Classifier(_input_channels=32, _input_array_shape=c.shape)

with torch.no_grad():
    print(m(a).shape)
    print(n(a).shape)
    print(p(a).shape)
    print(ds(a).shape)
    print(enc(c).shape)
    print(msc(c).shape)

print("TESTING BRIDGE")

# Output size / kernel / stride for the PSP mini-block: a quarter of `a`'s height and width.
output_size = (a.shape[2]//4,a.shape[3]//4)
mp = blocks.PSPPooling_miniBlock(_in_channels=16, _output_size=output_size, _kernel_size=output_size, _stride=output_size, _padding=0, _dilation=(1,1), _pyramid_levels=4)
with torch.no_grad():
    print(mp(a).shape)

pspp = blocks.PSPPooling(_tensor_array_shape = a.shape)
with torch.no_grad():
    print(pspp(a).shape)


upsh = blocks.UpSampleAndHalveChannels( d.shape[1])
with torch.no_grad():
    print(upsh(d).shape)

# `upsh` is rebound for a 32-channel input before being combined with `a`.
t1 = torch.randn(1,32,128,128)
upsh = blocks.UpSampleAndHalveChannels( t1.shape[1])
cbn= blocks.Combine(_in_channels=16)
with torch.no_grad():
    print(cbn(a,upsh(t1)).shape)

torch.Size([1, 16, 256, 256])
torch.Size([1, 20, 256, 256])
torch.Size([1, 16, 256, 256])
torch.Size([1, 32, 128, 128])
torch.Size([1, 2048, 8, 8])
torch.Size([1, 1])
TESTING BRIDGE
torch.Size([1, 4, 64, 64])
torch.Size([1, 16, 256, 256])
torch.Size([1, 512, 16, 16])
torch.Size([1, 16, 256, 256])


  "See the documentation of nn.Upsample for details.".format(mode)


In [5]:
# Generator
# Quick check that casting each float with int() yields a tuple of ints.
print(tuple(map(int, (1., 2.))))

input_test = torch.randn(1,2,512, 512)
generator = blocks.Generator_ResUNet_A(_input_channels=16, _input_array_shape=input_test.shape, _norm_type='BatchNorm', _ADL_drop_rate=0.75, _ADL_gamma=0.9)

# Shape-only smoke test: skip building an autograd graph for the forward pass.
with torch.no_grad():
    print(generator(input_test).shape)

(1, 2)


  "See the documentation of nn.Upsample for details.".format(mode)


torch.Size([1, 1, 512, 512])
