## [Multi-Scale Energy (MuSE) Framework for Inverse Problems in Imaging](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10645311)
Jyothi Rikhab Chand, Mathews Jacob

# Goal of this Notebook:
Train an implicit multi-scale energy model that represents the negative log-prior, using the [denoising score matching](https://ieeexplore.ieee.org/abstract/document/6795935) technique on an example MRI dataset.

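# Required libraries:
- numpy
- torch
- matplotlib
- tqdm
- os (Python standard library)

The notebook also imports the local modules `data_builder`, `saveModels`, `energy_model` and `network_unet`, and reads its hyperparameters from `settings.json`.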
  

In [1]:
# Preliminaries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from tqdm.notebook import tqdm
# cuBLAS workspace setting needed for deterministic cuBLAS kernels; it has to be
# set before CUDA is initialised, hence this early placement.
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs the notebook draws from (noise-level perturbation, Gaussian noise,
# network initialisation, DataLoader shuffling) so a Restart & Run All is repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices


#Load dataloader
from data_builder import DataGenBrain

#Load save model script
from saveModels import saveHelper

#Load energy model script
# NOTE(review): star import; the training cell relies on it for `EBM`.
from energy_model import *
from network_unet import UNetRes


# Release cached GPU memory left over from a previous run of this kernel.
torch.cuda.empty_cache()

## Load Train Settings

In [2]:
import json

# Training hyperparameters are read from settings.json; the context manager
# guarantees the file handle is closed.
with open("settings.json") as settings_file:
    ts = json.load(settings_file)

save_dir = "Models/Example_data/"
# NOTE(review): the captured output ("Directory Exist; removing it") indicates that
# saveHelper removes an existing save_dir — earlier checkpoints there appear to be lost.
save_model = saveHelper(save_dir, "Training")

n_epochs = ts["epochs"]
epochs = np.arange(n_epochs)  # numpy ints; the training loop calls epoch.item()
rate = ts["rate"]  # Adam learning rate
save_every_N_epoch = ts["save_every_N_epoch"]
t_loss = []  # mean training loss per epoch

# One noise standard deviation per sample in a batch, evenly spaced from
# std_start to std_end (the training loop requires full batches for this reason).
std_fixed = torch.linspace(ts["std_start"], ts["std_end"], ts["batch_size"])

Models/Example_data/Exists
Directory Exist; removing it


## Load Dataset

In [3]:
# Build the training set and a shuffling loader over it.
batch_size = ts["batch_size"]
data_set = DataGenBrain(
    start_sub=ts["startSubj"],
    # NOTE(review): "endSubj" is passed as `num_sub`; confirm whether DataGenBrain
    # expects a subject count or an end index here.
    num_sub=ts["endSubj"],
    device=device,
    acc=ts["acc"],
)
data_loader = DataLoader(dataset=data_set, batch_size=batch_size, shuffle=True)


T2_big_pickle - loading 1 of 1 subjects
ksp: (10, 12, 320, 320) 	csm: (10, 12, 320, 320)
Loaded dataset of 10 slices



## Training loop

In [4]:
# Build the energy network: a UNetRes backbone wrapped by EBM, whose
# `giveScore(x)` returns (energy, score) for an input batch.
net = UNetRes(in_nc=ts["input_channel"], out_nc=ts["output_channel"],
              nc=[ts["ch0"], ts["ch1"], ts["ch2"], ts["ch3"]])
net = net.to(device)
energy_net = EBM(net)
energy_net = energy_net.to(device)

optimizer = Adam(energy_net.parameters(), lr=rate)
# Optional: scheduler = ReduceLROnPlateau(optimizer, 'min')

# The evaluation cell switches the model to eval mode; re-running this cell
# afterwards must train in train mode again.
energy_net.train()

with tqdm(total=n_epochs, position=0) as pbar:
    for epoch in epochs:
        avg_loss = torch.zeros(1, dtype=torch.double)  # sum of per-sample losses
        num_items = 0

        for tstOrg, b, tstCsm, tstMask, idx in data_loader:
            tstOrg = torch.squeeze(tstOrg, 1)
            tstOrg = tstOrg.type(torch.complex64)
            tstOrg = tstOrg.to(device)

            # std_fixed holds exactly one noise level per sample of a full batch,
            # so an incomplete trailing batch cannot be paired with it; skip it.
            if tstOrg.shape[0] != batch_size:
                continue

            # Shift the whole noise ladder by one random offset in [0, pert).
            perturbation = ts["pert"] * torch.rand(1)
            std = (std_fixed + perturbation).to(device)  # .cuda() would fail on CPU-only runs

            noise = torch.randn_like(tstOrg) * std[:, None, None, None]
            _, score = energy_net.giveScore(tstOrg + noise)

            # Denoising score matching: the score is trained to predict the added
            # noise, so `noisy - score` is a denoised estimate (see evaluation cell).
            loss = torch.mean(torch.sum((noise - score).abs() ** 2, dim=(1, 2, 3)))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss += loss.item() * tstOrg.shape[0]
            num_items += tstOrg.shape[0]

        if num_items == 0:
            raise RuntimeError(
                f"No full batches of size {batch_size} were produced by the data loader; "
                "reduce ts['batch_size'] or provide more training slices."
            )

        epoch_loss = (avg_loss / num_items).item()
        t_loss.append(epoch_loss)

        # Advance the bar once per epoch so it tracks `total=n_epochs` exactly.
        pbar.set_description(f"loss: {epoch_loss: .6f}")
        pbar.update(1)

        # Checkpoint periodically and always after the final epoch.
        if np.mod(epoch, save_every_N_epoch) == 0 or epoch == n_epochs - 1:
            save_model.saveModel(energy_net, epoch)
            save_model.write("loss", epoch_loss, epoch.item())





  0%|          | 0/500 [00:00<?, ?it/s]

## Evaluating the denoising performance of the trained energy net

In [None]:
# Qualitative check of the trained energy net as a one-step denoiser: add
# Gaussian noise (std 0.05) to one training slice and subtract the predicted
# score, which the training loss drives towards the added noise.
energy_net = energy_net.eval()
std = 0.05 * torch.ones(1)
std = std.to(device)  # .cuda() would fail when `device` fell back to the CPU
img, b0, tstCsm, tstMask, idx = data_set[2]
img = img.type(torch.complex64)
img = img.to(device)
random_var = torch.randn_like(img)
noise = random_var * std[:, None, None, None]
noisy = img + noise
# NOTE(review): not wrapped in torch.no_grad(); giveScore presumably needs autograd
# to differentiate the energy with respect to its input — confirm.
cost, Egrad = energy_net.giveScore(noisy)
denoised = noisy - Egrad
prediction = np.squeeze(torch.abs(denoised).detach().cpu().numpy())
del cost, Egrad

IMAGE_VMAX = 0.5  # intensity clip for the magnitude images


def _show_magnitude(position, tensor, title, **imshow_kwargs):
    """Draw |tensor[0, 0]| in subplot `position`, titled, with axes hidden."""
    plt.subplot(position)
    plt.imshow(tensor[0, 0].abs().detach().cpu(), **imshow_kwargs)
    plt.title(title)
    plt.axis('off')


noisy_error = img - noisy
denoised_error = img - denoised
# Share one colour scale across both error maps so they can be compared directly.
error_vmax = float(noisy_error[0, 0].abs().max())

plt.figure(figsize=(8, 5))
_show_magnitude(231, img, 'Org', cmap=plt.cm.gray, vmax=IMAGE_VMAX)
_show_magnitude(232, noisy, 'Noisy', cmap=plt.cm.gray, vmax=IMAGE_VMAX)
_show_magnitude(233, denoised, 'Denoised', cmap=plt.cm.gray, vmax=IMAGE_VMAX)
_show_magnitude(235, noisy_error, 'Org-Noisy', vmin=0, vmax=error_vmax)
_show_magnitude(236, denoised_error, 'Org-Denoised', vmin=0, vmax=error_vmax)
plt.show()