In [1]:
import torch
from torch.linalg import matrix_norm as mn
import numpy as np
import create_data_upd as OD
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn

ambient = 400  #784, MNIST
measurements=round(0.25 * ambient)
N_SENSORS=10
#sparse_dim=10
mu = 100 #deafult
batch_size=128

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ShrinkageActivation(nn.Module):
    def __init__(self):
        super(ShrinkageActivation, self).__init__()

    def forward(self, x, epsilon):
        return torch.sign(x) * torch.max(torch.zeros_like(x), torch.abs(x) - epsilon)


class TruncationActivation(nn.Module):
    def __init__(self):
        super(TruncationActivation, self).__init__()

    def forward(self, x, epsilon):
        return torch.sign(x) * torch.min(torch.abs(x), epsilon * torch.ones_like(x))

class ENCODER:
    def __init__(self,A):
        self.A = A # size: (m ,amb), m=28*28*0.25, amb=28*28

    def measure_x(self, x):
        # Create measurements y
        y = torch.einsum("ma,ba->bm", self.A, x.to(DEVICE)).to(DEVICE)
        return y

    def forward(self, x):
        x = x.view(x.size(0), -1)
        min_x = torch.min(x)
        max_x = torch.max(x)

        y = self.measure_x(x)
        return y,min_x,max_x



""" sending in the network: A, min_x,max_x,y """

class Decoder:
    def __init__(self, A,y,mu, phi,min_x, max_x):
        self.A = A
        self.y = y
        self.min_x = min_x
        self.max_x = max_x

        self.y_noisy = self.noisy_measure(y)
        self.epsilon = torch.norm(self.y - self.y_noisy)
        self.mu = mu
        self.x0 =  torch.einsum("am,bm->ba", self.A.t(), self.y_noisy)

        self.sparse_dim = 400*10
        self.ambient = 400
        self.first_activation = TruncationActivation()
        self.second_activation = ShrinkageActivation()
        self.alpha = 0.7
        self.beta = 0.5
        self.acf_iterations=10

        """learned parameter"""
        self.phi = phi

    def decode(self, y, epsilon, mu, x0, z1, z2, min_x, max_x):
        u1 = z1
        u2 = z2
      #  t1 = ARGS.t10
       # t2 = ARGS.t20
        t1=1
        t2=1 #deafult

        theta1 = 1
        theta2 = 1
        Lexact = torch.tensor([1000.0]).to(DEVICE)

        for _ in range(self.acf_iterations):
            x_hat = (
                    x0
                    + (
                            (1 - theta1) * torch.einsum("as,bs->ba", self.phi.t(), u1)
                            + theta1 * torch.einsum("as,bs->ba", self.phi.t(), z1)
                            - (1 - theta2) * torch.einsum("am,bm->ba", self.A.t(), u2)
                            - theta2 * torch.einsum("am,bm->ba", self.A.t(), u2)
                    )
                    / mu
            )

            w1 = self.affine_transform1(theta1, u1, z1, t1, x_hat)
            w2 = self.affine_transform2(theta2, u2, z2, t2, y, x_hat)

            z1 = self.first_activation(w1, t1 / theta1)
            z2 = self.second_activation(w2, t2 * epsilon / theta2)
            u1 = (1 - theta1) * u1 + theta1 * z1
            u2 = (1 - theta2) * u2 + theta2 * z2

            t1 = self.alpha * t1
            t2 = self.beta * t2
            muL = torch.sqrt(mu / Lexact).to(DEVICE)
            theta_scale = (1 - muL) / (1 + muL).to(DEVICE)
            theta1 = torch.min(torch.tensor([1.0]).to(DEVICE), theta1 * theta_scale)
            theta2 = torch.min(torch.tensor([1.0]).to(DEVICE), theta2 * theta_scale)

        return torch.clamp(x_hat, min=min_x, max=max_x).to(DEVICE)

    def affine_transform1(self, theta1, u1, z1, t1, x):
        affine1 = (
                (1 - theta1) * u1
                + theta1 * z1
                - (t1 / theta1) * torch.einsum("sa,ba->bs", self.phi, x)
        )

        return affine1.detach()

    def affine_transform2(self, theta2, u2, z2, t2, y, x):
        affine2 = (
                (1 - theta2) * u2
                + theta2 * z2
                - (t2 / theta2) * (y - torch.einsum("sa,ba->bs", self.A, x))
        )
        return affine2.detach()
    def noisy_measure(self, y):
        # add Gaussian noise to y
        y_noisy = y + 0.0001 * torch.randn_like(y)
        return y_noisy

    def forward(self, y):
        y_noisy = self.noisy_measure(y)

        x0 = torch.einsum("am,bm->ba", self.A.t(), y_noisy)
        phix0 = torch.einsum("sa,ba->bs", self.phi, x0)
        mu = self.mu

        z1 = torch.zeros_like(phix0)
        z2 = torch.zeros_like(y)

        epsilon = torch.norm(y - y_noisy)

        x_hat = self.decode(y_noisy, epsilon, mu, x0, z1, z2, self.min_x, self.max_x)

        return x_hat


In [6]:
#New loss: 100*MSE normalized in the signal energy!

sig1vt_supp=10
zt_noise_sigma=0.01


sig1vt_supp_train=10
zt_noise_sigma_train=0.01
cs_ratio_train=0.25

instance='Z'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}_zt_noise_sigma={zt_noise_sigma_train}_cs_ratio={cs_ratio_train}_1000*normMseLoss.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)

len_test = 1000
rows = ambient
cols = N_SENSORS
data = OD.Data(rows=rows, cols=cols, sig1vt_supp=sig1vt_supp, k_sparse=5,zt_noise_sigma=zt_noise_sigma)
X_dataset_test, C_dataset_test, Z_dataset_test,Zt_dataset_test = data.create_Dataset_save_tag(len_test)
# Z_dataset_test_vecs = Z_dataset_test.transpose(2, 0, 1).reshape(cols, rows * len_test).transpose()
print('Checking Z:')
test=torch.from_numpy(Z_dataset_test).clone().float()

# print('Checking X:')
# test=torch.from_numpy(X_dataset_test).clone().float()

# print('Checking X-AVG:')
# test=torch.from_numpy(X_dataset_test).clone().float()
# test=test-test.mean(axis=2,keepdims=True)

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking Z:
tensor([[ 0.0076,  0.0217,  0.1611,  ..., -0.0277,  0.1159,  0.1351],
        [ 0.0233, -0.0635, -0.0287,  ...,  0.0863,  0.0057,  0.0303],
        [-0.1758,  0.0156, -0.0271,  ...,  0.1970,  0.0824,  0.0866],
        ...,
        [ 0.0447, -0.2061, -0.1032,  ...,  0.0032, -0.0671,  0.0743],
        [ 0.0518, -0.0986, -0.1742,  ...,  0.0522, -0.0398, -0.0576],
        [-0.0118, -0.0846,  0.0222,  ..., -0.0529,  0.1068, -0.0330]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[ 3.7536e-03, -1.1911e-02,  1.2482e-02,  ..., -7.4314e-03,
           2.7189e-02, -2.5323e-03],
         [ 1.6987e-01, -7.8710e-02,  6.7613e-02,  ..., -2.0458e-01,
           8.4931e-01, -5.9010e-02],
         [-1.7298e-03,  1.0715e-02, -7.2610e-03,  ..., -1.7275e-03,
          -9.7896e-03, -6.4406e-03],
         ...,
         [ 6.4307e-01,  1.1466e-01, -2.0487e-01,  ...,  3.1924e-01,
           2.2618e-01,  2.2072e-02],
         [-4.7671e-02, -4.1692e-02, -6.6076e-02,  ..., -4.1383e-02,
   

In [7]:
#New loss: 100*MSE normalized in the signal energy!

instance='X-AVG'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}_zt_noise_sigma={zt_noise_sigma_train}_cs_ratio={cs_ratio_train}_1000*normMseLoss.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)


print('Checking X-AVG:')
test=torch.from_numpy(X_dataset_test).clone().float()
test=test-test.mean(axis=2,keepdims=True)

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking X-AVG:
tensor([[ -7.5571,   0.5366,  -3.9894,  ...,   5.6508,  17.2290,   3.7099],
        [ -7.4146,   5.8784,   2.3494,  ...,   7.4280,   8.6229,   1.9846],
        [ -1.3724, -10.0131,  -5.5515,  ...,  -5.5298,   6.9166,   1.6528],
        ...,
        [ 21.7546,  12.5807,  -8.9720,  ...,   2.1904,   5.5885,  -6.7054],
        [ 10.2096,   9.1524,   9.7194,  ...,  -0.7187,  -0.5083,  -3.2920],
        [ 18.5401,  -6.7927,  -3.9267,  ..., -17.7382,  -1.9079,   3.5557]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[-3.8368e-01,  1.6191e-01, -1.5491e-01,  ...,  4.3826e-01,
           3.6080e-01,  1.4589e-01],
         [-2.7705e-01, -2.1619e-02, -1.8171e-01,  ...,  9.6656e-02,
           1.0499e+00, -2.4726e-02],
         [-4.9863e-01,  2.3613e-01, -2.2098e-01,  ...,  5.7358e-01,
           4.2127e-01,  1.8629e-01],
         ...,
         [-9.0415e+00,  4.1646e+00, -4.5048e+00,  ...,  1.1022e+01,
           8.1864e+00,  3.4506e+00],
         [-9.5804e+00,  4.2365e

In [8]:
#New loss: 100*MSE normalized in the signal energy!

instance='X'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}_zt_noise_sigma={zt_noise_sigma_train}_cs_ratio={cs_ratio_train}_1000*normMseLoss.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)


print('Checking X:')
test=torch.from_numpy(X_dataset_test).clone().float()

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking X:
tensor([[ -5.8300,  -0.3336,  -0.4865,  ...,  10.1690,   9.5961,  10.5005],
        [ -3.0223,  12.3749,   3.0643,  ...,   2.0424,   8.4198,  17.3993],
        [ -7.6591,   5.9494,   2.3339,  ...,   3.5761,  14.2462,   2.5563],
        ...,
        [ 20.8907,  12.6055,  -7.6679,  ...,  -1.4430,  -2.0984,  -3.2009],
        [  3.9419,  17.1858,  11.0154,  ...,  -0.2011,  -3.5432,  -0.9027],
        [ 16.2284,  -2.6353,  -4.8350,  ..., -16.4353,   8.0190,  -4.1140]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[-3.7653e-01,  1.6906e-01, -1.4776e-01,  ...,  4.4541e-01,
           3.6795e-01,  1.5304e-01],
         [-1.7163e-01,  8.3800e-02, -7.6290e-02,  ...,  2.0208e-01,
           1.1553e+00,  8.0694e-02],
         [-4.9115e-01,  2.4362e-01, -2.1349e-01,  ...,  5.8107e-01,
           4.2876e-01,  1.9378e-01],
         ...,
         [-8.6630e+00,  4.5431e+00, -4.1263e+00,  ...,  1.1401e+01,
           8.5649e+00,  3.8291e+00],
         [-9.4054e+00,  4.4114e+00,

In [9]:
#Old loss: MSE

instance='Z'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}_zt_noise_sigma={zt_noise_sigma_train}_cs_ratio={cs_ratio_train}.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)

print('Checking Z:')
test=torch.from_numpy(Z_dataset_test).clone().float()

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking Z:
tensor([[-0.1249,  0.0952,  0.1193,  ...,  0.0779, -0.0263,  0.3416],
        [-0.0318,  0.0803,  0.1804,  ..., -0.0012, -0.1126, -0.0065],
        [-0.0665, -0.0536, -0.1410,  ...,  0.4123, -0.1278,  0.1838],
        ...,
        [ 0.1056,  0.0884,  0.1970,  ...,  0.0818, -0.0631, -0.0585],
        [ 0.0199, -0.0091, -0.0650,  ..., -0.0739,  0.1110, -0.1121],
        [ 0.0502,  0.0187, -0.0360,  ..., -0.0102,  0.3954,  0.0788]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[ 3.7536e-03, -1.1911e-02,  1.2482e-02,  ..., -7.4314e-03,
           2.7189e-02, -2.5323e-03],
         [ 1.6987e-01, -7.8710e-02,  6.7613e-02,  ..., -2.0458e-01,
           8.4931e-01, -5.9010e-02],
         [-1.7298e-03,  1.0715e-02, -7.2610e-03,  ..., -1.7275e-03,
          -9.7896e-03, -6.4406e-03],
         ...,
         [ 6.4307e-01,  1.1466e-01, -2.0487e-01,  ...,  3.1924e-01,
           2.2618e-01,  2.2072e-02],
         [-4.7671e-02, -4.1692e-02, -6.6076e-02,  ..., -4.1383e-02,
   

In [11]:
#Old loss: MSE

instance='X-AVG'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)

print('Checking X-AVG:')
test=torch.from_numpy(X_dataset_test).clone().float()
test=test-test.mean(axis=2,keepdims=True)

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking X-AVG:
tensor([[ -4.0496,   4.8127,  -3.0249,  ...,   8.2124,  12.0647,   3.5035],
        [  3.7079,   5.6865,   5.8030,  ...,  -7.3220,  17.4591,   3.0042],
        [ -2.3856,  -4.0601,  -7.7825,  ...,  10.4627,  -0.6161, -10.9773],
        ...,
        [ 15.7657,  16.9717,  -5.7677,  ...,   0.0559,   2.2051,  -0.2387],
        [ -8.1848,  10.5842,  -1.2596,  ...,   2.6536,  -4.1896,  -2.7264],
        [  7.9342, -11.4847,  -5.2350,  ..., -12.8515,  -0.7379,  -4.0584]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[-3.8368e-01,  1.6191e-01, -1.5491e-01,  ...,  4.3826e-01,
           3.6080e-01,  1.4589e-01],
         [-2.7705e-01, -2.1619e-02, -1.8171e-01,  ...,  9.6656e-02,
           1.0499e+00, -2.4726e-02],
         [-4.9863e-01,  2.3613e-01, -2.2098e-01,  ...,  5.7358e-01,
           4.2127e-01,  1.8629e-01],
         ...,
         [-9.0415e+00,  4.1646e+00, -4.5048e+00,  ...,  1.1022e+01,
           8.1864e+00,  3.4506e+00],
         [-9.5804e+00,  4.2365e

In [12]:
#Old loss: MSE

instance='X'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)

print('Checking X:')
test=torch.from_numpy(X_dataset_test).clone().float()

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    print(z_hat)
    print(z_hat.shape)
    print(batch)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)


Checking X:
tensor([[  2.0955,   1.4015,  -2.3707,  ...,  12.9838,  13.3361,  11.3372],
        [  5.8594,  13.4893,  10.6718,  ...,   5.0203,  16.1126,   5.0551],
        [ -1.8416,  -5.8759, -10.0158,  ...,   1.7039,  18.1148,   3.5137],
        ...,
        [ 13.3895,  12.0785, -15.1087,  ...,   0.9298,  -7.3957,   2.8321],
        [  2.4183,  20.9959,  11.6366,  ...,   0.6142,  -1.1268,   2.7084],
        [ 10.5699,  -6.9245,  -8.1564,  ..., -14.9625,  -6.0427,  -0.6422]],
       device='cuda:0')
torch.Size([1000, 4000])
tensor([[[-3.7653e-01,  1.6906e-01, -1.4776e-01,  ...,  4.4541e-01,
           3.6795e-01,  1.5304e-01],
         [-1.7163e-01,  8.3800e-02, -7.6290e-02,  ...,  2.0208e-01,
           1.1553e+00,  8.0694e-02],
         [-4.9115e-01,  2.4362e-01, -2.1349e-01,  ...,  5.8107e-01,
           4.2876e-01,  1.9378e-01],
         ...,
         [-8.6630e+00,  4.5431e+00, -4.1263e+00,  ...,  1.1401e+01,
           8.5649e+00,  3.8291e+00],
         [-9.4054e+00,  4.4114e+00,

In [5]:
#Normalize object of training beafore training

sig1vt_supp=10
zt_noise_sigma=0.01
rows=400
cols=10


sig1vt_supp_train=10
zt_noise_sigma_train=0.01
cs_ratio_train=0.25
norm_data='Yes'
k_sparse=5

instance='Z'
file_name = f"DECONET(synthetic our data)-10L-4000-red4000-lr0.0001-mu100-initkaiming-datasets_{instance}_400x10_ksparse=5%_sig1vt_supp={sig1vt_supp_train}_zt_noise_sigma={zt_noise_sigma_train}_cs_ratio={cs_ratio_train}_normalizeDataset={norm_data}.pt"
state_dict = torch.load('/home-sipl/prj7482/'+file_name)
A=state_dict['A']
enc = ENCODER(A)

checkpoint_name = f"save_datasets_{rows}x{cols}_ksparse={k_sparse}%_sig1vt_supp={sig1vt_supp}_zt_noise_sigma={zt_noise_sigma}.pt"
loaded_datasets = torch.load('/home-sipl/prj7482/'+checkpoint_name)
mean=loaded_datasets['mean_Z']
std=loaded_datasets['std_Z']
print('mean ',mean)
print('std ',std)

len_test = 1000
rows = ambient
cols = N_SENSORS
data = OD.Data(rows=rows, cols=cols, sig1vt_supp=sig1vt_supp, k_sparse=5,zt_noise_sigma=zt_noise_sigma)
X_dataset_test, C_dataset_test, Z_dataset_test,Zt_dataset_test = data.create_Dataset_save_tag(len_test)

print('Checking Z:')
test=torch.from_numpy(Z_dataset_test).clone().float()
test=(test-mean)/std

test_loader = DataLoader(
    test,
    num_workers=2,
    batch_size=len_test,
    shuffle=False,
)

criterion = nn.MSELoss()
for i,batch in enumerate(test_loader):
    batch=batch.to(DEVICE)
    y, min_x, max_x = enc.forward(batch)
    phi = state_dict['phi']
    dec = Decoder(A=A, y=y, mu=mu, phi=phi, min_x=min_x, max_x=max_x)
    z_hat = dec.forward(y)
    mse = criterion(z_hat, batch.view(batch.size(0), -1))
    normal_mse_x=4000*std**2*mse/(torch.norm(torch.from_numpy(X_dataset_test), p='fro')**2)
    normal_mse_z=4000*mse/(torch.norm(batch, p='fro')**2)
 #   normal_mse_z=4000*mse/(torch.norm(batch*std+mean, p='fro')**2)
    print('mse: ',mse)
    print('normalized by X mse:', normal_mse_x)
    print('normalized by Z mse:', normal_mse_z)



  state_dict = torch.load('/home-sipl/prj7482/'+file_name)
  loaded_datasets = torch.load('/home-sipl/prj7482/'+checkpoint_name)


mean  0.008597723
std  0.20495129
Checking Z:
mse:  tensor(2.0297, device='cuda:0')
normalized by X mse: tensor(2.0356e-06, device='cuda:0', dtype=torch.float64)
normalized by Z mse: tensor(0.0481, device='cuda:0')
