In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Set the necessary configurations and hyperparameters
start_epoch = 0
end_epoch = 1
layer_num = 9
learning_rate = 1e-4
cs_ratio = 4
nrtrain = 88912
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_input = 1089
n_output = 109

# Generate a random sampling matrix
def generate_gaussian_sampling_matrix(num_rows, num_cols):
    matrix = np.random.randn(num_rows, num_cols)
    normalized_matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    return normalized_matrix

Phi_input = generate_gaussian_sampling_matrix(n_output, n_input)

# Load training data
Train_data = sio.loadmat(r"C:\Users\vrinda\Desktop\2ndsem\RL_DL\ISTA\i_did_it\setup.mat")
train_label = Train_data['labels']
X_data = train_label.transpose()
y_data = np.dot(Phi_input, X_data)
y_yT = np.dot(y_data, y_data.transpose())
invy_yT = np.linalg.inv(y_yT)
x_yT = np.dot(X_data, y_data.transpose())
Qinit = np.dot(x_yT, invy_yT)

# Define ISTA-Net Block
class BackBone(nn.Module):
    def __init__(self):
        super(BackBone, self).__init__()
        self.lambda_step = nn.Parameter(torch.Tensor([0.5]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))
        self.conv1_forward = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_forward = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_backward = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_backward = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.conv1_forward.weight)
        nn.init.xavier_normal_(self.conv2_forward.weight)
        nn.init.xavier_normal_(self.conv1_backward.weight)
        nn.init.xavier_normal_(self.conv2_backward.weight)

    def forward(self, x):
        x_input = x.view(-1, 1, 33, 33)
        x = F.relu(self.conv1_forward(x_input))
        x_forward = self.conv2_forward(x)
        x = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.soft_thr))
        x = F.relu(self.conv1_backward(x))
        x_backward = self.conv2_backward(x)
        x_pred = x_backward.view(-1, 1089)
        x_est = self.conv2_backward(F.relu(self.conv1_backward(x_forward)))
        symloss = x_est - x_input
        return [x_pred, symloss]

# Define ISTA-Net
class Net(nn.Module):
    def __init__(self, layer_num):
        super(Net, self).__init__()
        self.layer_num = layer_num
        self.blocks = nn.ModuleList([BackBone() for _ in range(layer_num)])

    def forward(self, Phix, Phi, Qinit):
        PhiTPhi = torch.mm(torch.transpose(Phi, 0, 1), Phi)
        PhiTb = torch.mm(Phix, Phi)
        x = torch.mm(Phix, torch.transpose(Qinit, 0, 1))
        sym_layers = []  # for computing symmetric loss
        for i in range(self.layer_num):
            [x, sym_layer] = self.blocks[i](x)
            sym_layers.append(sym_layer)
        x_final = x
        return [x_final, sym_layers]

# Create the ISTA-Net model
model = Net(layer_num).to(device)

# Create the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Convert data to float tensors
Phi = torch.from_numpy(Phi_input).type(torch.FloatTensor).to(device)
Qinit = torch.from_numpy(Qinit).type(torch.FloatTensor).to(device)

# Define a custom dataset for training
class RandomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return torch.Tensor(self.data[index, :]).float()

    def __len__(self):
        return len(self.data)

batch_size = 64
rand_loader = DataLoader(dataset=RandomDataset(train_label), batch_size=batch_size, num_workers=0, shuffle=True)

# Training loop
for epoch_i in range(start_epoch + 1, end_epoch + 1):
    for batch_x in rand_loader:
        batch_x = batch_x.to(device)
        Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))
        [x_output, loss_layers_sym] = model(Phix, Phi, Qinit)

        # Compute and print loss
        loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
        loss_constraint = sum(torch.mean(torch.pow(loss, 2)) for loss in loss_layers_sym)
        gamma = torch.Tensor([0.01]).to(device)
        Total_loss = loss_discrepancy + gamma * loss_constraint

        # Zero gradients, perform a backward pass, and update the weights
        optimizer.zero_grad()
        Total_loss.backward()
        optimizer.step()

        # Compute SSIM and PSNR
        ssim_value = ssim(batch_x.cpu().numpy(), x_output.detach().cpu().numpy(), data_range=1.0)
        psnr_value = psnr(batch_x.cpu().numpy(), x_output.detach().cpu().numpy(), data_range=1.0)

        # Print loss, SSIM, and PSNR
        output_data = "Epoch: [%d/%d], Total Loss: %.4f, Discrepancy Loss: %.4f, Constraint Loss: %.4f, SSIM: %.4f, PSNR: %.4f" % (epoch_i, end_epoch, Total_loss.item(), loss_discrepancy.item(), loss_constraint.item(), ssim_value, psnr_value)
        print(output_data)



Epoch: [1/1], Total Loss: 0.3960, Discrepancy Loss: 0.3888, Constraint Loss: 0.7199, SSIM: -0.0155, PSNR: 4.1023
Epoch: [1/1], Total Loss: 0.3234, Discrepancy Loss: 0.3170, Constraint Loss: 0.6378, SSIM: -0.0115, PSNR: 4.9896
Epoch: [1/1], Total Loss: 0.3481, Discrepancy Loss: 0.3416, Constraint Loss: 0.6449, SSIM: -0.0076, PSNR: 4.6647
Epoch: [1/1], Total Loss: 0.3149, Discrepancy Loss: 0.3086, Constraint Loss: 0.6241, SSIM: -0.0041, PSNR: 5.1055
Epoch: [1/1], Total Loss: 0.2694, Discrepancy Loss: 0.2634, Constraint Loss: 0.6058, SSIM: 0.0002, PSNR: 5.7941
Epoch: [1/1], Total Loss: 0.1967, Discrepancy Loss: 0.1909, Constraint Loss: 0.5784, SSIM: 0.0045, PSNR: 7.1927
Epoch: [1/1], Total Loss: 0.2236, Discrepancy Loss: 0.2174, Constraint Loss: 0.6263, SSIM: 0.0079, PSNR: 6.6283
Epoch: [1/1], Total Loss: 0.2071, Discrepancy Loss: 0.2007, Constraint Loss: 0.6452, SSIM: 0.0113, PSNR: 6.9746
Epoch: [1/1], Total Loss: 0.1698, Discrepancy Loss: 0.1633, Constraint Loss: 0.6439, SSIM: 0.0151, P

In [3]:
print(train_label.shape)
print((y_yT).shape)
print((y_data).shape)
print((invy_yT).shape)
print((x_yT).shape)
print((X_data).shape)

(88912, 1089)
(109, 109)
(109, 88912)
(109, 109)
(1089, 109)
(1089, 88912)
