In [None]:
import multiprocessing
import os
import platform
import types
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.nn import init
from torch.utils.data import Dataset, DataLoader

from utils.ISTANet import *

In [None]:
# Experiment configuration
matrix_dir = 'sampling_matrix'  # folder containing mask_<cs_ratio>.mat
data_dir = 'data'               # folder containing the training .mat file
n_epoch = 400
learning_rate = 0.0001
layer_num = 25                  # passed to ISTANetplus; also the number of constraint terms summed
group_num = 1
cs_ratio = 50                   # selects mask_<cs_ratio>.mat; available: 10, 20, 30, 40, 50
noise_sigma = 0.01              # std of the Gaussian noise added to training inputs: 0.01 or 0.05
seed = 0                        # fixed seed so noise injection and shuffling are reproducible

torch.manual_seed(seed)
np.random.seed(seed)

# Disable TF32 so matmuls and cuDNN convolutions on Ampere+ GPUs run in full FP32.
# Defaults differ by flag/version: torch.backends.cuda.matmul.allow_tf32 was True in
# PyTorch 1.7-1.11 and is False from 1.12; torch.backends.cudnn.allow_tf32 defaults to True.
try:
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
except (AttributeError, RuntimeError) as err:
    # Older PyTorch builds lack these flags: continue at default precision, but say so.
    print("Could not disable TF32: %s" % err)

In [None]:
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device:{device}")

In [None]:
# Number of training images served by the dataset, and mini-batch size.
nrtrain = 25
batch_size = 4

# CS sampling mask for the selected ratio: loaded as float32, duplicated into two
# identical slices along a new axis 2 (presumably real/imaginary parts for the FFT
# operator -- TODO confirm against FFT_Mask_ForBack), then moved to the compute device.
Phi_data_Name = './%s/mask_%d.mat' % (matrix_dir, cs_ratio)
Phi_data = sio.loadmat(Phi_data_Name)
mask_matrix = torch.from_numpy(Phi_data['mask_matrix']).float()
mask = torch.stack([mask_matrix, mask_matrix], dim=2).to(device)

# Training images are stored under the 'labels' key of the .mat file.
Training_data_Name = 'Train_BrainImages.mat'
Training_data = sio.loadmat('./%s/%s' % (data_dir, Training_data_Name))
Training_labels = Training_data['labels']

In [None]:
class RandomDataset(Dataset):
    """Map-style dataset over the first ``length`` rows of ``data``.

    Item ``index`` is ``data[index, :]`` converted to a float32 tensor with a
    leading singleton axis added, then passed through ``transform`` if one was
    given. Despite the name, nothing here is random; any shuffling is done by
    the DataLoader.
    """

    def __init__(self, data, length, transform=None):
        # `length` is taken as-is; it is not checked against the size of `data`.
        self.data, self.len, self.transform = data, length, transform

    def __getitem__(self, index):
        sample = torch.tensor(self.data[index, :], dtype=torch.float32).unsqueeze(0)
        return self.transform(sample) if self.transform else sample

    def __len__(self):
        return self.len

In [None]:
## Load data
# Per-sample preprocessing: centre-crop to 256x256, then scale intensities by 10.
data_transforms = transforms.Compose([
    transforms.CenterCrop(256),
    lambda x: x*10,
])

# DataLoader workers started with "spawn" (Windows, macOS on Python >= 3.8) or
# "forkserver" (Linux default from Python 3.14) must pickle the dataset and its
# transforms. RandomDataset is defined in this notebook and the transform above is a
# lambda, so neither pickles; only use worker processes when they are forked.
_start_method = (multiprocessing.get_start_method(allow_none=True)
                 or multiprocessing.get_all_start_methods()[0])
num_workers = 4 if _start_method == "fork" else 0

rand_loader = DataLoader(
    dataset=RandomDataset(Training_labels, nrtrain, data_transforms),
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)

In [None]:
# Model and optimiser: ISTANetplus with `layer_num` layers, wrapped in DataParallel and
# moved to the compute device, trained with Adam at `learning_rate`.
model = nn.DataParallel(ISTANetplus(layer_num)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training loop
# Per-epoch mean losses (0-d CPU tensors), one entry per completed epoch.
loss_all_epoch_array = []
loss_noise_epoch_array = []
# Per-batch losses of the current epoch, stored as Python floats via .item() so that
# each batch's autograd graph and GPU tensors are not kept alive until epoch end.
loss_all_epoch = []
loss_noise_epoch = []

# Weights of the constraint term and of the clean-image term in the total loss.
gamma_constraint = 0.005
gamma_noise = 0.005

# Checkpoints go here every 5 epochs; torch.save does not create missing directories.
os.makedirs("./models", exist_ok=True)

for epoch_i in range(1, n_epoch+1):
    for data in rand_loader:
        batch_x_orig = data.to(device)

        # Corrupt the input with additive white Gaussian noise of std `noise_sigma`,
        # generated directly on the device in the input's dtype.
        batch_x = batch_x_orig + noise_sigma * torch.randn_like(batch_x_orig)

        PhiTb = FFT_Mask_ForBack()(batch_x, mask)

        [x_output, loss_layers_sym] = model(PhiTb, mask)

        # Data term: reconstruction vs. the noisy input.
        loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
        # Constraint term: mean squared value of the first `layer_num` entries
        # of loss_layers_sym, summed.
        loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
        for k in range(1, layer_num):
            loss_constraint = loss_constraint + torch.mean(torch.pow(loss_layers_sym[k], 2))
        # Reconstruction vs. the clean (noise-free) input.
        loss_noise = torch.mean(torch.pow(x_output - batch_x_orig, 2))

        loss_all = loss_discrepancy + gamma_constraint * loss_constraint + gamma_noise * loss_noise

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()

        loss_all_epoch.append(loss_all.item())
        loss_noise_epoch.append(loss_noise.item())

        output_data = "[%02d/%02d] Total Loss: %.5f, Discrepancy Loss: %.5f,  Constraint Loss: %.5f, Noise Loss: %.5f" % (epoch_i, n_epoch, loss_all.item(), loss_discrepancy.item(), loss_constraint.item(), loss_noise.item())
        print(output_data, end='\r')

    loss_all_epoch_array.append(torch.tensor(loss_all_epoch).mean())
    loss_all_epoch = []

    loss_noise_epoch_array.append(torch.tensor(loss_noise_epoch).mean())
    loss_noise_epoch = []

    print()  # end the '\r'-overwritten progress line
    if epoch_i % 5 == 0:
        torch.save(model.state_dict(), "./models/net_params_cs_%d_%d.pkl" % (cs_ratio,epoch_i))  # save only the parameters
        print("Model saved")
        epochs = range(1, epoch_i + 1)
        fig, ax = plt.subplots()
        ax.plot(epochs, [v.item() for v in loss_all_epoch_array])
        ax.plot(epochs, [v.item() for v in loss_noise_epoch_array])
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend(['Total', 'Noise component'])
        plt.show()
        plt.close(fig)