In [1]:
%matplotlib inline
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import hdf5storage
import time
import h5py
from torch.utils.data import Dataset, DataLoader

from core import image_utils
from core import model_utils
from core import data_loader
from core import models_3D

In [2]:
data_path = "/media/damingshen/05ba0687-eb77-48c2-a767-e9a44fbe94c1/Daming_data"
model_path = data_path + "/models"
train_path = data_path + "/RT_cine/AFib/Training"

x_folder = "Zero_filled"
y_folder = "CS_recon"

x_fnames = os.listdir(train_path + "/" + x_folder)
x_files = [os.path.join(train_path, x_folder, fname) for fname in x_fnames]

y_fnames = os.listdir(train_path + "/" + y_folder)
y_files = [os.path.join(train_path, y_folder, fname[:-6] + "cs.mat") for fname in x_fnames]
print(len(x_files))

398


In [4]:
gpu_no = "1"
device = torch.device("cuda:"+str(gpu_no) if torch.cuda.is_available() else "cpu")
batch_size = 1
epochs = 50
lr = 1e-4
print_step = 50
num_workers = 6
lamda = 10e3
################################### network ###############################
mrinet = models_3D.ComplexUNet3Dres(features=32,drop_out = 0.0,mode = None).to(device) 
feature_extractor = models_3D.FeatureExtractor(num_layers = 15).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, mrinet.parameters()), lr = lr, betas = (0.9, 0.99))

In [5]:
# model_name = "RTcine_recon_{}Net_crossValid.pth".format(model_mode)
import logging
from torch.autograd import Variable
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
out = None
exit_flag = False
model_name = "RTcine_PCNN_res01.pth"

logging.info("Start training")

dataset = data_loader.Loader3D_complex(x_files, y_files, imsize = 192 ,t_slices = 80)
loader = DataLoader(dataset,
                     batch_size = batch_size,
                     shuffle = True,
                     drop_last = True,
                     num_workers = num_workers)

loss_epoch = np.zeros([epochs,1])

for i in range(epochs):
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, mrinet.parameters()), lr = lr, betas = (0.9, 0.99))
    total_loss = 0
    if exit_flag:
        break
    try:
        for batch_idx, data in enumerate(loader):
            optimizer.zero_grad()
            loss = 0
            x = data["x"].to(device)
            y = data["y"].to(device)
            out = mrinet(x)
            out_feat = feature_extractor(out)
            y_feat = feature_extractor(y)
            loss = loss_fn(out, y) + loss_fn(out_feat, y_feat) 
            loss.backward()
            optimizer.step()
            total_loss = total_loss + loss.data.cpu().numpy()

            #del out,x,y,out_feat,y_feat
            #torch.cuda.empty_cache()
            if (batch_idx+1)%print_step == 0:
                logging.info("Epoch {:},Batch = {:}, Learning Rate = {:.8f}, Loss= {:.8f}".format(i+1,(batch_idx+1)*batch_size,lr,total_loss/((batch_idx+1)*batch_size)))
    except KeyboardInterrupt:
        exit_flag = True
        print("Exiting training")
    loss_epoch[i] = total_loss/((batch_idx+1)*batch_size)            
    lr = lr*0.95
mrinet.save(path = model_path, filename = model_name, optimizer=optimizer)

2020-04-01 20:41:36,937 Start training
2020-04-01 20:45:16,442 Epoch 1,Batch = 50, Learning Rate = 0.00010000, Loss= 0.17558241
2020-04-01 20:48:52,080 Epoch 1,Batch = 100, Learning Rate = 0.00010000, Loss= 0.12782242
2020-04-01 20:52:28,413 Epoch 1,Batch = 150, Learning Rate = 0.00010000, Loss= 0.10724141
2020-04-01 20:56:04,168 Epoch 1,Batch = 200, Learning Rate = 0.00010000, Loss= 0.09516130
2020-04-01 20:59:37,724 Epoch 1,Batch = 250, Learning Rate = 0.00010000, Loss= 0.08722311
2020-04-01 21:03:10,974 Epoch 1,Batch = 300, Learning Rate = 0.00010000, Loss= 0.08119359
2020-04-01 21:06:46,581 Epoch 1,Batch = 350, Learning Rate = 0.00010000, Loss= 0.07603604
2020-04-01 21:13:51,451 Epoch 2,Batch = 50, Learning Rate = 0.00009500, Loss= 0.04348557
2020-04-01 21:17:26,306 Epoch 2,Batch = 100, Learning Rate = 0.00009500, Loss= 0.04114712
2020-04-01 21:21:01,085 Epoch 2,Batch = 150, Learning Rate = 0.00009500, Loss= 0.03983312
2020-04-01 21:24:36,689 Epoch 2,Batch = 200, Learning Rate = 0.

model saved as: /media/damingshen/05ba0687-eb77-48c2-a767-e9a44fbe94c1/Daming_data/models/RTcine_PCNN_res01.pth


In [5]:
## 1shot valid
data_path = "/media/damingshen/05ba0687-eb77-48c2-a767-e9a44fbe94c1/Daming_data"
train_path = data_path + "/RT_cine/AFib/Testing"
x_folder = "Zero_filled"
y_folder = "CS_recon"
x_fnames = os.listdir(train_path + "/" + x_folder)
x_files = [os.path.join(train_path, x_folder, fname) for fname in x_fnames]
y_fnames = os.listdir(train_path + "/" + y_folder)
y_files = [os.path.join(train_path, y_folder, fname[:-6] + "cs.mat") for fname in x_fnames]

In [6]:
# mrinet = models_3D.ComplexUNet3D(features=32,drop_out = 0.0,mode = None).to(device) 
#model_name = "RTcine_PCNN_res.pth"
#mrinet.load(path = model_path,mode='single', filename = model_name)

loaded: /media/damingshen/05ba0687-eb77-48c2-a767-e9a44fbe94c1/Daming_data/models/RTcine_PCNN_res.pth


In [20]:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
DL_folder = "Recon_DL/PCNN_res01"
x_suffix="zp.mat"
DL_suffix="PCNN_res01.mat"
mrinet.eval()   
for idx in range(len(x_files)):
    dataset = data_loader.Loader3D_complex(x_files[idx:idx+1], y_files[idx:idx+1], imsize = 192,t_slices = 80)
    loader = DataLoader(dataset,
                 batch_size = 1,
                 shuffle = False,
                 drop_last = False,
                 num_workers = 1)
    for batch_idx, data in enumerate(loader):
        x = data["x"].to(device)
    out = mrinet(x)
    cout = out.data.cpu().numpy()
    cout = cout[:,0] + 1j*cout[:,1]
    PCNN_res01 = np.transpose(np.squeeze(cout), (1, 2, 0))
    #del x,out
    #torch.cuda.empty_cache()
    x_file = x_files[idx]
    DL_file = x_file.replace(x_folder,DL_folder)
    DL_file = DL_file.replace(x_suffix,DL_suffix)
    hdf5storage.savemat(DL_file, dict(PCNN_res01 = PCNN_res01))
    logging.info(x_fnames[idx])

2020-04-03 13:56:12,957 Start
2020-04-03 13:56:13,787 End
2020-04-03 13:56:14,894 2019_0814_CAMRI_205545_017_slc6_zp.mat
2020-04-03 13:56:16,347 Start
2020-04-03 13:56:17,168 End
2020-04-03 13:56:18,164 2020_0130_CAMRI_205545_036_slc5_zp.mat
2020-04-03 13:56:19,689 Start
2020-04-03 13:56:20,515 End
2020-04-03 13:56:21,584 2019_0814_CAMRI_205545_017_slc7_zp.mat
2020-04-03 13:56:23,032 Start
2020-04-03 13:56:23,852 End
2020-04-03 13:56:24,913 2019_0814_CAMRI_205545_017_slc3_zp.mat
2020-04-03 13:56:26,322 Start
2020-04-03 13:56:27,140 End
2020-04-03 13:56:28,138 2020_0227_CAMRI_205545_044_slc5_zp.mat
2020-04-03 13:56:29,478 Start
2020-04-03 13:56:30,295 End
2020-04-03 13:56:31,286 2019_1205_CAMRI_205545_025_slc3_zp.mat
2020-04-03 13:56:32,668 Start
2020-04-03 13:56:33,491 End
2020-04-03 13:56:34,470 2019_1114_CAMRI_205545_024_slc3_zp.mat
2020-04-03 13:56:35,987 Start
2020-04-03 13:56:36,808 End
2020-04-03 13:56:37,789 2019_1205_CAMRI_205545_025_slc4_zp.mat
2020-04-03 13:56:39,172 Start
20