In [1]:
import os
import torch
from scipy.io import loadmat

mask_dir = os.path.join(os.getcwd(), 'masks')

mask_pths = [os.path.join(mask_dir, mask_name) for mask_name in os.listdir('masks')]

csmri_mask = [loadmat(mask_pth)['mask'] for mask_pth in mask_pths]


class GaussianModelD:  # discrete noise levels
    def __init__(self):
        super().__init__()
        
    def __call__(self, x, sigma):
        sigma = sigma / 255.
        y = x + torch.randn(*x.shape) * sigma
              
        return y
    
    
sigmas = [5, 10, 15]
noise_model = GaussianModelD()
    
    

In [3]:
import os

os.chdir('..')

In [4]:
os.getcwd()

'/Users/joesh/D4IR'

In [2]:
Ks = [4, 6, 8]

In [3]:
mask_dir = os.path.join(os.getcwd(), 'pr_masks')

mask_pths = [os.path.join(mask_dir, mask_name) for mask_name in os.listdir('pr_masks')]

pr_mask = [loadmat(mask_pth)['mask'] for mask_pth in mask_pths]

In [4]:
import numpy as np
import torch.nn.functional as F
import random
from src import utils


def build_csmri_observation(target, noise_val):
    acc = random.choice(csmri_mask)
    acc = torch.from_numpy(acc.astype(np.bool_))
    y0 = utils.fft(target)
    
    y0 = noise_model(y0, noise_val)
    y0[:, : ,~acc] = 0
    Aty0 = utils.ifft(y0)
    x0 = Aty0.clone().detach()
    output = Aty0.clone().detach().real
    dic = {'y0': y0, 
            'x0': x0, 
            'ATy0': Aty0, 
            'gt': target, 
            'output': output, 
            'mask': acc}
    return dic



def build_spi_observation(target, K):
    with torch.no_grad():
        y0 = utils.spi_forward(target, K, K**2, 1)
        x0 = F.avg_pool2d(y0, K)
    y0 = y0
    x0 = x0       
    x0 = x0.clone().detach()
    K = torch.ones_like(target) * K / 10
    dic = {'x0': x0, 'gt': target, 'K': K, 'y0': y0, 'output': x0.clone().detach()}
    return dic
        

In [5]:
x = torch.rand(1, 1, 128, 128)


dic = build_spi_observation(x, 4)

for key, value in dic.items():
    print(key)
    print(dic[key].shape)

x0
torch.Size([1, 1, 128, 128])
gt
torch.Size([1, 1, 128, 128])
K
torch.Size([1, 1, 128, 128])
y0
torch.Size([1, 1, 512, 512])
output
torch.Size([1, 1, 128, 128])


In [6]:
pr_mask[2].shape

(4, 128, 128)

In [7]:
def build_pr_observation(target, alpha_val):
        mask = random.choice(pr_mask)
        C = mask.shape[0]
        mask = torch.from_numpy(mask).reshape(1, C, 128, 128)

        y0 = utils.cdp_forward(torch.complex(target, torch.zeros_like(target)),
                                mask).abs()[0]
        y0 = noise_model(y0, alpha_val)
        x0 = torch.ones_like(target)
        #sigma_n = x0 * noise_lev
        dic = {'y0': y0, 'x0': x0, 'output': x0, 'gt': target, 'mask': mask}
        return dic

build_spi_observation(x, 4)['K']

tensor([[[[0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000],
          [0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000],
          [0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000],
          ...,
          [0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000],
          [0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000],
          [0.4000, 0.4000, 0.4000,  ..., 0.4000, 0.4000, 0.4000]]]])

In [27]:
from src.pnp.denoiser import UNetDenoiser2D

class PrSolverMixin:
    def _forward_pr(self, env_ob, parameters):
        sigma_d, mu, tau = parameters['sigma_d'], parameters['mu'], parameters['tau']
        state = env_ob['output']
        y0 = env_ob['y0']
        mask = env_ob['mask']
        
        x, z, u = torch.chunk(state, chunks = 3, dim = 1)
        B = x.shape[0]
        
        for i in range(self.iter_num):
            _sigma_d = sigma_d[:, i]
            _mu = mu[:, i]
            _tau = tau[:, i]
            temp_var = (z - u)
            x = self.denoiser(temp_var.real, _sigma_d)
            
            _tau = _tau.view(B, 1, 1, 1)
            _mu = _mu.view(B, 1, 1, 1)
            
            Az = utils.cdp_forward(z, mask)  
            y_hat = Az.abs()
            meas_err = y_hat - y0
            gradient_forward = meas_err/y_hat * Az
            gradient = utils.cdp_backward(gradient_forward, mask)
            z = z - _tau * (gradient + _mu * (z - (x + u)))
            
        return torch.cat((x, z, u), dim = 1)
    
class CsmriSolverMixin:
    def _forward_csmri(self, env_ob, parameters):
        mu, sigma_d = parameters['mu'], parameters['sigma_d']
        state = env_ob['output']
        mask = env_ob['mask']
        y0 = env_ob['y0']
        
        B = state.shape[0]
        
        x, z, u = torch.chunk(state, chunks = 3, dim = 1)
        #mask = mask.unsqueeze(1) 
        
        
        for i in range(self.iter_num):
            _sigma_d = sigma_d[:, i]
            _mu = mu[:, i] 
            temp_var = (z - u)
            x = self.denoiser(temp_var.real, _sigma_d)
            z = utils.fft(x + u)
            _mu = _mu.view(B, 1, 1, 1)
            temp = ((_mu * z.clone()) + y0)/(1+ _mu)
            z[mask] = temp[mask]
            z = utils.ifft(z)
            
            u = u + x - z
        
        return torch.cat((x, z, u), dim = 1)
    
    
class SpiSolverMixin:
    def _forward_spi(self, env_ob, parameters):
        mu, sigma_d = parameters['mu'], parameters['sigma_d']
        state = env_ob['output']
        K = env_ob['K']
        
        x, z, u = torch.chunk(state, chunks = 3, dim = 1)
        
        B = state.shape[0]
        
        K = K[:, 0, 0, 0].view(B, 1, 1, 1) * 10 
        K1 = env_ob['x0'] * (K ** 2)
        
        for i in range(self.iter_num):
            _sigma_d = sigma_d[:, i]
            _mu = mu[:, i]        
            _mu = _mu.view(B, 1, 1, 1)

            # z step (x + u)
            z = utils.spi_inverse(x + u, K1, K, _mu)

            # u step
            u = u + x - z

            # x step
            x = self.denoiser((z - u).real, _sigma_d) 
        
        return torch.cat((x, z, u), dim = 1) 
    
class PnPSolver(PrSolverMixin, SpiSolverMixin, CsmriSolverMixin):
    def __init__(self) -> None:
        super().__init__()
        self.denoiser = UNetDenoiser2D()
        self.iter_num = 6
        
    def forward(self, env_ob, parameters):
        return self._forward_spi(env_ob, parameters)
    
env_ob = {}
parameters = {}
    
env_ob['output'] = torch.rand(1, 3, 128, 128)
env_ob['x0'] = torch.rand(1, 1, 128, 128)
env_ob['y0'] = torch.rand(1, 1, 128, 128)
env_ob['K'] = torch.rand(1, 1, 128, 128)
env_ob['mask'] = torch.from_numpy(csmri_mask[0].reshape(1, 1, 128, 128))

#env_ob['mask'] = torch.from_numpy(pr_mask[0]).to(dtype=torch.complex64).reshape(1, 2, 128, 128)

parameters['sigma_d'] = torch.rand(1, 6)
parameters['mu'] = torch.rand(1, 6)
parameters['tau'] = torch.rand(1, 6)
#parameters['tau'] = torch.rand(1, 6)

In [28]:
solver = PnPSolver()

solver.forward(env_ob, parameters)

tensor([[[[ 0.8680,  0.9779,  0.9051,  ...,  0.7032,  0.6159,  0.6657],
          [ 1.0000,  1.0000,  0.9062,  ...,  0.7946,  0.7446,  0.6229],
          [ 0.9484,  0.8183,  0.6790,  ...,  0.8326,  0.8503,  0.6723],
          ...,
          [ 0.6734,  0.6201,  0.5191,  ...,  0.8732,  0.5589,  0.7157],
          [ 0.6691,  0.5972,  0.5415,  ...,  0.8157,  0.5747,  0.7602],
          [ 0.4294,  0.3219,  0.4631,  ...,  0.8307,  0.7054,  0.7777]],

         [[ 0.6000,  1.0000,  1.0000,  ...,  0.7890,  0.5312,  0.4367],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  0.6214,  0.2970],
          [ 1.0000,  1.0000,  0.0757,  ...,  0.9856,  1.0000,  0.5001],
          ...,
          [ 0.1434,  0.8255,  0.0199,  ...,  1.0000,  0.4678,  1.0000],
          [ 0.9587,  0.2132,  0.3948,  ...,  1.0000,  0.4453,  1.0000],
          [ 0.2154,  0.0843,  0.2637,  ...,  1.0000,  0.6300,  1.0000]],

         [[ 0.1491, -0.8104, -0.5713,  ..., -0.0453, -0.0125, -0.0502],
          [-0.7976, -0.6402, -

In [38]:
network = nn.Sequential(
        nn.Conv2d(4, 8, kernel_size=5, stride=2),
        nn.ReLU(),

        nn.Conv2d(32, 32, kernel_size=5, stride=2),
        nn.ReLU(),

        nn.Conv2d(32, 32, kernel_size=4, stride=2),
        nn.ReLU(),

        nn.Flatten(),

        nn.Linear(32 , 128)
    )

x = torch.rand(1, 4, 128 ,128)

network(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x5408 and 32x128)

In [40]:
import torch.nn as nn

state_encoder = nn.Sequential(
                                nn.Conv2d(4, 8, 5, stride = 4, padding = 0), nn.ReLU(),
                                nn.Conv2d(8, 16, 5, stride = 2, padding = 0), nn.ReLU(),
                                nn.Conv2d(16, 16, 4, stride = 1, padding = 0), nn.ReLU(),
                                nn.Flatten(), nn.Linear(1936, 128))


x = torch.rand(1, 4, 128 ,128)

state_encoder(x)



tensor([[ 0.0245, -0.0012,  0.0085, -0.0094,  0.0010, -0.0160,  0.0282, -0.0136,
         -0.0044, -0.0340, -0.0468, -0.0055,  0.0178, -0.0037, -0.0244,  0.0261,
          0.0283,  0.0080, -0.0100, -0.0217,  0.0356,  0.0041, -0.0347,  0.0093,
         -0.0143,  0.0103,  0.0178,  0.0365, -0.0155, -0.0015,  0.0004, -0.0254,
          0.0246,  0.0352, -0.0214,  0.0106,  0.0361,  0.0497, -0.0351, -0.0080,
         -0.0099, -0.0089,  0.0223, -0.0018, -0.0564,  0.0129,  0.0228, -0.0029,
         -0.0040,  0.0499,  0.0075,  0.0153, -0.0106,  0.0033, -0.0236,  0.0061,
         -0.0042, -0.0030,  0.0290,  0.0016, -0.0178,  0.0302, -0.0139, -0.0222,
          0.0126, -0.0194,  0.0227, -0.0178, -0.0216, -0.0060, -0.0216, -0.0349,
         -0.0203,  0.0200,  0.0162, -0.0082,  0.0255,  0.0110, -0.0091,  0.0120,
          0.0133,  0.0114,  0.0436, -0.0052,  0.0147,  0.0013, -0.0289, -0.0316,
          0.0342, -0.0082, -0.0032,  0.0039,  0.0262,  0.0052, -0.0118,  0.0275,
          0.0249, -0.0089,  

In [57]:
from dataclasses import dataclass, field
from typing import List

class ElementBuffer:
    csmri_5_noise: List[float] = []
    csmri_10_noise: List[float] = []
    spi_4_k: List[int] = []
    spi_8_k: List[int] = []
    pr_27_alpha: List[float] = []
    pr_81_alpha: List[float] = []

buffer = ElementBuffer()


buffer.csmri_5_noise
getattr(buffer, csmri_5_noise)


[]

In [59]:
getattr(buffer, 'csmri_5_noise')

[]