# GPU Check

In [3]:
import GPUtil

# List every GPU that GPUtil can see, with its total memory as GPUtil reports it.
# An empty result is reported explicitly instead of silently printing nothing.
GPUs = GPUtil.getGPUs()
if not GPUs:
    print('No GPU detected.')
for gpu in GPUs:
    print(gpu.name, gpu.memoryTotal)

Tesla P4 7680.0


# Imports

In [1]:
import gc
import os
from pathlib import Path

import numpy as np

from search_eval.eval_OneShot import Eval_OS
from search_eval.optimizer.SingleImageDataset import SingleImageDataset
from search_eval.utils.common_utils import *
from search_space.search_space import DARTS_UNet

# Keep the NNI imports below the wildcard import above: should it re-export any
# of the same names (e.g. a plain DataLoader), the NNI versions take precedence.
from nni import trace
import nni.retiarii.strategy as strategy
import nni.retiarii.strategy as nas_strategy  # alias that is never rebound
import nni.retiarii.serializer as serializer

from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.evaluator.pytorch import Lightning, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader
from nni.retiarii.strategy import DARTS as DartsStrategy

import torch

# Start from an empty CUDA cache (a no-op when CUDA has not been initialised).
torch.cuda.empty_cache()

# Seed the RNGs so repeated searches are comparable (architecture sampling and
# weight initialisation are random).
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

use_cuda = torch.cuda.is_available()
# Legacy tensor-type handle chosen by device availability. It is not used by the
# cells visible here; kept so any code relying on the name keeps working.
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
print('CUDA available: {}'.format(use_cuda))

# Strategy

In [None]:
# Select the one-shot search strategy (keep exactly one line uncommented).
# `strategy` is rebound here to a strategy *instance*, which the experiment cell
# passes to RetiariiExperiment. The choices therefore go through the
# `nas_strategy` module alias: `strategy.ENAS()` would fail once this cell has
# run, because `strategy` no longer refers to the module.
strategy = nas_strategy.DARTS()
# strategy = nas_strategy.ENAS()
# strategy = nas_strategy.GumbelDARTS()
# strategy = nas_strategy.RandomOneShot()

# One-shot architecture search

In [None]:
# --- Run configuration ------------------------------------------------------
# max_epochs is set to this budget; assumes SingleImageDataset(..., num_iter=1)
# yields one batch per epoch, so epochs == optimisation steps (TODO confirm).
total_iterations = 1200

resolution = 64
noise_type = 'gaussian'
noise_level = '0.09'  # string on purpose: used verbatim as a directory name part
phantom_index = 45

# Dataset root. Override with the PHANTOM_ROOT environment variable instead of
# editing the paths below; the default reproduces the original location.
phantom_root = Path(os.environ.get('PHANTOM_ROOT', '/home/joe/nas-for-dip/phantoms'))
phantom = np.load(phantom_root / 'ground_truth' / str(resolution) / f'{phantom_index}.npy')
phantom_noisy = np.load(
    phantom_root / noise_type / f'res_{resolution}' / f'nl_{noise_level}' / f'p_{phantom_index}.npy'
)

# --- Evaluator --------------------------------------------------------------
# Lightning module (Eval_OS from search_eval) that the one-shot strategy trains.
module = Eval_OS(
    phantom=phantom,
    phantom_noisy=phantom_noisy,
    lr=0.01,
    buffer_size=100,
    patience=1000,
    weight_decay=5e-7,
)

# Use one GPU when CUDA is available, otherwise fall back to CPU instead of
# failing at Trainer construction.
trainer = Trainer(
    max_epochs=total_iterations,
    fast_dev_run=False,
    gpus=1 if torch.cuda.is_available() else 0,
)

# Compatibility shim: `optimizer_frequencies` is presumably read by NNI's
# one-shot code but is not defined by some Lightning versions.
# NOTE(review): confirm against the installed nni / lightning versions.
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []

# Data loaders for the evaluator, built with the DataLoader imported from
# nni.retiarii.evaluator.pytorch.lightning.
train_loader = DataLoader(SingleImageDataset(phantom, num_iter=1), batch_size=1)
val_loader = DataLoader(SingleImageDataset(phantom, num_iter=1), batch_size=1)

# Wrap module, trainer and loaders into the NNI Lightning evaluator.
lightning = Lightning(
    lightning_module=module,
    trainer=trainer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# --- Search space and experiment -------------------------------------------
model_space = DARTS_UNet(depth=3)

config = RetiariiExeConfig(execution_engine='oneshot')
experiment = RetiariiExperiment(model_space, evaluator=lightning, strategy=strategy)
experiment.run(config)

In [None]:
# Stop the experiment, then release cached GPU memory. gc.collect() runs first
# so that tensors kept alive only by reference cycles are freed; empty_cache()
# can only hand back cached blocks that no live tensor occupies.
experiment.stop()
gc.collect()
torch.cuda.empty_cache()

In [None]:

# Ask the finished experiment for its single best architecture and show it as
# this cell's output.
exported_arch = experiment.export_top_models(top_k=1)
exported_arch
