In [1]:
import sys
sys.path.append('/home/liu.6221/PnP_fastMRI-master _Liu')
import logging
import pathlib
import random
import shutil
import time
import os
import h5py
import numpy as np
import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader
# from common.args import Args
from common.subsample import MaskFunc
from data import transforms
from data.mri_data import SelectiveSliceData
from models.unet.unet_model import UnetModel
from models.PnP.dncnn import DnCNN

from data.multicoil_sim import random_map

In [2]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


In [None]:
data_root = '/storage/fastMRI_brain/data/Sizhuo_preprocessed_data/training'
files = list(pathlib.Path(data_root).iterdir())
files = files[1:2]
keep_slices_o = []
keep_slices_i = []
for fname in sorted(files):
    with h5py.File(fname, 'r') as data:
        pnp_o = data['pnp_truth'].value
        random_slices = pnp_o[:,:,round(random.uniform(0,5))]
        print(abs(np.real(random_slices)).max())
        random_slices = random_slices/abs(np.real(random_slices)).max()
        print(abs(np.real(random_slices)).max())
        print(abs(np.imag(random_slices)).max())
        sigma = np.linalg.norm(random_slices)/np.sqrt(random_slices.size)/(10**(20/20))/np.sqrt(2)
        print(sigma)
        sigma = 0.037
        data_out = torch.from_numpy(np.stack((random_slices.real, random_slices.imag), axis=-1)).float()
                
#                 for i in range()
#                 data_out = transforms.complex_random_crop(data_out,(64,64))
    
        data_in = data_out + sigma*torch.randn(data_out.size())
#         data_in_max = torch.max(torch.abs(data_in[:,:,0]))
#         data_out = data_out / data_in_max
#         data_in = data_in / data_in_max
        
        error = ((data_in - data_out)**2).mean()
        print(error)
        

        
#                 self.keep_slices_o.append(data_out.permute(2,0,1))
#                 self.keep_slices_i.append(data_in.permute(2,0,1))

In [None]:
o = (data_out[:,:,0]).numpy()+1j*(data_out[:,:,1]).numpy()
print(o.shape)
i = (data_in[:,:,0]).numpy()+1j*(data_in[:,:,1]).numpy()
import matplotlib.pyplot as plt
# print(np.linalg.norm(i))
# print(np.linalg.norm(o))
plt.imshow(abs(o), cmap = 'gray',vmin=0, vmax=1)
# plt.imshow(abs(i), cmap = 'gray',vmin=0, vmax=1)

In [3]:
class create_datasets():
    def __init__(self):
        data_root = '/storage/fastMRI_brain/data/Sizhuo_preprocessed_data/training'
        files = list(pathlib.Path(data_root).iterdir())
#         files = files[0:1]
        self.keep_slices_o = []
        self.keep_slices_i = []
        for fname in sorted(files):
            with h5py.File(fname, 'r') as data:
                pnp_o = data['pnp_truth'].value
                random_slices = pnp_o[:,:,round(random.uniform(0,5))]
                random_slices = random_slices/abs(np.real(random_slices)).max()
#                 sigma = np.linalg.norm(random_slices)/np.sqrt(random_slices.size)/(10**(14/20))/np.sqrt(2)
                sigma = 0.015
                data_out = torch.from_numpy(np.stack((random_slices.real, random_slices.imag), axis=-1)).float()
                data_in = data_out + sigma*torch.randn(data_out.size())
                data_in_max = torch.max(torch.abs(data_in[:,:,0]))
                data_out = data_out / data_in_max
                data_in = data_in / data_in_max 
                
                for i in range(144):
                    data_out,data_in = transforms.complex_random_crop(data_out,data_in,(64,64))
    

#                     self.keep_slices_o.append(data_in.permute(2,0,1))
#                     self.keep_slices_i.append(data_in.permute(2,0,1))        
                    self.keep_slices_o.append(data_out.permute(2,0,1))
                    self.keep_slices_i.append(data_in.permute(2,0,1))
    def __len__(self):
        return len(self.keep_slices_o)
    def __getitem__(self, index):
        outp = self.keep_slices_o[index]
        inp = self.keep_slices_i[index]
        return inp, outp

In [4]:
def create_data_loaders():
#     dev_data, train_data = create_datasets(args)
#     display_data = [dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)]

    train_loader = DataLoader(
        create_datasets(),
        batch_size=32,
        shuffle=True,
        num_workers=16,
        pin_memory=True,
    )

    return train_loader

In [6]:
def main():

    writer = SummaryWriter(log_dir='/home/liu.6221/PnP_fastMRI-master _Liu/summary')

    if resume:
        device = torch.device('cuda')
        checkpoint = torch.load('/home/liu.6221/PnP_fastMRI-master _Liu/model.pt')
        model = DnCNN(depth=5, image_channels=2, n_channels=64, snorm = True, realsnorm = False, L = 1, bnorm_type = 'mean', residual=False).cuda()
        model.load_state_dict(checkpoint['model'])
        optimizer = torch.optim.Adam(model.parameters(),lr = 0.001, weight_decay = 0)
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        del checkpoint
    else:
        model = DnCNN(depth=5, image_channels=2, n_channels=64, snorm = True, realsnorm = False, L = 1, bnorm_type = 'mean', residual=False).cuda()
#         if 1:
#             model = torch.nn.DataParallel(model)
        optimizer = torch.optim.Adam(model.parameters(),lr = 0.001, weight_decay = 0)
        start_epoch = 0

    logging.info(model)

    train_loader = create_data_loaders()
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,  step_size = 600, gamma = 0.1)

    for epoch in range(start_epoch, 800):
        scheduler.step(epoch)
        model.train()
        avg_loss = 0.
        losses = []
        start_epoch = start_iter = time.perf_counter()
        global_step = epoch * len(train_loader)
        for train_iter, data in enumerate(train_loader):
            inp, outp = data
            inp = inp.cuda()
            outp = outp.cuda()
            my_recon = model(inp)
            loss = F.mse_loss(my_recon, outp, reduction='sum')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if train_iter > 0 else loss.item()
            writer.add_scalar('TrainLoss', loss.item(), global_step + train_iter)
            start_iter = time.perf_counter()
            train_time = time.perf_counter() - start_epoch
#             print('%4d %4d / %4d loss = %2.4f' % (epoch+1, train_iter, inp.size()[0], loss.item()/inp.size()[0]))
        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'exp_dir':  '/home/liu.6221/PnP_fastMRI-master _Liu/5layer_patch_0.1'
            },
            f='/home/liu.6221/PnP_fastMRI-master _Liu/5layer_patch_0.1/model.pt'
        )
        torch.save(model, '/home/liu.6221/PnP_fastMRI-master _Liu/5layer_patch_0.1/%03d_0.1.pt' % (epoch+1))   
#         save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best)
        logging.info(
            f'Epoch = [{epoch:4d}/{3000:4d}] TrainLoss = {avg_loss:.4g} '
            f'TrainTime = {train_time:.4f}s',
        )
    writer.close()

In [None]:
if __name__ == '__main__':
#     args = create_arg_parser().parse_args()
    # restrict visible cuda devices
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
#     random.seed(1000)
#     np.random.seed(1000)
#     torch.manual_seed(1000)
    
    resume = 0
    
    main()

INFO:root:DnCNN(
  (dncnn): Sequential(
    (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ScalarMultiplyLayer(L=1)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): ScalarMultiplyLayer(L=1)
    (5): MeanOnlyBatchNorm(64, momentum=0.95 )
    (6): ReLU(inplace)
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): ScalarMultiplyLayer(L=1)
    (9): MeanOnlyBatchNorm(64, momentum=0.95 )
    (10): ReLU(inplace)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): ScalarMultiplyLayer(L=1)
    (13): MeanOnlyBatchNorm(64, momentum=0.95 )
    (14): ReLU(inplace)
    (15): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (16): ScalarMultiplyLayer(L=1)
  )
)
INFO:root:Epoch = [   0/3000] TrainLoss = 96.69 TrainTime = 29.2326s
INFO:root:Epoch = [   1/3000] TrainLoss = 53.96 TrainT

INFO:root:Epoch = [ 107/3000] TrainLoss = 25.73 TrainTime = 27.9348s
INFO:root:Epoch = [ 108/3000] TrainLoss = 24.15 TrainTime = 29.5742s
INFO:root:Epoch = [ 109/3000] TrainLoss = 24.76 TrainTime = 28.6836s
INFO:root:Epoch = [ 110/3000] TrainLoss = 24.2 TrainTime = 28.1178s
INFO:root:Epoch = [ 111/3000] TrainLoss = 25.7 TrainTime = 31.6328s
INFO:root:Epoch = [ 112/3000] TrainLoss = 23.86 TrainTime = 26.9748s
INFO:root:Epoch = [ 113/3000] TrainLoss = 24.14 TrainTime = 28.5979s
INFO:root:Epoch = [ 114/3000] TrainLoss = 23.83 TrainTime = 29.5034s
INFO:root:Epoch = [ 115/3000] TrainLoss = 23.61 TrainTime = 28.8828s
INFO:root:Epoch = [ 116/3000] TrainLoss = 23.78 TrainTime = 28.3405s
INFO:root:Epoch = [ 117/3000] TrainLoss = 23.77 TrainTime = 29.1197s
INFO:root:Epoch = [ 118/3000] TrainLoss = 24.27 TrainTime = 29.2509s
INFO:root:Epoch = [ 119/3000] TrainLoss = 23.79 TrainTime = 27.8307s
INFO:root:Epoch = [ 120/3000] TrainLoss = 24.58 TrainTime = 29.6581s
INFO:root:Epoch = [ 121/3000] TrainL

INFO:root:Epoch = [ 227/3000] TrainLoss = 22.91 TrainTime = 30.7281s
INFO:root:Epoch = [ 228/3000] TrainLoss = 22.99 TrainTime = 32.1122s
INFO:root:Epoch = [ 229/3000] TrainLoss = 23.09 TrainTime = 31.1746s
INFO:root:Epoch = [ 230/3000] TrainLoss = 22.69 TrainTime = 30.7062s
INFO:root:Epoch = [ 231/3000] TrainLoss = 23.41 TrainTime = 32.4739s
INFO:root:Epoch = [ 232/3000] TrainLoss = 23.37 TrainTime = 31.5412s
INFO:root:Epoch = [ 233/3000] TrainLoss = 23.3 TrainTime = 30.6380s
INFO:root:Epoch = [ 234/3000] TrainLoss = 23.91 TrainTime = 33.8441s
INFO:root:Epoch = [ 235/3000] TrainLoss = 23.64 TrainTime = 30.6789s
INFO:root:Epoch = [ 236/3000] TrainLoss = 23.15 TrainTime = 29.4085s
INFO:root:Epoch = [ 237/3000] TrainLoss = 23.07 TrainTime = 30.6639s
INFO:root:Epoch = [ 238/3000] TrainLoss = 23.28 TrainTime = 34.5902s
INFO:root:Epoch = [ 239/3000] TrainLoss = 23.34 TrainTime = 31.3709s
INFO:root:Epoch = [ 240/3000] TrainLoss = 23.36 TrainTime = 30.7023s
INFO:root:Epoch = [ 241/3000] Train

INFO:root:Epoch = [ 346/3000] TrainLoss = 22.89 TrainTime = 33.5769s
INFO:root:Epoch = [ 347/3000] TrainLoss = 22.58 TrainTime = 32.6542s
INFO:root:Epoch = [ 348/3000] TrainLoss = 22.81 TrainTime = 31.7185s
INFO:root:Epoch = [ 349/3000] TrainLoss = 23 TrainTime = 32.8981s
INFO:root:Epoch = [ 350/3000] TrainLoss = 22.84 TrainTime = 33.0670s
INFO:root:Epoch = [ 351/3000] TrainLoss = 23.06 TrainTime = 32.5330s
INFO:root:Epoch = [ 352/3000] TrainLoss = 22.58 TrainTime = 32.6333s
INFO:root:Epoch = [ 353/3000] TrainLoss = 23.39 TrainTime = 34.0941s
INFO:root:Epoch = [ 354/3000] TrainLoss = 22.64 TrainTime = 37.2638s
INFO:root:Epoch = [ 355/3000] TrainLoss = 23.01 TrainTime = 34.0128s
INFO:root:Epoch = [ 356/3000] TrainLoss = 23.19 TrainTime = 41.0486s
INFO:root:Epoch = [ 357/3000] TrainLoss = 22.63 TrainTime = 33.3907s
INFO:root:Epoch = [ 358/3000] TrainLoss = 22.61 TrainTime = 31.8207s
INFO:root:Epoch = [ 359/3000] TrainLoss = 24.11 TrainTime = 35.9961s
INFO:root:Epoch = [ 360/3000] TrainLo

INFO:root:Epoch = [ 465/3000] TrainLoss = 22.59 TrainTime = 37.6699s
INFO:root:Epoch = [ 466/3000] TrainLoss = 22.7 TrainTime = 35.3524s
INFO:root:Epoch = [ 467/3000] TrainLoss = 22.84 TrainTime = 38.4808s
INFO:root:Epoch = [ 468/3000] TrainLoss = 22.98 TrainTime = 40.3103s
INFO:root:Epoch = [ 469/3000] TrainLoss = 22.94 TrainTime = 35.2263s
INFO:root:Epoch = [ 470/3000] TrainLoss = 22.48 TrainTime = 39.0667s
INFO:root:Epoch = [ 471/3000] TrainLoss = 22.85 TrainTime = 36.9266s
INFO:root:Epoch = [ 472/3000] TrainLoss = 22.97 TrainTime = 38.0669s
INFO:root:Epoch = [ 473/3000] TrainLoss = 22.77 TrainTime = 35.9625s
INFO:root:Epoch = [ 474/3000] TrainLoss = 22.87 TrainTime = 39.3157s
INFO:root:Epoch = [ 475/3000] TrainLoss = 22.55 TrainTime = 36.1714s
INFO:root:Epoch = [ 476/3000] TrainLoss = 22.65 TrainTime = 39.7668s
INFO:root:Epoch = [ 477/3000] TrainLoss = 22.83 TrainTime = 35.1719s
INFO:root:Epoch = [ 478/3000] TrainLoss = 22.75 TrainTime = 37.3280s
INFO:root:Epoch = [ 479/3000] Train

INFO:root:Epoch = [ 584/3000] TrainLoss = 22.69 TrainTime = 41.7445s
INFO:root:Epoch = [ 585/3000] TrainLoss = 22.7 TrainTime = 41.5695s
INFO:root:Epoch = [ 586/3000] TrainLoss = 22.66 TrainTime = 42.1554s
INFO:root:Epoch = [ 587/3000] TrainLoss = 22.38 TrainTime = 40.3513s
INFO:root:Epoch = [ 588/3000] TrainLoss = 22.49 TrainTime = 40.6654s
INFO:root:Epoch = [ 589/3000] TrainLoss = 22.96 TrainTime = 40.6082s
INFO:root:Epoch = [ 590/3000] TrainLoss = 22.92 TrainTime = 41.1183s
INFO:root:Epoch = [ 591/3000] TrainLoss = 22.47 TrainTime = 39.9355s
INFO:root:Epoch = [ 592/3000] TrainLoss = 22.86 TrainTime = 42.6685s
INFO:root:Epoch = [ 593/3000] TrainLoss = 22.6 TrainTime = 42.7851s
INFO:root:Epoch = [ 594/3000] TrainLoss = 23.07 TrainTime = 41.7531s
INFO:root:Epoch = [ 595/3000] TrainLoss = 22.72 TrainTime = 40.2487s


In [None]:
import torch.nn as nn
model = torch.load('/home/liu.6221/PnP_fastMRI-master _Liu/190.pt')
def kernel_specnorm(kernel, power_iter=10, print_values=False):
    with torch.no_grad():
        kernel_mat = kernel
        height = kernel_mat.size(0)
        kernel_mat = kernel_mat.reshape(height,-1)

        # Create random U
        u = torch.normal(mean=torch.zeros((height))).to(kernel.device)

        # Perform power iterations
        for iter in range(power_iter):
            Ku = torch.mv(kernel_mat.t(), u)
            v =  Ku / Ku.norm()
            Kv = torch.mv(kernel_mat, v)
            u = Kv / Kv.norm()
            if print_values:
                print(torch.dot(u, torch.mv(kernel_mat, v)))
        return torch.dot(u, torch.mv(kernel_mat, v)).item()


def test_spectral_norm(model, test_type='conv', N=128,power_iter=10, print_values=True):
    # Loop over layers
    a_l = []
    for layer in model.modules():
        if isinstance(layer,nn.Conv2d):
            kernel = layer.weight
            L = kernel_specnorm(kernel, power_iter=power_iter)
            a_l.append(L)
            if print_values:
                print(L)
            
    L_total = np.prod(np.array(a_l))
    if print_values:
        print(L_total)
    return a_l, L_total
                
L = test_spectral_norm(model, power_iter=1000)

In [None]:
0.1**(1/5)