# GPU Check

In [None]:
import GPUtil

# Report every GPU that GPUtil can see, one line per device: name and total memory.
GPUs = GPUtil.getGPUs()
for device in GPUs:
    print(device.name, device.memoryTotal)

# Imports

In [None]:
from eval_sgld import LightningEvalSearchSGLD, SingleImageDataset
from darts.common_utils import *
from darts.phantom import generate_phantom


from nni.retiarii.evaluator.pytorch import Lightning, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader

import torch

# Release unoccupied memory held by the CUDA caching allocator (a no-op without CUDA).
torch.cuda.empty_cache()

# Pick the legacy float tensor type according to whether a CUDA device is usable,
# and report that availability.
cuda_available = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if cuda_available else torch.FloatTensor
print('CUDA available: {}'.format(cuda_available))


# Execute

In [None]:


# Build the network to be trained: a U-Net from the mateuszbuda/brain-segmentation-pytorch
# torch.hub repo with 1 input channel, 1 output channel, 64 base features and
# untrained (randomly initialised) weights.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=1, out_channels=1, init_features=64, pretrained=False,
)

# Number of items in each SingleImageDataset, shared by the module and both loaders.
num_iter = 1
# Used as Trainer.max_epochs below.
total_iterations = 25000

resolution = 6
max_depth = resolution - 1  # not used in this cell
phantom = generate_phantom(resolution=resolution)

# Create the Lightning module that fits the model to the phantom with SGLD.
module = LightningEvalSearchSGLD(
    phantom=phantom,
    num_iter=num_iter,
    # NOTE: the learning rate affects the SGLD; overfitting happens FASTER at
    # LOWER learning rates, so start with 0.01.
    lr=0.01,
    noise_type='gaussian',
    noise_factor=0.09,
    resolution=resolution,
    burnin_iter=350,
    # torch.hub.load returns an instantiated module, passed here despite the
    # `_cls` keyword name — TODO confirm this is what LightningEvalSearchSGLD expects.
    model_cls=model,
)

# Create the PyTorch Lightning trainer.
trainer = Trainer(
    # Optional: enable checkpointing via the module's callback.
    # callbacks=[module.checkpoint_callback],
    max_epochs=total_iterations,
    fast_dev_run=False,
    gpus=1,
)

# Compatibility shim: define `optimizer_frequencies` when this Trainer lacks it.
# Presumably some downstream code (NNI's Lightning wrapper?) reads the attribute
# — TODO confirm which caller needs it.
if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []

# Both loaders serve the same single phantom image. Reuse `num_iter` so the
# dataset length stays in sync with the module's configuration.
train_loader = DataLoader(SingleImageDataset(phantom, num_iter=num_iter), batch_size=1)
val_loader = DataLoader(SingleImageDataset(phantom, num_iter=num_iter), batch_size=1)

# Wrap module, trainer and loaders in NNI's Lightning evaluator and start training.
lightning = Lightning(
    lightning_module=module,
    trainer=trainer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
lightning.fit(model)