# GPU Check

In [None]:
import GPUtil
GPUs = GPUtil.getGPUs()
for gpu in GPUs:
  print(gpu.name, gpu.memoryTotal)

# Imports

In [None]:
from search_eval.eval_generic import SGLDES
from search_eval.optimizer.SingleImageDataset import SingleImageDataset
from search_eval.utils.common_utils import *
from search_space.node_space import NodeSpace
from search_space.unet.unetspaceMT import UNetSpaceMT

from nni import trace
import nni.retiarii.strategy as strategy
import nni.retiarii.serializer as serializer

from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.evaluator.pytorch import Lightning, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader

from collections import OrderedDict
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import Cell

import torch

torch.cuda.empty_cache()
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
print('CUDA available: {}'.format(torch.cuda.is_available()))

# Strategy

In [None]:


@trace
def conv_2d(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation,
        nn.Conv2d(C_out, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

@trace
def depthwise_separable_conv(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    return nn.Sequential(
        nn.Conv2d(C_in, C_in, kernel_size=kernel_size, dilation=dilation, padding=padding, groups=C_in, bias=False),
        nn.Conv2d(C_in, C_out, 1, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation,
        nn.Conv2d(C_out, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, groups=C_out, bias=False),
        nn.Conv2d(C_out, C_out, 1, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

@trace
def transposed_conv_2d(C_in, C_out, kernel_size=4, stride=2, padding=1, activation=None):
    return nn.Sequential(
        nn.ConvTranspose2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

def pools():
    pool_dict = OrderedDict([
        ("MaxPool2d", nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),
        ("AvgPool2d", nn.AvgPool2d(kernel_size=2, stride=2, padding=0)),
        # ("AdaMaxPool2d", nn.AdaptiveMaxPool2d(1)),
        # ("AdaAvgPool2d", nn.AdaptiveAvgPool2d(1)),
        # ("DepthToSpace", nn.PixelShuffle(2)),
    ])
    return pool_dict

def upsamples(C_in, C_out):
    upsample_dict = OrderedDict([
        ("Upsample_nearest", nn.Upsample(scale_factor=2, mode='nearest')),
        ("Upsample_bilinear", nn.Upsample(scale_factor=2, mode='bilinear')),
        ("TransConv_4x4_Relu", transposed_conv_2d(C_in, C_out)),
        ("TransConv_2x2_RelU", transposed_conv_2d(C_in, C_out, kernel_size=2, stride=2, padding=0)),
    ])
    return upsample_dict

def convs(C_in, C_out):
    # all padding should follow this formula:
    # pd = (ks - 1) * dl // 2
    conv_dict = OrderedDict([
        
        ("conv2d_1x1_Relu", conv_2d(C_in, C_out)),
        # ("conv2d_1x1_SiLU", conv_2d(C_in, C_out, activation=nn.SiLU())),

        ("conv2d_3x3_Relu", conv_2d(C_in, C_out, kernel_size=3, padding=1)),
        ("conv2d_3x3_SiLU", conv_2d(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),
        ("conv2d_3x3_Sigmoid", conv_2d(C_in, C_out, kernel_size=3, padding=1, activation=nn.Sigmoid())),
        # ("conv2d_3x3_Relu_1dil", conv_2d(C_in, C_out, kernel_size=3, padding=2, dilation=2)),

        # ("conv2d_5x5_Relu", conv_2d(C_in, C_out, kernel_size=5, padding=2)),
        # ("conv2d_5x5_Relu_1dil", conv_2d(C_in, C_out, kernel_size=5, padding=4, dilation=2, activation=nn.SiLU())),
        # ("conv2d_5x5_SiLU", conv_2d(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),


        # ("convDS_1x1_Relu", depthwise_separable_conv(C_in, C_out)),
        ("convDS_1x1_SiLU", depthwise_separable_conv(C_in, C_out, activation=nn.SiLU())),

        ("convDS_3x3_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1)),
        ("convDS_3x3_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),

        # ("convDS_5x5_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2)),
        # ("convDS_5x5_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),
    ])
    return conv_dict

@trace
@model_wrapper
class SearchSpace(nn.Module):
    def __init__(self, C_in=1, C_out=1, depth=4, enNodes=1, deNodes=1):
        super().__init__()

        # all padding should follow this formula:
        # pd = (ks - 1) * dl // 2
        self.pr = False
        self.depth = depth
        
        self.in_layer = nn.Conv2d(C_in, 64, kernel_size=3, padding=1)
        
        # Encoders
        filters = 64
        self.encoders = nn.ModuleList()
        for i in range(depth):
            enNodes = enNodes
            pool_candidates = pools()
            enConv_candidates = convs(filters, filters*2 // enNodes)

            self.encoders.append(Cell(
                op_candidates=pool_candidates, 
                num_nodes=1, 
                num_ops_per_node=1, #len(pool_candidates), 
                num_predecessors=1, 
                label=f'pool_{i+1}',
                ))
            self.encoders.append(Cell(
                op_candidates=enConv_candidates, 
                num_nodes=enNodes, 
                num_ops_per_node=1,
                num_predecessors=1, 
                label=f'conv_{i+1}',
                ))
            filters *= 2

        # Decoders
        self.decoders = nn.ModuleList()
        for i in range(depth):
            deNodes = deNodes
            upsample_candidates = upsamples(filters, filters)

            self.decoders.append(Cell(
                op_candidates=upsample_candidates, 
                num_nodes=1, 
                num_ops_per_node=1, #len(upsample_candidates), 
                num_predecessors=1, 
                label=f'upsample_{i+1}'))
            
            filters //= 2

            deConv_candidates = convs(filters*3, filters // deNodes)

            self.decoders.append(Cell(
                op_candidates=deConv_candidates, 
                num_nodes=deNodes, 
                num_ops_per_node=1, #len(deConv_candidates), 
                num_predecessors=1, 
                label=f'conv_{i+1+depth}'))

        self.out_layer = nn.Conv2d(64, C_out, kernel_size=3, padding=1)

    def forward(self, x):

        if self.pr:
            print(f'input shape: {x.shape}\n')

        # print(f'input shape: {x.shape}')
        x = self.in_layer(x)  # Apply the initial layer
        # print(f'post in_layer shape: {x.shape}\n')
        skip_connections = [x]

        for i in range(self.depth):
            x = self.encoders[2*i]([x])
            x = self.encoders[2*i+1]([x])
            skip_connections.append(x)

        for i in range(self.depth):
            upsampled = self.decoders[2*i]([x])
            cropped = self.crop_tensor(upsampled, skip_connections[-(i+2)])
            x = torch.cat([cropped, upsampled], 1)
            x = self.decoders[2*i+1]([x])

        x = self.out_layer(x)  # Apply the final layer

        return x

    def crop_tensor(self, target_tensor, tensor):
        target_size = target_tensor.size()[2]  # Assuming height and width are same
        tensor_size = tensor.size()[2]
        delta = tensor_size - target_size
        delta = delta // 2
        return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]
    
    def test(self):
        """
        This will input a random tensor of 1x1x128x128 and test the forward pass.
        """
        self.pr = True
        x = torch.randn(1, 1, 128, 128)
        y = self.forward(x)
        assert y.shape == (1, 1, 128, 128), "Output shape should be (1, 1, 128, 128), got {}".format(y.shape)
        print(f'output shape: {y.shape}\n')
        print("Test passed.\n\n")



def get_U_Net(in_channels=1, out_channels=1, init_features=64, pretrained=False):
    return torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=in_channels, out_channels=out_channels, init_features=init_features, pretrained=pretrained)

# new space

In [None]:

search_strategy = strategy.DARTS()



@trace
def conv_2d(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation,
        nn.Conv2d(C_out, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

@trace
def depthwise_separable_conv(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    return nn.Sequential(
        nn.Conv2d(C_in, C_in, kernel_size=kernel_size, dilation=dilation, padding=padding, groups=C_in, bias=False),
        nn.Conv2d(C_in, C_out, 1, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation,
        nn.Conv2d(C_out, C_out, kernel_size=kernel_size, dilation=dilation, padding=padding, groups=C_out, bias=False),
        nn.Conv2d(C_out, C_out, 1, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

@trace
def transposed_conv_2d(C_in, C_out, kernel_size=4, stride=2, padding=1, activation=None):
    return nn.Sequential(
        nn.ConvTranspose2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
        nn.ReLU() if activation is None else activation
    )

def pools():
    pool_dict = OrderedDict([
        ("MaxPool2d", nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),
        ("AvgPool2d", nn.AvgPool2d(kernel_size=2, stride=2, padding=0)),
        # ("AdaMaxPool2d", nn.AdaptiveMaxPool2d(1)),
        # ("AdaAvgPool2d", nn.AdaptiveAvgPool2d(1)),
        # ("DepthToSpace", nn.PixelShuffle(2)),
    ])
    return pool_dict

def upsamples(C_in, C_out):
    upsample_dict = OrderedDict([
        ("Upsample_nearest", nn.Upsample(scale_factor=2, mode='nearest')),
        ("Upsample_bilinear", nn.Upsample(scale_factor=2, mode='bilinear')),
        ("TransConv_4x4_Relu", transposed_conv_2d(C_in, C_out)),
        ("TransConv_2x2_RelU", transposed_conv_2d(C_in, C_out, kernel_size=2, stride=2, padding=0)),
    ])
    return upsample_dict

def convs(C_in, C_out):
    # all padding should follow this formula:
    # pd = (ks - 1) * dl // 2
    conv_dict = OrderedDict([
        
        ("conv2d_1x1_Relu", conv_2d(C_in, C_out)),
        # ("conv2d_1x1_SiLU", conv_2d(C_in, C_out, activation=nn.SiLU())),

        ("conv2d_3x3_Relu", conv_2d(C_in, C_out, kernel_size=3, padding=1)),
        ("conv2d_3x3_SiLU", conv_2d(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),
        ("conv2d_3x3_Sigmoid", conv_2d(C_in, C_out, kernel_size=3, padding=1, activation=nn.Sigmoid())),
        # ("conv2d_3x3_Relu_1dil", conv_2d(C_in, C_out, kernel_size=3, padding=2, dilation=2)),

        # ("conv2d_5x5_Relu", conv_2d(C_in, C_out, kernel_size=5, padding=2)),
        # ("conv2d_5x5_Relu_1dil", conv_2d(C_in, C_out, kernel_size=5, padding=4, dilation=2, activation=nn.SiLU())),
        # ("conv2d_5x5_SiLU", conv_2d(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),


        # ("convDS_1x1_Relu", depthwise_separable_conv(C_in, C_out)),
        ("convDS_1x1_SiLU", depthwise_separable_conv(C_in, C_out, activation=nn.SiLU())),

        ("convDS_3x3_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1)),
        ("convDS_3x3_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),

        # ("convDS_5x5_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2)),
        # ("convDS_5x5_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),
    ])
    return conv_dict

# Encoder preprocessor 
class EncoderPreprocessor(nn.Module):
  def __init__(self, in_channels, out_channels):
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  
  def forward(self, inputs):
    return [self.conv(x) for x in inputs]

@trace
@model_wrapper
class SimpleNet(nn.Module):
    def __init__(
            self, 
            C_in=1, 
            C_out=1, 
            depth=2, 
            nodes_per_layer=1,
            ops_per_node=1,
            poolOps_per_node=1,
            upsampleOps_per_node=1,
            
            ):
        super().__init__()

        self.depth = depth
        self.nodes = nodes_per_layer
        
        nodes = nodes_per_layer
        start_filters = end_filters = 64
        self.in_layer = nn.Conv2d(C_in, start_filters, kernel_size=3, padding=1)

        # encoder layers
        mid_in = 64
        self.pools = nn.ModuleList()
        self.encoders = nn.ModuleList()
        self.postencoders = nn.ModuleList()
        for _ in range(self.depth):
            self.pools.append(Cell(
                op_candidates=pools(),
                num_nodes=1, 
                num_ops_per_node=poolOps_per_node,
                num_predecessors=1, 
            ))
            self.encoders.append(Cell(
                op_candidates=convs(mid_in,mid_in),
                num_nodes=nodes, 
                num_ops_per_node=ops_per_node,
                num_predecessors=1, 
            ))
            self.postencoders.append(nn.Conv2d(mid_in*nodes, mid_in*2, kernel_size=3, padding=1))
            mid_in *= 2

        # decoder layers
        self.upsamples = nn.ModuleList()
        self.predecoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.postdecoders = nn.ModuleList()
        for _ in range(self.depth):
            self.upsamples.append(Cell(
                op_candidates=upsamples(mid_in,mid_in),
                num_nodes=1, 
                num_ops_per_node=upsampleOps_per_node,
                num_predecessors=1, 
            ))
            mid_in //= 2
            self.predecoders.append(nn.Conv2d(mid_in*3, mid_in, kernel_size=3, padding=1))
            self.decoders.append(Cell(
                op_candidates=convs(mid_in,mid_in),
                num_nodes=nodes, 
                num_ops_per_node=ops_per_node,
                num_predecessors=1, 
            ))
            self.postdecoders.append(nn.Conv2d(mid_in*nodes, mid_in, kernel_size=3, padding=1))

        self.out_layer = nn.Conv2d(end_filters, C_out, kernel_size=3, padding=1)
        
    def forward(self, x):
        x = self.in_layer(x)

        skip_connections = [x]

        for i in range(self.depth):
            x = self.pools[i]([x])
            x = self.encoders[i]([x])
            x = self.postencoders[i](x)
            skip_connections.append(x)

        for i in range(self.depth):
            upsampled = self.upsamples[i]([x])
            cropped = self.crop_tensor(upsampled, skip_connections[-(i+2)])
            x = torch.cat([cropped, upsampled], 1)
            x = self.predecoders[i](x)
            x = self.decoders[i]([x])
            x = self.postdecoders[i](x)

        x = self.out_layer(x)
        return x

    def crop_tensor(self, target_tensor, tensor):
        target_size = target_tensor.size()[2]  # Assuming height and width are same
        tensor_size = tensor.size()[2]
        delta = tensor_size - target_size
        delta = delta // 2
        return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

# Execute

In [1]:

total_iterations = 1200 # 650 for nodespace

resolution = 64
noise_type = 'gaussian'
noise_level = '0.09'
phantom =       np.load(f'/home/joe/nas-for-dip/phantoms/ground_truth/{resolution}/{45}.npy')
phantom_noisy = np.load(f'/home/joe/nas-for-dip/phantoms/{noise_type}/res_{resolution}/nl_{noise_level}/p_{45}.npy')

# Create the lightning module
learning_rate = 0.1
buffer_size = 100
patience = 500 # 75 for NodeSpace
weight_decay = 5e-5
show_every = 200
report_every = 50

module = SGLDES(
                phantom=phantom, 
                phantom_noisy=phantom_noisy,
                
                learning_rate=learning_rate, 
                buffer_size=buffer_size,
                patience=patience,
                weight_decay= weight_decay,

                show_every=show_every,
                report_every=report_every,
                HPO=False,
                NAS=True,
                OneShot=False,
                SGLD_regularize=True,
                switch=None,
                MCMC_iter=10
                )

# Create a PyTorch Lightning trainer
trainer = Trainer(
            max_epochs=total_iterations,
            fast_dev_run=False,
            gpus=1,
            )
            
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []


# Create the lighting object for evaluator
train_loader = DataLoader(SingleImageDataset(phantom, num_iter=1), batch_size=1)
val_loader = DataLoader(SingleImageDataset(phantom, num_iter=1), batch_size=1)

lightning = Lightning(lightning_module=module, trainer=trainer, train_dataloaders=train_loader, val_dataloaders=val_loader)


# Create a Search Space
depth = 4
nodes_per_layer = 1
ops_per_node = 1
poolOps_per_node = 1
upsampleOps_per_node = 1
# model_space = NodeSpace(
#          depth=depth, 
#          nodes_per_layer=nodes_per_layer,
#          ops_per_node=ops_per_node, 
#          poolOps_per_node=poolOps_per_node, 
#          upsampleOps_per_node=upsampleOps_per_node
#         )

model_space = UNetSpaceMT()

# Select the Search Strategy
# search_strategy = strategy.DARTS()
# # search_strategy = strategy.ENAS()
# search_strategy = strategy.GumbelDARTS()
# # search_strategy = strategy.RandomOneShot()

# Select a Search Strategy
search_strategy = strategy.Random(dedup=True)
# search_strategy = strategy.TPE()
# search_strategy = strategy.RegularizedEvolution(dedup=True)

# fast_dev_run=False
print(f'\n\n----------------------------------\n')
print(f'Configration:\n')
print(f'total_iterations: {total_iterations}')
print(f'resolution: {resolution}')
print(f'noise_type: {noise_type}')
print(f'noise_level: {noise_level}')

print(f'\n-------++++++++++++++++++++---------\n')

print(f'depth: {depth}')
print(f'nodes_per_layer: {nodes_per_layer}')
print(f'ops_per_node: {ops_per_node}')
# print(f'poolOps_per_node: {poolOps_per_node}')
# print(f'upsampleOps_per_node: {upsampleOps_per_node}')

print(f'\n-------++++++++++++++++++++---------\n')
print(f'strategy: {strategy.__class__.__name__}')

print(f'\n----------------------------------\n\n')



config = RetiariiExeConfig(execution_engine='oneshot')
experiment = RetiariiExperiment(model_space, evaluator=lightning, strategy=search_strategy)
experiment.run(config)

NameError: name 'np' is not defined

In [None]:
# retrain top model
from search_space.node_space import exportedModel
from search_eval.eval_no_search_SGLD_ES import Eval_SGLD_ES

# construct output and retrain
exported_arch = experiment.export_top_models()[0]
# extract value from key -- pool 0
print("--------------------")
print("--------------------")
for i in range(depth):
    print(f'pool {i+1}: ', exported_arch[f'pool {i}/op_1_0'])
    print(f'encoder {i+1}: ', exported_arch[f'encoder {i}/op_1_0'])
    print("--------------------")
for i in range(depth):
    print(f'upsample {i+1}: ', exported_arch[f'upsample {i}/op_1_0'])
    print(f'decoder {i+1}: ', exported_arch[f'decoder {i}/op_1_0'])
    print("--------------------")
print("--------------------\n\n\n")

In [None]:
# stop experiment and clear cache
experiment.stop()
torch.cuda.empty_cache()

In [1]:
from search_space.unetAttention import UNetWithAttention
model = UNetWithAttention()
model.test()

input shape: torch.Size([1, 1, 64, 64])
after in layer: torch.Size([1, 64, 64, 64])
after pool 0: torch.Size([1, 64, 32, 32])
after enc 0: torch.Size([1, 128, 32, 32])

after attention 0: torch.Size([1, 128, 32, 32])


after pool 1: torch.Size([1, 128, 16, 16])
after enc 1: torch.Size([1, 256, 16, 16])

after pool 2: torch.Size([1, 256, 8, 8])
after enc 2: torch.Size([1, 512, 8, 8])

after attention 2: torch.Size([1, 512, 8, 8])


after pool 3: torch.Size([1, 512, 4, 4])
after enc 3: torch.Size([1, 1024, 4, 4])

after upsample 0: torch.Size([1, 1024, 6, 6])
after dec 0: torch.Size([1, 512, 6, 6])

after attention 0: torch.Size([1, 512, 6, 6])


after upsample 1: torch.Size([1, 512, 10, 10])
after dec 1: torch.Size([1, 256, 10, 10])

after upsample 2: torch.Size([1, 256, 18, 18])
after dec 2: torch.Size([1, 128, 18, 18])

after attention 2: torch.Size([1, 128, 18, 18])


after upsample 3: torch.Size([1, 128, 34, 34])
after dec 3: torch.Size([1, 64, 34, 34])

torch.Size([1, 1, 34, 34])
