In [14]:
#importing packages

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import platform
import torch.nn.init as init

In [3]:
# Training hyper-parameters

# Epoch window: the training loop runs epochs start_epoch + 1 .. end_epoch,
# and a start_epoch above 0 makes the training cell reload a saved checkpoint.
start_epoch, end_epoch = 0, 50

# Step size passed to the Adam optimizer.
learning_rate = 1e-4

# layer_num: how many BasicISTABlock modules ISTANet stacks.
# group_num: only appears in the model/log file names.
layer_num, group_num = 9, 1

# Sampling ratio; used as the key into ratio_dict.
cs_ratio = 25

In [4]:
# Compression ratio (key) -> number of measurements per image block (value).
# The values are approximately key / 100 * 1089.

ratio_dict = {
    1: 10,
    4: 43,
    10: 109,
    25: 272,
    30: 327,
    40: 436,
    50: 545,
}

In [5]:
# Problem dimensions
n_output = 1089                 # pixels per block (the network reshapes blocks to 33 x 33)
n_input = ratio_dict[cs_ratio]  # measurement count looked up for the chosen ratio

# Mini-batch size and number of training blocks
batch_size = 64
ntrain_blocks = 88912

In [6]:
# Load the sampling matrix for the configured cs_ratio.
# The ratio is part of the file name (phi_0_<ratio>_1089.mat); deriving it from
# cs_ratio keeps the loaded matrix in step with n_input when the ratio changes.
# NOTE(review): assumes matrices for other ratios follow the same naming — confirm.
# NOTE(review): the directory is an absolute, machine-specific path.

phi_matrix = sio.loadmat(r'/home/kudsit/github_gopika/ISTA-Net/phi_0_%d_1089.mat' % cs_ratio)

In [7]:
# Inspect the container type returned by sio.loadmat
type(phi_matrix)

dict

In [8]:
# Display the loaded contents: MAT-file metadata entries plus the 'phi' array
phi_matrix

{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Mon Sep 28 13:52:26 2015',
 '__version__': '1.0',
 '__globals__': [],
 'phi': array([[ 0.01110017, -0.00433498,  0.0031883 , ..., -0.01654842,
         -0.02102439,  0.01379417],
        [-0.01005728, -0.02996811,  0.06524745, ...,  0.06070643,
         -0.01883776,  0.02234696],
        [ 0.01976339,  0.01614716, -0.01344627, ..., -0.00730817,
         -0.00801255,  0.01941273],
        ...,
        [ 0.01689902,  0.01530075,  0.04503205, ...,  0.02353648,
         -0.01274751,  0.01843241],
        [ 0.00188634, -0.02561448, -0.07956788, ..., -0.04413104,
         -0.01198319,  0.00925998],
        [-0.00873868, -0.01577827, -0.01665587, ..., -0.01382855,
         -0.01844824,  0.02001504]])}

In [9]:
# Pull the 'phi' array out of the loaded .mat dict
phi_input_matrix = phi_matrix['phi'] # sampling matrix

In [10]:
# Path to the training-data .mat file; this cell stores the path only and loads nothing
training_data = r'/home/kudsit/github_gopika/ISTA-Net/Training_Data.mat' # training data

In [11]:
# Load the initialization matrix for the configured cs_ratio.
# The ratio is part of the file name (Initialization_Matrix_<ratio>.mat); deriving
# it from cs_ratio keeps this matrix consistent with the sampling matrix and n_input.
# NOTE(review): assumes files for other ratios follow the same naming — confirm.
Q_name = sio.loadmat(r'/home/kudsit/github_gopika/ISTA-Net/Initialization_Matrix_%d.mat' % cs_ratio)
q_init = Q_name['Qinit']   # initialization matrix

In [12]:
class BasicISTABlock(torch.nn.Module):
    """One network stage: a gradient-style update followed by a learned
    transform, soft thresholding, and a learned inverse transform.

    Learnable parameters:
        lambda_step     -- scalar step size of the update (initialised to 0.5)
        soft_thresh     -- scalar soft-threshold level (initialised to 0.1)
        conv1_forward, conv2_forward   -- 3x3 kernels of the forward transform (1 -> 32 -> 32 channels)
        conv1_backward, conv2_backward -- 3x3 kernels of the backward transform (32 -> 32 -> 1 channels)
    """

    def __init__(self):
        super(BasicISTABlock, self).__init__()
        self.lambda_step = nn.Parameter(torch.Tensor([0.5]))
        self.soft_thresh = nn.Parameter(torch.Tensor([0.1]))

        # Convolution kernels are stored as raw parameters (used with F.conv2d)
        # and filled with Xavier-normal initial values.
        self.conv1_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 1, 3, 3)))
        self.conv2_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv1_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv2_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(1, 32, 3, 3)))

    # forward must be a method of the class (not a function nested in __init__),
    # otherwise nn.Module has no forward to dispatch to when the block is called.
    def forward(self, x, PhiTPhi, PhiTb):
        """Run one stage.

        Args:
            x: current estimate, 2-D with 1089 values per row (it is reshaped to 33 x 33).
            PhiTPhi: 2-D matrix right-multiplied onto ``x``; the product must match ``x`` in shape.
            PhiTb: tensor added to ``x`` after scaling by ``lambda_step``.

        Returns:
            ``[x_pred, symloss]`` where ``x_pred`` is the updated estimate with
            shape (-1, 1089) and ``symloss`` is the (-1, 1, 33, 33) residual
            between the backward transform of the un-thresholded features and
            the input of the forward transform.
        """
        # Update step: x - lambda * x @ PhiTPhi + lambda * PhiTb
        x = x - self.lambda_step * torch.mm(x, PhiTPhi)
        x = x + self.lambda_step * PhiTb
        x_input = x.view(-1, 1, 33, 33)

        # Forward transform: conv -> ReLU -> conv
        x = F.conv2d(x_input, self.conv1_forward, padding=1)
        x = F.relu(x)
        x_forward = F.conv2d(x, self.conv2_forward, padding=1)

        # Soft thresholding: sign(v) * max(|v| - soft_thresh, 0)
        x = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.soft_thresh))

        # Backward transform of the thresholded features: conv -> ReLU -> conv
        x = F.conv2d(x, self.conv1_backward, padding=1)
        x = F.relu(x)
        x_backward = F.conv2d(x, self.conv2_backward, padding=1)

        x_pred = x_backward.view(-1, 1089)

        # Same backward transform applied to the un-thresholded features; the
        # difference to x_input measures how well backward inverts forward.
        x = F.conv2d(x_forward, self.conv1_backward, padding=1)
        x = F.relu(x)
        x_est = F.conv2d(x, self.conv2_backward, padding=1)
        symloss = x_est - x_input

        return [x_pred, symloss]

In [15]:
# Define ISTA-Net
class ISTANet(torch.nn.Module):
    """Chain of ``LayerNo`` BasicISTABlock stages applied one after another."""

    def __init__(self, LayerNo):
        super(ISTANet, self).__init__()
        self.LayerNo = LayerNo
        # One independent block per stage, registered through ModuleList.
        self.fcs = nn.ModuleList([BasicISTABlock() for _ in range(LayerNo)])

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements.

        Args:
            Phix: 2-D measurement batch; its column count must equal the row
                count of ``Phi`` and the column count of ``Qinit``.
            Phi: 2-D sampling matrix.
            Qinit: 2-D matrix whose transpose maps ``Phix`` to the initial estimate.

        Returns:
            ``[x_final, layers_sym]``: the output of the last stage and the
            list of per-stage symmetry residuals (used for the symmetric loss).
        """
        PhiTPhi = torch.mm(Phi.t(), Phi)
        PhiTb = torch.mm(Phix, Phi)

        # Initial estimate obtained from the measurements via Qinit^T.
        estimate = torch.mm(Phix, Qinit.t())

        layers_sym = []
        for stage in self.fcs:
            estimate, stage_sym = stage(estimate, PhiTPhi, PhiTb)
            layers_sym.append(stage_sym)

        return [estimate, layers_sym]
# Build the layer_num-stage network and wrap it in DataParallel
model = nn.DataParallel(ISTANet(layer_num))

In [None]:
# Set to 0 to skip the parameter listing
print_flag = 1

if print_flag:
    # Note: the counter numbers parameter tensors (one per nn.Parameter), printed under the label "Layer".
    num_count = 0
    for num_count, para in enumerate(model.parameters(), start=1):
        print('Layer %d' % num_count)
        print(para.size())



class RandomDataset(Dataset):
    """Row-wise dataset over an array indexable as ``data[index, :]``.

    ``length`` is the size reported by ``len()``; it is taken as given and not
    checked against ``data``.
    """

    def __init__(self, data, length):
        self.len = length
        self.data = data

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # One row of the array, returned as a float tensor.
        row = self.data[index, :]
        return torch.Tensor(row).float()


# ---------------------------------------------------------------------------
# Training setup and loop.
# Every name used below is defined either in this cell or in an earlier cell
# (training_data, ntrain_blocks, phi_input_matrix, q_init, model, ...).
# ---------------------------------------------------------------------------

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Output folders for checkpoints and logs.
# NOTE(review): 'model' and 'log' replace args.model_dir / args.log_dir, which
# are not defined anywhere in this notebook — adjust if other folders are wanted.
model_root = "model"
log_root = "log"

# Training blocks, read from the .mat file whose path is stored in training_data.
# NOTE(review): assumes the blocks live under the 'labels' key, one flattened
# block per row — confirm against the .mat file.
Training_labels = sio.loadmat(training_data)['labels']

# No worker processes on Windows, four elsewhere.
num_workers = 0 if platform.system() == "Windows" else 4
rand_loader = DataLoader(dataset=RandomDataset(Training_labels, ntrain_blocks), batch_size=batch_size,
                         num_workers=num_workers, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Checkpoint folder and log file share one run tag built from the hyper-parameters.
run_tag = "CS_ISTA_Net_layer_%d_group_%d_ratio_%d_lr_%.4f" % (layer_num, group_num, cs_ratio, learning_rate)
model_dir = "./%s/%s" % (model_root, run_tag)
log_file_name = "./%s/Log_%s.txt" % (log_root, run_tag)

# Both folders must exist before anything is written into them.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(os.path.dirname(log_file_name), exist_ok=True)

# Resume from a saved checkpoint when start_epoch is positive.
if start_epoch > 0:
    checkpoint_path = os.path.join(model_dir, 'net_params_%d.pkl' % start_epoch)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Sampling and initialization matrices as float32 tensors on the target device.
Phi = torch.from_numpy(phi_input_matrix).type(torch.FloatTensor).to(device)
Qinit = torch.from_numpy(q_init).type(torch.FloatTensor).to(device)

# Weight of the symmetry-constraint term in the total loss.
gamma = 0.01

# Training loop
for epoch_i in range(start_epoch + 1, end_epoch + 1):
    output_data = None  # stays None only if the loader yields no batch
    for data in rand_loader:
        batch_x = data.to(device)

        # Measurements of the current batch.
        Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))

        [x_output, loss_layers_sym] = model(Phix, Phi, Qinit)

        # Reconstruction error plus the summed per-stage symmetry residuals.
        loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
        loss_constraint = sum(torch.mean(torch.pow(layer_sym, 2)) for layer_sym in loss_layers_sym)

        loss_all = loss_discrepancy + gamma * loss_constraint

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()

        output_data = "[%02d/%02d] Total Loss: %.4f, Discrepancy Loss: %.4f,  Constraint Loss: %.4f\n" % (epoch_i, end_epoch, loss_all.item(), loss_discrepancy.item(), loss_constraint.item())
        print(output_data)

    # One log line per epoch: the losses of the last batch.
    if output_data is not None:
        with open(log_file_name, 'a') as output_file:
            output_file.write(output_data)

    if epoch_i % 5 == 0:
        torch.save(model.state_dict(), os.path.join(model_dir, "net_params_%d.pkl" % epoch_i))  # save only the parameters
