In [None]:
#%% v20000
from utils import *
# Run v20000 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20000: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20000' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            print(epoch, Rb[0], Rb.sum()/3/128)
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


In [None]:
#%% v20010
from utils import *
# Run v20010 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20010: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20010' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.1^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.1)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20100
from utils import *
# Run v20100 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam

from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20100: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20100' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-4
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = torch.optim.Adam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20110
from utils import *
# Run v20110 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam

from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20110: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20110' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = torch.optim.Adam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20200
from utils import *
# Run v20200 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20200: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20200' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-4
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = torch.optim.RMSprop(model.parameters(),
                lr= opts['lr'])

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20210
from utils import *
# Run v20210 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20210: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20210' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = torch.optim.RMSprop(model.parameters(),
                lr= opts['lr'])

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20220
from utils import *
# Run v20220 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a `beta`-weighted KL term.

    Relies on the module-level names `mydet`, `np` and `rid`.

    Args:
        x: mixture batch; it is permuted with (0, 2, 3, 1) before use
            (presumably [I, M, N, F] on entry -- TODO confirm).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            product (presumably [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing estimate, used as Hhat @ Rs @ Hhat^H.
        Rb: term added to Hhat @ Rs @ Hhat^H (presumably a noise
            covariance -- TODO confirm).
        mu, logvar: latent statistics entering the KL term.
        beta: weight applied to the KL term.

    Returns:
        The scalar tensor `-ll.sum().real + beta*kl`.

    Raises:
        Exception: whatever the determinant / inverse computation raised.
            Before re-raising, (Rx, Rs, Hhat, Rb) are saved to
            `rid{rid}_Rx_Rs_Hhat_Rb.pt` in the working directory.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors for post-mortem, then propagate the
        # original error. `except Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit pass through untouched.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

# ---- Run v20220: output folders, hyper-parameters, data, model, optimizer ----
rid = 'v20220' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create each output folder on its own, so a run where only one of the two
# already exists neither crashes nor leaves the other one missing.
_missing_dirs = [_p for _p in (fig_loc, mod_loc) if not os.path.isdir(_p)]
if _missing_dirs:
    print('made a new folder')
    for _p in _missing_dirs:
        os.makedirs(_p, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-4
opts['n_epochs'] = 1500

# Training data: every sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 mixtures are kept on the GPU.
# Note that `d` is rebound to the loaded validation tuple here.
xval, _ , hgt = d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) is re-drawn from N(0, 0.1^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.1)
optimizer = torch.optim.RMSprop(model.parameters(),
                lr= opts['lr'])

# Train for opts['n_epochs'] epochs. Every 10th epoch: plot the training
# loss, evaluate on the fixed validation batch, plot the estimates for
# validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the last mini-batch of the epoch is logged (divided by batch size).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual number of validation samples.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb[0] were
            # all produced from xval_cuda, so the mixture fed to the estimator
            # must be the validation one, not the last training batch `x`.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            xval0 = xval_cuda.permute(0,2,3,1)[0,:,:,:, None]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse() @ xval0).cpu()
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE(review): 'modle' (sic) is kept so existing checkpoint paths stay valid.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v20230
from utils import *
# Run v20230 preamble: expose only GPU 0, set plotting/printing defaults,
# log the start time and fix the torch RNG seed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On torch 1.8.1 RAdam is taken from `optim` (and `mydet` is presumably
# supplied by the star-import from utils -- TODO confirm); on any other
# version both come from torch itself.
if torch.__version__[:5] == '1.8.1':
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of `x` via `Tensor.det`."""
        return x.det()
    RAdam = torch.optim.RAdam
#%%
# Anomaly detection is enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of ``x`` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) so that the channel axis comes
            last (presumably [I, M, N, F] -> [I, N, F, M] -- TODO confirm).
        Rs: source covariances, permuted with (1,2,0,3,4) before mixing
            (presumably [I, N, F, K, K]).
        Hhat: mixing-matrix estimate (presumably [I, M, K]).
        Rb: noise covariance added to ``Hhat @ Rs @ Hhat^H``.
        mu, logvar: latent mean and log-variance entering the KL divergence.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-ll.sum().real + beta*kl`` (a sum, not a mean).

    Raises:
        Whatever the determinant / inverse computation raises. Before the
        error propagates, (Rx, Rs, Hhat, Rb) are saved to
        ``rid{rid}_Rx_Rs_Hhat_Rb.pt`` for post-mortem inspection.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors, then re-raise the original error instead
        # of recomputing the failing expression just to fail a second time.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit pass through.
        torch.save((Rx, Rs, Hhat, Rb), f'rid{rid}_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

rid = 'v20230' # running id
# Per-run output folders: figures and model checkpoints each live in rid<rid>/.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create the two folders independently so the model folder is not skipped
# when the figure folder already exists; makedirs also tolerates existing
# folders and creates missing parents.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# ---- Problem sizes and optimisation hyper-parameters ----
I = 3000  # number of training samples kept
M = 3
N = F = 100
K = 3
NF = N * F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 1500}

# ---- Training set: every sample is divided by its largest magnitude ----
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
xtr = (d / d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# ---- Validation set: the first 128 samples are kept on the GPU ----
d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval, _, hgt = d
xval_cuda = xval[:128].to(torch.cfloat).cuda()

# ---- Loss histories, model and optimiser ----
loss_iter = []
loss_tr = []
loss_eval = []
model = NN(M,K,N).cuda()
# Every parameter is re-drawn from N(0, 0.1^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.1)
optimizer = torch.optim.RMSprop(model.parameters(), lr=opts['lr'])

# Training loop: one pass over `tr` per epoch; every 10th epoch the loss
# curves are plotted, the model is evaluated on `xval_cuda`, diagnostic
# images are written and a checkpoint is saved.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the loss of the epoch's last mini-batch is logged (per sample).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with the estimates the model
            # produced for that same sample.  The filter must be applied to
            # xval_cuda, not to `x` (the last *training* mini-batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        # NOTE: the 'modle' spelling in the file name is kept on purpose so
        # that checkpoints already on disk keep the same naming scheme.
        torch.save(model, mod_loc+f'modle_epoch{epoch}.pt')
print('done')
# %%


In [None]:
#%% v21000
from utils import *
# Process-wide runtime configuration for run v21000.
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for weight init and sampling

# Pick the RAdam implementation (and determinant helper) matching the torch build.
if torch.__version__[:5] == '1.8.1':
    # On torch 1.8.1 `optim` and `mydet` are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of ``x`` using the built-in ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam

# Autograd anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import NN7 as NN
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of ``x`` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) so that the channel axis comes
            last (presumably [I, M, N, F] -> [I, N, F, M] -- TODO confirm).
        Rs: source covariances, permuted with (1,2,0,3,4) before mixing
            (presumably [I, N, F, K, K]).
        Hhat: mixing-matrix estimate (presumably [I, M, K]).
        Rb: noise covariance added to ``Hhat @ Rs @ Hhat^H``.
        mu, logvar: latent mean and log-variance entering the KL divergence.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-ll.sum().real + beta*kl`` (a sum, not a mean).

    Raises:
        Whatever the determinant / inverse computation raises. Before the
        error propagates, (x, Rx, Rs, Hhat, Rb) are saved to
        ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` for post-mortem inspection.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors, then re-raise the original error instead
        # of recomputing the failing expression just to fail a second time.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit pass through.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v21000' # running id
# Per-run output folders: figures and model checkpoints each live in rid<rid>/.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create the two folders independently so the model folder is not skipped
# when the figure folder already exists; makedirs also tolerates existing
# folders and creates missing parents.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# ---- Problem sizes and optimisation hyper-parameters ----
I = 3000  # number of training samples kept
M = 3
N = F = 100
K = 3
NF = N * F
eps = 5e-4
opts = {'batch_size': 64, 'lr': 1e-3, 'n_epochs': 1500}

# ---- Training set: noise is added via awgn_batch (snr=30), then every
# ---- sample is divided by its largest magnitude
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d / d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# ---- Validation set: the first 128 samples are kept on the GPU ----
d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval, _, hgt = d
xval_cuda = xval[:128].to(torch.cfloat).cuda()

# ---- Loss histories, model and optimiser ----
loss_iter = []
loss_tr = []
loss_eval = []
model = NN(M,K,N).cuda()
# Every parameter is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(
    model.parameters(),
    lr=opts['lr'],
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# Training loop: one pass over `tr` per epoch; every 10th epoch the loss
# curves are plotted, the model is evaluated on `xval_cuda`, diagnostic
# images are written and a checkpoint is saved.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the loss of the epoch's last mini-batch is logged (per sample).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with the estimates the model
            # produced for that same sample.  The filter must be applied to
            # xval_cuda, not to `x` (the last *training* mini-batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            # Rb of sample 0 and the Rb sum scaled by 1/(M * validation size).
            print(epoch, Rb[0], Rb.sum()/M/xval_cuda.shape[0])
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v22000
from utils import *
# Process-wide runtime configuration for run v22000.
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for weight init and sampling

# Pick the RAdam implementation (and determinant helper) matching the torch build.
if torch.__version__[:5] == '1.8.1':
    # On torch 1.8.1 `optim` and `mydet` are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of ``x`` using the built-in ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam

# Autograd anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN8(nn.Module):
    """Recursive Wiener filter version, with Rb threshold of [1e-3, 1e2].

    The K sources are handled one after another.  Each pass predicts a coarse
    variance map, a unit-modulus channel vector and a noise level, applies a
    Wiener filter to the mixture to obtain a coarse source estimate, and then
    refines the source variance with a convolutional encoder/decoder.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input mixture.
            K: number of sources, i.e. number of recursive passes.
            im_size: side length of the decoder output and coordinate grids.
        """
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.v_net = nn.Sequential(
            DoubleConv(in_channels=M*2, out_channels=1),
            ) 
        self.v_out = OutConv(in_channels=1, out_channels=1)
        # Three Down stages then a flatten to 12*12 features
        # (presumably 100 -> 50 -> 25 -> 12 per side -- TODO confirm).
        self.hb_net = nn.Sequential(
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Reshape(-1, 12*12),
            )
        # Estimate H
        self.h_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            nn.Tanh()
            )   
        # Estimate Rb
        self.b_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            )   
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` of shape [I, N, F, K, K],
            ``Hhat`` has shape [I, M, K], ``Rb`` has shape [I, M, M].
            NOTE(review): ``Rb``, ``mu`` and ``logvar`` are the values of the
            LAST pass only, so a KL term built from them covers only the last
            source -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract hj * (W @ inp), computed with the previous pass's
                # hj and W, from the running residual.
                temp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                # Drop only the trailing singleton axis; a bare squeeze()
                # would also drop the batch/channel axis when its size is 1.
                inp = inp - temp.squeeze(-1).permute(2,3,0,1)
            temp = self.v_net(torch.cat((inp.real, inp.imag), dim=1)).exp() 
            vj = self.v_out(temp).exp() #sigma_s**2 >=0
            vj = threshold(vj, floor=1e-3, ceiling=1e3)  # shape of [I, 1, N, F]
            hb = self.hb_net(vj)
            ang = self.h_net(hb)  # shape of [I,1]
            sig_b_squared = self.b_net(hb).exp() # shape of [I,1]
            # Get H
            # np.pi (already used by loss_fun in this file) instead of
            # torch.pi: same value, and older torch releases such as the
            # 1.8.1 branch handled at the top of this cell may lack torch.pi.
            ch = np.pi*torch.arange(self.M, device=ang.device)
            hj = ((ang @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Get Rb, the energy of the rest
            Rb = threshold(sig_b_squared[:,:,None]**2, 1e-3, 1e2)*torch.ones(batch_size, \
                self.M, device=ch.device).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter to get coarse shat
            Rs = vj.permute(2,3,0,1)[..., None].to(torch.cfloat)  #shape of [N,F,I,1,1]
            Rx = hj[...,None] @ Rs @ hj[:,None].conj() + Rb # shape of [N,F,I,M,M]
            W = Rs @ hj[:, None,].conj() @ Rx.inverse()  # shape of [N,F,I,1,M]
            # W @ x is [N,F,I,1,1]; remove exactly those two trailing axes.
            shat = (W @ x.permute(2,3,0,1)[...,None]).squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stacked tensor is [I, 1, N, F, K]; remove only the channel axis so
        # that I == 1 or K == 1 keeps the documented [I, N, F, K] layout.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of ``x`` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) so that the channel axis comes
            last (presumably [I, M, N, F] -> [I, N, F, M] -- TODO confirm).
        Rs: source covariances, permuted with (1,2,0,3,4) before mixing
            (presumably [I, N, F, K, K]).
        Hhat: mixing-matrix estimate (presumably [I, M, K]).
        Rb: noise covariance added to ``Hhat @ Rs @ Hhat^H``.
        mu, logvar: latent mean and log-variance entering the KL divergence.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-ll.sum().real + beta*kl`` (a sum, not a mean).

    Raises:
        Whatever the determinant / inverse computation raises. Before the
        error propagates, (x, Rx, Rs, Hhat, Rb) are saved to
        ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` for post-mortem inspection.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors, then re-raise the original error instead
        # of recomputing the failing expression just to fail a second time.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit pass through.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl


#%%
rid = 'v22000' # running id
# Per-run output folders: figures and model checkpoints each live in rid<rid>/.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create the two folders independently so the model folder is not skipped
# when the figure folder already exists; makedirs also tolerates existing
# folders and creates missing parents.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# ---- Problem sizes and optimisation hyper-parameters ----
I = 3000  # number of training samples kept
M = 3
N = F = 100
K = 3
NF = N * F
eps = 5e-4
opts = {'batch_size': 100, 'lr': 1e-3, 'n_epochs': 1500}

# ---- Training set: noise is added via awgn_batch (snr=30), then every
# ---- sample is divided by its largest magnitude
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d / d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# ---- Validation set: the first 128 samples are kept on the GPU ----
d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval, _, hgt = d
xval_cuda = xval[:128].to(torch.cfloat).cuda()

# ---- Loss histories, model and optimiser ----
loss_iter = []
loss_tr = []
loss_eval = []
NN = NN8  # model class used for this run
model = NN(M,K,N).cuda()
# Every parameter is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(
    model.parameters(),
    lr=opts['lr'],
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# Training loop: one pass over `tr` per epoch; every 10th epoch the loss
# curves are plotted, the model is evaluated on `xval_cuda`, diagnostic
# images are written and a checkpoint is saved.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the loss of the epoch's last mini-batch is logged (per sample).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with the estimates the model
            # produced for that same sample.  The filter must be applied to
            # xval_cuda, not to `x` (the last *training* mini-batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            # Rb of sample 0 and the Rb sum scaled by 1/(M * validation size).
            print(f'epoch{epoch}', Rb[0], Rb.sum()/M/xval_cuda.shape[0])
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


In [None]:
#%% v22100
from utils import *
# Process-wide runtime configuration for run v22100.
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for weight init and sampling

# Pick the RAdam implementation (and determinant helper) matching the torch build.
if torch.__version__[:5] == '1.8.1':
    # On torch 1.8.1 `optim` and `mydet` are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of ``x`` using the built-in ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam

# Autograd anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN8(nn.Module):
    """Recursive Wiener filter version, with Rb threshold of [1e-3, 1e2].

    The K sources are handled one after another.  Each pass predicts a coarse
    variance map, a unit-modulus channel vector and a noise level, applies a
    Wiener filter to the mixture to obtain a coarse source estimate, and then
    refines the source variance with a convolutional encoder/decoder.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input mixture.
            K: number of sources, i.e. number of recursive passes.
            im_size: side length of the decoder output and coordinate grids.
        """
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.v_net = nn.Sequential(
            DoubleConv(in_channels=M*2, out_channels=1),
            ) 
        self.v_out = OutConv(in_channels=1, out_channels=1)
        # Three Down stages then a flatten to 12*12 features
        # (presumably 100 -> 50 -> 25 -> 12 per side -- TODO confirm).
        self.hb_net = nn.Sequential(
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Reshape(-1, 12*12),
            )
        # Estimate H
        self.h_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            nn.Tanh()
            )   
        # Estimate Rb
        self.b_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            )   
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` of shape [I, N, F, K, K],
            ``Hhat`` has shape [I, M, K], ``Rb`` has shape [I, M, M].
            NOTE(review): ``Rb``, ``mu`` and ``logvar`` are the values of the
            LAST pass only, so a KL term built from them covers only the last
            source -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract hj * (W @ inp), computed with the previous pass's
                # hj and W, from the running residual.
                temp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                # Drop only the trailing singleton axis; a bare squeeze()
                # would also drop the batch/channel axis when its size is 1.
                inp = inp - temp.squeeze(-1).permute(2,3,0,1)
            temp = self.v_net(torch.cat((inp.real, inp.imag), dim=1)).exp() 
            vj = self.v_out(temp).exp() #sigma_s**2 >=0
            vj = threshold(vj, floor=1e-3, ceiling=1e3)  # shape of [I, 1, N, F]
            hb = self.hb_net(vj)
            ang = self.h_net(hb)  # shape of [I,1]
            sig_b_squared = self.b_net(hb).exp() # shape of [I,1]
            # Get H
            # np.pi (already used by loss_fun in this file) instead of
            # torch.pi: same value, and older torch releases such as the
            # 1.8.1 branch handled at the top of this cell may lack torch.pi.
            ch = np.pi*torch.arange(self.M, device=ang.device)
            hj = ((ang @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Get Rb, the energy of the rest
            Rb = threshold(sig_b_squared[:,:,None]**2, 1e-3, 1e2)*torch.ones(batch_size, \
                self.M, device=ch.device).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter to get coarse shat
            Rs = vj.permute(2,3,0,1)[..., None].to(torch.cfloat)  #shape of [N,F,I,1,1]
            Rx = hj[...,None] @ Rs @ hj[:,None].conj() + Rb # shape of [N,F,I,M,M]
            W = Rs @ hj[:, None,].conj() @ Rx.inverse()  # shape of [N,F,I,1,M]
            # W @ x is [N,F,I,1,1]; remove exactly those two trailing axes.
            shat = (W @ x.permute(2,3,0,1)[...,None]).squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stacked tensor is [I, 1, N, F, K]; remove only the channel axis so
        # that I == 1 or K == 1 keeps the documented [I, N, F, K] layout.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of ``x`` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) so that the channel axis comes
            last (presumably [I, M, N, F] -> [I, N, F, M] -- TODO confirm).
        Rs: source covariances, permuted with (1,2,0,3,4) before mixing
            (presumably [I, N, F, K, K]).
        Hhat: mixing-matrix estimate (presumably [I, M, K]).
        Rb: noise covariance added to ``Hhat @ Rs @ Hhat^H``.
        mu, logvar: latent mean and log-variance entering the KL divergence.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-ll.sum().real + beta*kl`` (a sum, not a mean).

    Raises:
        Whatever the determinant / inverse computation raises. Before the
        error propagates, (x, Rx, Rs, Hhat, Rb) are saved to
        ``rid{rid}x_Rx_Rs_Hhat_Rb.pt`` for post-mortem inspection.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the offending tensors, then re-raise the original error instead
        # of recomputing the failing expression just to fail a second time.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit pass through.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl


#%%
rid = 'v22100' # running id
# Per-run output folders: figures and model checkpoints each live in rid<rid>/.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create the two folders independently so the model folder is not skipped
# when the figure folder already exists; makedirs also tolerates existing
# folders and creates missing parents.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

# ---- Problem sizes and optimisation hyper-parameters ----
I = 3000  # number of training samples kept
M = 3
N = F = 100
K = 3
NF = N * F
eps = 5e-4
opts = {'batch_size': 100, 'lr': 1e-3, 'n_epochs': 1500}

# ---- Training set: noise is added via awgn_batch (snr=30), then every
# ---- sample is divided by its largest magnitude
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d / d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# ---- Validation set: the first 128 samples are kept on the GPU ----
d = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval, _, hgt = d
xval_cuda = xval[:128].to(torch.cfloat).cuda()

# ---- Loss histories, model and optimiser ----
loss_iter = []
loss_tr = []
loss_eval = []
NN = NN8  # model class used for this run
model = NN(M,K,N).cuda()
# Every parameter is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(
    model.parameters(),
    lr=opts['lr'],
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

# Training loop: one pass over `tr` per epoch; every 10th epoch the loss
# curves are plotted, the model is evaluated on `xval_cuda`, diagnostic
# images are written and a checkpoint is saved.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        # Gradient clipping is intentionally disabled in this run:
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the loss of the epoch's last mini-batch is logged (per sample).
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size.
            loss_eval.append(loss.cpu().item()/xval_cuda.shape[0])
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with the estimates the model
            # produced for that same sample.  The filter must be applied to
            # xval_cuda, not to `x` (the last *training* mini-batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            # Rb of sample 0 and the Rb sum scaled by 1/(M * validation size).
            print(f'epoch{epoch}', Rb[0], Rb.sum()/M/xval_cuda.shape[0])
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


In [None]:
#%% v23000
from utils import *
# Process-wide runtime configuration for run v23000.
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for weight init and sampling

# Pick the RAdam implementation (and determinant helper) matching the torch build.
if torch.__version__[:5] == '1.8.1':
    # On torch 1.8.1 `optim` and `mydet` are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam
else:
    def mydet(x):
        """Return the determinant of ``x`` using the built-in ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam

# Autograd anomaly detection stays on for the whole run.
torch.autograd.set_detect_anomaly(True)
from vae_model import *

class NN9(nn.Module):
    """Recursive Wiener filter version.

    The K sources are handled one after another.  Each pass predicts a coarse
    variance map, a unit-modulus channel vector and a noise level, applies a
    Wiener filter (built from the spatial mean of the coarse variance) to the
    mixture, and then refines the source variance with a convolutional
    encoder/decoder.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input mixture.
            K: number of sources, i.e. number of recursive passes.
            im_size: side length of the decoder output and coordinate grids.
        """
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.v_net = nn.Sequential(
            DoubleConv(in_channels=M*2, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            ) 
        self.v_out = OutConv(in_channels=4, out_channels=1)
        # Three Down stages then a flatten to 8*12*12 features
        # (presumably 100 -> 50 -> 25 -> 12 per side -- TODO confirm).
        self.hb_net = nn.Sequential(
            Down(in_channels=1, out_channels=32),
            Down(in_channels=32, out_channels=16),
            Down(in_channels=16, out_channels=8),
            Reshape(-1, 8*12*12),
            )
        # Estimate H
        self.h_net = nn.Sequential(
            LinearBlock(8*12*12, 64),
            nn.Linear(64, 32),
            nn.Linear(32, 1),
            nn.Tanh()
            )   
        # Estimate Rb
        self.b_net = nn.Sequential(
            LinearBlock(8*12*12, 64),
            nn.Linear(64, 32),
            nn.Linear(32, 1),
            )   
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5*logvar)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` is ``vhat.diag_embed()`` of shape [I, N, F, K, K],
            ``Hhat`` has shape [I, M, K], ``Rb`` has shape [I, M, M].
            NOTE(review): ``Rb``, ``mu`` and ``logvar`` are the values of the
            LAST pass only, so a KL term built from them covers only the last
            source -- confirm this is intended.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract hj * (W @ inp), computed with the previous pass's
                # hj and W, from the running residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                # Drop only the trailing singleton axis; a bare squeeze()
                # would also drop the batch/channel axis when its size is 1.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            temp = self.v_net(torch.cat((inp.real, inp.imag), dim=1)).exp() 
            vj = self.v_out(temp).exp() #sigma_s**2 >=0
            vj = threshold(vj, floor=1e-3, ceiling=1e2)  # shape of [I, 1, N, F]
            hb = self.hb_net(vj)
            ang = self.h_net(hb)  # shape of [I,1]
            sig_b_squared = self.b_net(hb).exp() # shape of [I,1]
            # Get H
            # np.pi (already used by loss_fun in this file) instead of
            # torch.pi: same value, and older torch releases such as the
            # 1.8.1 branch handled at the top of this cell may lack torch.pi.
            ch = np.pi*torch.arange(self.M, device=ang.device)
            hj = ((ang @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Get Rb, the energy of the rest (no thresholding in this version)
            Rb = sig_b_squared[:,:,None]*torch.ones(batch_size, \
                self.M, device=ch.device).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter to get coarse shat; one scalar variance per sample
            Rs = vj.mean(dim=(2,3)).to(torch.cfloat)[..., None]  #shape of [I,1,1]
            Rx = hj[...,None] @ Rs @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = Rs @ hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # W @ x is [N,F,I,1,1]; remove exactly those two trailing axes.
            shat = (W @ x.permute(2,3,0,1)[...,None]).squeeze(-1).squeeze(-1).permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Stacked tensor is [I, 1, N, F, K]; remove only the channel axis so
        # that I == 1 or K == 1 keeps the documented [I, N, F, K] layout.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a beta-weighted KL term.

    The covariance of every entry is built as Hhat @ Rs @ Hhat^H + Rb and the
    log-likelihood is -log(pi*det(Rx)) - x^H Rx^{-1} x, summed over all entries.

    Args:
        x: observations; dim 1 is moved to the last position before use
            (the file's comments give the layout [I, M, N, F]).
        Rs: source covariance, permuted with (1,2,0,3,4) before the product
            (i.e. a 5-D tensor; per the comment below the result is [I,N,F,M,M]).
        Hhat: mixing matrix estimate, multiplied from the left and, conjugate
            transposed, from the right.
        Rb: noise covariance added to the product.
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor: -sum(log-likelihood).real + beta * KL.

    Raises:
        Exception: whatever the determinant/inverse computation raised; the
            offending tensors are written to disk before it propagates.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 batch/N/F dim and mis-broadcast with det.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the failing tensors for post-mortem, then propagate the
        # original error instead of recomputing the same failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23000' # running id
# Figures and checkpoints of this run go to per-run sub-folders.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # makedirs(exist_ok=True) also recovers when only one of the two
    # folders (or a parent folder) is missing.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data. awgn_batch comes from the star-imports; presumably it adds
# white Gaussian noise at the given SNR -- confirm in utils.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 samples are evaluated.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()
n_val = xval_cuda.shape[0]

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN9
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) starts from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: only the loss of the epoch's last mini-batch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter the first validation sample with its own
            # estimates. hh, rs0 and Rb[0] all come from xval_cuda, so the
            # filtered signal must be xval_cuda[0] as well (not the last
            # training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')

        # A checkpoint is written on every evaluation epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23100
# Cell setup for run v23100: star-imports, GPU selection, seeding and
# torch-version specific aliases.
from utils import *
# Restrict CUDA to device 0. NOTE(review): this is set after the star-import
# above; it only takes effect if CUDA has not been initialised yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On every torch version except 1.8.1, define mydet via Tensor.det and take
# RAdam from torch.optim. On 1.8.1 RAdam is taken from `optim` and mydet is
# not defined here -- presumably both come from the star-import; confirm.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *

class NN10(nn.Module):
    """Recursive Wiener filter version: the K sources are extracted one by one.

    For each source, small networks predict a coarse source power, one
    steering angle and a noise power from the current residual. A Wiener
    filter built from them gives a coarse source estimate, which an
    encoder / latent / decoder chain turns into the per-pixel variance V.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE(review): the sizes 12*12 and 25*25 are hard-coded; presumably they
    match im_size=100 after three / two `Down` stages -- confirm against
    vae_model before using another im_size.
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input (the nets see 2*M real planes).
            K: number of sources, i.e. iterations of the loop in forward.
            im_size: side length of the coordinate grids fed to the decoder.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.v_net = nn.Sequential(
            DoubleConv(in_channels=M*2, out_channels=1),
            ) 
        self.v_out = OutConv(in_channels=1, out_channels=1)
        self.hb_net = nn.Sequential(
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Down(in_channels=1, out_channels=1),
            Reshape(-1, 12*12),
            )
        # Estimate H
        self.h_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            nn.Tanh()
            )   
        # Estimate Rb
        self.b_net = nn.Sequential(
            LinearBlock(12*12, 64),
            LinearBlock(64, 32),
            nn.Linear(32, 1),
            )   
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 outputs mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # The decoder sees z tiled over the image plus two coordinate planes.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*exp(0.5*logvar) with eps ~ N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar): Rs is vhat.diag_embed() with
            vhat of shape [I, N, F, K]; Hhat has shape [I, M, K]; Rb has
            shape [I, M, M]; mu and logvar are the even / odd columns of the
            fc1 output.
            NOTE(review): Rb, mu and logvar are overwritten in every loop
            iteration, so only the values of the last source are returned.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source, hj @ W @ inp, from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                # Drop only the trailing singleton so size-1 batches survive.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            temp = self.v_net(torch.cat((inp.real, inp.imag), dim=1)).exp() 
            vj = self.v_out(temp).exp() #sigma_s**2 >=0
            vj = threshold(vj, floor=1e-3, ceiling=1e2)  # shape of [I, 1, N, F]
            hb = self.hb_net(vj)
            ang = self.h_net(hb)  # shape of [I,1]
            sig_b_squared = self.b_net(hb).exp() # shape of [I,1]
            # Get H: hj[m] = exp(1j * pi * m * ang)
            ch = torch.pi*torch.arange(self.M, device=ang.device)
            hj = ((ang @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Get Rb, the energy of the rest
            Rb = sig_b_squared[:,:,None]*torch.ones(batch_size, \
                self.M, device=ch.device).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Wiener filter to get coarse shat
            Rs = vj.mean(dim=(2,3)).to(torch.cfloat)[..., None]  #shape of [I,1,1]
            Rx = hj[...,None] @ Rs @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = Rs @ hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N, F, I, 1, 1]; index the two singleton dims
            # explicitly instead of a bare squeeze().
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: IxDx1x1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: IxDxim_sizexim_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: Ix(dz+2)xNxF
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the channel dim so that I == 1 or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a beta-weighted KL term.

    The covariance of every entry is built as Hhat @ Rs @ Hhat^H + Rb and the
    log-likelihood is -log(pi*det(Rx)) - x^H Rx^{-1} x, summed over all entries.

    Args:
        x: observations; dim 1 is moved to the last position before use
            (the file's comments give the layout [I, M, N, F]).
        Rs: source covariance, permuted with (1,2,0,3,4) before the product
            (i.e. a 5-D tensor; per the comment below the result is [I,N,F,M,M]).
        Hhat: mixing matrix estimate, multiplied from the left and, conjugate
            transposed, from the right.
        Rb: noise covariance added to the product.
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor: -sum(log-likelihood).real + beta * KL.

    Raises:
        Exception: whatever the determinant/inverse computation raised; the
            offending tensors are written to disk before it propagates.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 batch/N/F dim and mis-broadcast with det.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the failing tensors for post-mortem, then propagate the
        # original error instead of recomputing the same failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23100' # running id
# Figures and checkpoints of this run go to per-run sub-folders.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # makedirs(exist_ok=True) also recovers when only one of the two
    # folders (or a parent folder) is missing.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data. awgn_batch comes from the star-imports; presumably it adds
# white Gaussian noise at the given SNR -- confirm in utils.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 samples are evaluated.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()
n_val = xval_cuda.shape[0]

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN10
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) starts from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: only the loss of the epoch's last mini-batch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter the first validation sample with its own
            # estimates. hh, rs0 and Rb[0] all come from xval_cuda, so the
            # filtered signal must be xval_cuda[0] as well (not the last
            # training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')

        # A checkpoint is written on every evaluation epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23200
# Cell setup for run v23200: star-imports, GPU selection, seeding and
# torch-version specific aliases.
from utils import *
# Restrict CUDA to device 0. NOTE(review): this is set after the star-import
# above; it only takes effect if CUDA has not been initialised yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On every torch version except 1.8.1, define mydet via Tensor.det and take
# RAdam from torch.optim. On 1.8.1 RAdam is taken from `optim` and mydet is
# not defined here -- presumably both come from the star-import; confirm.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *

class NN11(nn.Module):
    """This is recursive Wiener filter version, with Rb threshold of [1e-3, 1e2]

    A single network (`est`) maps the current residual to three numbers per
    sample: log source power, log noise power and a pre-tanh steering angle.
    Both powers are clamped by `threshold` to [1e-3, 1e2]. The coarse Wiener
    estimate of each source is refined into a per-pixel variance V by an
    encoder / latent / decoder chain.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE(review): the sizes 8*12*12 and 25*25 are hard-coded; presumably they
    match im_size=100 after three / two `Down` stages -- confirm against
    vae_model before using another im_size.
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input (the nets see 2*M real planes).
            K: number of sources, i.e. iterations of the loop in forward.
            im_size: side length of the coordinate grids fed to the decoder.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=8),
            Reshape(-1, 8*12*12),
            LinearBlock(8*12*12, 64),
            nn.Linear(64, 3),
            )   
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 outputs mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # The decoder sees z tiled over the image plus two coordinate planes.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*exp(0.5*logvar) with eps ~ N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar): Rs is vhat.diag_embed() with
            vhat of shape [I, N, F, K]; Hhat has shape [I, M, K]; Rb has
            shape [I, M, M]; mu and logvar are the even / odd columns of the
            fc1 output.
            NOTE(review): Rb, mu and logvar are overwritten in every loop
            iteration, so only the values of the last source are returned.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source, hj @ W @ inp, from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                # Drop only the trailing singleton so size-1 batches survive.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            # Column 0: log source power, column 1: log noise power,
            # column 2: pre-tanh steering angle.
            res = self.est(torch.cat((inp.real, inp.imag), dim=1)) #vj,Rb,ang
            vj = threshold(res[:, 0:1].exp(), floor=1e-3, ceiling=1e2)
            sb = threshold(res[:, 1:2].exp(), floor=1e-3, ceiling=1e2)
            Rb = (sb*torch.ones(batch_size, self.M, \
                device=sb.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Get H: hj[m] = exp(1j * pi * m * tanh(angle))
            ch = torch.pi*torch.arange(self.M, device=res.device)
            hj = ((res[:, 2:].tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rs = vj[..., None].to(torch.cfloat)  #shape of [I,1,1]
            Rx = hj[...,None] @ Rs @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = Rs @ hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N, F, I, 1, 1]; index the two singleton dims
            # explicitly instead of a bare squeeze().
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: IxDx1x1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: IxDxim_sizexim_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: Ix(dz+2)xNxF
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the channel dim so that I == 1 or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a beta-weighted KL term.

    The covariance of every entry is built as Hhat @ Rs @ Hhat^H + Rb and the
    log-likelihood is -log(pi*det(Rx)) - x^H Rx^{-1} x, summed over all entries.

    Args:
        x: observations; dim 1 is moved to the last position before use
            (the file's comments give the layout [I, M, N, F]).
        Rs: source covariance, permuted with (1,2,0,3,4) before the product
            (i.e. a 5-D tensor; per the comment below the result is [I,N,F,M,M]).
        Hhat: mixing matrix estimate, multiplied from the left and, conjugate
            transposed, from the right.
        Rb: noise covariance added to the product.
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor: -sum(log-likelihood).real + beta * KL.

    Raises:
        Exception: whatever the determinant/inverse computation raised; the
            offending tensors are written to disk before it propagates.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 batch/N/F dim and mis-broadcast with det.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the failing tensors for post-mortem, then propagate the
        # original error instead of recomputing the same failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23200' # running id
# Figures and checkpoints of this run go to per-run sub-folders.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # makedirs(exist_ok=True) also recovers when only one of the two
    # folders (or a parent folder) is missing.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data. awgn_batch comes from the star-imports; presumably it adds
# white Gaussian noise at the given SNR -- confirm in utils.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 samples are evaluated. The loaded
# tuple is unpacked directly so that `d` keeps referring to the training data.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()
n_val = xval_cuda.shape[0]

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN11
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) starts from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: only the loss of the epoch's last mini-batch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter the first validation sample with its own
            # estimates. hh, rs0 and Rb[0] all come from xval_cuda, so the
            # filtered signal must be xval_cuda[0] as well (not the last
            # training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)))
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')

        # A checkpoint is written on every evaluation epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23300
# Cell setup for run v23300: star-imports, GPU selection, seeding and
# torch-version specific aliases.
from utils import *
# Restrict CUDA to device 0. NOTE(review): this is set after the star-import
# above; it only takes effect if CUDA has not been initialised yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On every torch version except 1.8.1, define mydet via Tensor.det and take
# RAdam from torch.optim. On 1.8.1 RAdam is taken from `optim` and mydet is
# not defined here -- presumably both come from the star-import; confirm.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN12(nn.Module):
    """This is recursive Wiener filter version with a unit-power coarse source.

    `est` predicts only the (pre-tanh) steering angle from the current
    residual. The noise power comes from two linear layers applied to the
    residual magnitude and is exponentiated; no threshold is applied to it.
    The coarse Wiener estimate of each source is refined into a per-pixel
    variance V by an encoder / latent / decoder chain.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE(review): the sizes 4*12*12, 25*25 and the 100 inputs of b1/b2 are
    hard-coded; presumably they match im_size=100 -- confirm against
    vae_model before using another im_size.
    """
    def __init__(self, M=3, K=3, im_size=100):
        """Build the sub-networks.

        Args:
            M: number of channels of the input (the nets see 2*M real planes).
            K: number of sources, i.e. iterations of the loop in forward.
            im_size: side length of the coordinate grids fed to the decoder.
        """
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # b1 collapses the last input dim, b2 the one before it (see forward).
        self.b1 = nn.Linear(100, 1)
        self.b2 = nn.Linear(100, 1)
           
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # fc1 outputs mean and log-variance interleaved (2*dz values).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # The decoder sees z tiled over the image plus two coordinate planes.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*exp(0.5*logvar) with eps ~ N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex input of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar): Rs is vhat.diag_embed() with
            vhat of shape [I, N, F, K]; Hhat has shape [I, M, K]; Rb has
            shape [I, M, M]; mu and logvar are the even / odd columns of the
            fc1 output.
            NOTE(review): Rb, mu and logvar are overwritten in every loop
            iteration, so only the values of the last source are returned.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source, hj @ W @ inp, from the residual.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]
                # Drop only the trailing singleton so size-1 batches survive.
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1)) # pre-tanh angle, [I,1]

            # Noise power: b1 maps [I,M,N,F] -> [I,M,N,1], b2 maps
            # [I,M,N] -> [I,M,1]; the mean over dim 1 leaves [I,1].
            # squeeze(-1) (not a bare squeeze) keeps a size-1 batch dim.
            sb = self.b2(self.b1(inp.abs()).squeeze(-1)).mean(dim=1).exp()
            # sb broadcasts over the M diagonal entries.
            Rb = (sb*torch.ones(batch_size, self.M, \
                device=sb.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            # Get H: hj[m] = exp(1j * pi * m * tanh(ang))
            ch = torch.pi*torch.arange(self.M, device=inp.device)
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat (source power fixed to 1)
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # The product is [N, F, I, 1, 1]; index the two singleton dims
            # explicitly instead of a bare squeeze().
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
        
            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            
            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions            
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: IxDx1x1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: IxDxim_sizexim_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: Ix(dz+2)xNxF
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Remove only the channel dim so that I == 1 or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a beta-weighted KL term.

    The covariance of every entry is built as Hhat @ Rs @ Hhat^H + Rb and the
    log-likelihood is -log(pi*det(Rx)) - x^H Rx^{-1} x, summed over all entries.

    Args:
        x: observations; dim 1 is moved to the last position before use
            (the file's comments give the layout [I, M, N, F]).
        Rs: source covariance, permuted with (1,2,0,3,4) before the product
            (i.e. a 5-D tensor; per the comment below the result is [I,N,F,M,M]).
        Hhat: mixing matrix estimate, multiplied from the left and, conjugate
            transposed, from the right.
        Rb: noise covariance added to the product.
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor: -sum(log-likelihood).real + beta * KL.

    Raises:
        Exception: whatever the determinant/inverse computation raised; the
            offending tensors are written to disk before it propagates.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the two trailing singleton dims explicitly: a bare squeeze()
        # would also drop a size-1 batch/N/F dim and mis-broadcast with det.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the failing tensors for post-mortem, then propagate the
        # original error instead of recomputing the same failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23300' # running id
# Figures and checkpoints of this run go to per-run sub-folders.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    # makedirs(exist_ok=True) also recovers when only one of the two
    # folders (or a parent folder) is missing.
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data. awgn_batch comes from the star-imports; presumably it adds
# white Gaussian noise at the given SNR -- confirm in utils.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: only the first 128 samples are evaluated.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()
n_val = xval_cuda.shape[0]

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN12
model = NN(M,K,N).cuda()
# Every parameter (weights and biases alike) starts from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # NOTE: only the loss of the epoch's last mini-batch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter the first validation sample with its own
            # estimates. hh, rs0 and Rb[0] all come from xval_cuda, so the
            # filtered signal must be xval_cuda[0] as well (not the last
            # training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val, hh)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also closes the loss figures opened above.
                plt.close('all')

        # A checkpoint is written on every evaluation epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23310
# Cell setup for run v23310: star-imports, GPU selection, seeding and
# torch-version specific aliases.
from utils import *
# Restrict CUDA to device 0. NOTE(review): this is set after the star-import
# above; it only takes effect if CUDA has not been initialised yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# On every torch version except 1.8.1, define mydet via Tensor.det and take
# RAdam from torch.optim. On 1.8.1 RAdam is taken from `optim` and mydet is
# not defined here -- presumably both come from the star-import; confirm.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Anomaly detection is left on for the whole run (debugging aid).
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN12(nn.Module):
    """Recursive Wiener-filter VAE that extracts one source per pass.

    For each of the K passes the network predicts a steering vector ``hj`` and
    a noise covariance ``Rb`` from the current residual, Wiener-filters the
    mixture into a coarse source image, and encodes/decodes that image into a
    variance map ``v``. It is ``v`` (not Rb) that is clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    The fixed layer sizes (Linear(100, 1), 4*12*12, 25*25) appear to assume
    100x100 inputs -- TODO confirm against Down/DoubleConv in vae_model.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # Reduce |inp| over its last two axes to a scalar log noise level
        self.b1 = nn.Linear(100, 1)
        self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # Output interleaves mu (even columns) and logvar (odd columns)
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Coordinate grids kept as buffers of shape [1, 1, im_size, im_size]
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps*std with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal source
            covariance [I, N, F, K, K], ``Hhat`` is [I, M, K], ``Rb`` is
            [I, M, M]. ``Rb``, ``mu`` and ``logvar`` are those of the LAST
            pass only.
        """
        batch_size, _, N, F = x.shape
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's projection from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1))  # [I, 1]

            # Explicit squeeze(-1): a bare squeeze() would also drop the batch
            # axis when batch_size == 1.
            sb = self.b2(self.b1(inp.abs()).squeeze(-1)).mean(dim=1).exp()  # [I, 1]
            Rb = (sb*torch.ones(batch_size, self.M, \
                device=sb.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            ch = torch.pi*torch.arange(self.M, device=inp.device)
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            # NOTE(review): the filter is applied to the original x, not to
            # the residual inp -- confirm this is intended.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's singleton channel axis so that batch_size == 1
        # or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of earlier passes are discarded, so the KL
        # term in loss_fun only regularises the last source -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative complex-Gaussian log-likelihood of x plus a beta-weighted KL term.

    Args:
        x: mixture indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: 5-D source covariance with the batch axis first
            (presumably [I, N, F, K, K] -- confirm against the model).
        Hhat: mixing matrix, presumably [I, M, K].
        Rb: noise covariance added to every bin, presumably [I, M, M].
        mu, logvar: latent Gaussian parameters for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-sum(ll).real + beta*kl`` (summed, not averaged).

    Raises:
        Exception: whatever the determinant/inverse raises; the tensors
        involved are first saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^{-1} x is [..., 1, 1]; index the two singleton axes explicitly
        # so size-1 batch/N/F axes are never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Dump the inputs for offline debugging, then stop by re-raising the
        # original error instead of re-running the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23310' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
# Check BOTH folders and create them with makedirs: the old code only tested
# the figure folder and used mkdir, so a missing model folder (or missing
# parent) either crashed here or later at torch.save.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: add noise, then scale every sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: the first 128 samples are kept on the GPU for evaluation.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN12
model = NN(M,K,N).cuda()
# Re-initialise every parameter (weights and biases alike) from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for n_epochs. Every 10th epoch: plot the loss curves, evaluate on the
# held-out batch, visualise validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch's loss is recorded for the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # actual validation batch size
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with its own estimated
            # parameters. The filter must be applied to xval_cuda: `x` is the
            # last (shuffled) training batch and does not match Hhat/Rs/Rb.
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val, hh)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')  # also closes the loss figures opened above
            
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23400
# The star-import is relied on for os, plt, torch, nn, optim, np, Data, ...
# (they are used below with no other import) -- TODO confirm against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch RNG

# Version shim: on any torch other than 1.8.1, define mydet locally and take
# RAdam from torch.optim; on 1.8.1 RAdam comes from `optim` and mydet is
# presumably supplied by the star-import above -- verify.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: makes autograd report the op that produced a NaN in backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN13(nn.Module):
    """Recursive Wiener-filter VAE with a shallow encoder/decoder.

    For each of the K passes the network predicts a steering vector ``hj`` and
    a noise covariance ``Rb`` from the current residual, Wiener-filters the
    mixture into a coarse source image, and encodes/decodes that image into a
    variance map ``v``. It is ``v`` (not Rb) that is clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    The fixed layer sizes (Linear(100, 1), 4*12*12, 25*25) appear to assume
    100x100 inputs -- TODO confirm against Down/DoubleConv in vae_model.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # Reduce |inp| over its last two axes to a scalar log noise level
        self.b1 = nn.Linear(100, 1)
        self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # Output interleaves mu (even columns) and logvar (odd columns)
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Coordinate grids kept as buffers of shape [1, 1, im_size, im_size]
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps*std with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal source
            covariance [I, N, F, K, K], ``Hhat`` is [I, M, K], ``Rb`` is
            [I, M, M]. ``Rb``, ``mu`` and ``logvar`` are those of the LAST
            pass only.
        """
        batch_size, _, N, F = x.shape
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's projection from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1))  # [I, 1]

            # Explicit squeeze(-1): a bare squeeze() would also drop the batch
            # axis when batch_size == 1.
            sb = self.b2(self.b1(inp.abs()).squeeze(-1)).mean(dim=1).exp()  # [I, 1]
            Rb = (sb*torch.ones(batch_size, self.M, \
                device=sb.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            ch = torch.pi*torch.arange(self.M, device=inp.device)
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            # NOTE(review): the filter is applied to the original x, not to
            # the residual inp -- confirm this is intended.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's singleton channel axis so that batch_size == 1
        # or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of earlier passes are discarded, so the KL
        # term in loss_fun only regularises the last source -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of x plus a beta-weighted KL term.

    Args:
        x: mixture indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: 5-D source covariance with the batch axis first
            (presumably [I, N, F, K, K] -- confirm against the model).
        Hhat: mixing matrix, presumably [I, M, K].
        Rb: noise covariance added to every bin, presumably [I, M, M].
        mu, logvar: latent Gaussian parameters for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-sum(ll).real + beta*kl`` (summed, not averaged).

    Raises:
        Exception: whatever the determinant/inverse raises; the tensors
        involved are first saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^{-1} x is [..., 1, 1]; index the two singleton axes explicitly
        # so size-1 batch/N/F axes are never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Dump the inputs for offline debugging, then stop by re-raising the
        # original error instead of re-running the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23400' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
# Check BOTH folders and create them with makedirs: the old code only tested
# the figure folder and used mkdir, so a missing model folder (or missing
# parent) either crashed here or later at torch.save.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: add noise, then scale every sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: the first 128 samples are kept on the GPU for evaluation.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN13
model = NN(M,K,N).cuda()
# Re-initialise every parameter (weights and biases alike) from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for n_epochs. Every 10th epoch: plot the loss curves, evaluate on the
# held-out batch, visualise validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch's loss is recorded for the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # actual validation batch size
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with its own estimated
            # parameters. The filter must be applied to xval_cuda: `x` is the
            # last (shuffled) training batch and does not match Hhat/Rs/Rb.
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val, hh)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')  # also closes the loss figures opened above
            
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v23410
# The star-import is relied on for os, plt, torch, nn, optim, np, Data, ...
# (they are used below with no other import) -- TODO confirm against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch RNG

# Version shim: on any torch other than 1.8.1, define mydet locally and take
# RAdam from torch.optim; on 1.8.1 RAdam comes from `optim` and mydet is
# presumably supplied by the star-import above -- verify.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: makes autograd report the op that produced a NaN in backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN13(nn.Module):
    """Recursive Wiener-filter VAE with a shallow encoder/decoder.

    For each of the K passes the network predicts a steering vector ``hj`` and
    a noise covariance ``Rb`` from the current residual, Wiener-filters the
    mixture into a coarse source image, and encodes/decodes that image into a
    variance map ``v``. It is ``v`` (not Rb) that is clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    The fixed layer sizes (Linear(100, 1), 4*12*12, 25*25) appear to assume
    100x100 inputs -- TODO confirm against Down/DoubleConv in vae_model.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # Reduce |inp| over its last two axes to a scalar log noise level
        self.b1 = nn.Linear(100, 1)
        self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # Output interleaves mu (even columns) and logvar (odd columns)
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Coordinate grids kept as buffers of shape [1, 1, im_size, im_size]
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps*std with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances, mixing vectors and noise covariance.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal source
            covariance [I, N, F, K, K], ``Hhat`` is [I, M, K], ``Rb`` is
            [I, M, M]. ``Rb``, ``mu`` and ``logvar`` are those of the LAST
            pass only.
        """
        batch_size, _, N, F = x.shape
        v_all, h_all = [], []

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's projection from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1))  # [I, 1]

            # Explicit squeeze(-1): a bare squeeze() would also drop the batch
            # axis when batch_size == 1.
            sb = self.b2(self.b1(inp.abs()).squeeze(-1)).mean(dim=1).exp()  # [I, 1]
            Rb = (sb*torch.ones(batch_size, self.M, \
                device=sb.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

            ch = torch.pi*torch.arange(self.M, device=inp.device)
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            # NOTE(review): the filter is applied to the original x, not to
            # the residual inp -- confirm this is intended.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's singleton channel axis so that batch_size == 1
        # or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of earlier passes are discarded, so the KL
        # term in loss_fun only regularises the last source -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative complex-Gaussian log-likelihood of x plus a beta-weighted KL term.

    Args:
        x: mixture indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: 5-D source covariance with the batch axis first
            (presumably [I, N, F, K, K] -- confirm against the model).
        Hhat: mixing matrix, presumably [I, M, K].
        Rb: noise covariance added to every bin, presumably [I, M, M].
        mu, logvar: latent Gaussian parameters for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor ``-sum(ll).real + beta*kl`` (summed, not averaged).

    Raises:
        Exception: whatever the determinant/inverse raises; the tensors
        involved are first saved to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^{-1} x is [..., 1, 1]; index the two singleton axes explicitly
        # so size-1 batch/N/F axes are never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Dump the inputs for offline debugging, then stop by re-raising the
        # original error instead of re-running the failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v23410' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = f'../data/nem_ss/figures/rid{rid}/'
mod_loc = f'../data/nem_ss/models/rid{rid}/'
# Check BOTH folders and create them with makedirs: the old code only tested
# the figure folder and used mkdir, so a missing model folder (or missing
# parent) either crashed here or later at torch.save.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: add noise, then scale every sample to unit peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation data: the first 128 samples are kept on the GPU for evaluation.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN13
model = NN(M,K,N).cuda()
# Re-initialise every parameter (weights and biases alike) from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Train for n_epochs. Every 10th epoch: plot the loss curves, evaluate on the
# held-out batch, visualise validation sample 0 and save a checkpoint.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the last mini-batch's loss is recorded for the epoch.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # actual validation batch size
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Wiener-filter validation sample 0 with its own estimated
            # parameters. The filter must be applied to xval_cuda: `x` is the
            # last (shuffled) training batch and does not match Hhat/Rs/Rb.
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            print(epoch, Rb[0], Rb.sum()/M/n_val, hh)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')  # also closes the loss figures opened above
            
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% v24000
# The star-import is relied on for os, plt, torch, nn, optim, np, Data, ...
# (they are used below with no other import) -- TODO confirm against utils.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)  # fixed seed for the torch RNG

# Version shim: on any torch other than 1.8.1, define mydet locally and take
# RAdam from torch.optim; on 1.8.1 RAdam comes from `optim` and mydet is
# presumably supplied by the star-import above -- verify.
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Debugging aid: makes autograd report the op that produced a NaN in backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN14(nn.Module):
    """Recursive Wiener-filter VAE with a fixed noise covariance.

    For each of the K passes the network predicts a steering vector ``hj``
    from the current residual, Wiener-filters the mixture (using the constant
    ``Rb = 1.4e-3 * I``) into a coarse source image, normalises that image by
    its peak magnitude, and encodes/decodes it into a variance map ``v``.
    It is ``v`` (not Rb) that is clipped to [1e-3, 1e2].

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    The fixed layer sizes (Linear(100, 1), 4*12*12, 25*25) appear to assume
    100x100 inputs -- TODO confirm against Down/DoubleConv in vae_model.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate H and coarse V
        self.est = nn.Sequential(
            Down(in_channels=M*2, out_channels=64),
            Down(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=4),
            Reshape(-1, 4*12*12),
            LinearBlock(4*12*12, 64),
            nn.Linear(64, 1),
            )
        # Not used by forward() in this version (Rb is a constant); kept so the
        # parameter list and state_dict layout stay unchanged.
        self.b1 = nn.Linear(100, 1)
        self.b2 = nn.Linear(100, 1)

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # Output interleaves mu (even columns) and logvar (odd columns)
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Coordinate grids kept as buffers of shape [1, 1, im_size, im_size]
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps*std with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate source variances and mixing vectors.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is the diagonal source
            covariance [I, N, F, K, K], ``Hhat`` is [I, M, K], ``Rb`` is the
            constant [I, M, M] noise covariance. ``mu`` and ``logvar`` are
            those of the LAST pass only.
        """
        batch_size, _, N, F = x.shape
        v_all, h_all = [], []

        # Fixed noise covariance; it does not depend on the pass, so build it once.
        Rb = (1.4e-3*torch.ones(batch_size, self.M, \
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]

        # Neural nets for H,V
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Subtract the previous source's projection from the residual
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N,F,I,M,1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)
            ang = self.est(torch.cat((inp.real, inp.imag), dim=1))  # [I, 1]

            ch = torch.pi*torch.arange(self.M, device=inp.device)
            hj = ((ang.tanh() @ ch[None,:])*1j).exp() # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            # NOTE(review): the filter is applied to the original x, not to
            # the residual inp -- confirm this is intended.
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:, None,].conj() @ Rx.inverse()  # shape of [I,1,M]
            # Explicit indexing instead of a bare squeeze() so the batch axis
            # survives batch_size == 1.
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1) #[I, N, F]
            # Normalise by the peak magnitude over the whole batch; the scale
            # is detached so no gradient flows through it.
            shat = shat/shat.detach().abs().max()

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: I x dz x 1 x 1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)  #Shape: I x dz x im_size x im_size
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: I x (dz+2) x im_size x im_size
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's singleton channel axis so that batch_size == 1
        # or K == 1 keep their axes.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar of earlier passes are discarded, so the KL
        # term in loss_fun only regularises the last source -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar
    
def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative complex-Gaussian log-likelihood of `x` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) before use (the data loading
            code in this file comments the layout as [sample, M, N, F]).
        Rs: 5-D source covariance tensor, permuted with (1,2,0,3,4) so that it
            broadcasts against `Hhat`.
        Hhat: mixing-matrix estimate applied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean and log-variance entering the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor `-sum(log-likelihood).real + beta * KL`.

    Raises:
        Exception: re-raises whatever the determinant / inverse raised, after
            dumping (x, Rx, Rs, Hhat, Rb) to `rid{rid}x_Rx_Rs_Hhat_Rb.pt`.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^-1 x has two trailing singleton axes; index them explicitly so
        # that a size-1 batch / N / F axis is never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[...,0,0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the offending tensors for post-mortem, then stop by propagating
        # the original error (instead of recomputing the failing expression).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 'v24000' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
# makedirs(exist_ok=True) creates missing parents and tolerates the case where
# only one of the two folders already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 3000 # how many samples
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 1500

# Training data: add noise, then scale every sample by its own peak magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d, snr=30, seed=1)
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
# Validation file unpacks into three items; the first 128 mixtures are used
# for evaluation and `hgt` is compared against the estimated Hhat via h_corr.
xval, _ , hgt = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
xval_cuda = xval[:128].to(torch.cfloat).cuda()

loss_iter, loss_tr, loss_eval = [], [], []
NN = NN14
model = NN(M,K,N).cuda()
# Every parameter tensor is re-drawn from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)
optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Optimise the model; every 10th epoch also plot the loss curves, score the
# held-out batch, visualise validation sample 0 and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the last mini-batch of the epoch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%10 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # validation batch size (was a literal 128)
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Multichannel Wiener filter for validation sample 0. The mixture
            # has to be the validation input that produced hh / rs0, not the
            # last training batch `x` left over from the loop above.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()
                    @ xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(epoch, Rb[0], Rb.sum()/M/n_val, hh)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt)), '\n')
            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also releases the loss figures opened earlier in this epoch.
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# %%


In [None]:
#%% t2 -- test sbd with fewer layers, one batch, 
from utils import *
# Expose only GPU 0 to this process and set plotting / printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from `optim`; every other version uses
# torch.optim.RAdam and defines `mydet` as a thin wrapper over Tensor.det().
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Determinant of the trailing square matrices of `x`."""
        return x.det()

# Debug aid: makes autograd report the forward op behind a failing backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtdecoder(nn.Module):
    """Variance estimator that uses ground-truth steering vectors.

    For each of the K sources, `forward` takes column i of the module-level
    tensor `hgt` as the steering vector, forms a Wiener-filter estimate of the
    source from the mixture, encodes its magnitude into a `dz`-dimensional
    latent and decodes the latent (broadcast over an x/y coordinate grid) into
    a variance map passed through `threshold(floor=1e-3, ceiling=1e2)`.
    The noise covariance is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        # H is not estimated by a network in this variant; forward() reads the
        # steering vectors from the module-level `hgt`.

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source variances for the mixture batch `x`.

        Returns (Rs, Hhat, Rb, mu, logvar): Rs is the diagonal embedding of the
        stacked variance maps, Hhat stacks the K steering vectors on dim 2 and
        Rb is the fixed noise covariance.
        NOTE: mu / logvar are those of the last source processed only.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Fixed noise covariance, identical for every source.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        x_col = x.permute(2,3,0,1)[...,None] # shape:[N, F, I, M, 1]

        for i in range(self.K):
            # Ground-truth steering vector of source i, tiled over the actual
            # batch (not the global sample count I) so any batch size works.
            hj = hgt[:,i].repeat(batch_size, 1) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            # Index the trailing 1x1 explicitly; a bare squeeze() would also
            # drop a size-1 batch axis.
            shat = (W @ x_col)[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: IxDx1x1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: Ix(dz+2)xNxF
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's channel axis so that I == 1 or K == 1 survive.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative complex-Gaussian log-likelihood of `x` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) before use (the data loading
            code in this file comments the layout as [sample, M, N, F]).
        Rs: 5-D source covariance tensor, permuted with (1,2,0,3,4) so that it
            broadcasts against `Hhat`.
        Hhat: mixing matrix applied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean and log-variance entering the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor `-sum(log-likelihood).real + beta * KL`.

    Raises:
        Exception: re-raises whatever the determinant / inverse raised, after
            dumping (x, Rx, Rs, Hhat, Rb) to `rid{rid}x_Rx_Rs_Hhat_Rb.pt`.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^-1 x has two trailing singleton axes; index them explicitly so
        # that a size-1 batch / N / F axis is never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[...,0,0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the offending tensors for post-mortem, then stop by propagating
        # the original error (instead of recomputing the failing expression).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 't2' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
# makedirs(exist_ok=True) creates missing parents and tolerates the case where
# only one of the two folders already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 3000

# Training data: first I samples, noise added, each scaled by its own peak.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# The model tiles `hgt` over I samples, so the validation batch is tied to I
# instead of a separate literal 30.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtdecoder
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Optimise the model; every 20th epoch also plot the loss curves, score the
# held-out batch, visualise validation sample 0 and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the last mini-batch of the epoch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%20 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # validation batch size (was a literal 30)
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Multichannel Wiener filter for validation sample 0. The mixture
            # has to be the validation input that produced hh / rs0, not the
            # last training batch `x` left over from the loop above.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()
                    @ xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also releases the loss figures opened earlier in this epoch.
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t2_more -- test sbd with more layers, one batch,
from utils import *
# Expose only GPU 0 to this process and set plotting / printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from `optim`; every other version uses
# torch.optim.RAdam and defines `mydet` as a thin wrapper over Tensor.det().
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Determinant of the trailing square matrices of `x`."""
        return x.det()

# Debug aid: makes autograd report the forward op behind a failing backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtdecoder(nn.Module):
    """Variance estimator that uses ground-truth steering vectors (deeper nets).

    For each of the K sources, `forward` takes column i of the module-level
    tensor `hgt` as the steering vector, forms a Wiener-filter estimate of the
    source from the mixture, encodes its magnitude into a `dz`-dimensional
    latent and decodes the latent (broadcast over an x/y coordinate grid) into
    a variance map passed through `threshold(floor=1e-3, ceiling=1e2)`.
    The noise covariance is fixed to 1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M
         
        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            DoubleConv(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            ) 
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source variances for the mixture batch `x`.

        Returns (Rs, Hhat, Rb, mu, logvar): Rs is the diagonal embedding of the
        stacked variance maps, Hhat stacks the K steering vectors on dim 2 and
        Rb is the fixed noise covariance.
        NOTE: mu / logvar are those of the last source processed only.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Fixed noise covariance, identical for every source.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        x_col = x.permute(2,3,0,1)[...,None] # shape:[N, F, I, M, 1]

        for i in range(self.K):
            # Ground-truth steering vector of source i, tiled over the actual
            # batch (not the global sample count I) so any batch size works.
            hj = hgt[:,i].repeat(batch_size, 1) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            # Index the trailing 1x1 explicitly; a bare squeeze() would also
            # drop a size-1 batch axis.
            shat = (W @ x_col)[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            # View z as 4D tensor to be tiled across new N and F dimensions
            zr = z.view((batch_size, self.dz)+ (1, 1))  #Shape: IxDx1x1
            # Tile across to match image size
            zr = zr.expand(-1, -1, self.im_size, self.im_size)
            # Expand grids to batches and concatenate on the channel dimension
            zbd = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                        self.y_grid.expand(batch_size, -1, -1, -1), zr), dim=1) # Shape: Ix(dz+2)xNxF
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's channel axis so that I == 1 or K == 1 survive.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative complex-Gaussian log-likelihood of `x` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) before use (the data loading
            code in this file comments the layout as [sample, M, N, F]).
        Rs: 5-D source covariance tensor, permuted with (1,2,0,3,4) so that it
            broadcasts against `Hhat`.
        Hhat: mixing matrix applied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean and log-variance entering the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor `-sum(log-likelihood).real + beta * KL`.

    Raises:
        Exception: re-raises whatever the determinant / inverse raised, after
            dumping (x, Rx, Rs, Hhat, Rb) to `rid{rid}x_Rx_Rs_Hhat_Rb.pt`.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^-1 x has two trailing singleton axes; index them explicitly so
        # that a size-1 batch / N / F axis is never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[...,0,0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the offending tensors for post-mortem, then stop by propagating
        # the original error (instead of recomputing the failing expression).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
rid = 't2_more' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
# makedirs(exist_ok=True) creates missing parents and tolerates the case where
# only one of the two folders already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 3000

# Training data: first I samples, noise added, each scaled by its own peak.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# The model tiles `hgt` over I samples, so the validation batch is tied to I
# instead of a separate literal 30.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtdecoder
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Optimise the model; every 20th epoch also plot the loss curves, score the
# held-out batch, visualise validation sample 0 and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the last mini-batch of the epoch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%20 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # validation batch size (was a literal 30)
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Multichannel Wiener filter for validation sample 0. The mixture
            # has to be the validation input that produced hh / rs0, not the
            # last training batch `x` left over from the loop above.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()
                    @ xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also releases the loss figures opened earlier in this epoch.
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t3 -- upconv with fewer layers, one batch,
from utils import *
# Expose only GPU 0 to this process and set plotting / printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from `optim`; every other version uses
# torch.optim.RAdam and defines `mydet` as a thin wrapper over Tensor.det().
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Determinant of the trailing square matrices of `x`."""
        return x.det()

# Debug aid: makes autograd report the forward op behind a failing backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """Variance estimator with ground-truth steering vectors and an up-conv decoder.

    For each of the K sources, `forward` takes column i of the module-level
    tensor `hgt` as the steering vector, forms a Wiener-filter estimate of the
    source from the mixture, encodes its magnitude into a `dz`-dimensional
    latent and decodes the latent (linear layer, reshape to 25x25, then
    up-sampling blocks) into a variance map passed through
    `threshold(floor=1e-3, ceiling=1e2)`. The noise covariance is fixed to
    1.4e-3 * identity.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        # `im_size` is accepted for signature compatibility but not used here.
        super().__init__()
        self.dz = 32
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            Up_(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            ) 

    def reparameterize(self, mu, logvar):
        """Draw mu + eps*std with eps ~ N(0, 1) and std = exp(logvar/2)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate per-source variances for the mixture batch `x`.

        Returns (Rs, Hhat, Rb, mu, logvar): Rs is the diagonal embedding of the
        stacked variance maps, Hhat stacks the K steering vectors on dim 2 and
        Rb is the fixed noise covariance.
        NOTE: mu / logvar are those of the last source processed only.
        """
        batch_size = x.shape[0]
        v_all, h_all = [], []

        # Fixed noise covariance, identical for every source.
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat) # shape:[I, M, M]
        x_col = x.permute(2,3,0,1)[...,None] # shape:[N, F, I, M, 1]

        for i in range(self.K):
            # Ground-truth steering vector of source i, tiled over the actual
            # batch (not the global sample count I) so any batch size works.
            hj = hgt[:,i].repeat(batch_size, 1) # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb # shape of [I,M,M]
            W = hj[:,None].conj() @ Rx.inverse()  # shape of [I,1,M]
            # Index the trailing 1x1 explicitly; a bare squeeze() would also
            # drop a size-1 batch axis.
            shat = (W @ x_col)[...,0,0].permute(2,0,1) #[I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder
            xx = self.encoder(shat[:,None].abs())
            # Get latent variable: even columns are mu, odd columns logvar
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]
        # Drop only the decoder's channel axis so that I == 1 or K == 1 survive.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat) # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar


def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative complex-Gaussian log-likelihood of `x` plus a weighted KL term.

    Args:
        x: observations; permuted with (0,2,3,1) before use (the data loading
            code in this file comments the layout as [sample, M, N, F]).
        Rs: 5-D source covariance tensor, permuted with (1,2,0,3,4) so that it
            broadcasts against `Hhat`.
        Hhat: mixing matrix applied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to the mixed source covariance.
        mu, logvar: latent mean and log-variance entering the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar tensor `-sum(log-likelihood).real + beta * KL`.

    Raises:
        Exception: re-raises whatever the determinant / inverse raised, after
            dumping (x, Rx, Rs, Hhat, Rb) to `rid{rid}x_Rx_Rs_Hhat_Rb.pt`.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # x^H Rx^-1 x has two trailing singleton axes; index them explicitly so
        # that a size-1 batch / N / F axis is never squeezed away by accident.
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[...,0,0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except Exception:
        # Keep the offending tensors for post-mortem, then stop by propagating
        # the original error (instead of recomputing the failing expression).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
#@title
rid = 't3' # running id
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
# makedirs(exist_ok=True) creates missing parents and tolerates the case where
# only one of the two folders already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 3000

# Training data: first I samples, noise added, each scaled by its own peak.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# `hgt` is read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# The model tiles `hgt` over I samples, so the validation batch is tied to I
# instead of a separate literal 30.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Optimise the model; every 20th epoch also plot the loss curves, score the
# held-out batch, visualise validation sample 0 and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar = model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the last mini-batch of the epoch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%20 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            n_val = xval_cuda.shape[0]  # validation batch size (was a literal 30)
            Rs, Hhat, Rb, mu, logvar = model(xval_cuda)
            loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Multichannel Wiener filter for validation sample 0. The mixture
            # has to be the validation input that produced hh / rs0, not the
            # last training batch `x` left over from the loop above.
            hh, rs0 = Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()
                    @ xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Also releases the loss figures opened earlier in this epoch.
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t3_more --upconv with more layers, one batch,
from utils import *
# Expose only GPU 0 to this process and set plotting / printing defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# torch 1.8.1 takes RAdam from `optim`; every other version uses
# torch.optim.RAdam and defines `mydet` as a thin wrapper over Tensor.det().
if torch.__version__.startswith('1.8.1'):
    RAdam = optim.RAdam
else:
    RAdam = torch.optim.RAdam

    def mydet(x):
        """Determinant of the trailing square matrices of `x`."""
        return x.det()

# Debug aid: makes autograd report the forward op behind a failing backward.
torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_gtupconv(nn.Module):
    """Per-source VAE that estimates variance maps V with H taken from ground truth.

    For each of the K sources, ``forward`` builds a Wiener-type filter from the
    ground-truth column ``hgt[:, i]`` and a fixed noise covariance, encodes the
    magnitude of the filtered image into a ``dz``-dimensional latent code, and
    decodes that code into a variance map that is passed through ``threshold``
    with floor 1e-3 and ceiling 1e2 (the noise covariance Rb itself is a fixed
    constant, it is not thresholded).

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    NOTE(review): ``forward`` reads the module-level global ``hgt`` (indexed as
    ``hgt[:, i]``, so presumably shaped [M, K] -- confirm); it is not a
    constructor argument.
    NOTE(review): ``fc1`` expects 25*25 encoder features and the decoder starts
    from a 25x25 map, which presumably ties the model to 100x100 inputs --
    confirm. ``im_size`` is accepted for signature compatibility but unused.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits both mu and logvar (interleaved, see forward).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixtures, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` diagonal source covariances [I, N, F, K, K] (cfloat),
            ``Hhat`` stacked ground-truth columns [I, M, K],
            ``Rb`` fixed noise covariance [I, M, M],
            ``mu`` / ``logvar`` latent statistics [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are those of the LAST source only,
        so a KL term computed from them ignores sources 0..K-2 -- confirm
        whether that is intended.
        """
        batch_size, _, _, _ = x.shape

        # Fixed noise covariance 1.4e-3 * identity; it does not depend on the
        # source index, so build it once. shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat)

        v_all, h_all = [], []
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Residual after removing the previous source's filtered part.
                # NOTE(review): `inp` is updated here but never consumed -- the
                # filter below is applied to `x`, not `inp`. Confirm intent.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N, F, I, M, 1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Ground-truth column i, tiled over the actual batch (previously
            # tiled with the global I and a literal 3, which only worked when
            # batch_size == I and M == 3).
            hj = hgt[:, i].repeat(batch_size, 1)  # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb  # shape:[I, M, M]
            W = hj[:, None].conj() @ Rx.inverse()  # shape:[I, 1, M]
            # Index the two trailing singleton dims explicitly so a batch (or
            # image side) of size 1 is not squeezed away as well.
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1)  # [I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder -> latent statistics (even columns: mu, odd columns: logvar)
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2))  # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2)  # shape:[I, M, K]
        # Drop only the channel axis so K == 1 (or I == 1) keeps its dimension.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat)  # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar


def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=1e-3):
    """Negative Gaussian log-likelihood of the mixtures plus a beta-weighted KL term.

    The mixture covariance is modelled as ``Rx = Hhat @ Rs @ Hhat^H + Rb`` and
    the per-bin log-likelihood as ``-log(pi * det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixtures; permuted here with (0,2,3,1), i.e. given as [I, M, N, F].
        Rs: 5-D source covariances; per the shape comments in this file
            [I, N, F, K, K].
        Hhat: mixing matrices, per the shape comments [I, M, K].
        Rb: noise covariance, per the shape comments [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar real tensor ``-sum(ll).real + beta * kl``.

    Raises:
        RuntimeError: re-raised from ``det``/``inverse`` after the offending
            tensors were dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.

    NOTE(review): relies on the module-level names ``np``, ``mydet`` and ``rid``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the trailing 1x1 explicitly: a bare squeeze() would also drop
        # I/N/F when one of them is 1 and then mis-broadcast against det(Rx).
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except RuntimeError:
        # Dump the inputs for post-mortem, then let the original error surface
        # (the old bare `except:` also trapped Ctrl-C and simply recomputed).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
# Run configuration, output folders, data loading, model and optimizer.
rid = 't3_more' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create whichever output folder is missing. The old check looked only at the
# figure folder and then called os.mkdir on both, which crashed or left the
# model folder missing after a half-created run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 3000

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
# NOTE(review): presumably adds white Gaussian noise at 30 dB SNR -- confirm in utils.
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# Ground-truth mixing matrix, read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# Validation batch sized by I (was a literal 30) so it follows the setting above.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtupconv
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimisation pass over `tr` per epoch; every 20th epoch
# plot the loss curves, evaluate on the validation batch, save diagnostic
# figures and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the epoch's last mini-batch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%20 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            n_val = xval_cuda.shape[0]  # was a hard-coded 30
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb all come
            # from the validation forward pass, so the filter must be applied
            # to the validation mixture (it used to be applied to `x`, the last
            # shuffled training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)  # mean diagonal entry of Rb
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t2_05 -- test sbd with fewer layers, one batch, beta=0.5(t2 as 1e-3)
# Cell preamble: pull in shared names, pin the GPU, set plotting/printing
# defaults, seed torch's RNG and pick version-dependent helpers.
# NOTE(review): os, plt, torch (and later np, Data, optim, ...) are used in this
# cell without an explicit import; presumably this star import supplies them -- confirm.
from utils import *
# Expose only CUDA device 0 to this process.
# NOTE(review): presumably this runs before CUDA is first initialised -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for torch's global random number generator.
torch.manual_seed(1)

# Version shim keyed on the first five characters of the torch version string:
# on anything other than '1.8.1' define `mydet` via Tensor.det and take RAdam
# from torch.optim; on 1.8.1 use `optim.RAdam` and whatever `mydet` is already
# in scope (presumably from utils -- confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection stays enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): presumably provides the layer blocks used by the model class
# below (Down, DoubleConv, OutConv, ...) -- confirm.
from vae_model import *
class NN_gtdecoder(nn.Module):
    """Per-source VAE with a coordinate-conditioned decoder; H is taken from ground truth.

    For each of the K sources, ``forward`` builds a Wiener-type filter from the
    ground-truth column ``hgt[:, i]`` and a fixed noise covariance, encodes the
    magnitude of the filtered image into a ``dz``-dimensional latent code,
    tiles that code over an ``im_size`` x ``im_size`` grid together with two
    coordinate channels, and decodes it into a variance map
    ``exp(decoder(.))`` passed through ``threshold`` with floor 1e-3 and
    ceiling 1e2 (the noise covariance Rb itself is a fixed constant).

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    NOTE(review): ``forward`` reads the module-level global ``hgt`` (indexed as
    ``hgt[:, i]``, so presumably shaped [M, K] -- confirm); it is not a
    constructor argument.
    NOTE(review): ``fc1`` expects 25*25 encoder features, which presumably ties
    the encoder to 100x100 inputs -- confirm.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits both mu and logvar (interleaved, see forward).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Input channels: dz latent channels + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # Default ('ij') indexing is relied on; the `indexing=` keyword is not
        # passed because this file still supports torch 1.8.1.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixtures, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` diagonal source covariances [I, N, F, K, K] (cfloat),
            ``Hhat`` stacked ground-truth columns [I, M, K],
            ``Rb`` fixed noise covariance [I, M, M],
            ``mu`` / ``logvar`` latent statistics [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are those of the LAST source only,
        so a KL term computed from them ignores sources 0..K-2 -- confirm
        whether that is intended.
        """
        batch_size, _, _, _ = x.shape

        # Fixed noise covariance 1.4e-3 * identity; it does not depend on the
        # source index, so build it once. shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat)
        # Coordinate channels expanded to the batch. shape:[I, 1, im, im] each
        grid_x = self.x_grid.expand(batch_size, -1, -1, -1)
        grid_y = self.y_grid.expand(batch_size, -1, -1, -1)

        v_all, h_all = [], []
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Residual after removing the previous source's filtered part.
                # NOTE(review): `inp` is updated here but never consumed -- the
                # filter below is applied to `x`, not `inp`. Confirm intent.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N, F, I, M, 1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Ground-truth column i, tiled over the actual batch (previously
            # tiled with the global I and a literal 3, which only worked when
            # batch_size == I and M == 3).
            hj = hgt[:, i].repeat(batch_size, 1)  # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb  # shape:[I, M, M]
            W = hj[:, None].conj() @ Rx.inverse()  # shape:[I, 1, M]
            # Index the two trailing singleton dims explicitly so a batch (or
            # image side) of size 1 is not squeezed away as well.
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1)  # [I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder -> latent statistics (even columns: mu, odd columns: logvar)
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V: tile z over the image grid and prepend the two
            # coordinate channels. shape:[I, dz+2, im_size, im_size]
            zr = z.view(batch_size, self.dz, 1, 1).expand(-1, -1, self.im_size, self.im_size)
            zbd = torch.cat((grid_x, grid_y, zr), dim=1)
            v = self.decoder(zbd).exp()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2))  # 1e-3 to 1e2
        Hhat = torch.stack(h_all, 2)  # shape:[I, M, K]
        # Drop only the channel axis so K == 1 (or I == 1) keeps its dimension.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat)  # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative Gaussian log-likelihood of the mixtures plus a beta-weighted KL term.

    The mixture covariance is modelled as ``Rx = Hhat @ Rs @ Hhat^H + Rb`` and
    the per-bin log-likelihood as ``-log(pi * det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixtures; permuted here with (0,2,3,1), i.e. given as [I, M, N, F].
        Rs: 5-D source covariances; per the shape comments in this file
            [I, N, F, K, K].
        Hhat: mixing matrices, per the shape comments [I, M, K].
        Rb: noise covariance, per the shape comments [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar real tensor ``-sum(ll).real + beta * kl``.

    Raises:
        RuntimeError: re-raised from ``det``/``inverse`` after the offending
            tensors were dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.

    NOTE(review): relies on the module-level names ``np``, ``mydet`` and ``rid``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the trailing 1x1 explicitly: a bare squeeze() would also drop
        # I/N/F when one of them is 1 and then mis-broadcast against det(Rx).
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except RuntimeError:
        # Dump the inputs for post-mortem, then let the original error surface
        # (the old bare `except:` also trapped Ctrl-C and simply recomputed).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
# Run configuration, output folders, data loading, model and optimizer.
rid = 't2_05' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create whichever output folder is missing. The old check looked only at the
# figure folder and then called os.mkdir on both, which crashed or left the
# model folder missing after a half-created run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 2000

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
# NOTE(review): presumably adds white Gaussian noise at 30 dB SNR -- confirm in utils.
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# Ground-truth mixing matrix, read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# Validation batch sized by I (was a literal 30) so it follows the setting above.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtdecoder
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimisation pass over `tr` per epoch; every 40th epoch
# plot the loss curves, evaluate on the validation batch, save diagnostic
# figures and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the epoch's last mini-batch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%40 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            n_val = xval_cuda.shape[0]  # was a hard-coded 30
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb all come
            # from the validation forward pass, so the filter must be applied
            # to the validation mixture (it used to be applied to `x`, the last
            # shuffled training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)  # mean diagonal entry of Rb
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t2_05sig -- test sbd with fewer layers, one batch, beta=0.5(t2 as 1e-3), sigmoid*10
# Cell preamble: pull in shared names, pin the GPU, set plotting/printing
# defaults, seed torch's RNG and pick version-dependent helpers.
# NOTE(review): os, plt, torch (and later np, Data, optim, ...) are used in this
# cell without an explicit import; presumably this star import supplies them -- confirm.
from utils import *
# Expose only CUDA device 0 to this process.
# NOTE(review): presumably this runs before CUDA is first initialised -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for torch's global random number generator.
torch.manual_seed(1)

# Version shim keyed on the first five characters of the torch version string:
# on anything other than '1.8.1' define `mydet` via Tensor.det and take RAdam
# from torch.optim; on 1.8.1 use `optim.RAdam` and whatever `mydet` is already
# in scope (presumably from utils -- confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection stays enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): presumably provides the layer blocks used by the model class
# below (Down, DoubleConv, OutConv, ...) -- confirm.
from vae_model import *
class NN_gtdecoder(nn.Module):
    """Per-source VAE with a coordinate-conditioned decoder; H is taken from ground truth.

    For each of the K sources, ``forward`` builds a Wiener-type filter from the
    ground-truth column ``hgt[:, i]`` and a fixed noise covariance, encodes the
    magnitude of the filtered image into a ``dz``-dimensional latent code,
    tiles that code over an ``im_size`` x ``im_size`` grid together with two
    coordinate channels, and decodes it into a variance map
    ``10 * sigmoid(decoder(.))`` passed through ``threshold`` with floor 1e-3
    (the noise covariance Rb itself is a fixed constant).

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    NOTE(review): ``forward`` reads the module-level global ``hgt`` (indexed as
    ``hgt[:, i]``, so presumably shaped [M, K] -- confirm); it is not a
    constructor argument.
    NOTE(review): ``fc1`` expects 25*25 encoder features, which presumably ties
    the encoder to 100x100 inputs -- confirm.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits both mu and logvar (interleaved, see forward).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Input channels: dz latent channels + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # Default ('ij') indexing is relied on; the `indexing=` keyword is not
        # passed because this file still supports torch 1.8.1.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixtures, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` diagonal source covariances [I, N, F, K, K] (cfloat),
            ``Hhat`` stacked ground-truth columns [I, M, K],
            ``Rb`` fixed noise covariance [I, M, M],
            ``mu`` / ``logvar`` latent statistics [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are those of the LAST source only,
        so a KL term computed from them ignores sources 0..K-2 -- confirm
        whether that is intended.
        """
        batch_size, _, _, _ = x.shape

        # Fixed noise covariance 1.4e-3 * identity; it does not depend on the
        # source index, so build it once. shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat)
        # Coordinate channels expanded to the batch. shape:[I, 1, im, im] each
        grid_x = self.x_grid.expand(batch_size, -1, -1, -1)
        grid_y = self.y_grid.expand(batch_size, -1, -1, -1)

        v_all, h_all = [], []
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Residual after removing the previous source's filtered part.
                # NOTE(review): `inp` is updated here but never consumed -- the
                # filter below is applied to `x`, not `inp`. Confirm intent.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N, F, I, M, 1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Ground-truth column i, tiled over the actual batch (previously
            # tiled with the global I and a literal 3, which only worked when
            # batch_size == I and M == 3).
            hj = hgt[:, i].repeat(batch_size, 1)  # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb  # shape:[I, M, M]
            W = hj[:, None].conj() @ Rx.inverse()  # shape:[I, 1, M]
            # Index the two trailing singleton dims explicitly so a batch (or
            # image side) of size 1 is not squeezed away as well.
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1)  # [I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder -> latent statistics (even columns: mu, odd columns: logvar)
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V: tile z over the image grid and prepend the two
            # coordinate channels. shape:[I, dz+2, im_size, im_size]
            zr = z.view(batch_size, self.dz, 1, 1).expand(-1, -1, self.im_size, self.im_size)
            zbd = torch.cat((grid_x, grid_y, zr), dim=1)
            # sigmoid*10 bounds v above by 10; only a floor is applied below.
            v = self.decoder(zbd).sigmoid()*10
            v_all.append(threshold(v, floor=1e-3))
        Hhat = torch.stack(h_all, 2)  # shape:[I, M, K]
        # Drop only the channel axis so K == 1 (or I == 1) keeps its dimension.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat)  # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Negative Gaussian log-likelihood of the mixtures plus a beta-weighted KL term.

    The mixture covariance is modelled as ``Rx = Hhat @ Rs @ Hhat^H + Rb`` and
    the per-bin log-likelihood as ``-log(pi * det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixtures; permuted here with (0,2,3,1), i.e. given as [I, M, N, F].
        Rs: 5-D source covariances; per the shape comments in this file
            [I, N, F, K, K].
        Hhat: mixing matrices, per the shape comments [I, M, K].
        Rb: noise covariance, per the shape comments [I, M, M].
        mu, logvar: latent mean and log-variance used for the KL term.
        beta: weight of the KL term.

    Returns:
        Scalar real tensor ``-sum(ll).real + beta * kl``.

    Raises:
        RuntimeError: re-raised from ``det``/``inverse`` after the offending
            tensors were dumped to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.

    NOTE(review): relies on the module-level names ``np``, ``mydet`` and ``rid``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        # Index the trailing 1x1 explicitly: a bare squeeze() would also drop
        # I/N/F when one of them is 1 and then mis-broadcast against det(Rx).
        quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
        ll = -(np.pi*mydet(Rx)).log() - quad
    except RuntimeError:
        # Dump the inputs for post-mortem, then let the original error surface
        # (the old bare `except:` also trapped Ctrl-C and simply recomputed).
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
# Run configuration, output folders, data loading, model and optimizer.
rid = 't2_05sig' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
# Create whichever output folder is missing. The old check looked only at the
# figure folder and then called os.mkdir on both, which crashed or left the
# model folder missing after a half-created run.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-3
opts['n_epochs'] = 2000

d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
# NOTE(review): presumably adds white Gaussian noise at 30 dB SNR -- confirm in utils.
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
# Scale every sample by its own largest magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
# Ground-truth mixing matrix, read as a global by the model's forward pass.
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# Validation batch sized by I (was a literal 30) so it follows the setting above.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtdecoder
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Training loop: one optimisation pass over `tr` per epoch; every 40th epoch
# plot the loss curves, evaluate on the validation batch, save diagnostic
# figures and checkpoint the model.
for epoch in range(opts['n_epochs']):
    model.train()
    for (x,) in tr:
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the epoch's last mini-batch is logged.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    if epoch%40 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            val_loss = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            n_val = xval_cuda.shape[0]  # was a hard-coded 30
            loss_eval.append(val_loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Source estimate for validation sample 0. hh, rs0 and Rb all come
            # from the validation forward pass, so the filter must be applied
            # to the validation mixture (it used to be applied to `x`, the last
            # shuffled training batch).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu()
            print(Rb[0], Rb.sum()/M/n_val)  # mean diagonal entry of Rb
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

In [None]:
#%% t2_05sig_e-4 -- test sbd with fewer layers, one batch, beta=0.5(t2 as 1e-3), sigmoid*10, lr=1e-4
# Cell preamble: pull in shared names, pin the GPU, set plotting/printing
# defaults, seed torch's RNG and pick version-dependent helpers.
# NOTE(review): os, plt, torch (and later np, Data, optim, ...) are used in this
# cell without an explicit import; presumably this star import supplies them -- confirm.
from utils import *
# Expose only CUDA device 0 to this process.
# NOTE(review): presumably this runs before CUDA is first initialised -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Fixed seed for torch's global random number generator.
torch.manual_seed(1)

# Version shim keyed on the first five characters of the torch version string:
# on anything other than '1.8.1' define `mydet` via Tensor.det and take RAdam
# from torch.optim; on 1.8.1 use `optim.RAdam` and whatever `mydet` is already
# in scope (presumably from utils -- confirm).
if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        return x.det()
    RAdam = torch.optim.RAdam
else:
    RAdam = optim.RAdam

# Autograd anomaly detection stays enabled for the whole run.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): presumably provides the layer blocks used by the model class
# below (Down, DoubleConv, OutConv, ...) -- confirm.
from vae_model import *
class NN_gtdecoder(nn.Module):
    """Per-source VAE with a coordinate-conditioned decoder; H is taken from ground truth.

    For each of the K sources, ``forward`` builds a Wiener-type filter from the
    ground-truth column ``hgt[:, i]`` and a fixed noise covariance, encodes the
    magnitude of the filtered image into a ``dz``-dimensional latent code,
    tiles that code over an ``im_size`` x ``im_size`` grid together with two
    coordinate channels, and decodes it into a variance map
    ``10 * sigmoid(decoder(.))`` passed through ``threshold`` with floor 1e-3
    (the noise covariance Rb itself is a fixed constant).

    Input shape [I,M,N,F], e.g.[32,3,100,100]

    NOTE(review): ``forward`` reads the module-level global ``hgt`` (indexed as
    ``hgt[:, i]``, so presumably shaped [M, K] -- confirm); it is not a
    constructor argument.
    NOTE(review): ``fc1`` expects 25*25 encoder features, which presumably ties
    the encoder to 100x100 inputs -- confirm.
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32  # latent dimension per source
        self.K, self.M = K, M

        # Estimate V using auto encoder
        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            Down(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        # One linear layer emits both mu and logvar (interleaved, see forward).
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        # Input channels: dz latent channels + 2 coordinate channels.
        self.decoder = nn.Sequential(
            DoubleConv(in_channels=self.dz+2, out_channels=64),
            DoubleConv(in_channels=64, out_channels=16),
            OutConv(in_channels=16, out_channels=1),
            )
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        y = torch.linspace(-1, 1, im_size)
        # Default ('ij') indexing is relied on; the `indexing=` keyword is not
        # passed because this file still supports torch 1.8.1.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Add as constant, with extra dims for batch and channel
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate the source covariances for a batch of mixtures.

        Args:
            x: complex mixtures, shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``:
            ``Rs`` diagonal source covariances [I, N, F, K, K] (cfloat),
            ``Hhat`` stacked ground-truth columns [I, M, K],
            ``Rb`` fixed noise covariance [I, M, M],
            ``mu`` / ``logvar`` latent statistics [I, dz].

        NOTE(review): ``mu`` and ``logvar`` are those of the LAST source only,
        so a KL term computed from them ignores sources 0..K-2 -- confirm
        whether that is intended.
        """
        batch_size, _, _, _ = x.shape

        # Fixed noise covariance 1.4e-3 * identity; it does not depend on the
        # source index, so build it once. shape:[I, M, M]
        Rb = (1.4e-3*torch.ones(batch_size, self.M,
            device=x.device)).diag_embed().to(torch.cfloat)
        # Coordinate channels expanded to the batch. shape:[I, 1, im, im] each
        grid_x = self.x_grid.expand(batch_size, -1, -1, -1)
        grid_y = self.y_grid.expand(batch_size, -1, -1, -1)

        v_all, h_all = [], []
        for i in range(self.K):
            if i == 0:
                inp = x
            else:
                # Residual after removing the previous source's filtered part.
                # NOTE(review): `inp` is updated here but never consumed -- the
                # filter below is applied to `x`, not `inp`. Confirm intent.
                tmp = hj[...,None]@W@inp.permute(2,3,0,1)[...,None]  # [N, F, I, M, 1]
                inp = inp - tmp.squeeze(-1).permute(2,3,0,1)

            # Ground-truth column i, tiled over the actual batch (previously
            # tiled with the global I and a literal 3, which only worked when
            # batch_size == I and M == 3).
            hj = hgt[:, i].repeat(batch_size, 1)  # shape:[I, M]
            h_all.append(hj)

            # Wiener filter to get coarse shat
            Rx = hj[...,None] @ hj[:,None].conj() + Rb  # shape:[I, M, M]
            W = hj[:, None].conj() @ Rx.inverse()  # shape:[I, 1, M]
            # Index the two trailing singleton dims explicitly so a batch (or
            # image side) of size 1 is not squeezed away as well.
            shat = (W @ x.permute(2,3,0,1)[...,None])[..., 0, 0].permute(2,0,1)  # [I, N, F]
            shat = shat/shat.detach().abs().max()

            # Encoder -> latent statistics (even columns: mu, odd columns: logvar)
            xx = self.encoder(shat[:,None].abs())
            zz = self.fc1(xx.reshape(batch_size,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V: tile z over the image grid and prepend the two
            # coordinate channels. shape:[I, dz+2, im_size, im_size]
            zr = z.view(batch_size, self.dz, 1, 1).expand(-1, -1, self.im_size, self.im_size)
            zbd = torch.cat((grid_x, grid_y, zr), dim=1)
            # sigmoid*10 bounds v above by 10; only a floor is applied below.
            v = self.decoder(zbd).sigmoid()*10
            v_all.append(threshold(v, floor=1e-3))
        Hhat = torch.stack(h_all, 2)  # shape:[I, M, K]
        # Drop only the channel axis so K == 1 (or I == 1) keeps its dimension.
        vhat = torch.stack(v_all, 4).squeeze(1).to(torch.cfloat)  # shape:[I, N, F, K]

        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.5):
    """Return the negative log-likelihood of `x` plus a beta-weighted KL term.

    The mixture covariance is assembled as ``Rx = Hhat @ Rs @ Hhat^H + Rb``
    and every entry contributes ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` to the
    log-likelihood. The KL term is the usual closed form
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    Args:
        x: observed mixture; it is permuted with (0, 2, 3, 1) so that its
            second axis becomes the last one
            (assumed [I, M, N, F] on entry -- TODO confirm against callers).
        Rs: source covariance, permuted with (1, 2, 0, 3, 4) before the
            matrix product (assumed [I, N, F, K, K] -- TODO confirm).
        Hhat: mixing-matrix estimate multiplying `Rs` from both sides.
        Rb: noise covariance added to the product.
        mu: latent mean used in the KL term.
        logvar: latent log-variance used in the KL term.
        beta: weight applied to the KL term.

    Returns:
        Scalar tensor ``-sum(ll).real + beta * kl``.

    Raises:
        Exception: whatever the likelihood computation raised; the offending
            tensors are written to disk first and the error is re-raised.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Dump the inputs for post-mortem inspection, then propagate the
        # original error. `except Exception` (instead of a bare `except`)
        # keeps Ctrl-C / SystemExit from being intercepted here.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.sum().real + beta*kl

#%%
# ---- Run configuration: identifier, output folders and hyper-parameters ----
rid = 't2_05sige-4' # running id
fig_loc = '../data/nem_ss/figures/' + f'rid{rid}/'
mod_loc = '../data/nem_ss/models/' + f'rid{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create the two output folders independently of each other, including any
# missing parents, so that an existing figure folder can no longer leave the
# model folder uncreated and a pre-existing model folder no longer raises.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

I = 30 # how many samples ------------------------------------ TODO
M, N, F, K = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = I # how many samples ------------------------------------ TODO
opts['lr'] = 1e-4
opts['n_epochs'] = 2000

# ---- Training data: first I samples, passed through awgn_batch (from utils;
# presumably adds white Gaussian noise at the given SNR -- TODO confirm), then
# each sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr3kM3FT100.pt')
d = awgn_batch(d[:I], snr=30, seed=1) # ------------------------------------ TODO
xtr = (d/d.abs().amax(dim=(1,2,3))[:,None,None,None]) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# ---- Validation data and the ground-truth mixing matrix `hgt` that the
# forward pass reads as a global.
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
# The forward pass tiles `hgt` to the global sample count `I`
# (hgt[:,i].repeat(I).reshape(I,3)), so the validation batch is tied to `I`
# rather than a separately hard-coded number.
xval_cuda = xval[:I].to(torch.cfloat).cuda()  # ------------------------------------ TODO

# ---- Loss histories, model and optimizer.
loss_iter, loss_tr, loss_eval = [], [], []
loss1, loss2, loss3 = [], [], []
NN = NN_gtdecoder
model = NN(M,K,N).cuda()

# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# ---- Training loop: one optimisation pass over `tr` per epoch; every 40
# epochs the loss curves are plotted, the model is evaluated on `xval_cuda`,
# diagnostic images are saved and the model is checkpointed.
for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        loss = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()
 
    # Only the loss of the last mini-batch of the epoch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    # loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    # loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    # loss3.append(l3.detach().cpu().item()/opts['batch_size'])
    if epoch%40 == 0:
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        # plt.figure()
        # plt.plot(loss1, '-or')
        # plt.title(f'Loss1 fuction at epoch{epoch}')
        # plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        # plt.figure()
        # plt.plot(loss2, '-or')
        # plt.title(f'Loss2 fuction at epoch{epoch}')
        # plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        # plt.figure()
        # plt.plot(loss3, '-or')
        # plt.title(f'Loss3 fuction at epoch{epoch}')
        # plt.savefig(fig_loc + f'Epoch{epoch}_Loss3')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            loss  = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # Normalise by the actual validation batch size instead of a
            # hard-coded sample count.
            n_val = xval_cuda.shape[0]
            loss_eval.append(loss.cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval, '-xb')
            plt.title(f'Accumulated validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')

            # Reconstruct the sources of the first validation sample as
            # Rs H^H Rx^{-1} x. Hhat/Rs/Rb were estimated from `xval_cuda`,
            # so the filter is applied to that same sample (not to the last
            # training batch `x`).
            hh, rs0= Hhat[0], Rs[0]
            Rx = hh @ rs0 @ hh.conj().t() + Rb[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            # Second value: sum of Rb divided by M and by the batch size.
            print(Rb[0], Rb.sum()/M/n_val)
            print(f'epoch{epoch} h_corr is ', h_corr(hh.cpu(), torch.tensor(hgt0)))

            for ii in range(K):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                plt.figure()
                plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                plt.colorbar()
                plt.title(f'Epoch{epoch}_estimated V-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                plt.show()
                # Release every open figure so they do not accumulate.
                plt.close('all')
            
        # Checkpoint the whole model object at each evaluation epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')