# GPU Check

In [None]:
import GPUtil
# Ask GPUtil for the GPUs it can see and print each one's name and total memory.
GPUs = GPUtil.getGPUs()
for device in GPUs:
    print(device.name, device.memoryTotal)

# Imports

In [None]:
from search_eval.utils.common_utils import *
from search_eval.eval_no_search_SGLD_ES import Eval_SGLD_ES, SingleImageDataset

from nni.retiarii.evaluator.pytorch import Lightning, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader

import numpy as np
import torch
# Release any cached CUDA memory before starting.
torch.cuda.empty_cache()

# Query CUDA availability once and reuse the answer for both the tensor type
# selection and the status message.
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
print('CUDA available: {}'.format(_cuda_available))

# Execute

In [None]:
# INPUTS

# Non-HPO inputs
total_iterations = 1400  # handed to the Trainer as max_epochs
show_every = 50          # reported below as the plotting interval

# HPO inputs
# NOTE: a smaller learning rate affects the SGLD, so overfitting happens
# FASTER at LOWER learning rates (start with 0.01).
learning_rate = 0.08
patience = 200
buffer_size = 500
weight_decay = 5e-8

# Phantom selection
resolution = 64
noise_level = 0.09
noise_type = 'gaussian'
# Single source of truth for the phantom index so the clean and noisy files
# loaded below always refer to the same phantom.
phantom_id = 45

# Load the ground-truth phantom and its noisy counterpart from disk.
phantom = np.load(f'phantoms/ground_truth/{resolution}/{phantom_id}.npy')
phantom_noisy = np.load(f'phantoms/{noise_type}/res_{resolution}/nl_{noise_level}/p_{phantom_id}.npy')

# Single-channel U-Net fetched through torch.hub, without pretrained weights.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=1,
    out_channels=1,
    init_features=64,
    pretrained=False,
)

# Echo the configuration so it is recorded alongside the run's output.
print("\n\n----------------------------------")
print('Experiment Configuration:')
print(f'\tTotal Iterations: {total_iterations}')

print(f'\tPatience: {patience}')
print(f'\tBuffer Size: {buffer_size}')
print(f'\tLearning Rate: {learning_rate}')
print(f'\tWeight Decay: {weight_decay}')

print(f'\tImage Resolution: {resolution}')
print(f'\tPlotting every {show_every} iterations')
print("----------------------------------\n\n")

# Create the lightning module
module = Eval_SGLD_ES(
    phantom=phantom,
    phantom_noisy=phantom_noisy,

    learning_rate=learning_rate,
    patience=patience,
    buffer_size=buffer_size,
    weight_decay=weight_decay,

    model=model,  # model defaults to U-net
    show_every=show_every,
)

# Create a PyTorch Lightning trainer
trainer = Trainer(
    max_epochs=total_iterations,
    fast_dev_run=False,
    gpus=1,
)

# Make sure the trainer exposes `optimizer_frequencies`.
# NOTE(review): presumably a compatibility shim for a Lightning/NNI version
# mismatch -- confirm which versions actually need it.
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []


# Create the Lightning object for the evaluator.
# NOTE(review): num_iter=1 presumably means one sample per epoch, so that
# max_epochs counts iterations -- confirm against SingleImageDataset.
train_loader = DataLoader(SingleImageDataset(phantom, num_iter=1), batch_size=1)

lightning = Lightning(lightning_module=module, trainer=trainer, train_dataloaders=train_loader, val_dataloaders=None)
lightning.fit(model)

: 