In [6]:

import nni
import torch
import nni.retiarii.strategy as strategy
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F

from search_eval.utils import main_evaluation, psnr
from search_space.space import SearchSpace

from nni.experiment import Experiment
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.evaluator.pytorch import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10


In [8]:

class DeepImagePriorDenoising(pl.LightningModule):
    """Lightning module that fits a network from random noise to the batch input.

    Training regresses the model's output for a fresh standard-normal tensor
    onto the first element of each batch. Validation logs the value returned by
    ``psnr`` and, when fitting ends, the epoch-level value is reported to NNI.

    Args:
        model_cls: Zero-argument callable (e.g. a model class); its result is
            stored as ``self.model``.
    """

    def __init__(self, model_cls):
        super().__init__()
        self.model = model_cls()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE between the model's output on noise and ``noisy_img``.

        Returns:
            dict with the ``loss`` tensor plus the two unpacked batch elements
            under ``noisy_img`` and ``clean_img``.
        """
        # NOTE(review): the batch is assumed to be a (noisy, clean) image pair.
        # A plain torchvision CIFAR10 loader yields (image, label) instead —
        # confirm the dataset feeding this module.
        noisy_img, clean_img = batch
        # Allocate the noise directly on the module's device rather than
        # creating it on the CPU and copying it over on every step.
        noise = torch.randn(noisy_img.shape, device=self.device)
        output = self.model(noise)
        loss = F.mse_loss(output, noisy_img)
        return {"loss": loss, "noisy_img": noisy_img, "clean_img": clean_img}

    def validation_step(self, batch, batch_idx):
        """Log ``psnr(clean_img, model(noisy_img))`` as ``val_psnr``."""
        noisy_img, clean_img = batch
        # NOTE(review): validation feeds the noisy image itself to the model,
        # whereas training feeds random noise — confirm this is intended.
        denoised_output = self.model(noisy_img)
        psnr_value = psnr(clean_img, denoised_output)
        self.log('val_psnr', psnr_value)
        return {"val_psnr": psnr_value}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_validation_epoch_end(self):
        """Log the mean of the ``val_psnr`` callback metric as ``avg_val_psnr``."""
        self.log('avg_val_psnr', self.trainer.callback_metrics['val_psnr'].mean())

    def teardown(self, stage):
        """Report ``avg_val_psnr`` to NNI as the final result after fitting.

        Raises:
            KeyError: if ``avg_val_psnr`` was never logged, i.e. no validation
                epoch ran during ``fit``.
        """
        if stage != 'fit':
            return
        metrics = self.trainer.callback_metrics
        if 'avg_val_psnr' not in metrics:
            # Same exception type a plain dict lookup would raise, but with a
            # message that points at the likely cause.
            raise KeyError(
                "'avg_val_psnr' was never logged; no validation epoch ran during fit "
                "(was a validation dataloader provided?)"
            )
        nni.report_final_result(metrics['avg_val_psnr'].item())


In [7]:
# nni.trace(cls) returns a traced *class*; it has to be instantiated, because
# torchvision datasets call `transform(img)` on every sample.
transform = nni.trace(transforms.ToTensor)()
# Build the dataset through nni.trace as well so its constructor arguments are
# recorded alongside the traced DataLoader.
# NOTE(review): CIFAR10 yields (image, label) pairs, while
# DeepImagePriorDenoising unpacks each batch as (noisy_img, clean_img) —
# confirm a noisy/clean pairing is applied somewhere.
dataset = nni.trace(CIFAR10)(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(dataset, batch_size=100)


Files already downloaded and verified


In [None]:
# Lightning-based evaluator: the module wraps a model built from SearchSpace.
lightning_module = DeepImagePriorDenoising(SearchSpace)
trainer = pl.Trainer(max_epochs=10)

# NOTE(review): only a training loader is supplied. DeepImagePriorDenoising.teardown
# reports 'avg_val_psnr', which is logged only during validation — pass
# val_dataloaders=... here once a validation loader exists.
lightning = pl.Lightning(lightning_module, trainer, train_dataloaders=train_dataloader)

# Arguments are (base model, evaluator, mutators, strategy). `strategy` on its
# own is the imported module, so a concrete strategy instance is created here.
experiment = RetiariiExperiment(SearchSpace(), lightning, [], strategy.Random(dedup=True))


In [None]:
# Launch the Lightning-evaluator experiment built in the previous cell.
# NOTE(review): no RetiariiExeConfig or port is passed here, unlike the
# `exp.run(exp_config, 8081)` call in the last cell — confirm the installed
# NNI version accepts a config-less run for this strategy.
experiment.run()

In [None]:
# Tunable settings for the functional-evaluator experiment.
EXPERIMENT_NAME = 'mnist_search'
TRIAL_CODE_DIR = 'C:/Users/Public/Public_VS_Code/NAS_test'
EXPERIMENT_WORKING_DIR = 'C:/Users/Public/nni-experiments'
MAX_TRIALS = 12         # upper bound on the number of trials
TRIAL_CONCURRENCY = 2   # trials running at the same time
GPUS_PER_TRIAL = 1      # GPUs requested by each trial
EXPERIMENT_PORT = 8081  # port passed to exp.run

# Search space, evaluator and exploration strategy.
model_space = SearchSpace()
evaluator = FunctionalEvaluator(main_evaluation)
search_strategy = strategy.Random(dedup=True)

# Experiment object (no extra mutators) and its local execution config.
exp = RetiariiExperiment(model_space, evaluator, [], search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = EXPERIMENT_NAME
exp_config.trial_code_directory = TRIAL_CODE_DIR
exp_config.experiment_working_directory = EXPERIMENT_WORKING_DIR

exp_config.max_trial_number = MAX_TRIALS
exp_config.trial_concurrency = TRIAL_CONCURRENCY

exp_config.trial_gpu_number = GPUS_PER_TRIAL
# Let trials be scheduled on GPUs that other processes are already using.
exp_config.training_service.use_active_gpu = True

# Start the experiment.
exp.run(exp_config, EXPERIMENT_PORT)