# GPU Check

In [None]:
# Print the name and total memory of every GPU that GPUtil can see.
import GPUtil

GPUs = GPUtil.getGPUs()
for name, total_mem in ((g.name, g.memoryTotal) for g in GPUs):
    print(name, total_mem)

# Imports

In [None]:
from search_eval.eval_generic import SGLDES
from search_eval.optimizer.SingleImageDataset import SingleImageDataset
from search_eval.utils.common_utils import *
from search_space.unetspaceOS import UNetSpace

from nni import trace
import nni.retiarii.strategy as strategy
import nni.retiarii.serializer as serializer

from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.evaluator.pytorch import Lightning, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader

import torch

# Release cached-but-unused CUDA allocator memory, then choose the legacy
# tensor type according to whether a GPU is usable.
# NOTE(review): `dtype` is not referenced in the cells shown here.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
print(f'CUDA available: {torch.cuda.is_available()}')

# Strategy

In [None]:
# Select the one-shot search strategy by its class name in
# nni.retiarii.strategy. Alternatives tried in this notebook:
# 'DARTS', 'ENAS', 'GumbelDARTS', 'RandomOneShot'.
strategy_name = 'GumbelDARTS'
search_strategy = getattr(strategy, strategy_name)()

# GumbelDARTS

In [None]:

# GumbelDARTS one-shot search: build a UNetSpace supernet, fit it to a single
# noisy phantom through the SGLDES Lightning module, and let the strategy
# select an architecture.
from pathlib import Path

search_strategy = strategy.GumbelDARTS()

# Passed to the Lightning Trainer as max_epochs.
total_iterations = 4000

# --- Data selection --------------------------------------------------------
resolution = 64
noise_type = 'gaussian'
noise_level = .15
# Unseeded draw from ids 0-49, so each run may use a different phantom.
# Print it so results can be traced back to the input image.
img_id = np.random.randint(0, 50)
print(f'Using phantom {img_id} (resolution={resolution}, noise={noise_type}, level={noise_level})')

# Single root for all phantom files; edit here to relocate the dataset.
phantom_root = Path('/home/joe/nas-for-dip/phantoms')
phantom = np.load(phantom_root / 'ground_truth' / str(resolution) / f'{img_id}.npy')
phantom_noisy = np.load(
    phantom_root / noise_type / f'res_{resolution}' / f'nl_{noise_level}' / f'p_{img_id}.npy'
)

# --- Optimisation hyper-parameters -----------------------------------------
learning_rate = 0.11  # consider .01
buffer_size = 1000    # TODO: tune
patience = 1000       # TODO: tune
weight_decay = 5e-7
show_every = 200
report_every = 25

# Lightning module configured for one-shot NAS. Judging by the flag names,
# HPO, SGLD regularisation and early stopping are off (confirm in SGLDES).
module = SGLDES(
    phantom=phantom,
    phantom_noisy=phantom_noisy,
    learning_rate=learning_rate,
    buffer_size=buffer_size,
    patience=patience,
    weight_decay=weight_decay,
    show_every=show_every,
    report_every=report_every,
    HPO=False,
    NAS=True,
    OneShot=True,
    SGLD_regularize=False,
    ES=False,
)

# PyTorch Lightning trainer (via NNI's wrapper).
trainer = Trainer(
    max_epochs=total_iterations,
    fast_dev_run=False,
    gpus=1,
)
# NOTE(review): compatibility shim. Presumably NNI's one-shot code reads
# `trainer.optimizer_frequencies`, which some Lightning versions no longer
# define. Confirm against the installed nni / pytorch-lightning versions.
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []

# Both loaders serve the noisy image; the clean phantom is handed to SGLDES
# directly. An earlier version trained on `phantom` instead.
# TODO(review): confirm the noisy image is the intended training target.
train_loader = DataLoader(SingleImageDataset(phantom_noisy, num_iter=1), batch_size=1)
# May be redundant, depending on how SGLDES defines its forward/validation pass.
val_loader = DataLoader(SingleImageDataset(phantom_noisy, num_iter=1), batch_size=1)

# NNI evaluator wrapping the Lightning module, trainer and data.
lightning = Lightning(
    lightning_module=module,
    trainer=trainer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

model_space = UNetSpace(
    C_in=1,
    C_out=1,
    depth=4,
    nodes_per_layer=2,  # UNetSpace accepts only 1 or 2
    ops_per_node=4,
    use_attention=True,
)

config = RetiariiExeConfig(execution_engine='oneshot')
experiment = RetiariiExperiment(model_space, evaluator=lightning, strategy=search_strategy)
experiment.run(config)

In [None]:
# Stop the NNI experiment created by the GumbelDARTS search cell.
# NOTE(review): in this section the stop cell comes before the export cell;
# if stop() discards state that export_top_models() needs, run the export
# cell first (the DARTS section exports, then stops) - confirm.
experiment.stop()

In [None]:

# Export the top architecture(s) found by the search in NNI's 'dict' format
# (presumably mapping each mutable choice label to its selected candidate -
# confirm against the NNI docs).
exported_arch = experiment.export_top_models(formatter='dict')

# Last expression of the cell: Jupyter renders the top-ranked architecture.
exported_arch[0]

# DARTS

In [None]:


# DARTS one-shot search: build a UNetSpace supernet, fit it to a single
# noisy phantom through the SGLDES Lightning module, and let DARTS select an
# architecture.
from pathlib import Path

search_strategy = strategy.DARTS()

# Passed to the Lightning Trainer as max_epochs.
total_iterations = 4000

# --- Data selection --------------------------------------------------------
resolution = 64
noise_type = 'gaussian'
# NOTE(review): 0.09 here vs 0.15 in the GumbelDARTS cell; confirm this is
# intentional if the two searches are compared.
noise_level = .09
# Unseeded draw from ids 0-49, so each run may use a different phantom.
# Print it so results can be traced back to the input image.
img_id = np.random.randint(0, 50)
print(f'Using phantom {img_id} (resolution={resolution}, noise={noise_type}, level={noise_level})')

# Single root for all phantom files; edit here to relocate the dataset.
phantom_root = Path('/home/joe/nas-for-dip/phantoms')
phantom = np.load(phantom_root / 'ground_truth' / str(resolution) / f'{img_id}.npy')
phantom_noisy = np.load(
    phantom_root / noise_type / f'res_{resolution}' / f'nl_{noise_level}' / f'p_{img_id}.npy'
)

# --- Optimisation hyper-parameters -----------------------------------------
learning_rate = 0.11  # consider .01
buffer_size = 1000    # TODO: tune
patience = 1000       # TODO: tune
weight_decay = 5e-7
show_every = 200
report_every = 25

# Lightning module configured for one-shot NAS. Judging by the flag names,
# HPO, SGLD regularisation and early stopping are off (confirm in SGLDES).
module = SGLDES(
    phantom=phantom,
    phantom_noisy=phantom_noisy,
    learning_rate=learning_rate,
    buffer_size=buffer_size,
    patience=patience,
    weight_decay=weight_decay,
    show_every=show_every,
    report_every=report_every,
    HPO=False,
    NAS=True,
    OneShot=True,
    SGLD_regularize=False,
    ES=False,
)

# PyTorch Lightning trainer (via NNI's wrapper).
trainer = Trainer(
    max_epochs=total_iterations,
    fast_dev_run=False,
    gpus=1,
)
# NOTE(review): compatibility shim. Presumably NNI's one-shot code reads
# `trainer.optimizer_frequencies`, which some Lightning versions no longer
# define. Confirm against the installed nni / pytorch-lightning versions.
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []

# Both loaders serve the noisy image; the clean phantom is handed to SGLDES
# directly. An earlier version trained on `phantom` instead.
# TODO(review): confirm the noisy image is the intended training target.
train_loader = DataLoader(SingleImageDataset(phantom_noisy, num_iter=1), batch_size=1)
# May be redundant, depending on how SGLDES defines its forward/validation pass.
val_loader = DataLoader(SingleImageDataset(phantom_noisy, num_iter=1), batch_size=1)

# NNI evaluator wrapping the Lightning module, trainer and data.
lightning = Lightning(
    lightning_module=module,
    trainer=trainer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

model_space = UNetSpace(
    C_in=1,
    C_out=1,
    depth=4,
    nodes_per_layer=2,  # UNetSpace accepts only 1 or 2
    ops_per_node=4,
    use_attention=True,
)

config = RetiariiExeConfig(execution_engine='oneshot')
experiment = RetiariiExperiment(model_space, evaluator=lightning, strategy=search_strategy)
experiment.run(config)

In [None]:

# Export the top architecture(s) found by the DARTS search in NNI's 'dict'
# format (presumably mapping each mutable choice label to its selected
# candidate - confirm against the NNI docs).
exported_arch = experiment.export_top_models(formatter='dict')

# Last expression of the cell: Jupyter renders the top-ranked architecture.
exported_arch[0]

In [None]:
# Stop the NNI experiment created by the DARTS search cell (run after the
# architecture has been exported).
experiment.stop()