In [None]:
#%% s0
# Cell preamble for run "s0": sets the run id, global plotting/printing
# options, the RNG seed, and picks version-dependent helpers.
# NOTE(review): `os`, `plt`, `torch`, `optim` are not imported explicitly in
# this cell; presumably they come from `from utils import *` — confirm.
rid = 's0' # running id
from math import ceil
from utils import *
# Restrict this process to GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed so weight init and sampling are repeatable across runs.
torch.manual_seed(1)

# Version switch: for any torch other than 1.8.1, define `mydet` locally and
# use the built-in RAdam. On 1.8.1, `mydet` and `optim` must already be in
# scope (presumably provided by `utils` — TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: autograd anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# Star import placed last, so any name it shares with `utils` takes precedence.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener filter front end followed by a VAE that outputs V.

    For each of the K sources a Wiener filter is built from the steering
    vector ``hgt[:, k]`` (a module-level global, not a learned parameter) and
    a fixed diagonal noise covariance Rb.  The filter gives a coarse source
    estimate, and that source's contribution is subtracted from the mixture
    before the next source is processed.  Each coarse estimate is encoded to
    a latent z; the decoder receives z tiled over an (x, y) coordinate grid
    and its exponentiated output V is passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of channels (length of each steering vector).
        K: number of sources to extract.
        im_size: side length of the coordinate grid fed to the decoder.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Rb is a fixed constant in forward(); an earlier variant (removed
        # dead code) estimated the noise level with small linear layers.

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 needs the flattened encoder output to have 25*25 features
        # (presumably a 100x100 input downsampled 4x — TODO confirm).
        # Its output interleaves mu and logvar (see forward()).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # No `indexing=` keyword: it is unavailable on torch 1.8.1, which
        # this file explicitly supports.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra leading dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source covariances for a batch of mixtures.

        Args:
            x: complex mixture, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is ``vhat`` as a
            diagonal matrix, shape [I, N, F, K, K]; ``Hhat`` is [I, M, K];
            ``Rb`` is [I, M, M]; ``mu`` and ``logvar`` are [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are overwritten on every source
        iteration, so only the LAST source's values are returned — confirm
        that a KL term over one of the K latents is intended.
        """
        batch_size = x.shape[0]

        # Fixed noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        "Wiener filter to get coarse shat for every source"
        h_all, shat_all = [], []
        inp = x
        hj = W = None
        for i in range(self.K):
            if i > 0:
                # Subtract the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a
                # size-1 batch/N/F dimension and break the permute.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Same steering vector for every sample in the batch.
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]) # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) # [I, N, F]
            # Normalise by the batch-wide peak magnitude (no gradient through it).
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        v_all = []
        for shat in shat_all:
            "Encoder"
            xx = self.encoder(shat[:,None].abs())
            "Get latent variable"
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]      # even columns
            logvar = zz[:,1::2] # odd columns
            z = self.reparameterize(mu, logvar)

            "Decoder to get V"
            # View z as 4D tensor to be tiled across the two grid dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match the grid size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the single decoder channel (dim 1) so I == 1 or K == 1
        # keep their dimensions.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture, shape [I, M, N, F]; permuted to [I, N, F, M] internally.
        Rs: source covariance, indexed as [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(nll, beta*kl)``; both are sums over all elements (no mean).

    Raises:
        Exception: whatever the likelihood computation raised; the inputs
        are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance Rx = H Rs H^H + Rb, computed with [N, F, I] as batch dims.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): uses log(pi*det) rather than log(pi^M*det); the two
        # differ by a constant only — confirm this is intended.
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None] # [I,N,F,1,1]
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 I/N/F dim and mis-broadcast with logdet.
        ll = -logdet - quad[..., 0, 0]
    except Exception:
        # Dump the offending tensors for offline inspection, then stop.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%

# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/{rid}/'
mod_loc = f'../data/nem_ss/models/{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # Create each folder independently: checking only the figures folder
    # crashed (or left the models folder missing) when just one existed.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {
    'batch_size': 64,
    'lr': 1e-3,
    'n_epochs': 150,
}

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last keeps every batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file: the third entry is read as the global `hgt` by
# NN_gtsbd.forward; the first 128 samples form the validation batch.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr=opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

def _save_curve(values, style, title, fname):
    """Plot ``values`` in a new figure with ``title`` and save it under ``fig_loc``."""
    plt.figure()
    plt.plot(values, style)
    plt.title(title)
    plt.savefig(fig_loc + fname)


# Normalise the validation loss by the actual validation batch size.
n_val = xval_cuda.shape[0]

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: these record the LAST batch of the epoch (per-sample), not an
    # average over the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10th epoch: save the loss curves, evaluate, and checkpoint.
    if epoch%10 == 0:
        _save_curve(loss_tr, '-or', f'Loss fuction at epoch{epoch}',
                    f'Epoch{epoch}_LossFunAll')
        _save_curve(loss1, '-og', f'Reconstruction loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss1')
        _save_curve(loss2, '-og', f'KL loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss2')
        _save_curve(loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}',
                    f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append((l1+l2).cpu().item()/n_val)
        _save_curve(loss_eval, '-xb', f'Accumulated validation loss at epoch{epoch}',
                    f'Epoch{epoch}_val')
        # Release all figures opened above so they do not accumulate.
        plt.close('all')
        # Saves the whole pickled module, not just its state_dict.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s1
# Cell preamble for run "s1": sets the run id, global plotting/printing
# options, the RNG seed, and picks version-dependent helpers.
# NOTE(review): `os`, `plt`, `torch`, `optim` are not imported explicitly in
# this cell; presumably they come from `from utils import *` — confirm.
rid = 's1' # running id
from math import ceil
from utils import *
# Restrict this process to GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed so weight init and sampling are repeatable across runs.
torch.manual_seed(1)

# Version switch: for any torch other than 1.8.1, define `mydet` locally and
# use the built-in RAdam. On 1.8.1, `mydet` and `optim` must already be in
# scope (presumably provided by `utils` — TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: autograd anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# Star import placed last, so any name it shares with `utils` takes precedence.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener filter front end followed by a VAE that outputs V.

    For each of the K sources a Wiener filter is built from the steering
    vector ``hgt[:, k]`` (a module-level global, not a learned parameter) and
    a fixed diagonal noise covariance Rb.  The filter gives a coarse source
    estimate, and that source's contribution is subtracted from the mixture
    before the next source is processed.  Each coarse estimate is encoded to
    a latent z; the decoder receives z tiled over an (x, y) coordinate grid
    and its exponentiated output V is passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of channels (length of each steering vector).
        K: number of sources to extract.
        im_size: side length of the coordinate grid fed to the decoder.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Rb is a fixed constant in forward(); an earlier variant (removed
        # dead code) estimated the noise level with small linear layers.

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 needs the flattened encoder output to have 25*25 features
        # (presumably a 100x100 input downsampled 4x — TODO confirm).
        # Its output interleaves mu and logvar (see forward()).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # No `indexing=` keyword: it is unavailable on torch 1.8.1, which
        # this file explicitly supports.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra leading dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source covariances for a batch of mixtures.

        Args:
            x: complex mixture, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is ``vhat`` as a
            diagonal matrix, shape [I, N, F, K, K]; ``Hhat`` is [I, M, K];
            ``Rb`` is [I, M, M]; ``mu`` and ``logvar`` are [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are overwritten on every source
        iteration, so only the LAST source's values are returned — confirm
        that a KL term over one of the K latents is intended.
        """
        batch_size = x.shape[0]

        # Fixed noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        "Wiener filter to get coarse shat for every source"
        h_all, shat_all = [], []
        inp = x
        hj = W = None
        for i in range(self.K):
            if i > 0:
                # Subtract the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a
                # size-1 batch/N/F dimension and break the permute.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Same steering vector for every sample in the batch.
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]) # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) # [I, N, F]
            # Normalise by the batch-wide peak magnitude (no gradient through it).
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        v_all = []
        for shat in shat_all:
            "Encoder"
            xx = self.encoder(shat[:,None].abs())
            "Get latent variable"
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]      # even columns
            logvar = zz[:,1::2] # odd columns
            z = self.reparameterize(mu, logvar)

            "Decoder to get V"
            # View z as 4D tensor to be tiled across the two grid dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match the grid size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the single decoder channel (dim 1) so I == 1 or K == 1
        # keep their dimensions.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-2):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture, shape [I, M, N, F]; permuted to [I, N, F, M] internally.
        Rs: source covariance, indexed as [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(nll, beta*kl)``; both are sums over all elements (no mean).

    Raises:
        Exception: whatever the likelihood computation raised; the inputs
        are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance Rx = H Rs H^H + Rb, computed with [N, F, I] as batch dims.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): uses log(pi*det) rather than log(pi^M*det); the two
        # differ by a constant only — confirm this is intended.
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None] # [I,N,F,1,1]
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 I/N/F dim and mis-broadcast with logdet.
        ll = -logdet - quad[..., 0, 0]
    except Exception:
        # Dump the offending tensors for offline inspection, then stop.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/{rid}/'
mod_loc = f'../data/nem_ss/models/{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # Create each folder independently: checking only the figures folder
    # crashed (or left the models folder missing) when just one existed.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {
    'batch_size': 64,
    'lr': 1e-3,
    'n_epochs': 150,
}

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last keeps every batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file: the third entry is read as the global `hgt` by
# NN_gtsbd.forward; the first 128 samples form the validation batch.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr=opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

def _save_curve(values, style, title, fname):
    """Plot ``values`` in a new figure with ``title`` and save it under ``fig_loc``."""
    plt.figure()
    plt.plot(values, style)
    plt.title(title)
    plt.savefig(fig_loc + fname)


# Normalise the validation loss by the actual validation batch size.
n_val = xval_cuda.shape[0]

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: these record the LAST batch of the epoch (per-sample), not an
    # average over the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10th epoch: save the loss curves, evaluate, and checkpoint.
    if epoch%10 == 0:
        _save_curve(loss_tr, '-or', f'Loss fuction at epoch{epoch}',
                    f'Epoch{epoch}_LossFunAll')
        _save_curve(loss1, '-og', f'Reconstruction loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss1')
        _save_curve(loss2, '-og', f'KL loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss2')
        _save_curve(loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}',
                    f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append((l1+l2).cpu().item()/n_val)
        _save_curve(loss_eval, '-xb', f'Accumulated validation loss at epoch{epoch}',
                    f'Epoch{epoch}_val')
        # Release all figures opened above so they do not accumulate.
        plt.close('all')
        # Saves the whole pickled module, not just its state_dict.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s2
# Cell preamble for run "s2": sets the run id, global plotting/printing
# options, the RNG seed, and picks version-dependent helpers.
# NOTE(review): `os`, `plt`, `torch`, `optim` are not imported explicitly in
# this cell; presumably they come from `from utils import *` — confirm.
rid = 's2' # running id
from math import ceil
from utils import *
# Restrict this process to GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed so weight init and sampling are repeatable across runs.
torch.manual_seed(1)

# Version switch: for any torch other than 1.8.1, define `mydet` locally and
# use the built-in RAdam. On 1.8.1, `mydet` and `optim` must already be in
# scope (presumably provided by `utils` — TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: autograd anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# Star import placed last, so any name it shares with `utils` takes precedence.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener filter front end followed by a VAE that outputs V.

    For each of the K sources a Wiener filter is built from the steering
    vector ``hgt[:, k]`` (a module-level global, not a learned parameter) and
    a fixed diagonal noise covariance Rb.  The filter gives a coarse source
    estimate, and that source's contribution is subtracted from the mixture
    before the next source is processed.  Each coarse estimate is encoded to
    a latent z; the decoder receives z tiled over an (x, y) coordinate grid
    and its exponentiated output V is passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of channels (length of each steering vector).
        K: number of sources to extract.
        im_size: side length of the coordinate grid fed to the decoder.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Rb is a fixed constant in forward(); an earlier variant (removed
        # dead code) estimated the noise level with small linear layers.

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 needs the flattened encoder output to have 25*25 features
        # (presumably a 100x100 input downsampled 4x — TODO confirm).
        # Its output interleaves mu and logvar (see forward()).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # No `indexing=` keyword: it is unavailable on torch 1.8.1, which
        # this file explicitly supports.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra leading dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source covariances for a batch of mixtures.

        Args:
            x: complex mixture, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is ``vhat`` as a
            diagonal matrix, shape [I, N, F, K, K]; ``Hhat`` is [I, M, K];
            ``Rb`` is [I, M, M]; ``mu`` and ``logvar`` are [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are overwritten on every source
        iteration, so only the LAST source's values are returned — confirm
        that a KL term over one of the K latents is intended.
        """
        batch_size = x.shape[0]

        # Fixed noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        "Wiener filter to get coarse shat for every source"
        h_all, shat_all = [], []
        inp = x
        hj = W = None
        for i in range(self.K):
            if i > 0:
                # Subtract the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a
                # size-1 batch/N/F dimension and break the permute.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Same steering vector for every sample in the batch.
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]) # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) # [I, N, F]
            # Normalise by the batch-wide peak magnitude (no gradient through it).
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        v_all = []
        for shat in shat_all:
            "Encoder"
            xx = self.encoder(shat[:,None].abs())
            "Get latent variable"
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]      # even columns
            logvar = zz[:,1::2] # odd columns
            z = self.reparameterize(mu, logvar)

            "Decoder to get V"
            # View z as 4D tensor to be tiled across the two grid dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match the grid size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the single decoder channel (dim 1) so I == 1 or K == 1
        # keep their dimensions.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture, shape [I, M, N, F]; permuted to [I, N, F, M] internally.
        Rs: source covariance, indexed as [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(nll, beta*kl)``; both are sums over all elements (no mean).

    Raises:
        Exception: whatever the likelihood computation raised; the inputs
        are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance Rx = H Rs H^H + Rb, computed with [N, F, I] as batch dims.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): uses log(pi*det) rather than log(pi^M*det); the two
        # differ by a constant only — confirm this is intended.
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None] # [I,N,F,1,1]
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 I/N/F dim and mis-broadcast with logdet.
        ll = -logdet - quad[..., 0, 0]
    except Exception:
        # Dump the offending tensors for offline inspection, then stop.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/{rid}/'
mod_loc = f'../data/nem_ss/models/{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # Create each folder independently: checking only the figures folder
    # crashed (or left the models folder missing) when just one existed.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {
    'batch_size': 64,
    'lr': 1e-3,
    'n_epochs': 150,
}

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last keeps every batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file: the third entry is read as the global `hgt` by
# NN_gtsbd.forward; the first 128 samples form the validation batch.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr=opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

def _save_curve(values, style, title, fname):
    """Plot ``values`` in a new figure with ``title`` and save it under ``fig_loc``."""
    plt.figure()
    plt.plot(values, style)
    plt.title(title)
    plt.savefig(fig_loc + fname)


# Normalise the validation loss by the actual validation batch size.
n_val = xval_cuda.shape[0]

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: these record the LAST batch of the epoch (per-sample), not an
    # average over the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10th epoch: save the loss curves, evaluate, and checkpoint.
    if epoch%10 == 0:
        _save_curve(loss_tr, '-or', f'Loss fuction at epoch{epoch}',
                    f'Epoch{epoch}_LossFunAll')
        _save_curve(loss1, '-og', f'Reconstruction loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss1')
        _save_curve(loss2, '-og', f'KL loss at epoch{epoch}',
                    f'Epoch{epoch}_Loss2')
        _save_curve(loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}',
                    f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append((l1+l2).cpu().item()/n_val)
        _save_curve(loss_eval, '-xb', f'Accumulated validation loss at epoch{epoch}',
                    f'Epoch{epoch}_val')
        # Release all figures opened above so they do not accumulate.
        plt.close('all')
        # Saves the whole pickled module, not just its state_dict.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s3
# Cell preamble for run "s3": sets the run id, global plotting/printing
# options, the RNG seed, and picks version-dependent helpers.
# NOTE(review): `os`, `plt`, `torch`, `optim` are not imported explicitly in
# this cell; presumably they come from `from utils import *` — confirm.
rid = 's3' # running id
from math import ceil
from utils import *
# Restrict this process to GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed so weight init and sampling are repeatable across runs.
torch.manual_seed(1)

# Version switch: for any torch other than 1.8.1, define `mydet` locally and
# use the built-in RAdam. On 1.8.1, `mydet` and `optim` must already be in
# scope (presumably provided by `utils` — TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: autograd anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# Star import placed last, so any name it shares with `utils` takes precedence.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener filter front end followed by a VAE that outputs V.

    For each of the K sources a Wiener filter is built from the steering
    vector ``hgt[:, k]`` (a module-level global, not a learned parameter) and
    a fixed diagonal noise covariance Rb.  The filter gives a coarse source
    estimate, and that source's contribution is subtracted from the mixture
    before the next source is processed.  Each coarse estimate is encoded to
    a latent z; the decoder receives z tiled over an (x, y) coordinate grid
    and its exponentiated output V is passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of channels (length of each steering vector).
        K: number of sources to extract.
        im_size: side length of the coordinate grid fed to the decoder.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Rb is a fixed constant in forward(); an earlier variant (removed
        # dead code) estimated the noise level with small linear layers.

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 needs the flattened encoder output to have 25*25 features
        # (presumably a 100x100 input downsampled 4x — TODO confirm).
        # Its output interleaves mu and logvar (see forward()).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # No `indexing=` keyword: it is unavailable on torch 1.8.1, which
        # this file explicitly supports.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra leading dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source covariances for a batch of mixtures.

        Args:
            x: complex mixture, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is ``vhat`` as a
            diagonal matrix, shape [I, N, F, K, K]; ``Hhat`` is [I, M, K];
            ``Rb`` is [I, M, M]; ``mu`` and ``logvar`` are [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are overwritten on every source
        iteration, so only the LAST source's values are returned — confirm
        that a KL term over one of the K latents is intended.
        """
        batch_size = x.shape[0]

        # Fixed noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        "Wiener filter to get coarse shat for every source"
        h_all, shat_all = [], []
        inp = x
        hj = W = None
        for i in range(self.K):
            if i > 0:
                # Subtract the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a
                # size-1 batch/N/F dimension and break the permute.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Same steering vector for every sample in the batch.
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]) # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) # [I, N, F]
            # Normalise by the batch-wide peak magnitude (no gradient through it).
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        v_all = []
        for shat in shat_all:
            "Encoder"
            xx = self.encoder(shat[:,None].abs())
            "Get latent variable"
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]      # even columns
            logvar = zz[:,1::2] # odd columns
            z = self.reparameterize(mu, logvar)

            "Decoder to get V"
            # View z as 4D tensor to be tiled across the two grid dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match the grid size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the single decoder channel (dim 1) so I == 1 or K == 1
        # keep their dimensions.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    Args:
        x: observed data; permuted here with (0,2,3,1) so the channel axis is
            last (assumes [I, M, N, F] input -- TODO confirm against caller).
        Rs: source covariance, permuted with (1,2,0,3,4) before mixing
            (assumes [I, N, F, K, K]).
        Hhat: mixing matrix applied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean / log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``, both summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse computation raised; the
            offending tensors are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem, then re-raise the original error.
        # `except Exception` (not a bare except) lets Ctrl-C pass through
        # without writing a dump file.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Training driver for this run: create the output folders, load and normalise
# the data, build model + optimizer, then train with periodic plots/checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Create each folder independently and idempotently. Checking only the figure
# folder (as before) left a missing model folder uncreated, and raised
# FileExistsError when only the model folder already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a module-level global inside the model's forward().
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): loss/l1/l2 are the values of the LAST mini-batch of the
    # epoch, so every curve point is a single-batch sample, not an epoch mean.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s4
# Cell setup for run 's4': run id, wildcard project imports, GPU selection,
# plot/print options, RNG seed and torch-version-dependent aliases.
rid = 's4' # running id
# NOTE(review): os, torch, plt, optim, ... are not imported explicitly in this
# cell; presumably they are re-exported by the wildcard import -- confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version shim: for any torch version string not starting with '1.8.1', define
# mydet via Tensor.det() and take RAdam from torch.optim. On 1.8.1 mydet is not
# defined here (presumably provided by utils -- confirm) and RAdam comes from
# the `optim` name already in scope.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid and adds runtime overhead.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE estimating source power maps V; mixing vectors come from global ``hgt``.

    A Wiener filter built from each column of ``hgt`` and a constant diagonal
    ``Rb`` produces a coarse estimate of every source (each later source is
    filtered from the residual left after removing the previous one). An
    encoder / up-convolution decoder then maps each coarse estimate to a power
    map clamped with ``threshold(..., floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    (original note: J <=K)
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, latent projection and decoder.

        Args:
            M: number of channels (size of ``Rb`` and of each mixing vector).
            K: number of sources to estimate.
            im_size: accepted for signature compatibility; not used here. The
                layer sizes are hard-coded to a 25*25 bottleneck (presumably
                for 100x100 inputs -- TODO confirm).
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source power maps for a batch of mixtures.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal
            embedding of the stacked power maps, ``Hhat`` stacks the mixing
            vectors on dim 2, ``Rb`` is the constant noise covariance, and
            ``mu`` / ``logvar`` are the latent statistics of the LAST source.
        """
        batch_size, _, N, F = x.shape
        I = batch_size
        v_all, h_all, shat_all = [], [], []

        # Constant diagonal noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: successive Wiener filtering, one source per pass.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's contribution from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                inp = inp - tmp.squeeze().permute(2,3,0,1)

            # Tile the i-th column of hgt over the batch (size inferred, not 3).
            hj = hgt[:,i].repeat(I).reshape(I,-1) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]).squeeze().permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude, excluded from the gradient.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode / sample / decode every coarse source estimate.
        for i in range(self.K):
            # Use the i-th estimate. Feeding `shat` here would pass the last
            # source on every iteration and leave shat_all unused.
            xx = self.encoder(shat_all[i][:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        vhat = torch.stack(v_all, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar belong to the last source only, so the KL
        # term downstream regularises just that posterior -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    Args:
        x: observed data; permuted here with (0,2,3,1) so the channel axis is
            last (assumes [I, M, N, F] input -- TODO confirm against caller).
        Rs: source covariance, permuted with (1,2,0,3,4) before mixing
            (assumes [I, N, F, K, K]).
        Hhat: mixing matrix applied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean / log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``, both summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse computation raised; the
            offending tensors are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem, then re-raise the original error.
        # `except Exception` (not a bare except) lets Ctrl-C pass through
        # without writing a dump file.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Training driver for this run: create the output folders, load and normalise
# the data, build model + optimizer, then train with periodic plots/checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Create each folder independently and idempotently. Checking only the figure
# folder (as before) left a missing model folder uncreated, and raised
# FileExistsError when only the model folder already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a module-level global inside the model's forward().
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): loss/l1/l2 are the values of the LAST mini-batch of the
    # epoch, so every curve point is a single-batch sample, not an epoch mean.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s5
# Cell setup for run 's5': run id, wildcard project imports, GPU selection,
# plot/print options, RNG seed and torch-version-dependent aliases.
rid = 's5' # running id
# NOTE(review): os, torch, plt, optim, ... are not imported explicitly in this
# cell; presumably they are re-exported by the wildcard import -- confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version shim: for any torch version string not starting with '1.8.1', define
# mydet via Tensor.det() and take RAdam from torch.optim. On 1.8.1 mydet is not
# defined here (presumably provided by utils -- confirm) and RAdam comes from
# the `optim` name already in scope.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid and adds runtime overhead.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE estimating source power maps V; mixing vectors come from global ``hgt``.

    A Wiener filter built from each column of ``hgt`` and a constant diagonal
    ``Rb`` produces a coarse estimate of every source (each later source is
    filtered from the residual left after removing the previous one). An
    encoder / up-convolution decoder then maps each coarse estimate to a power
    map clamped with ``threshold(..., floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    (original note: J <=K)
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, latent projection and decoder.

        Args:
            M: number of channels (size of ``Rb`` and of each mixing vector).
            K: number of sources to estimate.
            im_size: accepted for signature compatibility; not used here. The
                layer sizes are hard-coded to a 25*25 bottleneck (presumably
                for 100x100 inputs -- TODO confirm).
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source power maps for a batch of mixtures.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal
            embedding of the stacked power maps, ``Hhat`` stacks the mixing
            vectors on dim 2, ``Rb`` is the constant noise covariance, and
            ``mu`` / ``logvar`` are the latent statistics of the LAST source.
        """
        batch_size, _, N, F = x.shape
        I = batch_size
        v_all, h_all, shat_all = [], [], []

        # Constant diagonal noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: successive Wiener filtering, one source per pass.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's contribution from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                inp = inp - tmp.squeeze().permute(2,3,0,1)

            # Tile the i-th column of hgt over the batch (size inferred, not 3).
            hj = hgt[:,i].repeat(I).reshape(I,-1) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]).squeeze().permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude, excluded from the gradient.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode / sample / decode every coarse source estimate.
        for i in range(self.K):
            # Use the i-th estimate. Feeding `shat` here would pass the last
            # source on every iteration and leave shat_all unused.
            xx = self.encoder(shat_all[i][:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        vhat = torch.stack(v_all, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar belong to the last source only, so the KL
        # term downstream regularises just that posterior -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-2):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    Args:
        x: observed data; permuted here with (0,2,3,1) so the channel axis is
            last (assumes [I, M, N, F] input -- TODO confirm against caller).
        Rs: source covariance, permuted with (1,2,0,3,4) before mixing
            (assumes [I, N, F, K, K]).
        Hhat: mixing matrix applied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean / log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``, both summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse computation raised; the
            offending tensors are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem, then re-raise the original error.
        # `except Exception` (not a bare except) lets Ctrl-C pass through
        # without writing a dump file.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Training driver for this run: create the output folders, load and normalise
# the data, build model + optimizer, then train with periodic plots/checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Create each folder independently and idempotently. Checking only the figure
# folder (as before) left a missing model folder uncreated, and raised
# FileExistsError when only the model folder already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a module-level global inside the model's forward().
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): loss/l1/l2 are the values of the LAST mini-batch of the
    # epoch, so every curve point is a single-batch sample, not an epoch mean.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s6
# Cell setup for run 's6': run id, wildcard project imports, GPU selection,
# plot/print options, RNG seed and torch-version-dependent aliases.
rid = 's6' # running id
# NOTE(review): os, torch, plt, optim, ... are not imported explicitly in this
# cell; presumably they are re-exported by the wildcard import -- confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version shim: for any torch version string not starting with '1.8.1', define
# mydet via Tensor.det() and take RAdam from torch.optim. On 1.8.1 mydet is not
# defined here (presumably provided by utils -- confirm) and RAdam comes from
# the `optim` name already in scope.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid and adds runtime overhead.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE estimating source power maps V; mixing vectors come from global ``hgt``.

    A Wiener filter built from each column of ``hgt`` and a constant diagonal
    ``Rb`` produces a coarse estimate of every source (each later source is
    filtered from the residual left after removing the previous one). An
    encoder / up-convolution decoder then maps each coarse estimate to a power
    map clamped with ``threshold(..., floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    (original note: J <=K)
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, latent projection and decoder.

        Args:
            M: number of channels (size of ``Rb`` and of each mixing vector).
            K: number of sources to estimate.
            im_size: accepted for signature compatibility; not used here. The
                layer sizes are hard-coded to a 25*25 bottleneck (presumably
                for 100x100 inputs -- TODO confirm).
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source power maps for a batch of mixtures.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal
            embedding of the stacked power maps, ``Hhat`` stacks the mixing
            vectors on dim 2, ``Rb`` is the constant noise covariance, and
            ``mu`` / ``logvar`` are the latent statistics of the LAST source.
        """
        batch_size, _, N, F = x.shape
        I = batch_size
        v_all, h_all, shat_all = [], [], []

        # Constant diagonal noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: successive Wiener filtering, one source per pass.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's contribution from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                inp = inp - tmp.squeeze().permute(2,3,0,1)

            # Tile the i-th column of hgt over the batch (size inferred, not 3).
            hj = hgt[:,i].repeat(I).reshape(I,-1) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]).squeeze().permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude, excluded from the gradient.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode / sample / decode every coarse source estimate.
        for i in range(self.K):
            # Use the i-th estimate. Feeding `shat` here would pass the last
            # source on every iteration and leave shat_all unused.
            xx = self.encoder(shat_all[i][:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        vhat = torch.stack(v_all, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar belong to the last source only, so the KL
        # term downstream regularises just that posterior -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    Args:
        x: observed data; permuted here with (0,2,3,1) so the channel axis is
            last (assumes [I, M, N, F] input -- TODO confirm against caller).
        Rs: source covariance, permuted with (1,2,0,3,4) before mixing
            (assumes [I, N, F, K, K]).
        Hhat: mixing matrix applied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean / log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``, both summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse computation raised; the
            offending tensors are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem, then re-raise the original error.
        # `except Exception` (not a bare except) lets Ctrl-C pass through
        # without writing a dump file.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Training driver for this run: create the output folders, load and normalise
# the data, build model + optimizer, then train with periodic plots/checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Create each folder independently and idempotently. Checking only the figure
# folder (as before) left a missing model folder uncreated, and raised
# FileExistsError when only the model folder already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a module-level global inside the model's forward().
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): loss/l1/l2 are the values of the LAST mini-batch of the
    # epoch, so every curve point is a single-batch sample, not an epoch mean.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s7
# Cell setup for run 's7': run id, wildcard project imports, GPU selection,
# plot/print options, RNG seed and torch-version-dependent aliases.
rid = 's7' # running id
# NOTE(review): os, torch, plt, optim, ... are not imported explicitly in this
# cell; presumably they are re-exported by the wildcard import -- confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version shim: for any torch version string not starting with '1.8.1', define
# mydet via Tensor.det() and take RAdam from torch.optim. On 1.8.1 mydet is not
# defined here (presumably provided by utils -- confirm) and RAdam comes from
# the `optim` name already in scope.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid and adds runtime overhead.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE estimating source power maps V; mixing vectors come from global ``hgt``.

    A Wiener filter built from each column of ``hgt`` and a constant diagonal
    ``Rb`` produces a coarse estimate of every source (each later source is
    filtered from the residual left after removing the previous one). An
    encoder / up-convolution decoder then maps each coarse estimate to a power
    map clamped with ``threshold(..., floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    (original note: J <=K)
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, latent projection and decoder.

        Args:
            M: number of channels (size of ``Rb`` and of each mixing vector).
            K: number of sources to estimate.
            im_size: accepted for signature compatibility; not used here. The
                layer sizes are hard-coded to a 25*25 bottleneck (presumably
                for 100x100 inputs -- TODO confirm).
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source power maps for a batch of mixtures.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal
            embedding of the stacked power maps, ``Hhat`` stacks the mixing
            vectors on dim 2, ``Rb`` is the constant noise covariance, and
            ``mu`` / ``logvar`` are the latent statistics of the LAST source.
        """
        batch_size, _, N, F = x.shape
        I = batch_size
        v_all, h_all, shat_all = [], [], []

        # Constant diagonal noise covariance; loop-invariant, so built once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: successive Wiener filtering, one source per pass.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's contribution from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                inp = inp - tmp.squeeze().permute(2,3,0,1)

            # Tile the i-th column of hgt over the batch (size inferred, not 3).
            hj = hgt[:,i].repeat(I).reshape(I,-1) # shape:[I, M]
            h_all.append(hj)

            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ inp.permute(2,3,0,1)[...,None]).squeeze().permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude, excluded from the gradient.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode / sample / decode every coarse source estimate.
        for i in range(self.K):
            # Use the i-th estimate. Feeding `shat` here would pass the last
            # source on every iteration and leave shat_all unused.
            xx = self.encoder(shat_all[i][:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        vhat = torch.stack(v_all, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar belong to the last source only, so the KL
        # term downstream regularises just that posterior -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    Args:
        x: observed data; permuted here with (0,2,3,1) so the channel axis is
            last (assumes [I, M, N, F] input -- TODO confirm against caller).
        Rs: source covariance, permuted with (1,2,0,3,4) before mixing
            (assumes [I, N, F, K, K]).
        Hhat: mixing matrix applied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean / log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``, both summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse computation raised; the
            offending tensors are saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem, then re-raise the original error.
        # `except Exception` (not a bare except) lets Ctrl-C pass through
        # without writing a dump file.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Training driver for this run: create the output folders, load and normalise
# the data, build model + optimizer, then train with periodic plots/checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Create each folder independently and idempotently. Checking only the figure
# folder (as before) left a missing model folder uncreated, and raised
# FileExistsError when only the model folder already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a module-level global inside the model's forward().
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): loss/l1/l2 are the values of the LAST mini-batch of the
    # epoch, so every curve point is a single-batch sample, not an epoch mean.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s8
# Cell setup for run 's8': run id, wildcard project imports, GPU selection,
# plot/print options, RNG seed and torch-version-dependent aliases.
rid = 's8' # running id
from math import ceil
# NOTE(review): os, torch, plt, optim, ... are not imported explicitly in this
# cell; presumably they are re-exported by the wildcard import -- confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version shim: for any torch version string not starting with '1.8.1', define
# mydet via Tensor.det() and take RAdam from torch.optim. On 1.8.1 mydet is not
# defined here (presumably provided by utils -- confirm) and RAdam comes from
# the `optim` name already in scope.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid and adds runtime overhead.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener-filter front end followed by a VAE that estimates V.

    For each of the K sources a coarse estimate is extracted from the mixture
    with a Wiener filter built from a fixed mixing vector (column ``i`` of the
    module-level global ``hgt``) and a fixed noise covariance ``1.4e-3 * I``;
    the previous source's filtered contribution is subtracted before the next
    source is extracted.  Every coarse estimate then goes through a shared
    encoder, a reparameterised latent of size ``dz`` and a shared decoder whose
    output is exponentiated and thresholded to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, decoder and the constant coordinate grids.

        Args:
            M: number of input channels (second dimension of the input).
            K: number of sources to extract.
            im_size: side length of the square grid the decoder is run on.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        # self.est = nn.Sequential(
        #     Down(in_channels=M*2, out_channels=64),
        #     Down(in_channels=64, out_channels=32),
        #     Down(in_channels=32, out_channels=4),
        #     Reshape(-1, 4*12*12),
        #     LinearBlock(4*12*12, 64),
        #     nn.Linear(64, 1),
        #     )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 consumes 25*25 flattened encoder features and emits interleaved
        # (mu, logvar) pairs.  NOTE(review): 25*25 presumably relies on
        # im_size=100 and each Down halving the spatial size -- confirm.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Decoder input = latent code (dz channels) + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances for a batch of mixtures.

        Args:
            x: complex mixture batch, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            Rs -- ``vhat.diag_embed()``, complex, shape [I, N, F, K, K];
            Hhat -- stacked mixing vectors, shape [I, M, K];
            Rb -- fixed noise covariance, shape [I, M, M];
            mu, logvar -- latent parameters of the LAST source only.

        NOTE(review): mu/logvar are overwritten on every source, so only the
        last source's values reach the caller -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed (not learned) noise covariance.  It is the same for every
        # source, so it is built once instead of once per loop iteration.
        # sb = self.b2(self.b1(inp.abs()).squeeze()).mean(dim=1).exp()
        # Rb = (sb[:None]*torch.ones(batch_size, self.M, \
        #     device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: recursive Wiener filtering to get one coarse shat per source.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # hj and W still hold the previous source's values here:
                # remove that source's filtered contribution from the input.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # Drop only the trailing singleton so that a batch of size 1
                # keeps its batch axis.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Mixing vector of source i, copied for every sample in the batch
            # (read from the module-level global hgt).
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude; the scale itself is
            # detached from the graph.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: shared VAE applied to each coarse estimate.
        for shat_k in shat_all:
            # Encoder works on the magnitude, with a singleton channel axis.
            xx = self.encoder(shat_k[:,None].abs())
            # Latent parameters: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x N x F
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stack to [I, 1, N, F, K] and drop only the channel axis, so I == 1 or
        # K == 1 does not lose the batch / source axis.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture batch; permuted here with (0,2,3,1), i.e. from
            [I, M, N, F] to [I, N, F, M].
        Rs: source covariance as returned by the model, [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta*kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood and ``kl`` is summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse raised.  The tensors
        involved are dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance of the mixture: H Rs H^H + Rb.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the offending tensors for post-mortem, then stop with the
        # original traceback.  `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Run-specific output folders: <base>/<rid>/ for figures and for checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Test both folders and create whichever is missing, so a run where only one of
# them exists neither crashes on creation nor fails later when saving.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3 # handed to the model below as NN(M, K, N)
NF = N*F # not referenced again in this cell
eps = 5e-4 # not referenced again in this cell
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data.  NOTE(review): awgn_batch is defined outside this cell;
# presumably it adds white Gaussian noise at the given SNR -- confirm.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True: every batch has exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data; the third item becomes `hgt`, which NN_gtsbd.forward reads
# as a module-level global.
# NOTE(review): xval is not rescaled here the way xtr is -- confirm the file is
# already normalised.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# First 128 validation samples (the training loop divides by the same 128).
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: optimise for opts['n_epochs'] epochs; on every 10th epoch
# save the loss curves, evaluate on xval_cuda and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm to 10 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # The histories store the LAST mini-batch of the epoch only (not an epoch
    # average), divided by the configured batch size.
    for _hist, _term in ((loss_tr, loss), (loss1, l1), (loss2, l2)):
        _hist.append(_term.detach().cpu().item()/opts['batch_size'])

    # Everything below runs on epochs 0, 10, 20, ... only.
    if epoch%10 != 0:
        continue
    print(epoch)

    # (data, line style, title, file stem) for each training-loss figure.
    _plots = (
        (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
        (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
        (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
        (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
    )
    for _curve, _fmt, _title, _stem in _plots:
        plt.figure()
        plt.plot(_curve, _fmt)
        plt.title(_title)
        plt.savefig(fig_loc + _stem)

    # Validation pass without gradient tracking; 128 is the hard-coded
    # normaliser for the validation loss.
    model.eval()
    with torch.no_grad():
        Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
        l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
        loss_eval.append((l1+l2).cpu().item()/128)
    plt.figure()
    plt.plot(loss_eval, '-xb')
    plt.title(f'Accumulated validation loss at epoch{epoch}')
    plt.savefig(fig_loc + f'Epoch{epoch}_val')
    # Release every figure opened during this epoch.
    plt.close('all')
    # Checkpoint: the whole module object is pickled, not just its state_dict.
    torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s9
# Cell preamble: run id, star imports, device / plotting / printing settings,
# RNG seed and the torch-version-dependent helpers used further down.
rid = 's9' # running id; names the output sub-folders and the debug dump file below
from math import ceil
# NOTE(review): os, plt, torch, optim, np, nn, Data, awgn_batch and threshold are
# used in this cell without an explicit import -- presumably they all come from
# this star import; confirm in utils.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for the torch RNG.
torch.manual_seed(1)

# Version switch, decided on the first five characters of the version string.
# Anything other than '1.8.1' defines mydet through Tensor.det and takes RAdam
# from torch.optim.  On 1.8.1 RAdam is taken from `optim` and mydet is NOT
# defined here -- presumably both are supplied by the star import above
# (TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; PyTorch documents it as
# slowing execution down.
torch.autograd.set_detect_anomaly(True)
# Presumably provides Down, DoubleConv and OutConv used by NN_gtsbd -- confirm.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener-filter front end followed by a VAE that estimates V.

    For each of the K sources a coarse estimate is extracted from the mixture
    with a Wiener filter built from a fixed mixing vector (column ``i`` of the
    module-level global ``hgt``) and a fixed noise covariance ``1.4e-3 * I``;
    the previous source's filtered contribution is subtracted before the next
    source is extracted.  Every coarse estimate then goes through a shared
    encoder, a reparameterised latent of size ``dz`` and a shared decoder whose
    output is exponentiated and thresholded to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, decoder and the constant coordinate grids.

        Args:
            M: number of input channels (second dimension of the input).
            K: number of sources to extract.
            im_size: side length of the square grid the decoder is run on.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        # self.est = nn.Sequential(
        #     Down(in_channels=M*2, out_channels=64),
        #     Down(in_channels=64, out_channels=32),
        #     Down(in_channels=32, out_channels=4),
        #     Reshape(-1, 4*12*12),
        #     LinearBlock(4*12*12, 64),
        #     nn.Linear(64, 1),
        #     )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 consumes 25*25 flattened encoder features and emits interleaved
        # (mu, logvar) pairs.  NOTE(review): 25*25 presumably relies on
        # im_size=100 and each Down halving the spatial size -- confirm.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Decoder input = latent code (dz channels) + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances for a batch of mixtures.

        Args:
            x: complex mixture batch, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            Rs -- ``vhat.diag_embed()``, complex, shape [I, N, F, K, K];
            Hhat -- stacked mixing vectors, shape [I, M, K];
            Rb -- fixed noise covariance, shape [I, M, M];
            mu, logvar -- latent parameters of the LAST source only.

        NOTE(review): mu/logvar are overwritten on every source, so only the
        last source's values reach the caller -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed (not learned) noise covariance.  It is the same for every
        # source, so it is built once instead of once per loop iteration.
        # sb = self.b2(self.b1(inp.abs()).squeeze()).mean(dim=1).exp()
        # Rb = (sb[:None]*torch.ones(batch_size, self.M, \
        #     device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: recursive Wiener filtering to get one coarse shat per source.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # hj and W still hold the previous source's values here:
                # remove that source's filtered contribution from the input.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # Drop only the trailing singleton so that a batch of size 1
                # keeps its batch axis.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Mixing vector of source i, copied for every sample in the batch
            # (read from the module-level global hgt).
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude; the scale itself is
            # detached from the graph.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: shared VAE applied to each coarse estimate.
        for shat_k in shat_all:
            # Encoder works on the magnitude, with a singleton channel axis.
            xx = self.encoder(shat_k[:,None].abs())
            # Latent parameters: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x N x F
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stack to [I, 1, N, F, K] and drop only the channel axis, so I == 1 or
        # K == 1 does not lose the batch / source axis.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-2):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture batch; permuted here with (0,2,3,1), i.e. from
            [I, M, N, F] to [I, N, F, M].
        Rs: source covariance as returned by the model, [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta*kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood and ``kl`` is summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse raised.  The tensors
        involved are dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance of the mixture: H Rs H^H + Rb.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the offending tensors for post-mortem, then stop with the
        # original traceback.  `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Run-specific output folders: <base>/<rid>/ for figures and for checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Test both folders and create whichever is missing, so a run where only one of
# them exists neither crashes on creation nor fails later when saving.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3 # handed to the model below as NN(M, K, N)
NF = N*F # not referenced again in this cell
eps = 5e-4 # not referenced again in this cell
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data.  NOTE(review): awgn_batch is defined outside this cell;
# presumably it adds white Gaussian noise at the given SNR -- confirm.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True: every batch has exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data; the third item becomes `hgt`, which NN_gtsbd.forward reads
# as a module-level global.
# NOTE(review): xval is not rescaled here the way xtr is -- confirm the file is
# already normalised.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# First 128 validation samples (the training loop divides by the same 128).
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: optimise for opts['n_epochs'] epochs; on every 10th epoch
# save the loss curves, evaluate on xval_cuda and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm to 10 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # The histories store the LAST mini-batch of the epoch only (not an epoch
    # average), divided by the configured batch size.
    for _hist, _term in ((loss_tr, loss), (loss1, l1), (loss2, l2)):
        _hist.append(_term.detach().cpu().item()/opts['batch_size'])

    # Everything below runs on epochs 0, 10, 20, ... only.
    if epoch%10 != 0:
        continue
    print(epoch)

    # (data, line style, title, file stem) for each training-loss figure.
    _plots = (
        (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
        (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
        (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
        (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
    )
    for _curve, _fmt, _title, _stem in _plots:
        plt.figure()
        plt.plot(_curve, _fmt)
        plt.title(_title)
        plt.savefig(fig_loc + _stem)

    # Validation pass without gradient tracking; 128 is the hard-coded
    # normaliser for the validation loss.
    model.eval()
    with torch.no_grad():
        Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
        l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
        loss_eval.append((l1+l2).cpu().item()/128)
    plt.figure()
    plt.plot(loss_eval, '-xb')
    plt.title(f'Accumulated validation loss at epoch{epoch}')
    plt.savefig(fig_loc + f'Epoch{epoch}_val')
    # Release every figure opened during this epoch.
    plt.close('all')
    # Checkpoint: the whole module object is pickled, not just its state_dict.
    torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s10
# Cell preamble: run id, star imports, device / plotting / printing settings,
# RNG seed and the torch-version-dependent helpers used further down.
rid = 's10' # running id; names the output sub-folders and the debug dump file below
from math import ceil
# NOTE(review): os, plt, torch, optim, np, nn, Data, awgn_batch and threshold are
# used in this cell without an explicit import -- presumably they all come from
# this star import; confirm in utils.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for the torch RNG.
torch.manual_seed(1)

# Version switch, decided on the first five characters of the version string.
# Anything other than '1.8.1' defines mydet through Tensor.det and takes RAdam
# from torch.optim.  On 1.8.1 RAdam is taken from `optim` and mydet is NOT
# defined here -- presumably both are supplied by the star import above
# (TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; PyTorch documents it as
# slowing execution down.
torch.autograd.set_detect_anomaly(True)
# Presumably provides Down, DoubleConv and OutConv used by NN_gtsbd -- confirm.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener-filter front end followed by a VAE that estimates V.

    For each of the K sources a coarse estimate is extracted from the mixture
    with a Wiener filter built from a fixed mixing vector (column ``i`` of the
    module-level global ``hgt``) and a fixed noise covariance ``1.4e-3 * I``;
    the previous source's filtered contribution is subtracted before the next
    source is extracted.  Every coarse estimate then goes through a shared
    encoder, a reparameterised latent of size ``dz`` and a shared decoder whose
    output is exponentiated and thresholded to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, decoder and the constant coordinate grids.

        Args:
            M: number of input channels (second dimension of the input).
            K: number of sources to extract.
            im_size: side length of the square grid the decoder is run on.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        # self.est = nn.Sequential(
        #     Down(in_channels=M*2, out_channels=64),
        #     Down(in_channels=64, out_channels=32),
        #     Down(in_channels=32, out_channels=4),
        #     Reshape(-1, 4*12*12),
        #     LinearBlock(4*12*12, 64),
        #     nn.Linear(64, 1),
        #     )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 consumes 25*25 flattened encoder features and emits interleaved
        # (mu, logvar) pairs.  NOTE(review): 25*25 presumably relies on
        # im_size=100 and each Down halving the spatial size -- confirm.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Decoder input = latent code (dz channels) + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances for a batch of mixtures.

        Args:
            x: complex mixture batch, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            Rs -- ``vhat.diag_embed()``, complex, shape [I, N, F, K, K];
            Hhat -- stacked mixing vectors, shape [I, M, K];
            Rb -- fixed noise covariance, shape [I, M, M];
            mu, logvar -- latent parameters of the LAST source only.

        NOTE(review): mu/logvar are overwritten on every source, so only the
        last source's values reach the caller -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed (not learned) noise covariance.  It is the same for every
        # source, so it is built once instead of once per loop iteration.
        # sb = self.b2(self.b1(inp.abs()).squeeze()).mean(dim=1).exp()
        # Rb = (sb[:None]*torch.ones(batch_size, self.M, \
        #     device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: recursive Wiener filtering to get one coarse shat per source.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # hj and W still hold the previous source's values here:
                # remove that source's filtered contribution from the input.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # Drop only the trailing singleton so that a batch of size 1
                # keeps its batch axis.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Mixing vector of source i, copied for every sample in the batch
            # (read from the module-level global hgt).
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude; the scale itself is
            # detached from the graph.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: shared VAE applied to each coarse estimate.
        for shat_k in shat_all:
            # Encoder works on the magnitude, with a singleton channel axis.
            xx = self.encoder(shat_k[:,None].abs())
            # Latent parameters: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x N x F
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stack to [I, 1, N, F, K] and drop only the channel axis, so I == 1 or
        # K == 1 does not lose the batch / source axis.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the negative log-likelihood and the weighted KL term.

    Args:
        x: mixture batch; permuted here with (0,2,3,1), i.e. from
            [I, M, N, F] to [I, N, F, M].
        Rs: source covariance as returned by the model, [I, N, F, K, K].
        Hhat: mixing matrix, [I, M, K].
        Rb: noise covariance, [I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta*kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood and ``kl`` is summed over all elements.

    Raises:
        Exception: whatever the determinant / inverse raised.  The tensors
        involved are dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Model covariance of the mixture: H Rs H^H + Rb.
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the offending tensors for post-mortem, then stop with the
        # original traceback.  `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Run-specific output folders: <base>/<rid>/ for figures and for checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Test both folders and create whichever is missing, so a run where only one of
# them exists neither crashes on creation nor fails later when saving.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3 # handed to the model below as NN(M, K, N)
NF = N*F # not referenced again in this cell
eps = 5e-4 # not referenced again in this cell
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data.  NOTE(review): awgn_batch is defined outside this cell;
# presumably it adds white Gaussian noise at the given SNR -- confirm.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True: every batch has exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data; the third item becomes `hgt`, which NN_gtsbd.forward reads
# as a module-level global.
# NOTE(review): xval is not rescaled here the way xtr is -- confirm the file is
# already normalised.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# First 128 validation samples (the training loop divides by the same 128).
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: optimise for opts['n_epochs'] epochs; on every 10th epoch
# save the loss curves, evaluate on xval_cuda and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm to 10 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # The histories store the LAST mini-batch of the epoch only (not an epoch
    # average), divided by the configured batch size.
    for _hist, _term in ((loss_tr, loss), (loss1, l1), (loss2, l2)):
        _hist.append(_term.detach().cpu().item()/opts['batch_size'])

    # Everything below runs on epochs 0, 10, 20, ... only.
    if epoch%10 != 0:
        continue
    print(epoch)

    # (data, line style, title, file stem) for each training-loss figure.
    _plots = (
        (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
        (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
        (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
        (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
    )
    for _curve, _fmt, _title, _stem in _plots:
        plt.figure()
        plt.plot(_curve, _fmt)
        plt.title(_title)
        plt.savefig(fig_loc + _stem)

    # Validation pass without gradient tracking; 128 is the hard-coded
    # normaliser for the validation loss.
    model.eval()
    with torch.no_grad():
        Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
        l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
        loss_eval.append((l1+l2).cpu().item()/128)
    plt.figure()
    plt.plot(loss_eval, '-xb')
    plt.title(f'Accumulated validation loss at epoch{epoch}')
    plt.savefig(fig_loc + f'Epoch{epoch}_val')
    # Release every figure opened during this epoch.
    plt.close('all')
    # Checkpoint: the whole module object is pickled, not just its state_dict.
    torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s11
# Cell preamble: run id, star imports, device / plotting / printing settings,
# RNG seed and the torch-version-dependent helpers used further down.
rid = 's11' # running id; names the output sub-folders and the debug dump file below
from math import ceil
# NOTE(review): os, plt, torch, optim, np, nn, Data, awgn_batch and threshold are
# used in this cell without an explicit import -- presumably they all come from
# this star import; confirm in utils.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for the torch RNG.
torch.manual_seed(1)

# Version switch, decided on the first five characters of the version string.
# Anything other than '1.8.1' defines mydet through Tensor.det and takes RAdam
# from torch.optim.  On 1.8.1 RAdam is taken from `optim` and mydet is NOT
# defined here -- presumably both are supplied by the star import above
# (TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; PyTorch documents it as
# slowing execution down.
torch.autograd.set_detect_anomaly(True)
# Presumably provides Down, DoubleConv and OutConv used by NN_gtsbd -- confirm.
from vae_model import *
class NN_gtsbd(nn.Module):
    """Recursive Wiener-filter front end followed by a VAE that estimates V.

    For each of the K sources a coarse estimate is extracted from the mixture
    with a Wiener filter built from a fixed mixing vector (column ``i`` of the
    module-level global ``hgt``) and a fixed noise covariance ``1.4e-3 * I``;
    the previous source's filtered contribution is subtracted before the next
    source is extracted.  Every coarse estimate then goes through a shared
    encoder, a reparameterised latent of size ``dz`` and a shared decoder whose
    output is exponentiated and thresholded to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build encoder, decoder and the constant coordinate grids.

        Args:
            M: number of input channels (second dimension of the input).
            K: number of sources to extract.
            im_size: side length of the square grid the decoder is run on.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        # self.est = nn.Sequential(
        #     Down(in_channels=M*2, out_channels=64),
        #     Down(in_channels=64, out_channels=32),
        #     Down(in_channels=32, out_channels=4),
        #     Reshape(-1, 4*12*12),
        #     LinearBlock(4*12*12, 64),
        #     nn.Linear(64, 1),
        #     )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 consumes 25*25 flattened encoder features and emits interleaved
        # (mu, logvar) pairs.  NOTE(review): 25*25 presumably relies on
        # im_size=100 and each Down halving the spatial size -- confirm.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Decoder input = latent code (dz channels) + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances for a batch of mixtures.

        Args:
            x: complex mixture batch, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            Rs -- ``vhat.diag_embed()``, complex, shape [I, N, F, K, K];
            Hhat -- stacked mixing vectors, shape [I, M, K];
            Rb -- fixed noise covariance, shape [I, M, M];
            mu, logvar -- latent parameters of the LAST source only.

        NOTE(review): mu/logvar are overwritten on every source, so only the
        last source's values reach the caller -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed (not learned) noise covariance.  It is the same for every
        # source, so it is built once instead of once per loop iteration.
        # sb = self.b2(self.b1(inp.abs()).squeeze()).mean(dim=1).exp()
        # Rb = (sb[:None]*torch.ones(batch_size, self.M, \
        #     device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Stage 1: recursive Wiener filtering to get one coarse shat per source.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # hj and W still hold the previous source's values here:
                # remove that source's filtered contribution from the input.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # Drop only the trailing singleton so that a batch of size 1
                # keeps its batch axis.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Mixing vector of source i, copied for every sample in the batch
            # (read from the module-level global hgt).
            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the batch-wide peak magnitude; the scale itself is
            # detached from the graph.
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: shared VAE applied to each coarse estimate.
        for shat_k in shat_all:
            # Encoder works on the magnitude, with a singleton channel axis.
            xx = self.encoder(shat_k[:,None].abs())
            # Latent parameters: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x N x F
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stack to [I, 1, N, F, K] and drop only the channel axis, so I == 1 or
        # K == 1 does not lose the batch / source axis.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the negative log-likelihood and the beta-weighted KL term.

    Args:
        x: observed mixture; permuted here so that its second axis becomes
            the last one (expected [I, M, N, F] -> [I, N, F, M]).
        Rs: source covariance, 5-D, expected [I, N, F, K, K].
        Hhat: mixing matrix, expected [I, M, K].
        Rb: noise covariance, expected [I, M, M].
        mu: latent mean.
        logvar: latent log-variance.
        beta: weight applied to the KL divergence.

    Returns:
        Tuple ``(nll, beta * kl)``; both are sums over all elements.

    Raises:
        Exception: any error raised while evaluating the likelihood is
            re-raised after the offending tensors were dumped to disk.
    """
    x = x.permute(0,2,3,1)
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): the constant is pi*det(Rx); an M-variate complex
        # Gaussian would use pi**M -- only a constant offset, left unchanged.
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem debugging, then stop by re-raising
        # the original error (KeyboardInterrupt is deliberately not caught).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders (named after `rid`) for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Both folders are checked: previously the model folder was only created when
# the figure folder was missing, so a half-existing pair broke the run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 150}

# Training data: awgn_batch is called with snr=30 (presumably adds white
# Gaussian noise), then every sample is scaled to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# `hgt` is read as a module-level global inside the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtsbd
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (values, line style, title, file tag) for each training curve
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        )
        for curve_vals, curve_fmt, curve_title, curve_tag in curves:
            plt.figure()
            plt.plot(curve_vals, curve_fmt)
            plt.title(curve_title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{curve_tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # normalise by the actual number of validation samples
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        # checkpoint every 10th epoch, together with the plots
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s12
rid = 's12' # running id: tags this run's output folders and debug dumps
# NOTE(review): os, plt, torch (and optim below) are never imported explicitly
# in this cell; presumably the star-import from utils provides them -- confirm.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1) # fixed seed for torch's random number generator

# Version switch keyed on the first five characters of torch.__version__:
# anything other than 1.8.1 defines mydet locally and uses torch.optim.RAdam;
# on 1.8.1 mydet is not defined here and RAdam is taken from `optim`
# (presumably both supplied by the star-imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via its own ``det`` method."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; the PyTorch docs warn that it
# slows execution down.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): this star-import runs last, so any name it exports overrides the
# definitions above (including mydet/RAdam) -- confirm that is intended.
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """VAE that turns coarse Wiener-filter source estimates into variances.

    The mixing vectors are not learned: column ``i`` of the module-level
    tensor ``hgt`` is used as the steering vector of source ``i``.  Each
    source is first estimated with a Wiener filter applied to the residual
    left by the previous sources; the magnitude of that estimate is encoded
    to a latent code and decoded to a variance map clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    Args:
        M: number of channels of the mixture.
        K: number of sources to extract.
        im_size: accepted for interface compatibility; not used here (the
            linear layers are sized for 25*25 feature maps).
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # one linear layer emits interleaved (mu, logvar) pairs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps`` standard normal (reparameterization trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: diagonal source covariance
            [I, N, F, K, K], stacked steering vectors [I, M, K], the fixed
            noise covariance [I, M, M], and the latent mean / log-variance
            of the LAST source only.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed noise covariance 1.4e-3 * identity, shared by all sources.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # subtract the previous source's image from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                inp = inp - tmp[...,0].permute(2,3,0,1)

            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # index the two singleton matrix dims explicitly so that a batch
            # of size 1 is not squeezed away
            shat = (W @ inp.permute(2,3,0,1)[...,None])[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        for i in range(self.K):
            # Encoder: each source is encoded from its OWN coarse estimate
            xx = self.encoder(shat_all[i][:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # drop only the channel axis (decoder emits one channel)
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of the last source only are returned, so the
        # KL term in the loss ignores the other K-1 latents -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood and the beta-weighted KL term.

    Args:
        x: observed mixture; permuted here so that its second axis becomes
            the last one (expected [I, M, N, F] -> [I, N, F, M]).
        Rs: source covariance, 5-D, expected [I, N, F, K, K].
        Hhat: mixing matrix, expected [I, M, K].
        Rb: noise covariance, expected [I, M, M].
        mu: latent mean.
        logvar: latent log-variance.
        beta: weight applied to the KL divergence.

    Returns:
        Tuple ``(nll, beta * kl)``; both are sums over all elements.

    Raises:
        Exception: any error raised while evaluating the likelihood is
            re-raised after the offending tensors were dumped to disk.
    """
    x = x.permute(0,2,3,1)
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): the constant is pi*det(Rx); an M-variate complex
        # Gaussian would use pi**M -- only a constant offset, left unchanged.
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem debugging, then stop by re-raising
        # the original error (KeyboardInterrupt is deliberately not caught).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders (named after `rid`) for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Both folders are checked: previously the model folder was only created when
# the figure folder was missing, so a half-existing pair broke the run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 150}

# Training data: awgn_batch is called with snr=30 (presumably adds white
# Gaussian noise), then every sample is scaled to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# `hgt` is read as a module-level global inside the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (values, line style, title, file tag) for each training curve
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        )
        for curve_vals, curve_fmt, curve_title, curve_tag in curves:
            plt.figure()
            plt.plot(curve_vals, curve_fmt)
            plt.title(curve_title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{curve_tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # normalise by the actual number of validation samples
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        # checkpoint every 10th epoch, together with the plots
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s13
rid = 's13' # running id: tags this run's output folders and debug dumps
# NOTE(review): os, plt, torch (and optim below) are never imported explicitly
# in this cell; presumably the star-import from utils provides them -- confirm.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1) # fixed seed for torch's random number generator

# Version switch keyed on the first five characters of torch.__version__:
# anything other than 1.8.1 defines mydet locally and uses torch.optim.RAdam;
# on 1.8.1 mydet is not defined here and RAdam is taken from `optim`
# (presumably both supplied by the star-imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via its own ``det`` method."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; the PyTorch docs warn that it
# slows execution down.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): this star-import runs last, so any name it exports overrides the
# definitions above (including mydet/RAdam) -- confirm that is intended.
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """VAE that turns coarse Wiener-filter source estimates into variances.

    The mixing vectors are not learned: column ``i`` of the module-level
    tensor ``hgt`` is used as the steering vector of source ``i``.  Each
    source is first estimated with a Wiener filter applied to the residual
    left by the previous sources; the magnitude of that estimate is encoded
    to a latent code and decoded to a variance map clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    Args:
        M: number of channels of the mixture.
        K: number of sources to extract.
        im_size: accepted for interface compatibility; not used here (the
            linear layers are sized for 25*25 feature maps).
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # one linear layer emits interleaved (mu, logvar) pairs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps`` standard normal (reparameterization trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: diagonal source covariance
            [I, N, F, K, K], stacked steering vectors [I, M, K], the fixed
            noise covariance [I, M, M], and the latent mean / log-variance
            of the LAST source only.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed noise covariance 1.4e-3 * identity, shared by all sources.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # subtract the previous source's image from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                inp = inp - tmp[...,0].permute(2,3,0,1)

            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # index the two singleton matrix dims explicitly so that a batch
            # of size 1 is not squeezed away
            shat = (W @ inp.permute(2,3,0,1)[...,None])[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        for i in range(self.K):
            # Encoder: each source is encoded from its OWN coarse estimate
            xx = self.encoder(shat_all[i][:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # drop only the channel axis (decoder emits one channel)
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of the last source only are returned, so the
        # KL term in the loss ignores the other K-1 latents -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-2):
    """Return the negative log-likelihood and the beta-weighted KL term.

    Args:
        x: observed mixture; permuted here so that its second axis becomes
            the last one (expected [I, M, N, F] -> [I, N, F, M]).
        Rs: source covariance, 5-D, expected [I, N, F, K, K].
        Hhat: mixing matrix, expected [I, M, K].
        Rb: noise covariance, expected [I, M, M].
        mu: latent mean.
        logvar: latent log-variance.
        beta: weight applied to the KL divergence.

    Returns:
        Tuple ``(nll, beta * kl)``; both are sums over all elements.

    Raises:
        Exception: any error raised while evaluating the likelihood is
            re-raised after the offending tensors were dumped to disk.
    """
    x = x.permute(0,2,3,1)
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): the constant is pi*det(Rx); an M-variate complex
        # Gaussian would use pi**M -- only a constant offset, left unchanged.
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem debugging, then stop by re-raising
        # the original error (KeyboardInterrupt is deliberately not caught).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders (named after `rid`) for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Both folders are checked: previously the model folder was only created when
# the figure folder was missing, so a half-existing pair broke the run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 150}

# Training data: awgn_batch is called with snr=30 (presumably adds white
# Gaussian noise), then every sample is scaled to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# `hgt` is read as a module-level global inside the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (values, line style, title, file tag) for each training curve
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        )
        for curve_vals, curve_fmt, curve_title, curve_tag in curves:
            plt.figure()
            plt.plot(curve_vals, curve_fmt)
            plt.title(curve_title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{curve_tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # normalise by the actual number of validation samples
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        # checkpoint every 10th epoch, together with the plots
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s14
rid = 's14' # running id: tags this run's output folders and debug dumps
# NOTE(review): os, plt, torch (and optim below) are never imported explicitly
# in this cell; presumably the star-import from utils provides them -- confirm.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1) # fixed seed for torch's random number generator

# Version switch keyed on the first five characters of torch.__version__:
# anything other than 1.8.1 defines mydet locally and uses torch.optim.RAdam;
# on 1.8.1 mydet is not defined here and RAdam is taken from `optim`
# (presumably both supplied by the star-imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via its own ``det`` method."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; the PyTorch docs warn that it
# slows execution down.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): this star-import runs last, so any name it exports overrides the
# definitions above (including mydet/RAdam) -- confirm that is intended.
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """VAE that turns coarse Wiener-filter source estimates into variances.

    The mixing vectors are not learned: column ``i`` of the module-level
    tensor ``hgt`` is used as the steering vector of source ``i``.  Each
    source is first estimated with a Wiener filter applied to the residual
    left by the previous sources; the magnitude of that estimate is encoded
    to a latent code and decoded to a variance map clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    Args:
        M: number of channels of the mixture.
        K: number of sources to extract.
        im_size: accepted for interface compatibility; not used here (the
            linear layers are sized for 25*25 feature maps).
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # one linear layer emits interleaved (mu, logvar) pairs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps`` standard normal (reparameterization trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: diagonal source covariance
            [I, N, F, K, K], stacked steering vectors [I, M, K], the fixed
            noise covariance [I, M, M], and the latent mean / log-variance
            of the LAST source only.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed noise covariance 1.4e-3 * identity, shared by all sources.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # subtract the previous source's image from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                inp = inp - tmp[...,0].permute(2,3,0,1)

            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # index the two singleton matrix dims explicitly so that a batch
            # of size 1 is not squeezed away
            shat = (W @ inp.permute(2,3,0,1)[...,None])[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        for i in range(self.K):
            # Encoder: each source is encoded from its OWN coarse estimate
            xx = self.encoder(shat_all[i][:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # drop only the channel axis (decoder emits one channel)
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of the last source only are returned, so the
        # KL term in the loss ignores the other K-1 latents -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the negative log-likelihood and the beta-weighted KL term.

    Args:
        x: observed mixture; permuted here so that its second axis becomes
            the last one (expected [I, M, N, F] -> [I, N, F, M]).
        Rs: source covariance, 5-D, expected [I, N, F, K, K].
        Hhat: mixing matrix, expected [I, M, K].
        Rb: noise covariance, expected [I, M, M].
        mu: latent mean.
        logvar: latent log-variance.
        beta: weight applied to the KL divergence.

    Returns:
        Tuple ``(nll, beta * kl)``; both are sums over all elements.

    Raises:
        Exception: any error raised while evaluating the likelihood is
            re-raised after the offending tensors were dumped to disk.
    """
    x = x.permute(0,2,3,1)
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): the constant is pi*det(Rx); an M-variate complex
        # Gaussian would use pi**M -- only a constant offset, left unchanged.
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem debugging, then stop by re-raising
        # the original error (KeyboardInterrupt is deliberately not caught).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders (named after `rid`) for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Both folders are checked: previously the model folder was only created when
# the figure folder was missing, so a half-existing pair broke the run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 150}

# Training data: awgn_batch is called with snr=30 (presumably adds white
# Gaussian noise), then every sample is scaled to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# `hgt` is read as a module-level global inside the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (values, line style, title, file tag) for each training curve
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        )
        for curve_vals, curve_fmt, curve_title, curve_tag in curves:
            plt.figure()
            plt.plot(curve_vals, curve_fmt)
            plt.title(curve_title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{curve_tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # normalise by the actual number of validation samples
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        # checkpoint every 10th epoch, together with the plots
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s15
rid = 's15' # running id: tags this run's output folders and debug dumps
# NOTE(review): os, plt, torch (and optim below) are never imported explicitly
# in this cell; presumably the star-import from utils provides them -- confirm.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1) # fixed seed for torch's random number generator

# Version switch keyed on the first five characters of torch.__version__:
# anything other than 1.8.1 defines mydet locally and uses torch.optim.RAdam;
# on 1.8.1 mydet is not defined here and RAdam is taken from `optim`
# (presumably both supplied by the star-imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via its own ``det`` method."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; the PyTorch docs warn that it
# slows execution down.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): this star-import runs last, so any name it exports overrides the
# definitions above (including mydet/RAdam) -- confirm that is intended.
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """VAE that turns coarse Wiener-filter source estimates into variances.

    The mixing vectors are not learned: column ``i`` of the module-level
    tensor ``hgt`` is used as the steering vector of source ``i``.  Each
    source is first estimated with a Wiener filter applied to the residual
    left by the previous sources; the magnitude of that estimate is encoded
    to a latent code and decoded to a variance map clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    Args:
        M: number of channels of the mixture.
        K: number of sources to extract.
        im_size: accepted for interface compatibility; not used here (the
            linear layers are sized for 25*25 feature maps).
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # one linear layer emits interleaved (mu, logvar) pairs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps`` standard normal (reparameterization trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: diagonal source covariance
            [I, N, F, K, K], stacked steering vectors [I, M, K], the fixed
            noise covariance [I, M, M], and the latent mean / log-variance
            of the LAST source only.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Fixed noise covariance 1.4e-3 * identity, shared by all sources.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # subtract the previous source's image from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                inp = inp - tmp[...,0].permute(2,3,0,1)

            hj = hgt[:,i].repeat(batch_size).reshape(batch_size, self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # index the two singleton matrix dims explicitly so that a batch
            # of size 1 is not squeezed away
            shat = (W @ inp.permute(2,3,0,1)[...,None])[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        for i in range(self.K):
            # Encoder: each source is encoded from its OWN coarse estimate
            xx = self.encoder(shat_all[i][:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # drop only the channel axis (decoder emits one channel)
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of the last source only are returned, so the
        # KL term in the loss ignores the other K-1 latents -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the negative log-likelihood and the beta-weighted KL term.

    Args:
        x: observed mixture; permuted here so that its second axis becomes
            the last one (expected [I, M, N, F] -> [I, N, F, M]).
        Rs: source covariance, 5-D, expected [I, N, F, K, K].
        Hhat: mixing matrix, expected [I, M, K].
        Rb: noise covariance, expected [I, M, M].
        mu: latent mean.
        logvar: latent log-variance.
        beta: weight applied to the KL divergence.

    Returns:
        Tuple ``(nll, beta * kl)``; both are sums over all elements.

    Raises:
        Exception: any error raised while evaluating the likelihood is
            re-raised after the offending tensors were dumped to disk.
    """
    x = x.permute(0,2,3,1)
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # NOTE(review): the constant is pi*det(Rx); an M-variate complex
        # Gaussian would use pi**M -- only a constant offset, left unchanged.
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem debugging, then stop by re-raising
        # the original error (KeyboardInterrupt is deliberately not caught).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders (named after `rid`) for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'{rid}/'
mod_loc = '../data/nem_ss/models/' + f'{rid}/'
# Both folders are checked: previously the model folder was only created when
# the figure folder was missing, so a half-existing pair broke the run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 150}

# Training data: awgn_batch is called with snr=30 (presumably adds white
# Gaussian noise), then every sample is scaled to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# `hgt` is read as a module-level global inside the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (values, line style, title, file tag) for each training curve
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        )
        for curve_vals, curve_fmt, curve_title, curve_tag in curves:
            plt.figure()
            plt.plot(curve_vals, curve_fmt)
            plt.title(curve_title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{curve_tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # normalise by the actual number of validation samples
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        # checkpoint every 10th epoch, together with the plots
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s4_
rid = 's4_' # running id: tags this run's output folders and debug dumps
# NOTE(review): os, plt, torch (and optim below) are never imported explicitly
# in this cell; presumably the star-import from utils provides them -- confirm.
from utils import *
# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1) # fixed seed for torch's random number generator

# Version switch keyed on the first five characters of torch.__version__:
# anything other than 1.8.1 defines mydet locally and uses torch.optim.RAdam;
# on 1.8.1 mydet is not defined here and RAdam is taken from `optim`
# (presumably both supplied by the star-imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via its own ``det`` method."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is a debugging aid; the PyTorch docs warn that it
# slows execution down.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): this star-import runs last, so any name it exports overrides the
# definitions above (including mydet/RAdam) -- confirm that is intended.
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE that estimates per-source variance maps using known mixing vectors.

    For each of the K sources the mixture is Wiener-filtered with a mixing
    vector read from the module-level tensor ``hgt`` (defined outside this
    class), the magnitude of the filtered signal is encoded into a
    ``dz``-dimensional latent code, and the exponentiated decoder output is
    passed through ``threshold`` with floor 1e-3 and ceiling 1e2.  The noise
    covariance ``Rb`` is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K  (NOTE(review): J is not used in this class -- confirm meaning.)
    """
    def __init__(self, M=3, K=3, im_size=100):
        # NOTE: `im_size` is accepted for interface compatibility but unused;
        # the 25*25 sizes below assume the encoder maps 100x100 to 25x25
        # (TODO confirm against Down/Up_ in vae_model).
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits interleaved (mu, logvar) pairs, hence 2*dz outputs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run Wiener filtering and the VAE on a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where Rs is the diagonal
            source covariance [I, N, F, K, K], Hhat the stacked mixing
            vectors [I, M, K], Rb the noise covariance [I, M, M], and
            mu/logvar the latent statistics of the LAST source only.
        """
        I, _, _, _ = x.shape  # unpacking also enforces a 4-D input
        # The noise covariance is constant, so build it once per call.
        Rb = (1.4e-3*torch.ones(I, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: successive Wiener filtering, one source at a time
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's contribution from the input
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze only the trailing axis so I == 1 keeps its batch dim
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Column i of the global ground-truth mixing matrix, tiled over the batch
            hj = hgt[:,i].repeat(I).reshape(I,self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the (detached) largest magnitude in the batch
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate and decode its variance map
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(I,-1))
            # Even columns are means, odd columns are log-variances
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        # NOTE(review): mu/logvar below belong to the last source only, so the
        # KL term in the loss ignores the other K-1 sources -- confirm intent.
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the channel axis so I == 1 or K == 1 keep their dims
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: complex mixture, expected shape [I, M, N, F].
        Rs: source covariance, expected shape [I, N, F, K, K].
        Hhat: mixing matrix, expected shape [I, M, K].
        Rb: noise covariance, expected shape [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum().real, beta*kl)``.

    Raises:
        Exception: whatever the determinant/inverse raised; the offending
        tensors are written to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)  # move the channel axis last: [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None]  # [I, N, F, 1, 1]
        ll = -logdet - quad.squeeze(-1).squeeze(-1)
    except Exception:
        # Dump the inputs for post-mortem analysis, then stop by re-raising
        # the original error instead of recomputing the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Output folders for this run: <base>/<rid>/ for figures and for checkpoints
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently (and any missing parent), so a run does
# not fail later when only one of the two already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch is defined elsewhere; presumably
# white Gaussian noise at 30 dB SNR), then scale every sample by its own
# largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds a 3-tuple; the third entry is the mixing matrix that
# the model's forward() reads through the global `hgt`.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: one optimisation pass over `tr` per epoch, with plots,
# a validation pass and a model checkpoint every 10th epoch.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): only the LAST mini-batch of the epoch is logged here, not
    # the epoch average -- kept as is so existing curves stay comparable.
    loss_tr.append(loss.item()/opts['batch_size'])
    loss1.append(l1.item()/opts['batch_size'])
    loss2.append(l2.item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (series, line format, title prefix, file-name suffix)
        curves = (
            (loss_tr, '-or', 'Loss fuction', '_LossFunAll'),
            (loss1, '-og', 'Reconstruction loss', '_Loss1'),
            (loss2, '-og', 'KL loss', '_Loss2'),
            (loss_tr[-50:], '-or', 'Last 50 of loss', '_last50'),
        )
        for series, fmt, title, suffix in curves:
            plt.figure()
            plt.plot(series, fmt)
            plt.title(f'{title} at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s5_
# Cell preamble: run identifier, environment/display configuration, RNG seed
# and torch-version-dependent aliases. The statement order matters (star
# imports and the CUDA device selection), so it is left untouched.
rid = 's5_' # running id; also used as the sub-folder name for figures/models
from utils import *  # presumably supplies os, plt, torch, np, optim, Data, ... -- confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Any torch build whose version string does not start with '1.8.1' gets a
# local `mydet` and the built-in RAdam; on 1.8.1 both `mydet` and `optim`
# must already be in scope (presumably from utils -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE that estimates per-source variance maps using known mixing vectors.

    For each of the K sources the mixture is Wiener-filtered with a mixing
    vector read from the module-level tensor ``hgt`` (defined outside this
    class), the magnitude of the filtered signal is encoded into a
    ``dz``-dimensional latent code, and the exponentiated decoder output is
    passed through ``threshold`` with floor 1e-3 and ceiling 1e2.  The noise
    covariance ``Rb`` is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K  (NOTE(review): J is not used in this class -- confirm meaning.)
    """
    def __init__(self, M=3, K=3, im_size=100):
        # NOTE: `im_size` is accepted for interface compatibility but unused;
        # the 25*25 sizes below assume the encoder maps 100x100 to 25x25
        # (TODO confirm against Down/Up_ in vae_model).
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits interleaved (mu, logvar) pairs, hence 2*dz outputs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run Wiener filtering and the VAE on a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where Rs is the diagonal
            source covariance [I, N, F, K, K], Hhat the stacked mixing
            vectors [I, M, K], Rb the noise covariance [I, M, M], and
            mu/logvar the latent statistics of the LAST source only.
        """
        I, _, _, _ = x.shape  # unpacking also enforces a 4-D input
        # The noise covariance is constant, so build it once per call.
        Rb = (1.4e-3*torch.ones(I, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: successive Wiener filtering, one source at a time
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's contribution from the input
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze only the trailing axis so I == 1 keeps its batch dim
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Column i of the global ground-truth mixing matrix, tiled over the batch
            hj = hgt[:,i].repeat(I).reshape(I,self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the (detached) largest magnitude in the batch
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate and decode its variance map
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(I,-1))
            # Even columns are means, odd columns are log-variances
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        # NOTE(review): mu/logvar below belong to the last source only, so the
        # KL term in the loss ignores the other K-1 sources -- confirm intent.
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the channel axis so I == 1 or K == 1 keep their dims
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-2):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: complex mixture, expected shape [I, M, N, F].
        Rs: source covariance, expected shape [I, N, F, K, K].
        Hhat: mixing matrix, expected shape [I, M, K].
        Rb: noise covariance, expected shape [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum().real, beta*kl)``.

    Raises:
        Exception: whatever the determinant/inverse raised; the offending
        tensors are written to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)  # move the channel axis last: [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None]  # [I, N, F, 1, 1]
        ll = -logdet - quad.squeeze(-1).squeeze(-1)
    except Exception:
        # Dump the inputs for post-mortem analysis, then stop by re-raising
        # the original error instead of recomputing the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Output folders for this run: <base>/<rid>/ for figures and for checkpoints
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently (and any missing parent), so a run does
# not fail later when only one of the two already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch is defined elsewhere; presumably
# white Gaussian noise at 30 dB SNR), then scale every sample by its own
# largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds a 3-tuple; the third entry is the mixing matrix that
# the model's forward() reads through the global `hgt`.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: one optimisation pass over `tr` per epoch, with plots,
# a validation pass and a model checkpoint every 10th epoch.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): only the LAST mini-batch of the epoch is logged here, not
    # the epoch average -- kept as is so existing curves stay comparable.
    loss_tr.append(loss.item()/opts['batch_size'])
    loss1.append(l1.item()/opts['batch_size'])
    loss2.append(l2.item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (series, line format, title prefix, file-name suffix)
        curves = (
            (loss_tr, '-or', 'Loss fuction', '_LossFunAll'),
            (loss1, '-og', 'Reconstruction loss', '_Loss1'),
            (loss2, '-og', 'KL loss', '_Loss2'),
            (loss_tr[-50:], '-or', 'Last 50 of loss', '_last50'),
        )
        for series, fmt, title, suffix in curves:
            plt.figure()
            plt.plot(series, fmt)
            plt.title(f'{title} at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s6_
# Cell preamble: run identifier, environment/display configuration, RNG seed
# and torch-version-dependent aliases. The statement order matters (star
# imports and the CUDA device selection), so it is left untouched.
rid = 's6_' # running id; also used as the sub-folder name for figures/models
from utils import *  # presumably supplies os, plt, torch, np, optim, Data, ... -- confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Any torch build whose version string does not start with '1.8.1' gets a
# local `mydet` and the built-in RAdam; on 1.8.1 both `mydet` and `optim`
# must already be in scope (presumably from utils -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE that estimates per-source variance maps using known mixing vectors.

    For each of the K sources the mixture is Wiener-filtered with a mixing
    vector read from the module-level tensor ``hgt`` (defined outside this
    class), the magnitude of the filtered signal is encoded into a
    ``dz``-dimensional latent code, and the exponentiated decoder output is
    passed through ``threshold`` with floor 1e-3 and ceiling 1e2.  The noise
    covariance ``Rb`` is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K  (NOTE(review): J is not used in this class -- confirm meaning.)
    """
    def __init__(self, M=3, K=3, im_size=100):
        # NOTE: `im_size` is accepted for interface compatibility but unused;
        # the 25*25 sizes below assume the encoder maps 100x100 to 25x25
        # (TODO confirm against Down/Up_ in vae_model).
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits interleaved (mu, logvar) pairs, hence 2*dz outputs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run Wiener filtering and the VAE on a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where Rs is the diagonal
            source covariance [I, N, F, K, K], Hhat the stacked mixing
            vectors [I, M, K], Rb the noise covariance [I, M, M], and
            mu/logvar the latent statistics of the LAST source only.
        """
        I, _, _, _ = x.shape  # unpacking also enforces a 4-D input
        # The noise covariance is constant, so build it once per call.
        Rb = (1.4e-3*torch.ones(I, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: successive Wiener filtering, one source at a time
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's contribution from the input
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze only the trailing axis so I == 1 keeps its batch dim
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Column i of the global ground-truth mixing matrix, tiled over the batch
            hj = hgt[:,i].repeat(I).reshape(I,self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the (detached) largest magnitude in the batch
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate and decode its variance map
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(I,-1))
            # Even columns are means, odd columns are log-variances
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        # NOTE(review): mu/logvar below belong to the last source only, so the
        # KL term in the loss ignores the other K-1 sources -- confirm intent.
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the channel axis so I == 1 or K == 1 keep their dims
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: complex mixture, expected shape [I, M, N, F].
        Rs: source covariance, expected shape [I, N, F, K, K].
        Hhat: mixing matrix, expected shape [I, M, K].
        Rb: noise covariance, expected shape [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum().real, beta*kl)``.

    Raises:
        Exception: whatever the determinant/inverse raised; the offending
        tensors are written to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)  # move the channel axis last: [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None]  # [I, N, F, 1, 1]
        ll = -logdet - quad.squeeze(-1).squeeze(-1)
    except Exception:
        # Dump the inputs for post-mortem analysis, then stop by re-raising
        # the original error instead of recomputing the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Output folders for this run: <base>/<rid>/ for figures and for checkpoints
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently (and any missing parent), so a run does
# not fail later when only one of the two already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch is defined elsewhere; presumably
# white Gaussian noise at 30 dB SNR), then scale every sample by its own
# largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds a 3-tuple; the third entry is the mixing matrix that
# the model's forward() reads through the global `hgt`.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: one optimisation pass over `tr` per epoch, with plots,
# a validation pass and a model checkpoint every 10th epoch.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): only the LAST mini-batch of the epoch is logged here, not
    # the epoch average -- kept as is so existing curves stay comparable.
    loss_tr.append(loss.item()/opts['batch_size'])
    loss1.append(l1.item()/opts['batch_size'])
    loss2.append(l2.item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (series, line format, title prefix, file-name suffix)
        curves = (
            (loss_tr, '-or', 'Loss fuction', '_LossFunAll'),
            (loss1, '-og', 'Reconstruction loss', '_Loss1'),
            (loss2, '-og', 'KL loss', '_Loss2'),
            (loss_tr[-50:], '-or', 'Last 50 of loss', '_last50'),
        )
        for series, fmt, title, suffix in curves:
            plt.figure()
            plt.plot(series, fmt)
            plt.title(f'{title} at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s7_
# Cell preamble: run identifier, environment/display configuration, RNG seed
# and torch-version-dependent aliases. The statement order matters (star
# imports and the CUDA device selection), so it is left untouched.
rid = 's7_' # running id; also used as the sub-folder name for figures/models
from utils import *  # presumably supplies os, plt, torch, np, optim, Data, ... -- confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Any torch build whose version string does not start with '1.8.1' gets a
# local `mydet` and the built-in RAdam; on 1.8.1 both `mydet` and `optim`
# must already be in scope (presumably from utils -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """VAE that estimates per-source variance maps using known mixing vectors.

    For each of the K sources the mixture is Wiener-filtered with a mixing
    vector read from the module-level tensor ``hgt`` (defined outside this
    class), the magnitude of the filtered signal is encoded into a
    ``dz``-dimensional latent code, and the exponentiated decoder output is
    passed through ``threshold`` with floor 1e-3 and ceiling 1e2.  The noise
    covariance ``Rb`` is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K  (NOTE(review): J is not used in this class -- confirm meaning.)
    """
    def __init__(self, M=3, K=3, im_size=100):
        # NOTE: `im_size` is accepted for interface compatibility but unused;
        # the 25*25 sizes below assume the encoder maps 100x100 to 25x25
        # (TODO confirm against Down/Up_ in vae_model).
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits interleaved (mu, logvar) pairs, hence 2*dz outputs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run Wiener filtering and the VAE on a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where Rs is the diagonal
            source covariance [I, N, F, K, K], Hhat the stacked mixing
            vectors [I, M, K], Rb the noise covariance [I, M, M], and
            mu/logvar the latent statistics of the LAST source only.
        """
        I, _, _, _ = x.shape  # unpacking also enforces a 4-D input
        # The noise covariance is constant, so build it once per call.
        Rb = (1.4e-3*torch.ones(I, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: successive Wiener filtering, one source at a time
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's contribution from the input
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze only the trailing axis so I == 1 keeps its batch dim
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Column i of the global ground-truth mixing matrix, tiled over the batch
            hj = hgt[:,i].repeat(I).reshape(I,self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the (detached) largest magnitude in the batch
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate and decode its variance map
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(I,-1))
            # Even columns are means, odd columns are log-variances
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        # NOTE(review): mu/logvar below belong to the last source only, so the
        # KL term in the loss ignores the other K-1 sources -- confirm intent.
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the channel axis so I == 1 or K == 1 keep their dims
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: complex mixture, expected shape [I, M, N, F].
        Rs: source covariance, expected shape [I, N, F, K, K].
        Hhat: mixing matrix, expected shape [I, M, K].
        Rb: noise covariance, expected shape [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum().real, beta*kl)``.

    Raises:
        Exception: whatever the determinant/inverse raised; the offending
        tensors are written to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` first.
    """
    x = x.permute(0,2,3,1)  # move the channel axis last: [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        quad = x[...,None,:].conj()@Rx.inverse()@x[...,None]  # [I, N, F, 1, 1]
        ll = -logdet - quad.squeeze(-1).squeeze(-1)
    except Exception:
        # Dump the inputs for post-mortem analysis, then stop by re-raising
        # the original error instead of recomputing the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Output folders for this run: <base>/<rid>/ for figures and for checkpoints
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently (and any missing parent), so a run does
# not fail later when only one of the two already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch is defined elsewhere; presumably
# white Gaussian noise at 30 dB SNR), then scale every sample by its own
# largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds a 3-tuple; the third entry is the mixing matrix that
# the model's forward() reads through the global `hgt`.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

# Training driver: one optimisation pass over `tr` per epoch, with plots,
# a validation pass and a model checkpoint every 10th epoch.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE(review): only the LAST mini-batch of the epoch is logged here, not
    # the epoch average -- kept as is so existing curves stay comparable.
    loss_tr.append(loss.item()/opts['batch_size'])
    loss1.append(l1.item()/opts['batch_size'])
    loss2.append(l2.item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # (series, line format, title prefix, file-name suffix)
        curves = (
            (loss_tr, '-or', 'Loss fuction', '_LossFunAll'),
            (loss1, '-og', 'Reconstruction loss', '_Loss1'),
            (loss2, '-og', 'KL loss', '_Loss2'),
            (loss_tr[-50:], '-or', 'Last 50 of loss', '_last50'),
        )
        for series, fmt, title, suffix in curves:
            plt.figure()
            plt.plot(series, fmt)
            plt.title(f'{title} at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size, not a literal 128
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s12_
# Cell preamble: run identifier, environment/display configuration, RNG seed
# and torch-version-dependent aliases. The statement order matters (star
# imports and the CUDA device selection), so it is left untouched.
rid = 's12_' # running id; also used as the sub-folder name for figures/models
from utils import *  # presumably supplies os, plt, torch, np, optim, Data, ... -- confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Any torch build whose version string does not start with '1.8.1' gets a
# local `mydet` and the built-in RAdam; on 1.8.1 both `mydet` and `optim`
# must already be in scope (presumably from utils -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """Smaller VAE variant estimating per-source variance maps with known mixing vectors.

    For each of the K sources the mixture is Wiener-filtered with a mixing
    vector read from the module-level tensor ``hgt`` (defined outside this
    class), the magnitude of the filtered signal is encoded into a
    ``dz``-dimensional latent code, and the exponentiated decoder output is
    passed through ``threshold`` with floor 1e-3 and ceiling 1e2.  The noise
    covariance ``Rb`` is fixed to 1.4e-3 * identity.  The encoder and decoder
    here are built from Down/Up_ stages and a single OutConv, without any
    DoubleConv stage.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K  (NOTE(review): J is not used in this class -- confirm meaning.)
    """
    def __init__(self, M=3, K=3, im_size=100):
        # NOTE: `im_size` is accepted for interface compatibility but unused;
        # the 25*25 sizes below assume the encoder maps 100x100 to 25x25
        # (TODO confirm against Down/Up_ in vae_model).
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 emits interleaved (mu, logvar) pairs, hence 2*dz outputs
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run Wiener filtering and the VAE on a batch of mixtures.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where Rs is the diagonal
            source covariance [I, N, F, K, K], Hhat the stacked mixing
            vectors [I, M, K], Rb the noise covariance [I, M, M], and
            mu/logvar the latent statistics of the LAST source only.
        """
        I, _, _, _ = x.shape  # unpacking also enforces a 4-D input
        # The noise covariance is constant, so build it once per call.
        Rb = (1.4e-3*torch.ones(I, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: successive Wiener filtering, one source at a time
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's contribution from the input
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze only the trailing axis so I == 1 keeps its batch dim
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Column i of the global ground-truth mixing matrix, tiled over the batch
            hj = hgt[:,i].repeat(I).reshape(I,self.M) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = W @ inp.permute(2,3,0,1)[...,None] # [N,F,I,1,1]
            shat = shat.squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
            # Scale by the (detached) largest magnitude in the batch
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate and decode its variance map
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(I,-1))
            # Even columns are means, odd columns are log-variances
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        # NOTE(review): mu/logvar below belong to the last source only, so the
        # KL term in the loss ignores the other K-1 sources -- confirm intent.
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the channel axis so I == 1 or K == 1 keep their dims
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture batch; its axes are permuted with (0, 2, 3, 1) so the
            second axis becomes the last one before the quadratic form.
        Rs: first output of the model; permuted with (1, 2, 0, 3, 4) so it
            broadcasts against ``Hhat``.
        Hhat: mixing matrices, multiplied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the modelled mixture covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood.

    Raises:
        Re-raises any exception from the determinant / inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem, then propagate the original error
        # (the old handler re-ran the same failing expression to stop the run).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')): 
    print('made a new folder')
# Create both folders independently: checking only the figure folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch comes from the star imports; presumably
# white Gaussian noise at 30 dB SNR), then scale each sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds (mixtures, <unused>, mixing matrix); hgt is read as a
# global by the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # The per-epoch curves record the LAST mini-batch of the epoch only.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # One figure per training curve, saved under the run's figure folder.
        for series, style, title, tag in (
                (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
                (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
                (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
                (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
                ):
            plt.figure()
            plt.plot(series, style)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size (was a literal 128).
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s15_
# Run identifier: used below to name the output folders and the debug dump file.
rid = 's15_' # running id
# Star import: names such as os, plt, torch, nn, Data and optim used in this
# cell are presumably provided here -- verify against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version switch: when the first five characters of the torch version are not
# '1.8.1', define mydet locally and take RAdam from torch.optim. Otherwise RAdam
# comes from `optim`, and mydet is NOT defined here (presumably it is supplied
# by the star imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv_less(nn.Module):
    """Recursive Wiener-filter model that uses the ground-truth mixing vectors.

    The mixing vector of each source is read from the module-level tensor
    ``hgt``; only the variance maps are learned, through a convolutional
    encoder, a linear bottleneck of size ``dz`` and an up-convolution decoder.
    Decoded variances are clipped to [1e-3, 1e2]; the noise covariance is the
    fixed diagonal ``1.4e-3 * I``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the encoder/decoder.

        Args:
            M: number of mixture channels.
            K: number of sources to extract.
            im_size: accepted for interface parity with sibling models; it is
                not used by this class.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values per sample: interleaved mean / log-variance.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run K Wiener-filter deflation steps, then VAE-decode each source.

        Args:
            x: complex mixture batch of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` with ``vhat`` of shape
            [I, N, F, K]; ``Hhat`` is [I, M, K]; ``Rb`` is [I, M, M].
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST source
            only, so the caller's KL term ignores the other sources --
            confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []

        # Stage 1: coarse source estimates by successive Wiener filtering.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source image hj * (W @ inp) from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N, F, I, M, 1]
                # Drop only the trailing matmul axis; a bare squeeze() would
                # also drop the batch axis when batch_size == 1.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            Rb = (1.4e-3*torch.ones(batch_size, self.M, \
                device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
            # One copy of the i-th column of hgt per sample (was reshape(I, 3)).
            hj = hgt[:,i].repeat(batch_size, 1) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter for a rank-1 source plus noise.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I, M, M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I, 1, M]
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate, sample z, decode a variance map.
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            # Even columns of fc1 hold the means, odd columns the log-variances.
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the decoder's single-channel axis (dim 1); a bare
        # squeeze() would also collapse I == 1 or K == 1.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture batch; its axes are permuted with (0, 2, 3, 1) so the
            second axis becomes the last one before the quadratic form.
        Rs: first output of the model; permuted with (1, 2, 0, 3, 4) so it
            broadcasts against ``Hhat``.
        Hhat: mixing matrices, multiplied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the modelled mixture covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood.

    Raises:
        Re-raises any exception from the determinant / inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem, then propagate the original error
        # (the old handler re-ran the same failing expression to stop the run).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')): 
    print('made a new folder')
# Create both folders independently: checking only the figure folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 150

# Training data: add noise (awgn_batch comes from the star imports; presumably
# white Gaussian noise at 30 dB SNR), then scale each sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds (mixtures, <unused>, mixing matrix); hgt is read as a
# global by the model's forward pass.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_gtupconv_less
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # The per-epoch curves record the LAST mini-batch of the epoch only.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # One figure per training curve, saved under the run's figure folder.
        for series, style, title, tag in (
                (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
                (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
                (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
                (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
                ):
            plt.figure()
            plt.plot(series, style)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size (was a literal 128).
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s16
# Run identifier: used below to name the output folders and the debug dump file.
rid = 's16' # running id
# Star import: names such as os, plt, torch, nn, Data and optim used in this
# cell are presumably provided here -- verify against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version switch: when the first five characters of the torch version are not
# '1.8.1', define mydet locally and take RAdam from torch.optim. Otherwise RAdam
# comes from `optim`, and mydet is NOT defined here (presumably it is supplied
# by the star imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter model with a learned steering angle per source.

    For each of the K sources, ``est`` maps the current residual mixture to
    one scalar; its tanh times ``pi * [0..M-1]`` gives the phases of a
    unit-modulus mixing vector. A convolutional encoder, a linear bottleneck
    of size ``dz`` and an up-convolution decoder then produce the variance
    map, clipped to [1e-3, 1e2]. The noise covariance is the fixed diagonal
    ``1.4e-3 * I``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the angle estimator and the encoder/decoder.

        Args:
            M: number of mixture channels (``est`` takes 2*M real channels).
            K: number of sources to extract.
            im_size: accepted for interface parity with sibling models; it is
                not used by this class.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values per sample: interleaved mean / log-variance.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run K Wiener-filter deflation steps, then VAE-decode each source.

        Args:
            x: complex mixture batch of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` with ``vhat`` of shape
            [I, N, F, K]; ``Hhat`` is [I, M, K]; ``Rb`` is [I, M, M].
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST source
            only, so the caller's KL term ignores the other sources --
            confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []
        ch = torch.pi*torch.arange(self.M, device=x.device)

        # Stage 1: estimate each mixing vector and a coarse source image.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source image hj * (W @ inp) from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N, F, I, M, 1]
                # Drop only the trailing matmul axis; a bare squeeze() would
                # also drop the batch axis when batch_size == 1.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            # Real and imaginary parts are stacked on the channel axis for est.
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) #vj,Rb,ang
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            Rb = (1.4e-3*torch.ones(batch_size, self.M, \
                device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter for a rank-1 source plus noise.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I, M, M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I, 1, M]
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate, sample z, decode a variance map.
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            # Even columns of fc1 hold the means, odd columns the log-variances.
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the decoder's single-channel axis (dim 1); a bare
        # squeeze() would also collapse I == 1 or K == 1.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture batch; its axes are permuted with (0, 2, 3, 1) so the
            second axis becomes the last one before the quadratic form.
        Rs: first output of the model; permuted with (1, 2, 0, 3, 4) so it
            broadcasts against ``Hhat``.
        Hhat: mixing matrices, multiplied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the modelled mixture covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood.

    Raises:
        Re-raises any exception from the determinant / inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem, then propagate the original error
        # (the old handler re-ran the same failing expression to stop the run).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')): 
    print('made a new folder')
# Create both folders independently: checking only the figure folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1000

# Training data: add noise (awgn_batch comes from the star imports; presumably
# white Gaussian noise at 30 dB SNR), then scale each sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds (mixtures, <unused>, mixing matrix); hgt0 is the
# reference passed to h_corr below.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # The per-epoch curves record the LAST mini-batch of the epoch only.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # One figure per training curve, saved under the run's figure folder.
        for series, style, title, tag in (
                (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
                (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
                (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
                (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
                ):
            plt.figure()
            plt.plot(series, style)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size (was a literal 128).
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Diagnostics for the first validation sample: multichannel Wiener
            # estimate of the sources from the model outputs.
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            # Filter the validation mixture that produced hh/rs0; the previous
            # code filtered `x`, the last TRAINING batch, by mistake.
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s17
# Run identifier: used below to name the output folders and the debug dump file.
rid = 's17' # running id
# Star import: names such as os, plt, torch, nn, Data and optim used in this
# cell are presumably provided here -- verify against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version switch: when the first five characters of the torch version are not
# '1.8.1', define mydet locally and take RAdam from torch.optim. Otherwise RAdam
# comes from `optim`, and mydet is NOT defined here (presumably it is supplied
# by the star imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter model with a learned steering angle per source.

    Same layout as the other ``NN_upconv`` cells, except that the residual
    passed to the next deflation step is detached, so gradients do not flow
    through earlier steps via the residual. For each of the K sources,
    ``est`` maps the residual mixture to one scalar; its tanh times
    ``pi * [0..M-1]`` gives the phases of a unit-modulus mixing vector.
    Decoded variances are clipped to [1e-3, 1e2]; the noise covariance is the
    fixed diagonal ``1.4e-3 * I``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the angle estimator and the encoder/decoder.

        Args:
            M: number of mixture channels (``est`` takes 2*M real channels).
            K: number of sources to extract.
            im_size: accepted for interface parity with sibling models; it is
                not used by this class.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values per sample: interleaved mean / log-variance.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run K Wiener-filter deflation steps, then VAE-decode each source.

        Args:
            x: complex mixture batch of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` with ``vhat`` of shape
            [I, N, F, K]; ``Hhat`` is [I, M, K]; ``Rb`` is [I, M, M].
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST source
            only, so the caller's KL term ignores the other sources --
            confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []
        ch = torch.pi*torch.arange(self.M, device=x.device)

        # Stage 1: estimate each mixing vector and a coarse source image.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source image hj * (W @ inp) from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N, F, I, M, 1]
                # inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # Drop only the trailing matmul axis; a bare squeeze() would
                # also drop the batch axis when batch_size == 1.
                inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            # Real and imaginary parts are stacked on the channel axis for est.
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) #vj,Rb,ang
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            Rb = (1.4e-3*torch.ones(batch_size, self.M, \
                device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter for a rank-1 source plus noise.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I, M, M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I, 1, M]
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate, sample z, decode a variance map.
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            # Even columns of fc1 hold the means, odd columns the log-variances.
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the decoder's single-channel axis (dim 1); a bare
        # squeeze() would also collapse I == 1 or K == 1.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture batch; its axes are permuted with (0, 2, 3, 1) so the
            second axis becomes the last one before the quadratic form.
        Rs: first output of the model; permuted with (1, 2, 0, 3, 4) so it
            broadcasts against ``Hhat``.
        Hhat: mixing matrices, multiplied as ``Hhat @ Rs @ Hhat^H``.
        Rb: noise covariance added to the modelled mixture covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is the real part of the summed
        negative log-likelihood.

    Raises:
        Re-raises any exception from the determinant / inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Keep the inputs for post-mortem, then propagate the original error
        # (the old handler re-ran the same failing expression to stop the run).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')): 
    print('made a new folder')
# Create both folders independently: checking only the figure folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1000

# Training data: add noise (awgn_batch comes from the star imports; presumably
# white Gaussian noise at 30 dB SNR), then scale each sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds (mixtures, <unused>, mixing matrix); hgt0 is the
# reference passed to h_corr below.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # The per-epoch curves record the LAST mini-batch of the epoch only.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        # One figure per training curve, saved under the run's figure folder.
        for series, style, title, tag in (
                (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
                (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
                (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
                (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
                ):
            plt.figure()
            plt.plot(series, style)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size (was a literal 128).
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Diagnostics for the first validation sample: multichannel Wiener
            # estimate of the sources from the model outputs.
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            # Filter the validation mixture that produced hh/rs0; the previous
            # code filtered `x`, the last TRAINING batch, by mistake.
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s16_1
# Run identifier: used below to name the output folders and the debug dump file.
rid = 's16_1' # running id
# Star import: names such as os, plt, torch, nn, Data and optim used in this
# cell are presumably provided here -- verify against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Version switch: when the first five characters of the torch version are not
# '1.8.1', define mydet locally and take RAdam from torch.optim. Otherwise RAdam
# comes from `optim`, and mydet is NOT defined here (presumably it is supplied
# by the star imports -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter model with a slim steering-angle estimator.

    Variant whose ``est`` network keeps a single channel through its three
    ``Down`` stages before the linear head. For each of the K sources,
    ``est`` maps the residual mixture to one scalar; its tanh times
    ``pi * [0..M-1]`` gives the phases of a unit-modulus mixing vector.
    Decoded variances are clipped to [1e-3, 1e2]; the noise covariance is the
    fixed diagonal ``1.4e-3 * I``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the angle estimator and the encoder/decoder.

        Args:
            M: number of mixture channels (``est`` takes 2*M real channels).
            K: number of sources to extract.
            im_size: accepted for interface parity with sibling models; it is
                not used by this class.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Reshape(-1, 12*12),
            LinearBlock(12*12, 64),
            nn.Linear(64, 1),
            )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values per sample: interleaved mean / log-variance.
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run K Wiener-filter deflation steps, then VAE-decode each source.

        Args:
            x: complex mixture batch of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` with ``vhat`` of shape
            [I, N, F, K]; ``Hhat`` is [I, M, K]; ``Rb`` is [I, M, M].
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST source
            only, so the caller's KL term ignores the other sources --
            confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all, shat_all = [], [], []
        ch = torch.pi*torch.arange(self.M, device=x.device)

        # Stage 1: estimate each mixing vector and a coarse source image.
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source image hj * (W @ inp) from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N, F, I, M, 1]
                # Drop only the trailing matmul axis; a bare squeeze() would
                # also drop the batch axis when batch_size == 1.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            # Real and imaginary parts are stacked on the channel axis for est.
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) #vj,Rb,ang
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            Rb = (1.4e-3*torch.ones(batch_size, self.M, \
                device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter for a rank-1 source plus noise.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I, M, M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I, 1, M]
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)

        # Stage 2: encode each coarse estimate, sample z, decode a variance map.
        for shat in shat_all:
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            # Even columns of fc1 hold the means, odd columns the log-variances.
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the decoder's single-channel axis (dim 1); a bare
        # squeeze() would also collapse I == 1 or K == 1.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture, permuted here from [I, M, N, F] to [I, N, F, M].
        Rs: source covariance, treated as a 5-D tensor [I, N, F, K, K].
        Hhat: mixing-matrix estimate, [I, M, K] per the model's comments.
        Rb: noise covariance added to the modelled covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``.

    Raises:
        Re-raises any exception from the determinant/inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        # Index the two trailing singleton axes explicitly; a bare squeeze()
        # would also drop a size-1 I/N/F axis and mis-broadcast against logdet.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -logdet - quad
    except Exception:
        # Keep the tensors for post-mortem, then propagate the original error
        # (instead of re-running the failing expression to trigger it again).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently: checking only the figures folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# Problem sizes and optimisation settings.
I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1000

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds three items; the second is not used here.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
# NOTE(review): the class defined in this cell is NN_gtsbd, but NN_upconv is
# instantiated here (it must come from an earlier cell) -- confirm intended.
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimiser step per mini-batch.  Every 10th epoch the loss
# curves, a validation pass, example separations and a checkpoint are saved.
for epoch in range(opts['n_epochs']):
    model.train()
    # KL weight: 1e-3 up to and including epoch 100, afterwards 1e-3*epoch.
    beta = 1e-3 if epoch <= 100 else 1e-3*epoch
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=beta)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Example separation for the first validation sample.  Hhat/Rs/Rb
            # come from xval_cuda, so the filter must be applied to xval_cuda
            # as well (not to the last training batch still bound to x).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s16_2
# Cell preamble: run id, shared imports, device/printing options and the
# torch-version shim for det/RAdam.  Statement order matters because of the
# star imports.
rid = 's16_2' # running id
# Star import; names such as os, plt, torch, nn, np and Data are used below
# without any other import, so they are presumably re-exported by utils -- TODO confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# For any torch version other than 1.8.1 use the built-in det and RAdam;
# on 1.8.1 `mydet` and `optim` are expected to come from the star import above
# (assumption -- neither is defined in this cell).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is a debugging aid for autograd; it is enabled globally here.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter separator with a VAE (up-convolution decoder).

    The steering-vector angle is predicted by two linear layers applied to the
    real and imaginary parts of the residual; source variance maps come from a
    convolutional VAE and are passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est1 = nn.Linear(100, 1)
        self.est2 = nn.Sequential(
            nn.Linear(100, 1),
            nn.Tanh()
        )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate steering vectors and source variance maps from a mixture.

        Args:
            x: complex mixture, unpacked here as a 4-D tensor [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where ``Rs`` is
            ``vhat.diag_embed()`` ([I, N, F, K, K]), ``Hhat`` stacks the K
            steering vectors ([I, M, K]) and ``Rb`` is the fixed diagonal
            noise covariance ([I, M, M]).
            NOTE(review): ``mu``/``logvar`` are overwritten on every pass of
            the second loop, so only the LAST source's values are returned.
        """
        batch_size, _, _, _ = x.shape
        v_all, h_all, shat_all = [], [], []
        ch = torch.arange(self.M, device=x.device).to(torch.float) #not *torch.pi here

        # Fixed diagonal noise covariance; independent of the source index,
        # so it is built once instead of once per loop iteration.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Neural net for H and Wiener filter for the coarse source estimates
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a size-1
                # batch/channel axis and break the permute below.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            # est1 collapses the last axis, est2 the one before it.
            r_ang = self.est2(self.est1(inp.real).squeeze(-1)) # [I, M, 1]
            # The imaginary branch must read inp.imag; reading inp.real twice
            # made r_ang == i_ang, so the angle was always pi/4 or -3pi/4.
            i_ang = self.est2(self.est1(inp.imag).squeeze(-1)) # [I, M, 1]
            ang = (r_ang+1j*i_ang).squeeze(-1).mean(-1).angle() # [I]
            hj = ((ang[...,None]@ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get the coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N,F,I,1,1]; index the two trailing singleton
            # axes explicitly so a batch of one keeps its batch axis.
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)
        
        for shat in shat_all:
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Decoder output is [I, 1, N, F] (Reshape + OutConv with one channel);
        # drop only that channel axis so I==1 or K==1 are preserved.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture, permuted here from [I, M, N, F] to [I, N, F, M].
        Rs: source covariance, treated as a 5-D tensor [I, N, F, K, K].
        Hhat: mixing-matrix estimate, [I, M, K] per the model's comments.
        Rb: noise covariance added to the modelled covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``.

    Raises:
        Re-raises any exception from the determinant/inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        # Index the two trailing singleton axes explicitly; a bare squeeze()
        # would also drop a size-1 I/N/F axis and mis-broadcast against logdet.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -logdet - quad
    except Exception:
        # Keep the tensors for post-mortem, then propagate the original error
        # (instead of re-running the failing expression to trigger it again).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently: checking only the figures folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# Problem sizes and optimisation settings.
I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1000

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds three items; the second is not used here.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimiser step per mini-batch.  Every 10th epoch the loss
# curves, a validation pass, example separations and a checkpoint are saved.
for epoch in range(opts['n_epochs']):
    model.train()
    # KL weight: 1e-3 up to and including epoch 100, afterwards 1e-3*epoch.
    beta = 1e-3 if epoch <= 100 else 1e-3*epoch
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=beta)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Example separation for the first validation sample.  Hhat/Rs/Rb
            # come from xval_cuda, so the filter must be applied to xval_cuda
            # as well (not to the last training batch still bound to x).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s16_3
# Cell preamble: run id, shared imports, device/printing options and the
# torch-version shim for det/RAdam.  Statement order matters because of the
# star imports.
rid = 's16_3' # running id
# Star import; names such as os, plt, torch, nn, np and Data are used below
# without any other import, so they are presumably re-exported by utils -- TODO confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# For any torch version other than 1.8.1 use the built-in det and RAdam;
# on 1.8.1 `mydet` and `optim` are expected to come from the star import above
# (assumption -- neither is defined in this cell).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is a debugging aid for autograd; it is enabled globally here.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter separator with a VAE (up-convolution decoder).

    The steering-vector angle is predicted by a convolutional estimator fed
    with the stacked real/imaginary parts of the residual; source variance
    maps come from a convolutional VAE and are passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate steering vectors and source variance maps from a mixture.

        Args:
            x: complex mixture, unpacked here as a 4-D tensor [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where ``Rs`` is
            ``vhat.diag_embed()`` ([I, N, F, K, K]), ``Hhat`` stacks the K
            steering vectors ([I, M, K]) and ``Rb`` is the fixed diagonal
            noise covariance ([I, M, M]).
            NOTE(review): ``mu``/``logvar`` are overwritten on every pass of
            the second loop, so only the LAST source's values are returned.
        """
        batch_size, _, _, _ = x.shape
        v_all, h_all, shat_all = [], [], []
        ch = torch.pi*torch.arange(self.M, device=x.device)

        # Fixed diagonal noise covariance; independent of the source index,
        # so it is built once instead of once per loop iteration.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Neural net for H and Wiener filter for the coarse source estimates
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a size-1
                # batch/channel axis and break the permute below.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) # [I, 1]
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get the coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N,F,I,1,1]; index the two trailing singleton
            # axes explicitly so a batch of one keeps its batch axis.
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)
        
        for shat in shat_all:
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Decoder output is [I, 1, N, F] (Reshape + OutConv with one channel);
        # drop only that channel axis so I==1 or K==1 are preserved.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture, permuted here from [I, M, N, F] to [I, N, F, M].
        Rs: source covariance, treated as a 5-D tensor [I, N, F, K, K].
        Hhat: mixing-matrix estimate, [I, M, K] per the model's comments.
        Rb: noise covariance added to the modelled covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``.

    Raises:
        Re-raises any exception from the determinant/inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        # Index the two trailing singleton axes explicitly; a bare squeeze()
        # would also drop a size-1 I/N/F axis and mis-broadcast against logdet.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -logdet - quad
    except Exception:
        # Keep the tensors for post-mortem, then propagate the original error
        # (instead of re-running the failing expression to trigger it again).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently: checking only the figures folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# Problem sizes and optimisation settings.
I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1000

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds three items; the second is not used here.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimiser step per mini-batch.  Every 10th epoch the loss
# curves, a validation pass, example separations and a checkpoint are saved.
for epoch in range(opts['n_epochs']):
    model.train()
    # KL weight: 1e-3 up to and including epoch 100, afterwards 1e-3*epoch.
    beta = 1e-3 if epoch <= 100 else 1e-3*epoch
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=beta)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Example separation for the first validation sample.  Hhat/Rs/Rb
            # come from xval_cuda, so the filter must be applied to xval_cuda
            # as well (not to the last training batch still bound to x).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% s16_cont
# Cell preamble: run id, shared imports, device/printing options and the
# torch-version shim for det/RAdam.  Statement order matters because of the
# star imports.
rid = 's16_cont' # running id
# Star import; names such as os, plt, torch, nn, np and Data are used below
# without any other import, so they are presumably re-exported by utils -- TODO confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# For any torch version other than 1.8.1 use the built-in det and RAdam;
# on 1.8.1 `mydet` and `optim` are expected to come from the star import above
# (assumption -- neither is defined in this cell).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is a debugging aid for autograd; it is enabled globally here.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_upconv(nn.Module):
    """Recursive Wiener-filter separator with a VAE (up-convolution decoder).

    The steering-vector angle is predicted by a convolutional estimator fed
    with the stacked real/imaginary parts of the residual; source variance
    maps come from a convolutional VAE and are passed through
    ``threshold(v, floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
        # Estimate H 
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # self.b1 = nn.Linear(100, 1)
        # self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate steering vectors and source variance maps from a mixture.

        Args:
            x: complex mixture, unpacked here as a 4-D tensor [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)`` where ``Rs`` is
            ``vhat.diag_embed()`` ([I, N, F, K, K]), ``Hhat`` stacks the K
            steering vectors ([I, M, K]) and ``Rb`` is the fixed diagonal
            noise covariance ([I, M, M]).
            NOTE(review): ``mu``/``logvar`` are overwritten on every pass of
            the second loop, so only the LAST source's values are returned.
        """
        batch_size, _, _, _ = x.shape
        v_all, h_all, shat_all = [], [], []
        ch = torch.pi*torch.arange(self.M, device=x.device)

        # Fixed diagonal noise covariance; independent of the source index,
        # so it is built once instead of once per loop iteration.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Neural net for H and Wiener filter for the coarse source estimates
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Remove the previous source's filtered contribution.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None] # [N,F,I,M,1]
                # squeeze(-1) only: a bare squeeze() would also drop a size-1
                # batch/channel axis and break the permute below.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
                # inp = (inp - tmp.squeeze(-1).permute(2,3,0,1)).detach()

            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) # [I, 1]
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get the coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N,F,I,1,1]; index the two trailing singleton
            # axes explicitly so a batch of one keeps its batch axis.
            shat = (W @ inp.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()
            shat_all.append(shat)
        
        for shat in shat_all:
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Decoder output is [I, 1, N, F] (Reshape + OutConv with one channel);
        # drop only that channel axis so I==1 or K==1 are preserved.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Return the negative log-likelihood term and the weighted KL term.

    Args:
        x: mixture, permuted here from [I, M, N, F] to [I, N, F, M].
        Rs: source covariance, treated as a 5-D tensor [I, N, F, K, K].
        Hhat: mixing-matrix estimate, [I, M, K] per the model's comments.
        Rb: noise covariance added to the modelled covariance.
        mu, logvar: latent Gaussian parameters used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``.

    Raises:
        Re-raises any exception from the determinant/inverse after dumping
        the offending tensors to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        logdet = (np.pi*mydet(Rx)).log()
        # Index the two trailing singleton axes explicitly; a bare squeeze()
        # would also drop a size-1 I/N/F axis and mis-broadcast against logdet.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -logdet - quad
    except Exception:
        # Keep the tensors for post-mortem, then propagate the original error
        # (instead of re-running the failing expression to trigger it again).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real, beta*kl

#%%
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/'
mod_loc = '../data/nem_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently: checking only the figures folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# Problem sizes and optimisation settings (this run trains for 2000 epochs).
I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 2000

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file holds three items; the second is not used here.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_upconv
model = NN(M,K,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimiser step per mini-batch, with loss_fun's default
# KL weight throughout.  Every 10th epoch the loss curves, a validation pass,
# example separations and a checkpoint are saved.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch of the epoch is logged (per-sample average).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append((l1+l2).cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 

            # Example separation for the first validation sample.  Hhat/Rs/Rb
            # come from xval_cuda, so the filter must be applied to xval_cuda
            # as well (not to the last training batch still bound to x).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                
            plt.close('all')           
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% v1 
# Cell preamble: shared imports, device/printing options and the torch-version
# shim for det/RAdam.
# Star import; names such as os, plt, torch, nn, np and sio are used below
# without any other import, so they are presumably re-exported by utils -- TODO confirm.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# For any torch version other than 1.8.1 use the built-in det and RAdam;
# on 1.8.1 `mydet` and `optim` are expected to come from the star import above
# (assumption -- neither is defined in this cell).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
# Anomaly detection is a debugging aid for autograd; it is enabled globally here.
torch.autograd.set_detect_anomaly(True)

# Run id and per-run output folders for figures and model checkpoints.
rid = 'v1'
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# Create each folder independently: checking only the figures folder and then
# calling os.mkdir on both fails when exactly one of them already exists.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data: one "signal" source with variance map v0 plus two directional
# interferers, mixed on an M-element array.  The order of the random draws
# below determines the generated data, so statements must not be reordered.
M = 3
dicts = sio.loadmat('../data/nem_ss/v2.mat')
# Slices 0 and 1 along the last axis of the 'v' entry.
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize to 100x100, pass through awgn (project helper; presumably adds noise
# at the given SNR -- confirm in utils), take the magnitude, cast to complex.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Interferer power: mean of v0 scaled by 10**(-snr/10).
delta = v0.mean()*10**(-snr/10)
# Steering vectors are exp(1j*sin(angle)*m) for m = 0..M-1, one row per sample.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# The source/interferer waveforms are drawn once (a single row of 100*100
# values) and multiplied by every sample's steering vector, so all n_data
# samples share the same waveform realisation and differ only in direction.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

# Disabled sanity check of the per-sample covariance rank.
# NOTE(review): mix[i] is [M, 100, 100], so reshape(10000, 3) does not put the
# channel on the last axis (mix_all[i] would); xbar is computed but unused.
if False: # check data low rank or not
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)
#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 real inputs to 18 real outputs.

    Layer stack held in ``rx_net``: Linear(18, 128) -> ReLU -> Linear(128, 18)
    -> ReLU -> Linear(18, 18). The training code in this cell feeds it the
    real and imaginary parts of a 3x3 matrix and reads the output back as a
    complex 3x3 matrix.
    """

    def __init__(self):
        super().__init__()
        # Widths of the ReLU-activated stages; a plain Linear(18, 18) head follows.
        widths = (18, 128, 18)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(18, 18))
        self.rx_net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension 18) through ``rx_net`` and return the result."""
        return self.rx_net(x)
# Training script for the covariance-inverse experiment: the network output is
# read as a complex M x M matrix and trained so that (output @ rx) is close to
# the identity, where rx is the per-sample spatial covariance of the mixture.
model = DOA().cuda()

#%%
h_all, M = h, sig_all.shape[-1]  # M is re-read from the last (channel) axis of sig_all
n_tr = int(1e4)  # samples [0, n_tr) are used for training
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # NOTE(review): never read in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
val0 = mix_all[n_tr:n_tr+200]  # 200 held-out samples right after the training split
# Validation covariances: outer product x x^H at every grid point, averaged
# over the two grid axes (dims 1 and 2), giving one M x M matrix per sample.
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
# Network input layout: all real parts stacked before all imaginary parts,
# then flattened to one row per sample.
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
# drop_last=True keeps every batch at exactly 32 samples, which the fixed-size
# identity stack `Is` below depends on.
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Identity targets, one 3x3 per sample: 32 for a training batch, 200 for validation.
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []  # sampled training losses / validation losses
for epoch in range(2001):
    for i, (mix, sig, _) in enumerate(tr):  # note: rebinds the module-level name `mix`
        optimizer.zero_grad()

        "prepare for input to the NN"
        mix = mix.cuda()
        sig = sig.cuda()  # NOTE(review): sig is not used by the loss below
        # Same covariance construction as rx_val above, for the current batch.
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: the trailing axis of size 2 holds (real, imag) of each
        # matrix entry, which differs from the blocked layout used for the input.
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]

        "Calc gradient and update"
        # Mean squared magnitude of the residual I - rx_inv_hat @ rx.
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:  # keep one loss value every 30 batches
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 20 == 0:  # every 20 epochs: plot, validate, checkpoint
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            # Zoomed view of the 100 most recently recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        with torch.no_grad():
            # Same forward pass and loss as in training, on the 200 validation samples.
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.cpu().item())
            # Product for the first three validation samples; it approaches the
            # identity as the loss goes to zero.
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')           
        # Saves the whole module object (not only its state_dict) for this epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
"v2 For now, I need to verify that the covariance inverse works"
#%% 
# Run setup for experiment 'v2': imports, device/plot/print options, RNG seed,
# optimizer alias and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the random draws made later in this cell

# Any torch version other than 1.8.1 defines mydet locally and takes RAdam from
# torch.optim; on 1.8.1 RAdam comes from `optim` (presumably supplied by the
# star import above -- confirm in utils).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # autograd anomaly checks stay on for the whole run

rid = 'v2'  # running id; names the per-run output sub-folders
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create the two folders independently. The previous version tested
# only the figures folder and used os.mkdir, which raised when the models
# folder already existed or a parent was missing, and skipped creating the
# models folder when only the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data for this run: one source whose per-pixel power follows v0 plus
# two directional noise terms of constant power delta, each seen through a
# per-sample M-channel phase ramp over a 100x100 grid.
M = 3  # number of channels
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize the map to 100x100, pass it through awgn (helper from utils; presumably
# adds white Gaussian noise at the given SNR -- confirm there) and keep the
# magnitude, stored as complex64.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
# v1 gets the same treatment but, in this cell, is only referenced by the
# commented-out n1 variant below.
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Noise power: mean of v0 scaled by 10**(-snr/10), i.e. equal to mean(v0) at snr=0.
delta = v0.mean()*10**(-snr/10)
# Each row of h / hs_n1 / hs_n2 is exp(1j*sin(angle)*m) for m = 0..M-1, with one
# random angle per sample drawn from the range given in the trailing comment.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# signal: [n_data, M, 1] @ [1, 10000] -> [n_data, M, 10000] -> [n_data, M, 100, 100].
# NOTE(review): the random source draw has no sample axis, so one realization is
# shared by all n_data samples and only h differs between them; the same holds
# for the two noise draws below -- confirm this is intended.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views of the mixture and the clean signal: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic. NOTE(review): mix[i] is laid out [M, 100, 100], so
    # reshape(10000, 3) does not put one pixel's M channels on each row
    # (mix_all[i] would); xbar is computed but never used.
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 real inputs to 18 real outputs.

    Layer stack held in ``rx_net``: Linear(18, 128) -> ReLU -> Linear(128, 128)
    -> ReLU -> Linear(128, 18) -> ReLU -> Linear(18, 18). The training code in
    this cell feeds it the real and imaginary parts of a 3x3 matrix and reads
    the output back as a complex 3x3 matrix.
    """

    def __init__(self):
        super().__init__()
        # Widths of the ReLU-activated stages; a plain Linear(18, 18) head follows.
        widths = (18, 128, 128, 18)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(18, 18))
        self.rx_net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension 18) through ``rx_net`` and return the result."""
        return self.rx_net(x)
# Training script for the covariance-inverse experiment: the network output is
# read as a complex M x M matrix and trained so that (output @ rx) is close to
# the identity, where rx is the per-sample spatial covariance of the mixture.
model = DOA().cuda()

#%%
h_all, M = h, sig_all.shape[-1]  # M is re-read from the last (channel) axis of sig_all
n_tr = int(1e4)  # samples [0, n_tr) are used for training
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # NOTE(review): never read in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
val0 = mix_all[n_tr:n_tr+200]  # 200 held-out samples right after the training split
# Validation covariances: outer product x x^H at every grid point, averaged
# over the two grid axes (dims 1 and 2), giving one M x M matrix per sample.
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
# Network input layout: all real parts stacked before all imaginary parts,
# then flattened to one row per sample.
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
# drop_last=True keeps every batch at exactly 32 samples, which the fixed-size
# identity stack `Is` below depends on.
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Identity targets, one 3x3 per sample: 32 for a training batch, 200 for validation.
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []  # sampled training losses / validation losses
for epoch in range(2001):
    for i, (mix, sig, _) in enumerate(tr):  # note: rebinds the module-level name `mix`
        optimizer.zero_grad()

        "prepare for input to the NN"
        mix = mix.cuda()
        sig = sig.cuda()  # NOTE(review): sig is not used by the loss below
        # Same covariance construction as rx_val above, for the current batch.
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: the trailing axis of size 2 holds (real, imag) of each
        # matrix entry, which differs from the blocked layout used for the input.
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]

        "Calc gradient and update"
        # Mean squared magnitude of the residual I - rx_inv_hat @ rx.
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:  # keep one loss value every 30 batches
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 20 == 0:  # every 20 epochs: plot, validate, checkpoint
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            # Zoomed view of the 100 most recently recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        with torch.no_grad():
            # Same forward pass and loss as in training, on the 200 validation samples.
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.cpu().item())
            # Product for the first three validation samples; it approaches the
            # identity as the loss goes to zero.
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')           
        # Saves the whole module object (not only its state_dict) for this epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%v3
# Run setup for experiment 'v3': imports, device/plot/print options, RNG seed,
# optimizer alias and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the random draws made later in this cell

# Any torch version other than 1.8.1 defines mydet locally and takes RAdam from
# torch.optim; on 1.8.1 RAdam comes from `optim` (presumably supplied by the
# star import above -- confirm in utils).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # autograd anomaly checks stay on for the whole run

rid = 'v3'  # running id; names the per-run output sub-folders
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create the two folders independently. The previous version tested
# only the figures folder and used os.mkdir, which raised when the models
# folder already existed or a parent was missing, and skipped creating the
# models folder when only the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data for this run: one source whose per-pixel power follows v0 plus
# two directional noise terms of constant power delta, each seen through a
# per-sample M-channel phase ramp over a 100x100 grid.
M = 3  # number of channels
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize the map to 100x100, pass it through awgn (helper from utils; presumably
# adds white Gaussian noise at the given SNR -- confirm there) and keep the
# magnitude, stored as complex64.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
# v1 gets the same treatment but, in this cell, is only referenced by the
# commented-out n1 variant below.
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Noise power: mean of v0 scaled by 10**(-snr/10), i.e. equal to mean(v0) at snr=0.
delta = v0.mean()*10**(-snr/10)
# Each row of h / hs_n1 / hs_n2 is exp(1j*sin(angle)*m) for m = 0..M-1, with one
# random angle per sample drawn from the range given in the trailing comment.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# signal: [n_data, M, 1] @ [1, 10000] -> [n_data, M, 10000] -> [n_data, M, 100, 100].
# NOTE(review): the random source draw has no sample axis, so one realization is
# shared by all n_data samples and only h differs between them; the same holds
# for the two noise draws below -- confirm this is intended.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views of the mixture and the clean signal: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic. NOTE(review): mix[i] is laid out [M, 100, 100], so
    # reshape(10000, 3) does not put one pixel's M channels on each row
    # (mix_all[i] would); xbar is computed but never used.
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 real inputs to 18 real outputs.

    Layer stack held in ``rx_net``: Linear(18, 128), three Linear(128, 128)
    and Linear(128, 18), each followed by an in-place ReLU, then a final
    Linear(18, 18) with no activation. The training code in this cell feeds it
    the real and imaginary parts of a 3x3 matrix and reads the output back as
    a complex 3x3 matrix.
    """

    def __init__(self):
        super().__init__()
        # Widths of the ReLU-activated stages; a plain Linear(18, 18) head follows.
        widths = (18, 128, 128, 128, 128, 18)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(18, 18))
        self.rx_net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension 18) through ``rx_net`` and return the result."""
        return self.rx_net(x)
# Training script for the covariance-inverse experiment: the network output is
# read as a complex M x M matrix and trained so that (output @ rx) is close to
# the identity, where rx is the per-sample spatial covariance of the mixture.
model = DOA().cuda()

#%%
h_all, M = h, sig_all.shape[-1]  # M is re-read from the last (channel) axis of sig_all
n_tr = int(1e4)  # samples [0, n_tr) are used for training
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # NOTE(review): never read in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
val0 = mix_all[n_tr:n_tr+200]  # 200 held-out samples right after the training split
# Validation covariances: outer product x x^H at every grid point, averaged
# over the two grid axes (dims 1 and 2), giving one M x M matrix per sample.
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
# Network input layout: all real parts stacked before all imaginary parts,
# then flattened to one row per sample.
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
# drop_last=True keeps every batch at exactly 32 samples, which the fixed-size
# identity stack `Is` below depends on.
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Identity targets, one 3x3 per sample: 32 for a training batch, 200 for validation.
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []  # sampled training losses / validation losses
for epoch in range(2001):
    for i, (mix, sig, _) in enumerate(tr):  # note: rebinds the module-level name `mix`
        optimizer.zero_grad()

        "prepare for input to the NN"
        mix = mix.cuda()
        sig = sig.cuda()  # NOTE(review): sig is not used by the loss below
        # Same covariance construction as rx_val above, for the current batch.
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: the trailing axis of size 2 holds (real, imag) of each
        # matrix entry, which differs from the blocked layout used for the input.
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]

        "Calc gradient and update"
        # Mean squared magnitude of the residual I - rx_inv_hat @ rx.
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:  # keep one loss value every 30 batches
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 20 == 0:  # every 20 epochs: plot, validate, checkpoint
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            # Zoomed view of the 100 most recently recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        with torch.no_grad():
            # Same forward pass and loss as in training, on the 200 validation samples.
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.cpu().item())
            # Product for the first three validation samples; it approaches the
            # identity as the loss goes to zero.
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')           
        # Saves the whole module object (not only its state_dict) for this epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%v4 
# Run setup for experiment 'v4': imports, device/plot/print options, RNG seed,
# optimizer alias and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the random draws made later in this cell

# Any torch version other than 1.8.1 defines mydet locally and takes RAdam from
# torch.optim; on 1.8.1 RAdam comes from `optim` (presumably supplied by the
# star import above -- confirm in utils).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # autograd anomaly checks stay on for the whole run

rid = 'v4'  # running id; names the per-run output sub-folders
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create the two folders independently. The previous version tested
# only the figures folder and used os.mkdir, which raised when the models
# folder already existed or a parent was missing, and skipped creating the
# models folder when only the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data for this run: one source whose per-pixel power follows v0 plus
# two directional noise terms of constant power delta, each seen through a
# per-sample M-channel phase ramp over a 100x100 grid.
M = 3  # number of channels
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize the map to 100x100, pass it through awgn (helper from utils; presumably
# adds white Gaussian noise at the given SNR -- confirm there) and keep the
# magnitude, stored as complex64.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
# v1 gets the same treatment but, in this cell, is only referenced by the
# commented-out n1 variant below.
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Noise power: mean of v0 scaled by 10**(-snr/10), i.e. equal to mean(v0) at snr=0.
delta = v0.mean()*10**(-snr/10)
# Each row of h / hs_n1 / hs_n2 is exp(1j*sin(angle)*m) for m = 0..M-1, with one
# random angle per sample drawn from the range given in the trailing comment.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# signal: [n_data, M, 1] @ [1, 10000] -> [n_data, M, 10000] -> [n_data, M, 100, 100].
# NOTE(review): the random source draw has no sample axis, so one realization is
# shared by all n_data samples and only h differs between them; the same holds
# for the two noise draws below -- confirm this is intended.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views of the mixture and the clean signal: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic. NOTE(review): mix[i] is laid out [M, 100, 100], so
    # reshape(10000, 3) does not put one pixel's M channels on each row
    # (mix_all[i] would); xbar is computed but never used.
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 real inputs to 18 real outputs.

    Layer stack held in ``rx_net``: Linear(18, 128) -> ReLU -> Linear(128, 18)
    -> ReLU -> Linear(18, 18). The training code in this cell feeds it the
    real and imaginary parts of a 3x3 matrix and reads the output back as a
    complex 3x3 matrix.
    """

    def __init__(self):
        super().__init__()
        # Widths of the ReLU-activated stages; a plain Linear(18, 18) head follows.
        widths = (18, 128, 18)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(18, 18))
        self.rx_net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension 18) through ``rx_net`` and return the result."""
        return self.rx_net(x)
# Training script for the covariance-inverse experiment: the network output is
# read as a complex M x M matrix, made Hermitian, and trained so that
# (output @ rx) is close to the identity, where rx is the per-sample spatial
# covariance of the mixture.
model = DOA().cuda()

#%%
h_all, M = h, sig_all.shape[-1]  # M is re-read from the last (channel) axis of sig_all
n_tr = int(1e4)  # samples [0, n_tr) are used for training
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # NOTE(review): never read in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
val0 = mix_all[n_tr:n_tr+200]  # 200 held-out samples right after the training split
# Validation covariances: outer product x x^H at every grid point, averaged
# over the two grid axes (dims 1 and 2), giving one M x M matrix per sample.
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
# Network input layout: all real parts stacked before all imaginary parts,
# then flattened to one row per sample.
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
# drop_last=True keeps every batch at exactly 32 samples, which the fixed-size
# identity stack `Is` below depends on.
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Identity targets, one 3x3 per sample: 32 for a training batch, 200 for validation.
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []  # sampled training losses / validation losses
for epoch in range(2001):
    for i, (mix, sig, _) in enumerate(tr):  # note: rebinds the module-level name `mix`
        optimizer.zero_grad()

        "prepare for input to the NN"
        mix = mix.cuda()
        sig = sig.cuda()  # NOTE(review): sig is not used by the loss below
        # Same covariance construction as rx_val above, for the current batch.
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: the trailing axis of size 2 holds (real, imag) of each
        # matrix entry, which differs from the blocked layout used for the input.
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Average with the conjugate transpose so the estimate is Hermitian.
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        "Calc gradient and update"
        # Mean squared magnitude of the residual I - rx_inv_hat @ rx.
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:  # keep one loss value every 30 batches
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 20 == 0:  # every 20 epochs: plot, validate, checkpoint
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            # Zoomed view of the 100 most recently recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        with torch.no_grad():
            # Same forward pass, Hermitian averaging and loss as in training,
            # on the 200 validation samples.
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.cpu().item())
            # Product for the first three validation samples; it approaches the
            # identity as the loss goes to zero.
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')           
        # Saves the whole module object (not only its state_dict) for this epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v5
# Run setup for experiment 'v5': imports, device/plot/print options, RNG seed,
# optimizer alias and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the random draws made later in this cell

# Any torch version other than 1.8.1 defines mydet locally and takes RAdam from
# torch.optim; on 1.8.1 RAdam comes from `optim` (presumably supplied by the
# star import above -- confirm in utils).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # autograd anomaly checks stay on for the whole run

rid = 'v5'  # running id; names the per-run output sub-folders
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create the two folders independently. The previous version tested
# only the figures folder and used os.mkdir, which raised when the models
# folder already existed or a parent was missing, and skipped creating the
# models folder when only the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data for this run: one source whose per-pixel power follows v0 plus
# two directional noise terms of constant power delta, each seen through a
# per-sample M-channel phase ramp over a 100x100 grid.
M = 3  # number of channels
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize the map to 100x100, pass it through awgn (helper from utils; presumably
# adds white Gaussian noise at the given SNR -- confirm there) and keep the
# magnitude, stored as complex64.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
# v1 gets the same treatment but, in this cell, is only referenced by the
# commented-out n1 variant below.
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Noise power: mean of v0 scaled by 10**(-snr/10), i.e. equal to mean(v0) at snr=0.
delta = v0.mean()*10**(-snr/10)
# Each row of h / hs_n1 / hs_n2 is exp(1j*sin(angle)*m) for m = 0..M-1, with one
# random angle per sample drawn from the range given in the trailing comment.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# signal: [n_data, M, 1] @ [1, 10000] -> [n_data, M, 10000] -> [n_data, M, 100, 100].
# NOTE(review): the random source draw has no sample axis, so one realization is
# shared by all n_data samples and only h differs between them; the same holds
# for the two noise draws below -- confirm this is intended.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views of the mixture and the clean signal: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic. NOTE(review): mix[i] is laid out [M, 100, 100], so
    # reshape(10000, 3) does not put one pixel's M channels on each row
    # (mix_all[i] would); xbar is computed but never used.
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 real inputs to 18 real outputs.

    Layer stack held in ``rx_net``: Linear(18, 128) -> ReLU -> Linear(128, 128)
    -> ReLU -> Linear(128, 18) -> ReLU -> Linear(18, 18). The training code in
    this cell feeds it the real and imaginary parts of a 3x3 matrix and reads
    the output back as a complex 3x3 matrix.
    """

    def __init__(self):
        super().__init__()
        # Widths of the ReLU-activated stages; a plain Linear(18, 18) head follows.
        widths = (18, 128, 128, 18)
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(18, 18))
        self.rx_net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension 18) through ``rx_net`` and return the result."""
        return self.rx_net(x)
# Training script for the covariance-inverse experiment: the network output is
# read as a complex M x M matrix, made Hermitian, and trained so that
# (output @ rx) is close to the identity, where rx is the per-sample spatial
# covariance of the mixture.
model = DOA().cuda()

#%%
h_all, M = h, sig_all.shape[-1]  # M is re-read from the last (channel) axis of sig_all
n_tr = int(1e4)  # samples [0, n_tr) are used for training
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # NOTE(review): never read in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
val0 = mix_all[n_tr:n_tr+200]  # 200 held-out samples right after the training split
# Validation covariances: outer product x x^H at every grid point, averaged
# over the two grid axes (dims 1 and 2), giving one M x M matrix per sample.
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
# Network input layout: all real parts stacked before all imaginary parts,
# then flattened to one row per sample.
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
# drop_last=True keeps every batch at exactly 32 samples, which the fixed-size
# identity stack `Is` below depends on.
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Identity targets, one 3x3 per sample: 32 for a training batch, 200 for validation.
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []  # sampled training losses / validation losses
for epoch in range(2001):
    for i, (mix, sig, _) in enumerate(tr):  # note: rebinds the module-level name `mix`
        optimizer.zero_grad()

        "prepare for input to the NN"
        mix = mix.cuda()
        sig = sig.cuda()  # NOTE(review): sig is not used by the loss below
        # Same covariance construction as rx_val above, for the current batch.
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: the trailing axis of size 2 holds (real, imag) of each
        # matrix entry, which differs from the blocked layout used for the input.
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Average with the conjugate transpose so the estimate is Hermitian.
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        "Calc gradient and update"
        # Mean squared magnitude of the residual I - rx_inv_hat @ rx.
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:  # keep one loss value every 30 batches
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 20 == 0:  # every 20 epochs: plot, validate, checkpoint
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            # Zoomed view of the 100 most recently recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        with torch.no_grad():
            # Same forward pass, Hermitian averaging and loss as in training,
            # on the 200 validation samples.
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.cpu().item())
            # Product for the first three validation samples; it approaches the
            # identity as the loss goes to zero.
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')           
        # Saves the whole module object (not only its state_dict) for this epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v6
# Run setup for experiment 'v6': imports, device/plot/print options, RNG seed,
# optimizer alias and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the random draws made later in this cell

# Any torch version other than 1.8.1 defines mydet locally and takes RAdam from
# torch.optim; on 1.8.1 RAdam comes from `optim` (presumably supplied by the
# star import above -- confirm in utils).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det()``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # autograd anomaly checks stay on for the whole run

rid = 'v6'  # running id; names the per-run output sub-folders
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create the two folders independently. The previous version tested
# only the figures folder and used os.mkdir, which raised when the models
# folder already existed or a parent was missing, and skipped creating the
# models folder when only the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%%
# Synthetic data for this run: one source whose per-pixel power follows v0 plus
# two directional noise terms of constant power delta, each seen through a
# per-sample M-channel phase ramp over a 100x100 grid.
M = 3  # number of channels
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize the map to 100x100, pass it through awgn (helper from utils; presumably
# adds white Gaussian noise at the given SNR -- confirm there) and keep the
# magnitude, stored as complex64.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
# v1 gets the same treatment but, in this cell, is only referenced by the
# commented-out n1 variant below.
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)

snr=0; n_data=int(1.1e4)
# Noise power: mean of v0 scaled by 10**(-snr/10), i.e. equal to mean(v0) at snr=0.
delta = v0.mean()*10**(-snr/10)
# Each row of h / hs_n1 / hs_n2 is exp(1j*sin(angle)*m) for m = 0..M-1, with one
# random angle per sample drawn from the range given in the trailing comment.
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# signal: [n_data, M, 1] @ [1, 10000] -> [n_data, M, 10000] -> [n_data, M, 100, 100].
# NOTE(review): the random source draw has no sample axis, so one realization is
# shared by all n_data samples and only h differs between them; the same holds
# for the two noise draws below -- confirm this is intended.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Channel-last views of the mixture and the clean signal: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic. NOTE(review): mix[i] is laid out [M, 100, 100], so
    # reshape(10000, 3) does not put one pixel's M channels on each row
    # (mix_all[i] would); xbar is computed but never used.
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 input values to 18 output values.

    Six ``nn.Linear`` layers with widths 18 -> 128 -> 128 -> 128 -> 128 -> 18
    -> 18, with an in-place ReLU between consecutive layers and none after
    the last one.
    """

    def __init__(self):
        super().__init__()
        widths = (18, 128, 128, 128, 128, 18, 18)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(inplace=True))
        # The last Linear is the output layer, so the ReLU after it is dropped.
        self.rx_net = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Return the raw 18-value network output for input ``x``."""
        return self.rx_net(x)
model = DOA().cuda()

#%%
# Train the network to output an approximate inverse of each sample's 3x3
# spatial covariance: loss = mean |I - rx_inv_hat @ rx|^2.
h_all, M = h, sig_all.shape[-1]
n_tr = int(1e4)
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not used below in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
# Validation set: the 200 samples right after the training split. The
# covariance is x x^H averaged over the two 100-long axes.
val0 = mix_all[n_tr:n_tr+200]
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Batched identities sized for one training batch (32) and the validation set (200).
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    # The dataset also yields the clean signal and the steering vector; this
    # loss needs neither, so they are not copied to the GPU.
    for i, (mix, _, _) in enumerate(tr):
        optimizer.zero_grad()

        # Network input: real and imaginary parts of the batch covariance, flattened
        mix = mix.cuda()
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: last axis holds (real, imag) of each matrix entry
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Symmetrise so the estimate is Hermitian
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        # Gradient step
        loss = ((Is-rx_inv_hat@rx).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        if i % 30 == 0:
            loss_all.append(loss.item())

    # Release cached GPU memory once per epoch; doing it after every step
    # forces the allocator to re-request memory each iteration.
    torch.cuda.empty_cache()

    if epoch % 20 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        # Validation on the fixed 200-sample set, same loss as training
        with torch.no_grad():
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((Is2-rx_inv_hat@rx_val_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.item())
            print('epoch', epoch, rx_inv_hat[:3]@rx_val_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')
        # Checkpoint the whole model object at every evaluation epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%v7
# Cell preamble: load the shared helpers, pin the visible GPU, configure
# plotting/printing and seed torch before any random data is drawn.
from utils import *  # presumably supplies os, plt, torch, nn, optim, sio, np, Data, awgn — confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation from the first five characters of the torch
# version string: '1.8.1' uses `optim.RAdam` (`optim` is not defined in this
# cell), every other version uses torch.optim.RAdam and also defines `mydet`.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        # Determinant via the object's own .det() method.
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # anomaly detection stays enabled for the whole run

rid = 'v7'  # run id: names this experiment's output sub-folders
# Per-run output folders for figures and saved models.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check each folder on its own: the previous code only tested the figures
# folder, so a half-created state either raised FileExistsError or left the
# models folder missing. makedirs also creates missing parent folders.
missing_dirs = [path for path in (fig_loc, mod_loc) if not os.path.isdir(path)]
if missing_dirs:
    print('made a new folder')
    for out_dir in missing_dirs:
        os.makedirs(out_dir, exist_ok=True)


#%%
# Build the synthetic 3-sensor data set: one source plus two interferers, each
# arriving through a per-sample steering vector.
M = 3
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize each map to 100x100, pass it through awgn (utils helper, presumably
# adds white Gaussian noise at the given SNR — confirm), keep the magnitude.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)  # only referenced by the commented-out n1 line below

snr=0; n_data=int(1.1e4)
# Interference power: mean of v0 scaled by 10^(-snr/10); complex-valued because v0 is cfloat.
delta = v0.mean()*10**(-snr/10)
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# A single random 100x100 draw scaled by sqrt(v0) is multiplied by every
# sample's steering vector, so samples differ only through h.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# Interferers: one flat random row each, scaled by sqrt(delta) and steered per sample.
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Move the sensor axis last: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

# Disabled sanity check.
# NOTE(review): mix[i] is [M, 100, 100] here, so reshape(10000, 3) does not put
# the sensor axis last, and xbar is computed but cov is built from x — fix
# both before re-enabling.
if False: # check data low rank or not
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Small fully connected network mapping 18 input values to 18 outputs.

    Three ``nn.Linear`` layers with widths 18 -> 128 -> 18 -> 18 and an
    in-place ReLU between consecutive layers (none after the last one).
    The 128 -> 128 hidden layers are left out of this variant; the original
    kept them only as commented-out lines.
    """

    def __init__(self):
        super().__init__()
        widths = (18, 128, 18, 18)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(inplace=True))
        # The last Linear is the output layer, so the ReLU after it is dropped.
        self.rx_net = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Return the raw 18-value network output for input ``x``."""
        return self.rx_net(x)
model = DOA().cuda()

#%%
# Train the network to regress the exact inverse of each sample's 3x3 spatial
# covariance: loss = mean |rx_inv_hat - inverse(rx)|^2.
h_all, M = h, sig_all.shape[-1]
n_tr = int(1e4)
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not used below in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
# Validation set: the 200 samples right after the training split. The
# covariance is x x^H averaged over the two 100-long axes; its inverse is
# the validation target.
val0 = mix_all[n_tr:n_tr+200]
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_inv_cuda = rx_val.inverse().cuda()

#%%
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Batched identities (not referenced by this cell's loss)
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    # The dataset also yields the clean signal and the steering vector; this
    # loss needs neither, so they are not copied to the GPU.
    for i, (mix, _, _) in enumerate(tr):
        optimizer.zero_grad()

        # Network input: real and imaginary parts of the batch covariance, flattened
        mix = mix.cuda()
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: last axis holds (real, imag) of each matrix entry
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Symmetrise so the estimate is Hermitian
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        # Gradient step
        loss = ((rx_inv_hat-rx.inverse()).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        if i % 30 == 0:
            loss_all.append(loss.item())

    # Release cached GPU memory once per epoch; doing it after every step
    # forces the allocator to re-request memory each iteration.
    torch.cuda.empty_cache()

    if epoch % 20 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        # Validation on the fixed 200-sample set, same loss as training
        with torch.no_grad():
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((rx_inv_hat - rx_inv_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.item())
            print('epoch', epoch, rx_inv_hat[:3], '\n',rx_inv_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')
        # Checkpoint the whole model object at every evaluation epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%v8
# Cell preamble: load the shared helpers, pin the visible GPU, configure
# plotting/printing and seed torch before any random data is drawn.
from utils import *  # presumably supplies os, plt, torch, nn, optim, sio, np, Data, awgn — confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation from the first five characters of the torch
# version string: '1.8.1' uses `optim.RAdam` (`optim` is not defined in this
# cell), every other version uses torch.optim.RAdam and also defines `mydet`.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        # Determinant via the object's own .det() method.
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # anomaly detection stays enabled for the whole run

rid = 'v8'  # run id: names this experiment's output sub-folders
# Per-run output folders for figures and saved models.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check each folder on its own: the previous code only tested the figures
# folder, so a half-created state either raised FileExistsError or left the
# models folder missing. makedirs also creates missing parent folders.
missing_dirs = [path for path in (fig_loc, mod_loc) if not os.path.isdir(path)]
if missing_dirs:
    print('made a new folder')
    for out_dir in missing_dirs:
        os.makedirs(out_dir, exist_ok=True)


#%%
# Build the synthetic 3-sensor data set: one source plus two interferers, each
# arriving through a per-sample steering vector.
M = 3
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize each map to 100x100, pass it through awgn (utils helper, presumably
# adds white Gaussian noise at the given SNR — confirm), keep the magnitude.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)  # only referenced by the commented-out n1 line below

snr=0; n_data=int(1.1e4)
# Interference power: mean of v0 scaled by 10^(-snr/10); complex-valued because v0 is cfloat.
delta = v0.mean()*10**(-snr/10)
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# A single random 100x100 draw scaled by sqrt(v0) is multiplied by every
# sample's steering vector, so samples differ only through h.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# Interferers: one flat random row each, scaled by sqrt(delta) and steered per sample.
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Move the sensor axis last: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

# Disabled sanity check.
# NOTE(review): mix[i] is [M, 100, 100] here, so reshape(10000, 3) does not put
# the sensor axis last, and xbar is computed but cov is built from x — fix
# both before re-enabling.
if False: # check data low rank or not
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 input values to 18 output values.

    Four ``nn.Linear`` layers with widths 18 -> 128 -> 128 -> 18 -> 18 and an
    in-place ReLU between consecutive layers (none after the last one).
    Two further 128 -> 128 hidden layers are left out of this variant; the
    original kept them only as commented-out lines.
    """

    def __init__(self):
        super().__init__()
        widths = (18, 128, 128, 18, 18)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(inplace=True))
        # The last Linear is the output layer, so the ReLU after it is dropped.
        self.rx_net = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Return the raw 18-value network output for input ``x``."""
        return self.rx_net(x)
model = DOA().cuda()

#%%
# Train the network to regress the exact inverse of each sample's 3x3 spatial
# covariance: loss = mean |rx_inv_hat - inverse(rx)|^2.
h_all, M = h, sig_all.shape[-1]
n_tr = int(1e4)
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not used below in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
# Validation set: the 200 samples right after the training split. The
# covariance is x x^H averaged over the two 100-long axes; its inverse is
# the validation target.
val0 = mix_all[n_tr:n_tr+200]
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_inv_cuda = rx_val.inverse().cuda()

#%%
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Batched identities (not referenced by this cell's loss)
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    # The dataset also yields the clean signal and the steering vector; this
    # loss needs neither, so they are not copied to the GPU.
    for i, (mix, _, _) in enumerate(tr):
        optimizer.zero_grad()

        # Network input: real and imaginary parts of the batch covariance, flattened
        mix = mix.cuda()
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: last axis holds (real, imag) of each matrix entry
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Symmetrise so the estimate is Hermitian
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        # Gradient step
        loss = ((rx_inv_hat-rx.inverse()).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        if i % 30 == 0:
            loss_all.append(loss.item())

    # Release cached GPU memory once per epoch; doing it after every step
    # forces the allocator to re-request memory each iteration.
    torch.cuda.empty_cache()

    if epoch % 20 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        # Validation on the fixed 200-sample set, same loss as training
        with torch.no_grad():
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((rx_inv_hat - rx_inv_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.item())
            print('epoch', epoch, rx_inv_hat[:3], '\n',rx_inv_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')
        # Checkpoint the whole model object at every evaluation epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%v9
# Cell preamble: load the shared helpers, pin the visible GPU, configure
# plotting/printing and seed torch before any random data is drawn.
from utils import *  # presumably supplies os, plt, torch, nn, optim, sio, np, Data, awgn — confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation from the first five characters of the torch
# version string: '1.8.1' uses `optim.RAdam` (`optim` is not defined in this
# cell), every other version uses torch.optim.RAdam and also defines `mydet`.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        # Determinant via the object's own .det() method.
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # anomaly detection stays enabled for the whole run

rid = 'v9'  # run id: names this experiment's output sub-folders
# Per-run output folders for figures and saved models.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check each folder on its own: the previous code only tested the figures
# folder, so a half-created state either raised FileExistsError or left the
# models folder missing. makedirs also creates missing parent folders.
missing_dirs = [path for path in (fig_loc, mod_loc) if not os.path.isdir(path)]
if missing_dirs:
    print('made a new folder')
    for out_dir in missing_dirs:
        os.makedirs(out_dir, exist_ok=True)


#%%
# Build the synthetic 3-sensor data set: one source plus two interferers, each
# arriving through a per-sample steering vector.
M = 3
dicts = sio.loadmat('../data/nem_ss/v2.mat')
v0 = dicts['v'][..., 0]
v1 = dicts['v'][..., 1]
from skimage.transform import resize
# Resize each map to 100x100, pass it through awgn (utils helper, presumably
# adds white Gaussian noise at the given SNR — confirm), keep the magnitude.
v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
plt.imshow(v0.abs())
plt.colorbar()
v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
v1 = awgn(v1, snr=30, seed=0).abs().to(torch.cfloat)  # only referenced by the commented-out n1 line below

snr=0; n_data=int(1.1e4)
# Interference power: mean of v0 scaled by 10^(-snr/10); complex-valued because v0 is cfloat.
delta = v0.mean()*10**(-snr/10)
angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
h = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

# A single random 100x100 draw scaled by sqrt(v0) is multiplied by every
# sample's steering vector, so samples differ only through h.
signal = (h[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:]).reshape(n_data, M, 100, 100)
# Interferers: one flat random row each, scaled by sqrt(delta) and steered per sample.
n1 = hs_n1[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5 
n2 = hs_n2[...,None] @ torch.randn(1, torch.tensor(v0.shape).prod(), dtype=torch.cfloat)*delta**0.5
mix =  signal + n1.reshape(n_data, M, 100, 100) + n2.reshape(n_data, M, 100, 100)   
# Move the sensor axis last: [n_data, 100, 100, M].
mix_all, sig_all = mix.permute(0,2,3,1), signal.permute(0,2,3,1)

# torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
plt.figure()
plt.imshow(mix_all[0,:,:,0].abs())
plt.colorbar()

# Disabled sanity check.
# NOTE(review): mix[i] is [M, 100, 100] here, so reshape(10000, 3) does not put
# the sensor axis last, and xbar is computed but cov is built from x — fix
# both before re-enabling.
if False: # check data low rank or not
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% load data and model
class DOA(nn.Module):
    """Fully connected network mapping 18 input values to 18 output values.

    Six ``nn.Linear`` layers with widths 18 -> 128 -> 128 -> 128 -> 128 -> 18
    -> 18, with an in-place ReLU between consecutive layers and none after
    the last one.
    """

    def __init__(self):
        super().__init__()
        widths = (18, 128, 128, 128, 128, 18, 18)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(inplace=True))
        # The last Linear is the output layer, so the ReLU after it is dropped.
        self.rx_net = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Return the raw 18-value network output for input ``x``."""
        return self.rx_net(x)
model = DOA().cuda()

#%%
# Train the network to regress the exact inverse of each sample's 3x3 spatial
# covariance: loss = mean |rx_inv_hat - inverse(rx)|^2.
h_all, M = h, sig_all.shape[-1]
n_tr = int(1e4)
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not used below in this cell
data = Data.TensorDataset(mix_all[:n_tr], sig_all[:n_tr], h_all[:n_tr])
# Validation set: the 200 samples right after the training split. The
# covariance is x x^H averaged over the two 100-long axes; its inverse is
# the validation target.
val0 = mix_all[n_tr:n_tr+200]
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_inv_cuda = rx_val.inverse().cuda()

#%%
inp_val = torch.stack((rx_val.real, rx_val.imag), dim=1).reshape(rx_val.shape[0], -1).cuda()
tr = Data.DataLoader(data, batch_size=32, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)
# Batched identities (not referenced by this cell's loss)
Is = torch.stack([torch.eye(3,3)]*32, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    # The dataset also yields the clean signal and the steering vector; this
    # loss needs neither, so they are not copied to the GPU.
    for i, (mix, _, _) in enumerate(tr):
        optimizer.zero_grad()

        # Network input: real and imaginary parts of the batch covariance, flattened
        mix = mix.cuda()
        Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
        rx = Rx.mean(dim=(1,2))
        inp = torch.stack((rx.real, rx.imag), dim=1).reshape(mix.shape[0], -1)
        rx_raw = model(inp)
        # Output layout: last axis holds (real, imag) of each matrix entry
        rx_reshape = rx_raw.reshape(mix.shape[0],M,M,2)
        rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
        # Symmetrise so the estimate is Hermitian
        rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2

        # Gradient step
        loss = ((rx_inv_hat-rx.inverse()).abs()**2).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        if i % 30 == 0:
            loss_all.append(loss.item())

    # Release cached GPU memory once per epoch; doing it after every step
    # forces the allocator to re-request memory each iteration.
    torch.cuda.empty_cache()

    if epoch % 20 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.savefig(fig_loc + f'last_100_at_epoch{epoch}')

        # Validation on the fixed 200-sample set, same loss as training
        with torch.no_grad():
            rx_raw = model(inp_val)
            rx_reshape = rx_raw.reshape(inp_val.shape[0],M,M,2)
            rx_inv_hat = rx_reshape[...,0] + 1j*rx_reshape[...,1]
            rx_inv_hat = (rx_inv_hat + rx_inv_hat.conj().permute(0,2,1))/2
            loss_val = ((rx_inv_hat - rx_inv_cuda).abs()**2).mean()
            loss_val_all.append(loss_val.item())
            print('epoch', epoch, rx_inv_hat[:3], '\n',rx_inv_cuda[:3])
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')
        # Checkpoint the whole model object at every evaluation epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')

print('done')




In [None]:
#%% v10
# Cell preamble: load the shared helpers, pin the visible GPU, configure
# plotting/printing and seed torch before any random data is drawn.
from utils import *  # presumably supplies os, plt, torch, nn, optim, sio, np, Data, awgn, awgn_batch — confirm in utils
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation from the first five characters of the torch
# version string: '1.8.1' uses `optim.RAdam` (`optim` is not defined in this
# cell), every other version uses torch.optim.RAdam and also defines `mydet`.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        # Determinant via the object's own .det() method.
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)  # anomaly detection stays enabled for the whole run

rid = 'v10'  # run id: names this experiment's output sub-folders
# Per-run output folders for figures and saved models.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check each folder on its own: the previous code only tested the figures
# folder, so a half-created state either raised FileExistsError or left the
# models folder missing. makedirs also creates missing parent folders.
missing_dirs = [path for path in (fig_loc, mod_loc) if not os.path.isdir(path)]
if missing_dirs:
    print('made a new folder')
    for out_dir in missing_dirs:
        os.makedirs(out_dir, exist_ok=True)


#%% generate data
# Build the synthetic 3-sensor data set: three sources, each with its own
# variance map (v0, v1, v2) and its own per-sample steering vector.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Resize each map to 100x100, pass it through awgn (utils helper, presumably
    # adds white Gaussian noise at the given SNR — confirm), keep the magnitude.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    delta = v0.mean()*10**(-snr/10)  # not referenced again in this block
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # For each source a single random 100x100 draw scaled by the square root of
    # its variance map is multiplied by every sample's steering vector.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # Move the sensor axis last: [n_data, 100, 100, M].
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    mixn = awgn_batch(mix_all)  # utils helper; mixn is not referenced again in the visible part of this cell
    # Steering vectors of the three sources stacked on the last axis: [n_data, M, 3].
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

# Disabled sanity check.
# NOTE(review): mix[i] is [M, 100, 100] here, so reshape(10000, 3) does not put
# the sensor axis last, and xbar is computed but cov is built from x — fix
# both before re-enabling.
if False: # check data low rank or not
    for i in range(n_data):
        x = mix[i,:,:].reshape(10000, 3)
        xbar = x - x.mean(0)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Shared trunk with two output heads.

    ``mainnet`` maps 12 inputs to a 128-wide feature (12 -> 128 -> 128, ReLU
    after both layers). ``hnet`` (128 -> 64 -> 6) and ``rxnet``
    (128 -> 64 -> 32 -> 12) each read that feature and end in a plain Linear.
    """

    def __init__(self):
        super().__init__()

        def dense_stack(widths, relu_last):
            # Linear layers of the given widths with in-place ReLUs between
            # them; the trailing ReLU is kept only when relu_last is True.
            mods = []
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                mods += [nn.Linear(n_in, n_out), nn.ReLU(inplace=True)]
            return nn.Sequential(*(mods if relu_last else mods[:-1]))

        # Built in the order trunk, h head, rx head.
        self.mainnet = dense_stack((12, 128, 128), relu_last=True)
        self.hnet = dense_stack((128, 64, 6), relu_last=False)
        self.rxnet = dense_stack((128, 64, 32, 12), relu_last=False)

    def forward(self, x):
        """Return ``(h6, rx12)``: the 6-value and 12-value head outputs."""
        shared = self.mainnet(x)
        return self.hnet(shared), self.rxnet(shared)

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real values per sample.

    Args:
        rx12: real tensor of shape [batch, 12]. Columns 0-5 are the real parts
            and columns 6-11 the imaginary parts of the six lower-triangular
            entries, in ``torch.tril_indices(3, 3)`` order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [batch, 3, 3] on the same device as
        ``rx12``. The lower triangle holds the given entries, the upper
        triangle their conjugates, and the diagonal is the real part of the
        given diagonal entries.
    """
    rows, cols = torch.tril_indices(3, 3, device=rx12.device)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, rows, cols] = (rx12[:, :6] + 1j*rx12[:, 6:]).to(lower.dtype)
    hermitian = lower + lower.permute(0, 2, 1).conj()
    # lower + lower^H counts the diagonal twice (2*Re). Remove one copy for
    # every sample; the old code indexed [0,0], [1,1], [2,2], which selected
    # (batch, row) pairs instead of the diagonal.
    diag_real = torch.diagonal(lower, dim1=-2, dim2=-1).real
    return hermitian - torch.diag_embed(diag_real)

model = DOA().cuda()
# Re-initialise every parameter (weights and biases) from N(0, 0.01^2)
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

#%%
# Recursive training: for each of the three sources in turn the network
# predicts a steering vector and an inverse covariance from the current
# covariance, then that source is projected out before the next pass.
M, btsize, n_tr = 3, 64, int(1e4)
lamb = 1  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not used below in this cell
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Pre-computed constants: batched identities for a training batch and for the
# 200 validation samples, plus lower-triangle indices (2-D and flattened).
Is = torch.stack([torch.eye(3,3)]*btsize, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()
ind = torch.tril_indices(3,3)
indx = ind[0]*3+ind[1]

# Validation set: the 200 samples right after the training split
val0, h0 = mix_all[n_tr:n_tr+200], H_all[n_tr:n_tr+200].cuda()
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # First pass: covariance x x^H averaged over the two 100-long axes
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Later passes: build weights from the previous pass's
                # estimates and project that source out of the covariance
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # Network input: real then imaginary parts of the lower triangle
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            hhat = h6[:,:3] + 1j*h6[:,3:]
            rx_inv_hat = lower2matrix(rx12)
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        if i % 30 == 0:
            loss_all.append(loss.item())

    # Release cached GPU memory once per epoch; doing it after every step
    # forces the allocator to re-request memory each iteration.
    torch.cuda.empty_cache()

    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        # Validation: same recursion on the fixed 200-sample set
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')
        # Checkpoint the whole model object at every evaluation epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v11
from utils import *
# Pin the process to GPU 0 and configure plotting / tensor printing.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch global generator

# Pick the RAdam implementation by torch version. On 1.8.1 it comes from
# `optim` (presumably provided by `from utils import *` - confirm); on any
# other version the built-in optimizer is used and `mydet` is (re)defined.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v11'
# Per-run output folders for figures and model checkpoints. The trailing '/'
# is kept because later code builds file names by string concatenation.
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Create whichever folder is missing (parents included). The previous check
# looked at the figures folder only and used os.mkdir for both, so an
# already-existing models folder raised FileExistsError and a missing models
# folder was never created when the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%% generate data
if True:
    # Synthetic data: one "signal" source plus two interfering sources, each
    # spread over M channels by its own steering vector and summed into a
    # mixture of shape [n_data, M, 100, 100].
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # Three maps stored along the last axis of the 'v' entry, one per source.
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through the project helper
    # `awgn` (snr=30, per-map seed) and made non-negative with abs() before
    # the cast to complex.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # NOTE(review): `snr` and `delta` are not referenced again in this cell.
    snr=0
    delta = v0.mean()*10**(-snr/10)
    # One random angle per sample (degrees converted to radians), turned into
    # a length-M vector exp(1j*sin(angle)*m), m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # A single complex Gaussian draw per source, scaled by sqrt(v) and
    # flattened to 10000 values, is shared by all n_data samples; only the
    # steering vector ([n, M, 1]) differs per sample. Each product broadcasts
    # to [n, M, 10000].
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # Channel-last views, [n_data, 100, 100, M].
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): `mixn` (output of the project helper awgn_batch) is not
    # referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; last axis ordered (signal, noise 1, noise 2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M channel
    # covariance is rank deficient.
    for i in range(n_data):
        # mix_all is channel-last ([n, 100, 100, M]), so each row of x is one
        # pixel's M-channel vector. Reshaping the channel-first mix[i]
        # straight to (10000, 3) grouped neighbouring pixels of one channel
        # instead. The unused mean-removed copy (xbar) was dropped.
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    forward() takes a [B, 12] input and returns a pair: the 6-value output
    of `hnet` and the 12-value output of `rxnet`, both computed from the
    same trunk (`mainnet`) activation.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive widths with an in-place ReLU
        after each one; the last ReLU is added only when final_relu is set."""
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Creation order (trunk, hnet, rxnet) is kept so parameter order and
        # state_dict keys stay the same.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        shared = self.mainnet(x)
        return self.hnet(shared), self.rxnet(shared)

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower triangles.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 hold the real parts
            and columns 6-11 the imaginary parts of the lower-triangular
            entries, in torch.tril_indices(3, 3) order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        rx12. The lower triangle carries the packed values, the upper
        triangle their conjugates, and the diagonal is the real part of the
        packed diagonal entries.
    """
    rows, cols = torch.tril_indices(3, 3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose counts every diagonal entry twice, so
    # the diagonal of EVERY matrix is scaled by 1/2. The previous
    # `rx_inv_hat[k, k]` indexing halved row k of batch item k instead and
    # failed for batches smaller than 3.
    scale = 1 - 0.5*torch.eye(3, device=rx12.device)
    return (lower + lower.permute(0, 2, 1).conj()) * scale

# Put the network on the GPU and redraw every parameter (weights and biases
# alike) from a normal distribution with std 0.01.
model = DOA().cuda()
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Sizes, loss weight, data loader and optimizer.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 1
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

"Pre calc constants"
# Batched complex identities: one sized for a training batch, one for the
# 200 validation samples.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Lower-triangle (row, col) indices of a 3x3 matrix and the matching
# positions in the row-major flattened matrix.
ind = torch.tril_indices(3, 3)
indx = 3*ind[0] + ind[1]

"validation"
# The 200 samples right after the training split, with their per-sample
# covariance (outer product averaged over both spatial axes) precomputed.
val0 = mix_all[n_tr:n_tr+200]
h0 = H_all[n_tr:n_tr+200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            hhat = h6[:,:3] + 1j*h6[:,3:]
            rx_inv_hat = lower2matrix(rx12)
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v12
from utils import *
# Pin the process to GPU 0 and configure plotting / tensor printing.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch global generator

# Pick the RAdam implementation by torch version. On 1.8.1 it comes from
# `optim` (presumably provided by `from utils import *` - confirm); on any
# other version the built-in optimizer is used and `mydet` is (re)defined.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v12'
# Per-run output folders for figures and model checkpoints. The trailing '/'
# is kept because later code builds file names by string concatenation.
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Create whichever folder is missing (parents included). The previous check
# looked at the figures folder only and used os.mkdir for both, so an
# already-existing models folder raised FileExistsError and a missing models
# folder was never created when the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%% generate data
if True:
    # Synthetic data: one "signal" source plus two interfering sources, each
    # spread over M channels by its own steering vector and summed into a
    # mixture of shape [n_data, M, 100, 100].
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # Three maps stored along the last axis of the 'v' entry, one per source.
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through the project helper
    # `awgn` (snr=30, per-map seed) and made non-negative with abs() before
    # the cast to complex.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # NOTE(review): `snr` and `delta` are not referenced again in this cell.
    snr=0
    delta = v0.mean()*10**(-snr/10)
    # One random angle per sample (degrees converted to radians), turned into
    # a length-M vector exp(1j*sin(angle)*m), m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # A single complex Gaussian draw per source, scaled by sqrt(v) and
    # flattened to 10000 values, is shared by all n_data samples; only the
    # steering vector ([n, M, 1]) differs per sample. Each product broadcasts
    # to [n, M, 10000].
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # Channel-last views, [n_data, 100, 100, M].
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): `mixn` (output of the project helper awgn_batch) is not
    # referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; last axis ordered (signal, noise 1, noise 2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M channel
    # covariance is rank deficient.
    for i in range(n_data):
        # mix_all is channel-last ([n, 100, 100, M]), so each row of x is one
        # pixel's M-channel vector. Reshaping the channel-first mix[i]
        # straight to (10000, 3) grouped neighbouring pixels of one channel
        # instead. The unused mean-removed copy (xbar) was dropped.
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a short shared trunk and two deep heads.

    forward() takes a [B, 12] input and returns a pair: the 6-value output
    of `hnet` and the 12-value output of `rxnet`, both computed from the
    same trunk (`mainnet`) activation.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive widths with an in-place ReLU
        after each one; the last ReLU is added only when final_relu is set."""
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Creation order (trunk, hnet, rxnet) is kept so parameter order and
        # state_dict keys stay the same.
        self.mainnet = self._mlp((12, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 128, 128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 128, 128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        shared = self.mainnet(x)
        return self.hnet(shared), self.rxnet(shared)

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower triangles.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 hold the real parts
            and columns 6-11 the imaginary parts of the lower-triangular
            entries, in torch.tril_indices(3, 3) order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        rx12. The lower triangle carries the packed values, the upper
        triangle their conjugates, and the diagonal is the real part of the
        packed diagonal entries.
    """
    rows, cols = torch.tril_indices(3, 3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose counts every diagonal entry twice, so
    # the diagonal of EVERY matrix is scaled by 1/2. The previous
    # `rx_inv_hat[k, k]` indexing halved row k of batch item k instead and
    # failed for batches smaller than 3.
    scale = 1 - 0.5*torch.eye(3, device=rx12.device)
    return (lower + lower.permute(0, 2, 1).conj()) * scale

# Put the network on the GPU and redraw every parameter (weights and biases
# alike) from a normal distribution with std 0.01.
model = DOA().cuda()
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Sizes, loss weight, data loader and optimizer.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 1
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

"Pre calc constants"
# Batched complex identities: one sized for a training batch, one for the
# 200 validation samples.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Lower-triangle (row, col) indices of a 3x3 matrix and the matching
# positions in the row-major flattened matrix.
ind = torch.tril_indices(3, 3)
indx = 3*ind[0] + ind[1]

"validation"
# The 200 samples right after the training split, with their per-sample
# covariance (outer product averaged over both spatial axes) precomputed.
val0 = mix_all[n_tr:n_tr+200]
h0 = H_all[n_tr:n_tr+200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

# Training driver. Every 10th epoch it also plots the loss curves, runs the
# 200-sample validation pass and saves the whole model object.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    # NOTE: the loop targets rebind the globals `mix` and `H` to the current
    # mini-batch, replacing the full tensors built in the data cell.
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # Outer product x x^H at every pixel, averaged over the two
                # spatial axes: one 3x3 matrix per sample.
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Reuse the previous pass's network outputs (h = hhat,
                # Ri = rx_inv_hat): w = Ri h / (h^H Ri h), p = I - h w^H,
                # then rx <- p rx p^H.
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # Network input: lower-triangle entries of rx, 6 real parts
            # followed by 6 imaginary parts -> [btsize, 12].
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            # First three outputs are the real part, last three the imaginary.
            hhat = h6[:,:3] + 1j*h6[:,3:]
            rx_inv_hat = lower2matrix(rx12)
            # Loss: distance of rx_inv_hat@rx from the identity, plus lamb
            # times the squared error between hhat and column j of H.
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # Record the training loss once every 30 mini-batches.
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        # Zoomed view of the last 100 recorded training losses.
        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # Same three-pass recursion as training, run without gradients on the
        # precomputed validation covariances (batch size fixed at 200).
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            # Validation curves are written to fig_loc rather than shown.
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # Saves the full module object (not only its state_dict).
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v13
from utils import *
# Pin the process to GPU 0 and configure plotting / tensor printing.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch global generator

# Pick the RAdam implementation by torch version. On 1.8.1 it comes from
# `optim` (presumably provided by `from utils import *` - confirm); on any
# other version the built-in optimizer is used and `mydet` is (re)defined.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v13'
# Per-run output folders for figures and model checkpoints. The trailing '/'
# is kept because later code builds file names by string concatenation.
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Create whichever folder is missing (parents included). The previous check
# looked at the figures folder only and used os.mkdir for both, so an
# already-existing models folder raised FileExistsError and a missing models
# folder was never created when the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%% generate data
if True:
    # Synthetic data: one "signal" source plus two interfering sources, each
    # spread over M channels by its own steering vector and summed into a
    # mixture of shape [n_data, M, 100, 100].
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # Three maps stored along the last axis of the 'v' entry, one per source.
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through the project helper
    # `awgn` (snr=30, per-map seed) and made non-negative with abs() before
    # the cast to complex.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # NOTE(review): `snr` and `delta` are not referenced again in this cell.
    snr=0
    delta = v0.mean()*10**(-snr/10)
    # One random angle per sample (degrees converted to radians), turned into
    # a length-M vector exp(1j*sin(angle)*m), m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # A single complex Gaussian draw per source, scaled by sqrt(v) and
    # flattened to 10000 values, is shared by all n_data samples; only the
    # steering vector ([n, M, 1]) differs per sample. Each product broadcasts
    # to [n, M, 10000].
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # Channel-last views, [n_data, 100, 100, M].
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): `mixn` (output of the project helper awgn_batch) is not
    # referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; last axis ordered (signal, noise 1, noise 2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M channel
    # covariance is rank deficient.
    for i in range(n_data):
        # mix_all is channel-last ([n, 100, 100, M]), so each row of x is one
        # pixel's M-channel vector. Reshaping the channel-first mix[i]
        # straight to (10000, 3) grouped neighbouring pixels of one channel
        # instead. The unused mean-removed copy (xbar) was dropped.
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a deep shared trunk and two deep heads.

    forward() takes a [B, 12] input and returns a pair: the 6-value output
    of `hnet` and the 12-value output of `rxnet`, both computed from the
    same trunk (`mainnet`) activation.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive widths with an in-place ReLU
        after each one; the last ReLU is added only when final_relu is set."""
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Creation order (trunk, hnet, rxnet) is kept so parameter order and
        # state_dict keys stay the same.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 128, 128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 128, 128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        shared = self.mainnet(x)
        return self.hnet(shared), self.rxnet(shared)

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower triangles.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 hold the real parts
            and columns 6-11 the imaginary parts of the lower-triangular
            entries, in torch.tril_indices(3, 3) order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        rx12. The lower triangle carries the packed values, the upper
        triangle their conjugates, and the diagonal is the real part of the
        packed diagonal entries.
    """
    rows, cols = torch.tril_indices(3, 3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose counts every diagonal entry twice, so
    # the diagonal of EVERY matrix is scaled by 1/2. The previous
    # `rx_inv_hat[k, k]` indexing halved row k of batch item k instead and
    # failed for batches smaller than 3.
    scale = 1 - 0.5*torch.eye(3, device=rx12.device)
    return (lower + lower.permute(0, 2, 1).conj()) * scale

# Put the network on the GPU and redraw every parameter (weights and biases
# alike) from a normal distribution with std 0.01.
model = DOA().cuda()
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Sizes, loss weight, data loader and optimizer.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 1
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

"Pre calc constants"
# Batched complex identities: one sized for a training batch, one for the
# 200 validation samples.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Lower-triangle (row, col) indices of a 3x3 matrix and the matching
# positions in the row-major flattened matrix.
ind = torch.tril_indices(3, 3)
indx = 3*ind[0] + ind[1]

"validation"
# The 200 samples right after the training split, with their per-sample
# covariance (outer product averaged over both spatial axes) precomputed.
val0 = mix_all[n_tr:n_tr+200]
h0 = H_all[n_tr:n_tr+200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

# Training driver. Every 10th epoch it also plots the loss curves, runs the
# 200-sample validation pass and saves the whole model object.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    # NOTE: the loop targets rebind the globals `mix` and `H` to the current
    # mini-batch, replacing the full tensors built in the data cell.
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # Outer product x x^H at every pixel, averaged over the two
                # spatial axes: one 3x3 matrix per sample.
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Reuse the previous pass's network outputs (h = hhat,
                # Ri = rx_inv_hat): w = Ri h / (h^H Ri h), p = I - h w^H,
                # then rx <- p rx p^H.
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # Network input: lower-triangle entries of rx, 6 real parts
            # followed by 6 imaginary parts -> [btsize, 12].
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            # First three outputs are the real part, last three the imaginary.
            hhat = h6[:,:3] + 1j*h6[:,3:]
            rx_inv_hat = lower2matrix(rx12)
            # Loss: distance of rx_inv_hat@rx from the identity, plus lamb
            # times the squared error between hhat and column j of H.
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # Record the training loss once every 30 mini-batches.
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        # Zoomed view of the last 100 recorded training losses.
        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # Same three-pass recursion as training, run without gradients on the
        # precomputed validation covariances (batch size fixed at 200).
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            # Validation curves are written to fig_loc rather than shown.
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # Saves the full module object (not only its state_dict).
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v14
from utils import *
# Pin the process to GPU 0 and configure plotting / tensor printing.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch global generator

# Pick the RAdam implementation by torch version. On 1.8.1 it comes from
# `optim` (presumably provided by `from utils import *` - confirm); on any
# other version the built-in optimizer is used and `mydet` is (re)defined.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()
    RAdam = torch.optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v14'
# Per-run output folders for figures and model checkpoints. The trailing '/'
# is kept because later code builds file names by string concatenation.
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Create whichever folder is missing (parents included). The previous check
# looked at the figures folder only and used os.mkdir for both, so an
# already-existing models folder raised FileExistsError and a missing models
# folder was never created when the figures folder was present.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)


#%% generate data
if True:
    # Synthetic data: one "signal" source plus two interfering sources, each
    # spread over M channels by its own steering vector and summed into a
    # mixture of shape [n_data, M, 100, 100].
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # Three maps stored along the last axis of the 'v' entry, one per source.
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through the project helper
    # `awgn` (snr=30, per-map seed) and made non-negative with abs() before
    # the cast to complex.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # NOTE(review): `snr` and `delta` are not referenced again in this cell.
    snr=0
    delta = v0.mean()*10**(-snr/10)
    # One random angle per sample (degrees converted to radians), turned into
    # a length-M vector exp(1j*sin(angle)*m), m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # A single complex Gaussian draw per source, scaled by sqrt(v) and
    # flattened to 10000 values, is shared by all n_data samples; only the
    # steering vector ([n, M, 1]) differs per sample. Each product broadcasts
    # to [n, M, 10000].
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # Channel-last views, [n_data, 100, 100, M].
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): `mixn` (output of the project helper awgn_batch) is not
    # referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; last axis ordered (signal, noise 1, noise 2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M channel
    # covariance is rank deficient.
    for i in range(n_data):
        # mix_all is channel-last ([n, 100, 100, M]), so each row of x is one
        # pixel's M-channel vector. Reshaping the channel-first mix[i]
        # straight to (10000, 3) grouped neighbouring pixels of one channel
        # instead. The unused mean-removed copy (xbar) was dropped.
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    forward() takes a [B, 12] input and returns a pair: the 6-value output
    of `hnet` and the 12-value output of `rxnet`, both computed from the
    same trunk (`mainnet`) activation.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive widths with an in-place ReLU
        after each one; the last ReLU is added only when final_relu is set."""
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Creation order (trunk, hnet, rxnet) is kept so parameter order and
        # state_dict keys stay the same.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        shared = self.mainnet(x)
        return self.hnet(shared), self.rxnet(shared)

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower triangles.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 hold the real parts
            and columns 6-11 the imaginary parts of the lower-triangular
            entries, in torch.tril_indices(3, 3) order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        rx12. The lower triangle carries the packed values, the upper
        triangle their conjugates, and the diagonal is the real part of the
        packed diagonal entries.
    """
    rows, cols = torch.tril_indices(3, 3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose counts every diagonal entry twice, so
    # the diagonal of EVERY matrix is scaled by 1/2. The previous
    # `rx_inv_hat[k, k]` indexing halved row k of batch item k instead and
    # failed for batches smaller than 3.
    scale = 1 - 0.5*torch.eye(3, device=rx12.device)
    return (lower + lower.permute(0, 2, 1).conj()) * scale

# Put the network on the GPU and redraw every parameter (weights and biases
# alike) from a normal distribution with std 0.01.
model = DOA().cuda()
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Sizes, loss weight, data loader and optimizer.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 0.1
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

"Pre calc constants"
# Batched complex identities: one sized for a training batch, one for the
# 200 validation samples.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Lower-triangle (row, col) indices of a 3x3 matrix and the matching
# positions in the row-major flattened matrix.
ind = torch.tril_indices(3, 3)
indx = 3*ind[0] + ind[1]

"validation"
# The 200 samples right after the training split, with their per-sample
# covariance (outer product averaged over both spatial axes) precomputed.
val0 = mix_all[n_tr:n_tr+200]
h0 = H_all[n_tr:n_tr+200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

# Training driver. Every 10th epoch it also plots the loss curves, runs the
# 200-sample validation pass and saves the whole model object.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    # NOTE: the loop targets rebind the globals `mix` and `H` to the current
    # mini-batch, replacing the full tensors built in the data cell.
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # Outer product x x^H at every pixel, averaged over the two
                # spatial axes: one 3x3 matrix per sample.
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Reuse the previous pass's network outputs (h = hhat,
                # Ri = rx_inv_hat): w = Ri h / (h^H Ri h), p = I - h w^H,
                # then rx <- p rx p^H.
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # Network input: lower-triangle entries of rx, 6 real parts
            # followed by 6 imaginary parts -> [btsize, 12].
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            # First three outputs are the real part, last three the imaginary.
            hhat = h6[:,:3] + 1j*h6[:,3:]
            rx_inv_hat = lower2matrix(rx12)
            # Loss: distance of rx_inv_hat@rx from the identity, plus lamb
            # times the squared error between hhat and column j of H.
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # Record the training loss once every 30 mini-batches.
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        # Zoomed view of the last 100 recorded training losses.
        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # Same three-pass recursion as training, run without gradients on the
        # precomputed validation covariances (batch size fixed at 200).
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            # Validation curves are written to fig_loc rather than shown.
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # Saves the full module object (not only its state_dict).
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v15
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation: torch.optim.RAdam unless running torch 1.8.1,
# in which case the `optim` name already in scope is used instead
# (presumably provided by the star import above -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v15'  # running id, used as the per-run sub-folder name
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Create the figure and model folders independently. The previous code only
# checked the figure folder and then called os.mkdir on both, which raised
# FileExistsError when just the model folder existed and skipped creating the
# model folder when just the figure folder existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)


#%% generate data
# Synthetic data generation: three sources with fixed power maps v0/v1/v2,
# each arriving through a per-sample random steering vector, summed on M sensors.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # the three source maps are taken from the last axis of dicts['v']
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through awgn (project helper,
    # snr=30, fixed seed), then its magnitude is stored as a complex tensor.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    delta = v0.mean()*10**(-snr/10)  # not referenced again in this cell
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    # steering vector: exp(1j * sin(angle) * sensor_index), sensor_index = 0..M-1
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is one complex Gaussian draw of shape v.shape scaled by sqrt(v).
    # The draw has no n_data axis, so the same realization is shared by every
    # sample; only the steering vectors differ between samples.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # move the sensor axis last: [n_data, 100, 100, M]
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    mixn = awgn_batch(mix_all)  # not referenced again in this cell
    # ground-truth steering vectors of the three sources, stacked on the last axis: [n_data, M, 3]
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M covariance is rank deficient.
    for i in range(n_data):
        # mix_all[i] is [100, 100, M], so flattening yields one row per pixel
        # holding that pixel's M sensor values. (mix[i] is [M, 100, 100];
        # reshaping it to (10000, 3) interleaved sensors and pixels.)
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    The trunk (`mainnet`) maps a 12-feature input to a 128-wide hidden
    vector. `hnet` maps that vector to 6 outputs and `rxnet` maps it to
    12 outputs; both heads end in a plain Linear layer.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive `widths` with ReLU in between.

        A ReLU follows the last Linear only when `final_relu` is true.
        """
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Construction order (trunk, hnet, rxnet) matches the parameter order
        # and the Sequential indices of the layers.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        """Return `(h6, rx12)`: the 6-wide and 12-wide head outputs for `x`."""
        trunk = self.mainnet(x)
        h6 = self.hnet(trunk)
        rx12 = self.rxnet(trunk)
        return h6, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in `torch.tril_indices(3, 3)` order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        `rx12`. Each matrix equals its own conjugate transpose, and its
        diagonal holds the real part of the packed diagonal entries.

    Note:
        Earlier revisions halved `rx_inv_hat[k, k]` (row k of batch element k)
        instead of each matrix's diagonal, so diagonals came out doubled.
        Checkpoints trained against that behavior may need retraining.
    """
    n = 3
    device = rx12.device
    rows, cols = torch.tril_indices(n, n, device=device)
    lower = torch.zeros(rx12.shape[0], n, n, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = (rx12[:, :6] + 1j*rx12[:, 6:]).to(lower.dtype)
    herm = lower + lower.permute(0, 2, 1).conj()
    # lower + lower^H counts every diagonal entry twice; halve the diagonal of
    # every matrix in the batch.
    diag = torch.arange(n, device=device)
    herm[:, diag, diag] = herm[:, diag, diag] / 2
    return herm

# Instantiate the network on the GPU and overwrite every tensor returned by
# model.parameters() with zero-mean normal samples (std 0.01).
model = DOA().cuda()
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

#%%
# number of sensors, batch size, number of training samples
M, btsize, n_tr = 3, 64, int(1e4)
lamb = 10  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not referenced again in this cell
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
# drop_last=True keeps every batch at exactly btsize, matching the fixed-size `Is` below
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

"Pre calc constants"
# Batched 3x3 identities: one per training-batch element (Is) and one per
# validation sample (Is2).
Is = torch.stack([torch.eye(3,3)]*btsize, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()
ind = torch.tril_indices(3,3)  # (row, col) indices of the 3x3 lower triangle
indx = ind[0]*3+ind[1]  # the same entries as flat indices into a row-major 3x3

"validation"
# 200 samples taken right after the training split; their covariances
# (per-pixel outer product averaged over axes 1 and 2) are computed once here.
val0, h0 = mix_all[n_tr:n_tr+200], H_all[n_tr:n_tr+200].cuda()
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()

# Training: each batch runs three recursive passes (one per source). In pass j
# the model receives the packed covariance and returns a steering-vector
# estimate plus a packed inverse-covariance estimate; passes 1 and 2 first
# deflate the covariance using the previous pass's estimates.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # sample covariance: per-pixel outer product averaged over axes 1 and 2
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # w = R^-1 h / (h^H R^-1 h), built from the previous pass's estimates
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                # p = I - h w^H, then rx <- p rx p^H
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # model input: the 6 lower-triangle real parts followed by the 6 imaginary parts
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            hhat = h6[:,:3] + 1j*h6[:,3:]  # first 3 outputs: real part, last 3: imaginary part
            rx_inv_hat = lower2matrix(rx12)
            # loss = mean|I - rx_inv_hat @ rx|^2 + lamb * mean|H_j - hhat|^2
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # record the training loss every 30th iteration
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # every 10 epochs: plot training loss, run validation, save figures and the model
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # same three-pass recursion as training, on the 200 precomputed
        # validation covariances and without gradient tracking
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # the whole model object is saved, one file per validated epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%%  v16
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation: torch.optim.RAdam unless running torch 1.8.1,
# in which case the `optim` name already in scope is used instead
# (presumably provided by the star import above -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v16'  # running id, used as the per-run sub-folder name
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Create the figure and model folders independently. The previous code only
# checked the figure folder and then called os.mkdir on both, which raised
# FileExistsError when just the model folder existed and skipped creating the
# model folder when just the figure folder existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)


#%% generate data
# Synthetic data generation: three sources with fixed power maps v0/v1/v2,
# each arriving through a per-sample random steering vector, summed on M sensors.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # the three source maps are taken from the last axis of dicts['v']
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through awgn (project helper,
    # snr=30, fixed seed), then its magnitude is stored as a complex tensor.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    delta = v0.mean()*10**(-snr/10)  # not referenced again in this cell
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    # steering vector: exp(1j * sin(angle) * sensor_index), sensor_index = 0..M-1
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is one complex Gaussian draw of shape v.shape scaled by sqrt(v).
    # The draw has no n_data axis, so the same realization is shared by every
    # sample; only the steering vectors differ between samples.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # move the sensor axis last: [n_data, 100, 100, M]
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    mixn = awgn_batch(mix_all)  # not referenced again in this cell
    # ground-truth steering vectors of the three sources, stacked on the last axis: [n_data, M, 3]
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M covariance is rank deficient.
    for i in range(n_data):
        # mix_all[i] is [100, 100, M], so flattening yields one row per pixel
        # holding that pixel's M sensor values. (mix[i] is [M, 100, 100];
        # reshaping it to (10000, 3) interleaved sensors and pixels.)
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    The trunk (`mainnet`) maps a 12-feature input to a 128-wide hidden
    vector. `hnet` maps that vector to 6 outputs and `rxnet` maps it to
    12 outputs; both heads end in a plain Linear layer.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive `widths` with ReLU in between.

        A ReLU follows the last Linear only when `final_relu` is true.
        """
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Construction order (trunk, hnet, rxnet) matches the parameter order
        # and the Sequential indices of the layers.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 6), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        """Return `(h6, rx12)`: the 6-wide and 12-wide head outputs for `x`."""
        trunk = self.mainnet(x)
        h6 = self.hnet(trunk)
        rx12 = self.rxnet(trunk)
        return h6, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in `torch.tril_indices(3, 3)` order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        `rx12`. Each matrix equals its own conjugate transpose, and its
        diagonal holds the real part of the packed diagonal entries.

    Note:
        Earlier revisions halved `rx_inv_hat[k, k]` (row k of batch element k)
        instead of each matrix's diagonal, so diagonals came out doubled.
        Checkpoints trained against that behavior may need retraining.
    """
    n = 3
    device = rx12.device
    rows, cols = torch.tril_indices(n, n, device=device)
    lower = torch.zeros(rx12.shape[0], n, n, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = (rx12[:, :6] + 1j*rx12[:, 6:]).to(lower.dtype)
    herm = lower + lower.permute(0, 2, 1).conj()
    # lower + lower^H counts every diagonal entry twice; halve the diagonal of
    # every matrix in the batch.
    diag = torch.arange(n, device=device)
    herm[:, diag, diag] = herm[:, diag, diag] / 2
    return herm

# Instantiate the network on the GPU and overwrite every tensor returned by
# model.parameters() with zero-mean normal samples (std 0.01).
model = DOA().cuda()
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

#%%
# number of sensors, batch size, number of training samples
M, btsize, n_tr = 3, 64, int(1e4)
lamb = 1  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not referenced again in this cell
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
# drop_last=True keeps every batch at exactly btsize, matching the fixed-size `Is` below
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

"Pre calc constants"
# Batched 3x3 identities: one per training-batch element (Is) and one per
# validation sample (Is2).
Is = torch.stack([torch.eye(3,3)]*btsize, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()
ind = torch.tril_indices(3,3)  # (row, col) indices of the 3x3 lower triangle
indx = ind[0]*3+ind[1]  # the same entries as flat indices into a row-major 3x3

"validation"
# 200 samples taken right after the training split; their covariances
# (per-pixel outer product averaged over axes 1 and 2) are computed once here.
val0, h0 = mix_all[n_tr:n_tr+200], H_all[n_tr:n_tr+200].cuda()
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()

# Training: each batch runs three recursive passes (one per source). In pass j
# the model receives the packed covariance and returns a steering-vector
# estimate plus a packed inverse-covariance estimate; passes 1 and 2 first
# deflate the covariance using the previous pass's estimates.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # sample covariance: per-pixel outer product averaged over axes 1 and 2
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # The deflation is computed without gradient tracking, so no
                # gradient flows from later passes back through it.
                with torch.no_grad():
                    # w = R^-1 h / (h^H R^-1 h), built from the previous pass's estimates
                    w = rx_inv_hat@hhat[...,None] / \
                            (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                    # p = I - h w^H, then rx <- p rx p^H
                    p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

            # model input: the 6 lower-triangle real parts followed by the 6 imaginary parts
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            h6, rx12 = model(inp)
            hhat = h6[:,:3] + 1j*h6[:,3:]  # first 3 outputs: real part, last 3: imaginary part
            rx_inv_hat = lower2matrix(rx12)
            # loss = mean|I - rx_inv_hat @ rx|^2 + lamb * mean|H_j - hhat|^2
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # record the training loss every 30th iteration
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # every 10 epochs: plot training loss, run validation, save figures and the model
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # same three-pass recursion as training, on the 200 precomputed
        # validation covariances and without gradient tracking
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                h6, rx12 = model(inp)
                hhat_val = h6[:,:3] + 1j*h6[:,3:]
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # the whole model object is saved, one file per validated epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v17
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation: torch.optim.RAdam unless running torch 1.8.1,
# in which case the `optim` name already in scope is used instead
# (presumably provided by the star import above -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v17'  # running id, used as the per-run sub-folder name
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Create the figure and model folders independently. The previous code only
# checked the figure folder and then called os.mkdir on both, which raised
# FileExistsError when just the model folder existed and skipped creating the
# model folder when just the figure folder existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)


#%% generate data
# Synthetic data generation: three sources with fixed power maps v0/v1/v2,
# each arriving through a per-sample random steering vector, summed on M sensors.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # the three source maps are taken from the last axis of dicts['v']
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through awgn (project helper,
    # snr=30, fixed seed), then its magnitude is stored as a complex tensor.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    delta = v0.mean()*10**(-snr/10)  # not referenced again in this cell
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    # steering vector: exp(1j * sin(angle) * sensor_index), sensor_index = 0..M-1
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is one complex Gaussian draw of shape v.shape scaled by sqrt(v).
    # The draw has no n_data axis, so the same realization is shared by every
    # sample; only the steering vectors differ between samples.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # move the sensor axis last: [n_data, 100, 100, M]
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    mixn = awgn_batch(mix_all)  # not referenced again in this cell
    # ground-truth steering vectors of the three sources, stacked on the last axis: [n_data, M, 3]
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M covariance is rank deficient.
    for i in range(n_data):
        # mix_all[i] is [100, 100, M], so flattening yields one row per pixel
        # holding that pixel's M sensor values. (mix[i] is [M, 100, 100];
        # reshaping it to (10000, 3) interleaved sensors and pixels.)
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network that outputs a steering vector and 12 packed values.

    The trunk (`mainnet`) maps a 12-feature input to a 128-wide hidden
    vector. `hnet` reduces it to one scalar per sample, which is squashed
    with tanh and expanded into a 3-element unit-modulus complex vector.
    `rxnet` maps the hidden vector to 12 real outputs.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive `widths` with ReLU in between.

        A ReLU follows the last Linear only when `final_relu` is true.
        """
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Construction order (trunk, hnet, rxnet) matches the parameter order
        # and the Sequential indices of the layers.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 1), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        """Return `(hj, rx12)` for input `x`.

        `hj[:, m] = exp(1j * pi * m * tanh(ang))` for m = 0, 1, 2, where `ang`
        is the scalar produced by `hnet`; `rx12` is the raw `rxnet` output.
        """
        trunk = self.mainnet(x)
        ang = self.hnet(trunk)
        rx12 = self.rxnet(trunk)
        sensor_phase = torch.pi*torch.arange(3, device=x.device)
        phase = ang.tanh() @ sensor_phase[None, :]  # shape:[I, M]
        hj = (phase*1j).exp()
        return hj, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in `torch.tril_indices(3, 3)` order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        `rx12`. Each matrix equals its own conjugate transpose, and its
        diagonal holds the real part of the packed diagonal entries.

    Note:
        Earlier revisions halved `rx_inv_hat[k, k]` (row k of batch element k)
        instead of each matrix's diagonal, so diagonals came out doubled.
        Checkpoints trained against that behavior may need retraining.
    """
    n = 3
    device = rx12.device
    rows, cols = torch.tril_indices(n, n, device=device)
    lower = torch.zeros(rx12.shape[0], n, n, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = (rx12[:, :6] + 1j*rx12[:, 6:]).to(lower.dtype)
    herm = lower + lower.permute(0, 2, 1).conj()
    # lower + lower^H counts every diagonal entry twice; halve the diagonal of
    # every matrix in the batch.
    diag = torch.arange(n, device=device)
    herm[:, diag, diag] = herm[:, diag, diag] / 2
    return herm

# Instantiate the network on the GPU and overwrite every tensor returned by
# model.parameters() with zero-mean normal samples (std 0.01).
model = DOA().cuda()
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

#%%
# number of sensors, batch size, number of training samples
M, btsize, n_tr = 3, 64, int(1e4)
lamb = 1  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not referenced again in this cell
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
# drop_last=True keeps every batch at exactly btsize, matching the fixed-size `Is` below
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

"Pre calc constants"
# Batched 3x3 identities: one per training-batch element (Is) and one per
# validation sample (Is2).
Is = torch.stack([torch.eye(3,3)]*btsize, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()
ind = torch.tril_indices(3,3)  # (row, col) indices of the 3x3 lower triangle
indx = ind[0]*3+ind[1]  # the same entries as flat indices into a row-major 3x3

"validation"
# 200 samples taken right after the training split; their covariances
# (per-pixel outer product averaged over axes 1 and 2) are computed once here.
val0, h0 = mix_all[n_tr:n_tr+200], H_all[n_tr:n_tr+200].cuda()
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()

# Training: each batch runs three recursive passes (one per source). In pass j
# the model receives the packed covariance and returns a complex steering-vector
# estimate plus a packed inverse-covariance estimate; passes 1 and 2 first
# deflate the covariance using the previous pass's estimates.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # sample covariance: per-pixel outer product averaged over axes 1 and 2
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # w = R^-1 h / (h^H R^-1 h), built from the previous pass's estimates
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                # p = I - h w^H, then rx <- p rx p^H
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # model input: the 6 lower-triangle real parts followed by the 6 imaginary parts
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            hhat, rx12 = model(inp)
            rx_inv_hat = lower2matrix(rx12)
            # loss = mean|I - rx_inv_hat @ rx|^2 + lamb * mean|H_j - hhat|^2
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # record the training loss every 30th iteration
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # every 10 epochs: plot training loss, run validation, save figures and the model
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # same three-pass recursion as training, on the 200 precomputed
        # validation covariances and without gradient tracking
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                hhat_val, rx12 = model(inp)
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # the whole model object is saved, one file per validated epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v18
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Pick the RAdam implementation: torch.optim.RAdam unless running torch 1.8.1,
# in which case the `optim` name already in scope is used instead
# (presumably provided by the star import above -- TODO confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam
torch.autograd.set_detect_anomaly(True)

rid = 'v18'  # running id, used as the per-run sub-folder name
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Create the figure and model folders independently. The previous code only
# checked the figure folder and then called os.mkdir on both, which raised
# FileExistsError when just the model folder existed and skipped creating the
# model folder when just the figure folder existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)


#%% generate data
# Synthetic data generation: three sources with fixed power maps v0/v1/v2,
# each arriving through a per-sample random steering vector, summed on M sensors.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # the three source maps are taken from the last axis of dicts['v']
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # Each map is resized to 100x100, passed through awgn (project helper,
    # snr=30, fixed seed), then its magnitude is stored as a complex tensor.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    delta = v0.mean()*10**(-snr/10)  # not referenced again in this cell
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    # steering vector: exp(1j * sin(angle) * sensor_index), sensor_index = 0..M-1
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is one complex Gaussian draw of shape v.shape scaled by sqrt(v).
    # The draw has no n_data axis, so the same realization is shared by every
    # sample; only the steering vectors differ between samples.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    # move the sensor axis last: [n_data, 100, 100, M]
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    mixn = awgn_batch(mix_all)  # not referenced again in this cell
    # ground-truth steering vectors of the three sources, stacked on the last axis: [n_data, M, 3]
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: report every sample whose M x M covariance is rank deficient.
    for i in range(n_data):
        # mix_all[i] is [100, 100, M], so flattening yields one row per pixel
        # holding that pixel's M sensor values. (mix[i] is [M, 100, 100];
        # reshaping it to (10000, 3) interleaved sensors and pixels.)
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network that outputs a steering vector and 12 packed values.

    The trunk (`mainnet`) maps a 12-feature input to a 128-wide hidden
    vector. `hnet` reduces it to one scalar per sample, which is squashed
    with tanh and expanded into a 3-element unit-modulus complex vector.
    `rxnet` maps the hidden vector to 12 real outputs.
    """

    @staticmethod
    def _mlp(widths, final_relu):
        """Stack Linear layers over consecutive `widths` with ReLU in between.

        A ReLU follows the last Linear only when `final_relu` is true.
        """
        layers = []
        last = len(widths) - 2
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if final_relu or k < last:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Construction order (trunk, hnet, rxnet) matches the parameter order
        # and the Sequential indices of the layers.
        self.mainnet = self._mlp((12, 128, 128, 128, 128), final_relu=True)
        self.hnet = self._mlp((128, 64, 1), final_relu=False)
        self.rxnet = self._mlp((128, 64, 32, 12), final_relu=False)

    def forward(self, x):
        """Return `(hj, rx12)` for input `x`.

        `hj[:, m] = exp(1j * pi * m * tanh(ang))` for m = 0, 1, 2, where `ang`
        is the scalar produced by `hnet`; `rx12` is the raw `rxnet` output.
        """
        trunk = self.mainnet(x)
        ang = self.hnet(trunk)
        rx12 = self.rxnet(trunk)
        sensor_phase = torch.pi*torch.arange(3, device=x.device)
        phase = ang.tanh() @ sensor_phase[None, :]  # shape:[I, M]
        hj = (phase*1j).exp()
        return hj, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [B, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in `torch.tril_indices(3, 3)` order:
            (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).

    Returns:
        Complex (cfloat) tensor of shape [B, 3, 3] on the same device as
        `rx12`. Each matrix equals its own conjugate transpose, and its
        diagonal holds the real part of the packed diagonal entries.

    Note:
        Earlier revisions halved `rx_inv_hat[k, k]` (row k of batch element k)
        instead of each matrix's diagonal, so diagonals came out doubled.
        Checkpoints trained against that behavior may need retraining.
    """
    n = 3
    device = rx12.device
    rows, cols = torch.tril_indices(n, n, device=device)
    lower = torch.zeros(rx12.shape[0], n, n, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = (rx12[:, :6] + 1j*rx12[:, 6:]).to(lower.dtype)
    herm = lower + lower.permute(0, 2, 1).conj()
    # lower + lower^H counts every diagonal entry twice; halve the diagonal of
    # every matrix in the batch.
    diag = torch.arange(n, device=device)
    herm[:, diag, diag] = herm[:, diag, diag] / 2
    return herm

# Instantiate the network on the GPU and overwrite every tensor returned by
# model.parameters() with zero-mean normal samples (std 0.01).
model = DOA().cuda()
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

#%%
# number of sensors, batch size, number of training samples
M, btsize, n_tr = 3, 64, int(1e4)
lamb = 0.1  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None,:].cuda()  # not referenced again in this cell
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
# drop_last=True keeps every batch at exactly btsize, matching the fixed-size `Is` below
tr = Data.DataLoader(data, batch_size=btsize, drop_last=True, shuffle=True)
optimizer = RAdam(model.parameters(),
                lr= 1e-4,
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

"Pre calc constants"
# Batched 3x3 identities: one per training-batch element (Is) and one per
# validation sample (Is2).
Is = torch.stack([torch.eye(3,3)]*btsize, dim=0).to(torch.cfloat).cuda()
Is2 = torch.stack([torch.eye(3,3)]*200, dim=0).to(torch.cfloat).cuda()
ind = torch.tril_indices(3,3)  # (row, col) indices of the 3x3 lower triangle
indx = ind[0]*3+ind[1]  # the same entries as flat indices into a row-major 3x3

"validation"
# 200 samples taken right after the training split; their covariances
# (per-pixel outer product averaged over axes 1 and 2) are computed once here.
val0, h0 = mix_all[n_tr:n_tr+200], H_all[n_tr:n_tr+200].cuda()
rx_val = (val0[...,None] @ val0[:,:,:,None,:].conj()).mean(dim=(1,2))
rx_val_cuda = rx_val.cuda()

# Training: each batch runs three recursive passes (one per source). In pass j
# the model receives the packed covariance and returns a complex steering-vector
# estimate plus a packed inverse-covariance estimate; passes 1 and 2 first
# deflate the covariance using the previous pass's estimates.
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # sample covariance: per-pixel outer product averaged over axes 1 and 2
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # w = R^-1 h / (h^H R^-1 h), built from the previous pass's estimates
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                # p = I - h w^H, then rx <- p rx p^H
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # model input: the 6 lower-triangle real parts followed by the 6 imaginary parts
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            hhat, rx12 = model(inp)
            rx_inv_hat = lower2matrix(rx12)
            # loss = mean|I - rx_inv_hat @ rx|^2 + lamb * mean|H_j - hhat|^2
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        # record the training loss every 30th iteration
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # every 10 epochs: plot training loss, run validation, save figures and the model
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        # same three-pass recursion as training, on the 200 precomputed
        # validation covariances and without gradient tracking
        with torch.no_grad():
            loss_val = 0
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                hhat_val, rx12 = model(inp)
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # the whole model object is saved, one file per validated epoch
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v19
from utils import *
from datetime import datetime

# Expose only GPU 0 to this process and set plotting/printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from the `optim` name already in scope; any other
# version uses torch.optim.RAdam and also defines the `mydet` helper here.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()

# Anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)

rid = 'v19'  # run id; names the per-run output folders below
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create each folder on its own. Previously only the figures folder
# was tested and both were created with os.mkdir, so a half-created state
# either raised FileExistsError or left the models folder missing for
# torch.save later on. makedirs(exist_ok=True) also creates missing parents.
_missing_dirs = [d for d in (fig_loc, mod_loc) if not os.path.isdir(d)]
if _missing_dirs:
    print('made a new folder')
for _out_dir in _missing_dirs:
    os.makedirs(_out_dir, exist_ok=True)


#%% generate data
# Synthesize the toy dataset: n_data mixtures of three sources on M channels.
# Statement order matters here: every torch.rand / torch.randn call draws from
# the global generator seeded at the top of the cell.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # The last axis of dicts['v'] selects one map per source (presumably 2-D;
    # each is resized to 100x100 below).
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # awgn comes from utils; presumably it adds white Gaussian noise at the
    # given SNR with the given seed -- TODO confirm. The magnitude is kept and
    # stored as complex so it can scale complex samples further down.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    # NOTE(review): delta is not referenced again in this cell.
    delta = v0.mean()*10**(-snr/10)
    # One angle per sample, uniform over a 20-degree window, converted to
    # radians. Each H / hs_* row is exp(1j * sin(angle) * m) for m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is complex Gaussian noise scaled per pixel by sqrt(v), then
    # flattened to [1, 10000] and left-multiplied by its [n_data, M, 1] vector.
    # torch.randn is called once per source, so the same source realization is
    # shared by all n_data samples; only the mixing vectors differ per sample.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    # mix is [n_data, M, 100, 100]; mix_all / sig_all move the channel axis
    # last, giving [n_data, 100, 100, M].
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): awgn_batch is a utils helper (presumably a noisy copy of
    # mix_all); mixn is not referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; the last axis indexes the source (signal, n1, n2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: flag samples whose M x M channel Gram matrix is
    # rank deficient. It reads mix_all ([n_data, 100, 100, M], channels last)
    # so each row of x is one pixel's M channel values. The earlier version
    # reshaped the channels-first `mix[i]` ([M, 100, 100]) straight to
    # (10000, 3), which interleaves pixels of one channel instead of stacking
    # channels. The unused mean-removed copy `xbar` was dropped.
    for i in range(n_data):
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    Input is a real tensor of shape [I, 12]. The forward pass returns:
      * hj   -- complex tensor [I, 3] with entries exp(1j*pi*tanh(a)*m) for
                m = 0, 1, 2, where a is the scalar output of `hnet`;
      * rx12 -- real tensor [I, 12], the raw output of `rxnet`.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: four Linear + ReLU pairs, 12 -> 128 -> 128 -> 128 -> 128.
        trunk = []
        for n_in in (12, 128, 128, 128):
            trunk.append(nn.Linear(n_in, 128))
            trunk.append(nn.ReLU(inplace=True))
        self.mainnet = nn.Sequential(*trunk)
        # Head producing one scalar per item.
        self.hnet = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
        )
        # Head producing 12 real numbers per item.
        self.rxnet = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 12),
        )

    def forward(self, x):
        feat = self.mainnet(x)
        # tanh bounds the scalar to (-1, 1) before it is scaled by pi*m.
        bounded = self.hnet(feat).tanh()
        rx12 = self.rxnet(feat)
        steps = torch.pi * torch.arange(3, device=x.device)
        hj = torch.exp((bounded @ steps[None, :]) * 1j)  # shape:[I, M]
        return hj, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real numbers each.

    Args:
        rx12: real tensor [B, 12]. Columns 0-5 are the real parts and columns
            6-11 the imaginary parts of the lower-triangular entries, in
            torch.tril_indices(3, 3) order: (0,0),(1,0),(1,1),(2,0),(2,1),(2,2).

    Returns:
        Complex tensor [B, 3, 3] on rx12's device. The lower triangle holds the
        given entries, the upper triangle their conjugates, and the diagonal
        the real part of the given diagonal entries.

    Note:
        The earlier version halved `m[0,0]`, `m[1,1]`, `m[2,2]` on the batched
        tensor, which addresses row 0 of item 0, row 1 of item 1 and row 2 of
        item 2 rather than the diagonals. The diagonals therefore stayed
        doubled and those three rows were corrupted. Checkpoints trained with
        the old mapping will yield different matrices through this function.
    """
    tril = torch.tril_indices(3, 3, device=rx12.device)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat,
                        device=rx12.device)
    lower[:, tril[0], tril[1]] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose fills the upper triangle and doubles the
    # real part of the diagonal (the imaginary part cancels).
    rx_inv_hat = lower + lower.permute(0, 2, 1).conj()
    # Undo the doubling on the diagonal of every batch item.
    diag = torch.arange(3, device=rx12.device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag] / 2
    return rx_inv_hat

model = DOA().cuda()
# Replace the default initialisation: every parameter tensor (weights and
# biases alike) is drawn from a normal distribution with mean 0, std 0.01.
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Problem size and optimisation hyper-parameters.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 10  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
# The first n_tr samples are the training set, served in shuffled full batches.
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, shuffle=True, drop_last=True, batch_size=btsize)
optimizer = RAdam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)

"Pre calc constants"
# Stacks of 3x3 complex identities: one per training batch item, and one per
# item of the 200-sample validation set.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Row/column indices of the lower triangle, and the matching flat indices
# into a row-major flattened 3x3 matrix.
ind = torch.tril_indices(3, 3)
indx = 3 * ind[0] + ind[1]

"validation"
# The 200 samples right after the training split. rx_val averages the
# per-pixel outer products over both spatial axes, one 3x3 matrix per sample.
val0 = mix_all[n_tr:n_tr + 200]
h0 = H_all[n_tr:n_tr + 200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

# Training history: loss_all gets one entry every 30th batch, loss_val_all one
# entry per validation pass (every 10th epoch).
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # Per-pixel outer products averaged over the two middle axes
                # give one 3x3 matrix per batch item.
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Uses hhat and rx_inv_hat from the previous pass of this loop:
                #   w = Rinv_hat h / (h^H Rinv_hat h),  p = I - h w^H,
                #   rx <- p rx p^H.
                # These ops are not wrapped in no_grad, so autograd records
                # them and gradients flow through the update.
                w = rx_inv_hat@hhat[...,None] / \
                        (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                rx = p@rx@p.permute(0,2,1).conj()

            # Network input: real then imaginary parts of the six
            # lower-triangular entries of rx -> [btsize, 12].
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            hhat, rx12 = model(inp)
            rx_inv_hat = lower2matrix(rx12)
            # Mean |I - rx_inv_hat @ rx|^2 plus lamb times the mean squared
            # error between column j of H and hhat.
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        # Clip the total gradient norm at 10 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # Every 10th epoch: plot the training curve, validate, save the model.
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            # Zoomed view of the last 100 recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        with torch.no_grad():
            loss_val = 0
            # Same three-pass recursion as in training, on the 200 validation
            # matrices precomputed in rx_val_cuda.
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                # indx holds flat indices into the row-major flattened 3x3
                # matrix, so this picks the same six entries as in training.
                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                hhat_val, rx12 = model(inp)
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            # Save the full validation curve and its last 50 points, then
            # close every open figure.
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # Saves the whole model object (not just a state_dict).
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')




In [None]:
#%% v20
from utils import *
from datetime import datetime

# Expose only GPU 0 to this process and set plotting/printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from the `optim` name already in scope; any other
# version uses torch.optim.RAdam and also defines the `mydet` helper here.
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Return the determinant of x via Tensor.det()."""
        return x.det()

# Anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)

rid = 'v20'  # run id; names the per-run output folders below
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create each folder on its own. Previously only the figures folder
# was tested and both were created with os.mkdir, so a half-created state
# either raised FileExistsError or left the models folder missing for
# torch.save later on. makedirs(exist_ok=True) also creates missing parents.
_missing_dirs = [d for d in (fig_loc, mod_loc) if not os.path.isdir(d)]
if _missing_dirs:
    print('made a new folder')
for _out_dir in _missing_dirs:
    os.makedirs(_out_dir, exist_ok=True)


#%% generate data
# Synthesize the toy dataset: n_data mixtures of three sources on M channels.
# Statement order matters here: every torch.rand / torch.randn call draws from
# the global generator seeded at the top of the cell.
if True:
    M, n_data= 3, int(1.2e4)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # The last axis of dicts['v'] selects one map per source (presumably 2-D;
    # each is resized to 100x100 below).
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    # awgn comes from utils; presumably it adds white Gaussian noise at the
    # given SNR with the given seed -- TODO confirm. The magnitude is kept and
    # stored as complex so it can scale complex samples further down.
    v0 = torch.tensor(resize(v0, (100, 100), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = torch.tensor(resize(v1, (100, 100), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = torch.tensor(resize(v2, (100, 100), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    snr=0
    # NOTE(review): delta is not referenced again in this cell.
    delta = v0.mean()*10**(-snr/10)
    # One angle per sample, uniform over a 20-degree window, converted to
    # radians. Each H / hs_* row is exp(1j * sin(angle) * m) for m = 0..M-1.
    angs = (torch.rand(n_data,1)*20 +10)/180*np.pi  # signal aoa [10, 30]
    H = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n1 = (torch.rand(n_data,1)*20 -70)/180*np.pi  # noise aoa [-70, -50]
    hs_n1 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    angs_n2 = (torch.rand(n_data,1)*20 +120)/180*np.pi  # noise aoa [120, 140]
    hs_n2 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Each source is complex Gaussian noise scaled per pixel by sqrt(v), then
    # flattened to [1, 10000] and left-multiplied by its [n_data, M, 1] vector.
    # torch.randn is called once per source, so the same source realization is
    # shared by all n_data samples; only the mixing vectors differ per sample.
    signal = (H[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    n1 = (hs_n1[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    n2 = (hs_n2[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    # mix is [n_data, M, 100, 100]; mix_all / sig_all move the channel axis
    # last, giving [n_data, 100, 100, M].
    mix =  (signal + n1 + n2).reshape(n_data, M, 100, 100)   
    mix_all, sig_all = mix.permute(0,2,3,1), signal.reshape(n_data, M, 100, 100).permute(0,2,3,1)
    # NOTE(review): awgn_batch is a utils helper (presumably a noisy copy of
    # mix_all); mixn is not referenced again in this cell.
    mixn = awgn_batch(mix_all)
    # [n_data, M, 3]; the last axis indexes the source (signal, n1, n2).
    H_all = torch.stack((H, hs_n1, hs_n2), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not
    # Disabled diagnostic: flag samples whose M x M channel Gram matrix is
    # rank deficient. It reads mix_all ([n_data, 100, 100, M], channels last)
    # so each row of x is one pixel's M channel values. The earlier version
    # reshaped the channels-first `mix[i]` ([M, 100, 100]) straight to
    # (10000, 3), which interleaves pixels of one channel instead of stacking
    # channels. The unused mean-removed copy `xbar` was dropped.
    for i in range(n_data):
        x = mix_all[i].reshape(-1, M)
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != M:
            print('low rank', i, 'rank is ', r)


#%% load data and model
class DOA(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    Input is a real tensor of shape [I, 12]. The forward pass returns:
      * hj   -- complex tensor [I, 3] with entries exp(1j*pi*tanh(a)*m) for
                m = 0, 1, 2, where a is the scalar output of `hnet`;
      * rx12 -- real tensor [I, 12], the raw output of `rxnet`.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: four Linear + ReLU pairs, 12 -> 128 -> 128 -> 128 -> 128.
        trunk = []
        for n_in in (12, 128, 128, 128):
            trunk.append(nn.Linear(n_in, 128))
            trunk.append(nn.ReLU(inplace=True))
        self.mainnet = nn.Sequential(*trunk)
        # Head producing one scalar per item.
        self.hnet = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
        )
        # Head producing 12 real numbers per item.
        self.rxnet = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 12),
        )

    def forward(self, x):
        feat = self.mainnet(x)
        # tanh bounds the scalar to (-1, 1) before it is scaled by pi*m.
        bounded = self.hnet(feat).tanh()
        rx12 = self.rxnet(feat)
        steps = torch.pi * torch.arange(3, device=x.device)
        hj = torch.exp((bounded @ steps[None, :]) * 1j)  # shape:[I, M]
        return hj, rx12

def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real numbers each.

    Args:
        rx12: real tensor [B, 12]. Columns 0-5 are the real parts and columns
            6-11 the imaginary parts of the lower-triangular entries, in
            torch.tril_indices(3, 3) order: (0,0),(1,0),(1,1),(2,0),(2,1),(2,2).

    Returns:
        Complex tensor [B, 3, 3] on rx12's device. The lower triangle holds the
        given entries, the upper triangle their conjugates, and the diagonal
        the real part of the given diagonal entries.

    Note:
        The earlier version halved `m[0,0]`, `m[1,1]`, `m[2,2]` on the batched
        tensor, which addresses row 0 of item 0, row 1 of item 1 and row 2 of
        item 2 rather than the diagonals. The diagonals therefore stayed
        doubled and those three rows were corrupted. Checkpoints trained with
        the old mapping will yield different matrices through this function.
    """
    tril = torch.tril_indices(3, 3, device=rx12.device)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat,
                        device=rx12.device)
    lower[:, tril[0], tril[1]] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose fills the upper triangle and doubles the
    # real part of the diagonal (the imaginary part cancels).
    rx_inv_hat = lower + lower.permute(0, 2, 1).conj()
    # Undo the doubling on the diagonal of every batch item.
    diag = torch.arange(3, device=rx12.device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag] / 2
    return rx_inv_hat

model = DOA().cuda()
# Replace the default initialisation: every parameter tensor (weights and
# biases alike) is drawn from a normal distribution with mean 0, std 0.01.
for param in model.parameters():
    nn.init.normal_(param, mean=0., std=0.01)

#%%
# Problem size and optimisation hyper-parameters.
M = 3
btsize = 64
n_tr = int(1e4)
lamb = 1  # weight of the steering-vector term in the loss
const_range = torch.arange(M).to(torch.cfloat)[None, :].cuda()
# The first n_tr samples are the training set, served in shuffled full batches.
data = Data.TensorDataset(mix_all[:n_tr], H_all[:n_tr])
tr = Data.DataLoader(data, shuffle=True, drop_last=True, batch_size=btsize)
optimizer = RAdam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)

"Pre calc constants"
# Stacks of 3x3 complex identities: one per training batch item, and one per
# item of the 200-sample validation set.
Is = torch.eye(3, 3).repeat(btsize, 1, 1).to(torch.cfloat).cuda()
Is2 = torch.eye(3, 3).repeat(200, 1, 1).to(torch.cfloat).cuda()
# Row/column indices of the lower triangle, and the matching flat indices
# into a row-major flattened 3x3 matrix.
ind = torch.tril_indices(3, 3)
indx = 3 * ind[0] + ind[1]

"validation"
# The 200 samples right after the training split. rx_val averages the
# per-pixel outer products over both spatial axes, one 3x3 matrix per sample.
val0 = mix_all[n_tr:n_tr + 200]
h0 = H_all[n_tr:n_tr + 200].cuda()
rx_val = (val0.unsqueeze(-1) @ val0.unsqueeze(-2).conj()).mean(dim=(1, 2))
rx_val_cuda = rx_val.cuda()

# Training history: loss_all gets one entry every 30th batch, loss_val_all one
# entry per validation pass (every 10th epoch).
loss_all, loss_val_all = [], []
for epoch in range(2001):
    for i, (mix, H) in enumerate(tr):
        loss = 0
        optimizer.zero_grad()
        mix, H = mix.cuda(), H.cuda()
        for j in range(3): # recursive for each source
            if j == 0:
                # Per-pixel outer products averaged over the two middle axes
                # give one 3x3 matrix per batch item.
                Rx = mix[...,None] @ mix[:,:,:,None,:].conj()
                rx = Rx.mean(dim=(1,2))
            else:
                # Uses hhat and rx_inv_hat from the previous pass of this loop:
                #   w = Rinv_hat h / (h^H Rinv_hat h),  p = I - h w^H,
                #   rx <- p rx p^H.
                # Wrapped in no_grad: the updated rx is treated as a constant
                # input, so no gradient flows back through earlier passes.
                with torch.no_grad():
                    w = rx_inv_hat@hhat[...,None] / \
                            (hhat[:,None,:].conj()@rx_inv_hat@hhat[...,None])
                    p = Is - hhat[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

            # Network input: real then imaginary parts of the six
            # lower-triangular entries of rx -> [btsize, 12].
            inp = torch.stack((rx.real, rx.imag), dim=1)[:,:,ind[0], ind[1]].reshape(btsize, -1)
            hhat, rx12 = model(inp)
            rx_inv_hat = lower2matrix(rx12)
            # Mean |I - rx_inv_hat @ rx|^2 plus lamb times the mean squared
            # error between column j of H and hhat.
            loss = loss + ((Is-rx_inv_hat@rx).abs()**2).mean() + \
                lamb*((H[:,:,j]-hhat).abs()**2).mean()
        loss.backward()
        # Clip the total gradient norm at 10 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
        if i % 30 == 0:
            loss_all.append(loss.detach().cpu().item())
        
    # Every 10th epoch: plot the training curve, validate, save the model.
    if epoch % 10 == 0:
        plt.figure()
        plt.plot(loss_all, '-x')
        plt.title(f'epoch {epoch}')
        plt.show()

        if epoch >50:
            # Zoomed view of the last 100 recorded training losses.
            plt.figure()
            plt.plot(loss_all[-100:], '-rx')
            plt.title(f'epoch {epoch}')
            plt.show()

        "Validation"
        with torch.no_grad():
            loss_val = 0
            # Same three-pass recursion as in training, on the 200 validation
            # matrices precomputed in rx_val_cuda.
            for j in range(3): # recursive for each source
                if j ==0 :
                    rx = rx_val_cuda
                else:
                    w = rx_inv_hat_val@hhat_val[...,None] / \
                            (hhat_val[:,None,:].conj()@rx_inv_hat_val@hhat_val[...,None])
                    p = Is2 - hhat_val[...,None]@w.permute(0,2,1).conj()
                    rx = p@rx@p.permute(0,2,1).conj()

                # indx holds flat indices into the row-major flattened 3x3
                # matrix, so this picks the same six entries as in training.
                rl = rx.reshape(200, -1)[:,indx] # take lower triangle
                inp = torch.stack((rl.real, rl.imag), dim=1).reshape(200, -1)
                hhat_val, rx12 = model(inp)
                rx_inv_hat_val = lower2matrix(rx12)

                loss_val = loss_val + ((Is2-rx_inv_hat_val@rx).abs()**2).mean() + \
                    lamb*((h0[:,:,j]-hhat_val).abs()**2).mean()

            loss_val_all.append(loss_val.cpu().item())
            # Save the full validation curve and its last 50 points, then
            # close every open figure.
            plt.figure()
            plt.plot(loss_val_all, '-x')
            plt.title(f'val loss at epoch {epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_validation_loss')

            plt.figure()
            plt.plot(loss_val_all[-50:], '-rx')
            plt.title(f'last 50 val loss epoch {epoch}')
            plt.savefig(fig_loc + f'last_50val_at_epoch{epoch}')
            plt.close('all')   
        # Saves the whole model object (not just a state_dict).
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


