In [None]:
import skimage
import numpy as np
import torch
import torch.nn as nn
from torch import cuda, optim, tensor, zeros_like
from torch import device as torch_device
from torch.nn import L1Loss, MSELoss
from matplotlib import pyplot as plt


from darts.common_utils import *
from darts.phantom import generate_phantom, phantom_to_torch
from darts.noises import add_selected_noise
from darts.early_stop import EarlyStop, MSE, MAE

torch.cuda.empty_cache()

# base

In [None]:

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch_device('cuda' if cuda.is_available() else "cpu")
# The tensor type must follow the device: casting to cuda.FloatTensor on a
# CPU-only machine fails, so pick the CPU float type in that case.
dtype = cuda.FloatTensor if cuda.is_available() else torch.FloatTensor


# Untrained single-channel U-Net fetched through torch.hub (pretrained=False).
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=1, out_channels=1, init_features=64, pretrained=False)

buffer_size = 100   # size handed to EarlyStop; variance is tracked once this many outputs are collected
patience = 1000     # patience handed to EarlyStop
num_iter = 7500     # upper bound on optimisation iterations
show_every = 1      # evaluate PSNR/SSIM every `show_every` iterations
lr = 0.00005        # Adam learning rate

# Std of the noise added to the network input on every iteration.
reg_noise_std = tensor(1./30.).type(dtype).to(device)
noise_type = 'gaussian'
noise_factor = 0.1
resolution = 6
n_channels = 1

raw_img_np = generate_phantom(resolution=resolution) # 1x64x64 np array
img_np = raw_img_np.copy() # 1x64x64 np array
img_torch = torch.tensor(raw_img_np, dtype=torch.float32).unsqueeze(0) # 1x1x64x64 torch tensor
# Add synthetic noise to the clean phantom.
img_noisy_torch = add_selected_noise(img_torch, noise_type=noise_type,noise_factor=noise_factor) # 1x1x64x64 torch tensor
img_noisy_np = img_noisy_torch.squeeze(0).numpy() # 1x64x64 np array

img_noisy_torch = img_noisy_torch.to(device)
# NOTE(review): the same `noise_type` string ('gaussian') is passed to get_noise
# as to add_selected_noise -- confirm get_noise accepts that value.
net_input = get_noise(input_depth=1, spatial_size=raw_img_np.shape[1], noise_type=noise_type).type(dtype).to(device)

# Network
net = model.to(device)
net = net.type(dtype)

# Loss
criterion = MSELoss().type(dtype).to(device)

# Optimizer
p = get_params('net', net, net_input)  # network parameters to be optimized
optimizer = optim.Adam(p, lr=lr)

# Optimize

# Per-iteration histories filled in by closure().
loss_history = []
psnr_history = []
ssim_history = []
variance_history = []
x_axis = []  # iteration index for each entry of variance_history
earlystop = EarlyStop(size=buffer_size,patience=patience)
def closure(iterator):
    """Run one forward/backward pass and update the metric histories.

    Reads the notebook-level state (net, net_input, criterion, img_np,
    img_noisy_torch, earlystop and the *_history lists) and mutates the
    histories and ``earlystop`` in place.

    Args:
        iterator (int): Index of the current optimisation iteration.

    Returns:
        The string "STOP" once ``earlystop.stop`` is truthy, otherwise the
        loss tensor of this iteration (its gradients are already computed).
    """
    # DIP: perturb the fixed network input with fresh noise on every call.
    net_input_perturbed = net_input + zeros_like(net_input).normal_(std=reg_noise_std)
    r_img_torch = net(net_input_perturbed)
    total_loss = criterion(r_img_torch, img_noisy_torch)
    total_loss.backward()
    loss_history.append(total_loss.item())

    if iterator % show_every == 0:
        # Evaluate the recovered image against the clean image (PSNR, SSIM).
        r_img_np = torch_to_np(r_img_torch)
        psnr = skimage.metrics.peak_signal_noise_ratio(img_np, r_img_np)
        # SSIM is computed on channels-last arrays.
        temp_img_np = np.transpose(img_np, (1, 2, 0))
        temp_r_img_np = np.transpose(r_img_np, (1, 2, 0))
        data_range = temp_img_np.max() - temp_img_np.min()
        # Only `channel_axis` is passed: the legacy `multichannel` keyword is
        # deprecated in scikit-image and conflicts with `channel_axis`.
        ssim = skimage.metrics.structural_similarity(
            temp_img_np, temp_r_img_np,
            win_size=7, channel_axis=-1, data_range=data_range)
        psnr_history.append(psnr)
        ssim_history.append(ssim)

        # Variance history: feed the flattened output to the early-stop buffer.
        r_img_np = r_img_np.reshape(-1)
        earlystop.update_img_collection(r_img_np)
        img_collection = earlystop.get_img_collection()
        if iterator % (show_every*10) == 0:
            print('Iteration %05d    Loss %.4f    PSNR %.4f    SSIM %.4f    Collection Size %.4f'
                  % (iterator, total_loss.item(), psnr, ssim, int(len(img_collection))))
        if len(img_collection) == buffer_size:
            # Mean MSE between each buffered output and the buffer average.
            ave_img = np.mean(img_collection, axis=0)
            cur_var = np.mean([MSE(ave_img, tmp) for tmp in img_collection])
            cur_epoch = iterator
            variance_history.append(cur_var)
            x_axis.append(cur_epoch)
            # Once the stop flag is set it is never re-evaluated.
            if not earlystop.stop:
                earlystop.stop = earlystop.check_stop(cur_var, cur_epoch)

    if earlystop.stop:
        return "STOP"
    return total_loss
    
# Optimise until num_iter is reached or closure() signals early stopping.
for iterator in range(num_iter):
    optimizer.zero_grad()
    early_stop = closure(iterator)  # "STOP" or this iteration's loss tensor
    optimizer.step()

    if iterator % (show_every*100) == 0:
        # Visualisation only: skip autograd bookkeeping for this extra
        # forward pass so no graph is built and kept alive.
        with torch.no_grad():
            r_img_np = torch_to_np(net(net_input))
        plot_side_by_side(np.clip(img_np, 0, 1), np.clip(r_img_np, 0, 1), np.clip(img_noisy_np,0,1))

    # closure() returns either a tensor or the sentinel string; check the type
    # first so a tensor is never compared against a str.
    if isinstance(early_stop, str) and early_stop == "STOP":
        print("Early stopping triggered.")
        break


# next iter

In [10]:
import skimage
import numpy as np

from torch import cuda, optim, tensor, zeros_like
from torch import device as torch_device


from darts.common_utils import *
from darts.early_stop import EarlyStop, MSE, MAE
from darts.noises import add_selected_noise
from darts.phantom import generate_phantom, phantom_to_torch
from darts.space import SearchSpace


import nni
import torch
import nni.retiarii.strategy as strategy
from nni.retiarii import model_wrapper
import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F


from nni.experiment import Experiment
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.evaluator import FunctionalEvaluator

torch.cuda.empty_cache()

In [11]:
def preprocess_image(resolution, noise_type, noise_factor, input_img_np=None):
    """
    Build the clean/noisy image pair used for denoising, as numpy arrays and torch tensors.

    Args:
    - resolution (int): Resolution passed to generate_phantom when no image is supplied.
    - noise_type (str): Noise type forwarded to add_selected_noise.
    - noise_factor (float): Noise factor forwarded to add_selected_noise.
    - input_img_np (numpy.ndarray, optional): Raw image to use instead of a generated phantom.
      It is copied, so the caller's array is not modified.

    Returns:
    - img_np (numpy.ndarray): Copy of the clean image.
    - img_noisy_np (numpy.ndarray): Noisy image with the leading batch axis removed.
    - img_torch (torch.Tensor): Clean image as a float32 tensor with a leading batch axis.
    - img_noisy_torch (torch.Tensor): Noisy image as returned by add_selected_noise.
    """
    # Use the caller's image when given, otherwise synthesise a phantom.
    raw_img_np = (
        generate_phantom(resolution=resolution)
        if input_img_np is None
        else input_img_np.copy()
    )

    # Clean image: prepend a batch axis for the network.
    clean_torch = torch.tensor(raw_img_np, dtype=torch.float32).unsqueeze(0)
    noisy_torch = add_selected_noise(
        clean_torch, noise_type=noise_type, noise_factor=noise_factor
    )

    # numpy views of the same data, without the batch axis on the noisy image.
    clean_np = raw_img_np.copy()
    noisy_np = noisy_torch.squeeze(0).numpy()

    return clean_np, noisy_np, clean_torch, noisy_torch


In [12]:
# model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
#                        in_channels=1, out_channels=1, init_features=64, pretrained=False)

def main_evaluation(model_cls):
    """Train one candidate network on a noisy phantom and report its PSNR to NNI.

    The candidate is fitted to a noisy phantom image; PSNR against the clean
    image is reported to NNI every 10 iterations as an intermediate result,
    and the last evaluated PSNR is reported exactly once as the final result.
    Training ends after ``num_iter`` iterations or as soon as the
    variance-based EarlyStop criterion fires.

    Args:
        model_cls: Callable that builds the network when called with no
            arguments (presumably the sampled architecture supplied by NNI's
            FunctionalEvaluator -- confirm against the NNI version in use).

    Returns:
        None. Results are communicated through ``nni.report_*``.
    """
    device = torch_device('cuda' if cuda.is_available() else "cpu")
    dtype = cuda.FloatTensor if cuda.is_available() else torch.FloatTensor

    buffer_size = 100
    patience = 600
    num_iter = 1200
    show_every = 1
    lr = 0.00005

    reg_noise_std = tensor(1./30.).type(dtype).to(device)
    noise_type = 'gaussian'
    noise_factor = 0.1
    resolution = 6

    img_np, _, _, img_noisy_torch = preprocess_image(resolution, noise_type, noise_factor)
    img_noisy_torch = img_noisy_torch.to(device)
    net_input = get_noise(input_depth=1, spatial_size=img_np.shape[1], noise_type=noise_type).type(dtype).to(device)

    # Network
    net = model_cls().to(device)
    net = net.type(dtype)

    # Loss
    criterion = nn.MSELoss().type(dtype).to(device)

    # Optimizer
    p = get_params('net', net, net_input)  # network parameters to be optimized
    optimizer = optim.Adam(p, lr=lr)

    # Optimize

    loss_history = []
    psnr_history = []
    ssim_history = []
    variance_history = []
    x_axis = []
    earlystop = EarlyStop(size=buffer_size,patience=patience)

    def closure(iterator):
        """Run one forward/backward pass.

        Returns:
            (total_loss, psnr): the loss tensor and the PSNR evaluated in this
            call, or None for psnr when no evaluation was due. The return
            shape is the same whether or not early stopping fired; the stop
            decision is read from ``earlystop.stop`` by the caller.
        """
        psnr = None
        # DIP: perturb the fixed network input with fresh noise on every call.
        net_input_perturbed = net_input + zeros_like(net_input).normal_(std=reg_noise_std)
        r_img_torch = net(net_input_perturbed)
        total_loss = criterion(r_img_torch, img_noisy_torch)
        total_loss.backward()
        loss_history.append(total_loss.item())
        if iterator % show_every == 0:
            # Evaluate the recovered image against the clean image (PSNR, SSIM).
            r_img_np = torch_to_np(r_img_torch)
            psnr = skimage.metrics.peak_signal_noise_ratio(img_np, r_img_np)
            # SSIM is computed on channels-last arrays.
            temp_img_np = np.transpose(img_np, (1, 2, 0))
            temp_r_img_np = np.transpose(r_img_np, (1, 2, 0))
            data_range = temp_img_np.max() - temp_img_np.min()
            # Only `channel_axis` is passed: the legacy `multichannel` keyword
            # is deprecated in scikit-image and conflicts with `channel_axis`.
            ssim = skimage.metrics.structural_similarity(
                temp_img_np, temp_r_img_np,
                win_size=7, channel_axis=-1, data_range=data_range)
            psnr_history.append(psnr)
            ssim_history.append(ssim)

            # Variance history: feed the flattened output to the early-stop buffer.
            r_img_np = r_img_np.reshape(-1)
            earlystop.update_img_collection(r_img_np)
            img_collection = earlystop.get_img_collection()
            if iterator % (show_every*10) == 0:
                print('Iteration %05d    Loss %.4f    PSNR %.4f    SSIM %.4f'
                      % (iterator, total_loss.item(), psnr, ssim))
                nni.report_intermediate_result(psnr)
            if len(img_collection) == buffer_size:
                # Mean MSE between each buffered output and the buffer average.
                ave_img = np.mean(img_collection, axis=0)
                cur_var = np.mean([MSE(ave_img, tmp) for tmp in img_collection])
                cur_epoch = iterator
                variance_history.append(cur_var)
                x_axis.append(cur_epoch)
                # Once the stop flag is set it is never re-evaluated.
                if not earlystop.stop:
                    earlystop.stop = earlystop.check_stop(cur_var, cur_epoch)
        return total_loss, psnr

    final_psnr = None  # last PSNR that was actually evaluated
    for iterator in range(num_iter):
        optimizer.zero_grad()
        _, step_psnr = closure(iterator)
        optimizer.step()
        if step_psnr is not None:
            final_psnr = step_psnr

        if earlystop.stop:
            print("Early stopping triggered.")
            break

    # Report the final metric exactly once, whether the loop ended by early
    # stopping or by exhausting num_iter.
    if final_psnr is not None:
        nni.report_final_result(final_psnr)
    


In [13]:

# search space: candidate architectures with one input and one output channel
model_space = SearchSpace(in_channels=1, out_channels=1)
# each sampled model is scored by main_evaluation (defined in an earlier cell)
evaluator = FunctionalEvaluator(main_evaluation)

# search strategy: random sampling with de-duplication enabled
search_strategy = strategy.Random(dedup=True)

# experiment
exp = RetiariiExperiment(model_space, evaluator, [], search_strategy)
exp_config = RetiariiExeConfig('local')
# NOTE(review): the name says "mnist", but this notebook denoises a phantom image -- consider renaming
exp_config.experiment_name = 'mnist_search'
# NOTE(review): absolute, machine-specific Windows paths -- adjust before running elsewhere
exp_config.trial_code_directory = 'C:/Users/Public/Public_VS_Code/NAS_test'
exp_config.experiment_working_directory = 'C:/Users/Public/nni-experiments'

exp_config.max_trial_number = 12   # run at most 12 trials
exp_config.trial_concurrency = 2  # will run two trials concurrently

exp_config.trial_gpu_number = 1 # GPU count requested for each trial (not a concurrency setting)
# presumably lets trials use GPUs that already have running processes -- confirm in the NNI docs
exp_config.training_service.use_active_gpu = True

# Execute: start the experiment with the web portal on port 8081
exp.run(exp_config, 8081)

[2023-08-15 22:07:07] [32mCreating experiment, Experiment ID: [36mej2c5i91[0m


2023-08-15 22:07:07,385 - INFO - Creating experiment, Experiment ID: ${CYAN}ej2c5i91


[2023-08-15 22:07:07] [32mStarting web server...[0m


2023-08-15 22:07:07,436 - INFO - Starting web server...


[2023-08-15 22:07:08] [32mSetting up...[0m


2023-08-15 22:07:08,541 - INFO - Setting up...


[2023-08-15 22:07:08] [32mWeb portal URLs: [36mhttp://169.254.138.100:8081 http://169.254.67.161:8081 http://169.254.50.13:8081 http://10.0.0.172:8081 http://127.0.0.1:8081[0m


2023-08-15 22:07:08,719 - INFO - Web portal URLs: ${CYAN}http://169.254.138.100:8081 http://169.254.67.161:8081 http://169.254.50.13:8081 http://10.0.0.172:8081 http://127.0.0.1:8081


[2023-08-15 22:07:08] [32mDispatcher started[0m


2023-08-15 22:07:08,806 - INFO - Dispatcher started


[2023-08-15 22:07:08] [32mStart strategy...[0m


2023-08-15 22:07:08,846 - INFO - Start strategy...


[2023-08-15 22:07:08] [32mSuccessfully update searchSpace.[0m


2023-08-15 22:07:08,876 - INFO - Successfully update searchSpace.


[2023-08-15 22:07:08] [32mRandom search running in fixed size mode. Dedup: on.[0m


2023-08-15 22:07:08,879 - INFO - Random search running in fixed size mode. Dedup: on.


[2023-08-15 22:18:07] [32mStrategy exit[0m


2023-08-15 22:18:07,552 - INFO - Strategy exit


[2023-08-15 22:18:07] [32mSearch process is done, the experiment is still alive, `stop()` can terminate the experiment.[0m


2023-08-15 22:18:07,581 - INFO - Search process is done, the experiment is still alive, `stop()` can terminate the experiment.


In [14]:
# Reconnect to the experiment on port 8081 (the port passed to exp.run above)
# and stop it; the search log notes the experiment stays alive until stop() is called.
experiment = Experiment.connect(8081)
experiment.stop()

[2023-08-15 22:23:13] [32mConnect to port 8081 success, experiment id is ej2c5i91, status is DONE.[0m


2023-08-15 22:23:13,127 - INFO - Connect to port 8081 success, experiment id is ej2c5i91, status is DONE.


[2023-08-15 22:23:13] [32mStopping experiment, please wait...[0m


2023-08-15 22:23:13,129 - INFO - Stopping experiment, please wait...


[2023-08-15 22:23:13] [32mExperiment stopped[0m


2023-08-15 22:23:13,162 - INFO - Experiment stopped


[2023-08-15 22:23:13] [32mDispatcher exiting...[0m


2023-08-15 22:23:13,168 - INFO - Dispatcher exiting...


[2023-08-15 22:23:15] [32mDispatcher terminiated[0m


2023-08-15 22:23:15,106 - INFO - Dispatcher terminiated
