# In [None]:
#%% s30
# Run setup for experiment s30: device selection, plotting/printing options,
# RNG seed, per-run output folders and the torch-version-dependent aliases.
rid = 's30' # running id
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the previous code only created the models
# folder when the figures folder was missing, and os.mkdir fails when a parent
# directory is absent or the target already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    # On torch 1.8.1 `optim` (and `mydet`) are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam

torch.autograd.set_detect_anomaly(True)
from vae_model import *
class NN_s5(nn.Module):
    """Recursive Wiener-filter model with a VAE per extracted component.

    The decoded variance maps V are passed through
    ``threshold(floor=1e-3, ceiling=1e2)``; Rb is computed from the residual.
    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of channels of the input (size of the mixing vectors).
        K: number of components extracted (stored as ``self.J``).
        im_size: accepted for interface compatibility; not used here (the
            fully connected layers are fixed to a 25x25 bottleneck).
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        # Input is the stacked real/imag parts of an MxM covariance (2*M*M
        # values, 18 for M=3); output is real/imag parts of an M-vector.
        self.hnet = nn.Sequential(
            nn.Linear(2*M*M,128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128,128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128,128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128,64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Linear(64,M*2)
        )

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
            )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Estimate mixing vectors, component variances and residual covariance.

        Args:
            x: complex input of shape [I,M,N,F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar)``: ``Rs`` is ``vhat.diag_embed()``
            (complex), ``Hhat`` stacks the J mixing vectors as [I, M, K],
            ``Rb`` is the mean outer product of the residual, and
            ``mu``/``logvar`` are the latent statistics of the LAST component
            only.
        """
        btsize = x.shape[0]
        v_all, h_all = [], []
        # Identity on the input's device/dtype (was a hard-coded 3x3 on CUDA);
        # broadcasting over the batch replaces the explicit stack of copies.
        Is = torch.eye(self.M, dtype=x.dtype, device=x.device).expand(btsize, -1, -1)
        eps0 = 1e-5
        x0 = x.permute(0,2,3,1)[...,None]
        Rx = x0 @ x.permute(0,2,3,1).conj()[...,None,:]
        rx = Rx.mean(dim=(1,2)) # shape of [I,M,M]
        rx_inv0 = rx.pinverse()
        for i in range(self.J):
            if i == 0:
                rx_inv = rx_inv0
            else:
                # Deflate the covariance with the previous mixing vector
                # before estimating the next one.
                w = rx_inv@hhat[...,None] / \
                        (eps0+ hhat[:,None,:].conj()@rx_inv@hhat[...,None])
                p = Is - (hhat[...,None]@w.permute(0,2,1).conj()).detach()
                rx = p@rx@p.permute(0,2,1).conj()
                rx_inv = rx.pinverse()
            temp = self.hnet(torch.stack((rx.real, rx.imag), dim=1).reshape(btsize,-1))
            # Use the model's own channel count instead of the global `M`.
            hhat = temp[:,:self.M] +1j*temp[:,self.M:]
            # Normalise so the first entry is 1 (the divisor is detached).
            hhat = hhat/hhat.detach()[:,0:1]
            h_all.append(hhat)
        Hhat = torch.stack(h_all, 2) # shape:[I, M, K]

        b = x0
        for i in range(self.J):
            hhat = Hhat[:,:,i]
            w = rx_inv0@hhat[...,None] / \
                    (hhat[:,None,:].conj()@rx_inv0@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@x0
            b = b - shat*hhat[:, None,None,:,None]

            # Encoder
            xx = self.encoder(shat.squeeze()[:,None].abs())
            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)

            # Decoder to get V
            v = self.decoder(z).square()
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2
        Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        vhat = torch.stack(v_all, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]

        # NOTE(review): mu/logvar belong to the last component only, so the KL
        # term built from them ignores the earlier components -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, beta=0.1):
    """Return the negative mean log-likelihood and the weighted KL term.

    Args:
        x: complex input, permuted here from [I,M,N,F] to [I,N,F,M].
        Rs: component covariance, permuted here as a 5-D tensor with the
            batch on the first axis.
        Hhat: mixing matrix estimate (batched, last two dims [M, K]).
        Rb: residual covariance added to ``Hhat Rs Hhat^H``.
        mu, logvar: latent Gaussian statistics used for the KL term.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.mean().real, beta*kl)``.

    Raises:
        Re-raises any exception from the likelihood computation after dumping
        ``(x, Rx, Rs, Hhat, Rb)`` to ``rid{rid}x_Rx_Rs_Hhat_Rb.pt``.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]

    try:
        ll = -(np.pi*mydet(Rx)).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    except Exception:
        # Save the offending tensors for post-mortem, then propagate the
        # original error instead of recomputing the same failing expression.
        torch.save((x, Rx, Rs, Hhat, Rb), f'rid{rid}x_Rx_Rs_Hhat_Rb.pt')
        print('error happpened, data saved and stop')
        raise
    return -ll.mean().real, beta*kl

#%% load data
# Experiment configuration, data loading and the training loop for run s30.
I = 3000 # how many samples
M, N, F, J = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 2001

d = torch.load('../data/nem_ss/tr9kM3FT100_ang6915-30.pt')
# awgn_batch comes from utils; presumably adds white Gaussian noise at the
# given SNR -- TODO confirm.
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own maximum magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh_ang6915-30.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat).cuda()
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
# Only the first 128 validation samples are evaluated.
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_s5
model = NN(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

    # Only the loss of the last batch of the epoch is recorded.
    loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
    loss1.append(l1.detach().cpu().item()/opts['batch_size'])
    loss2.append(l2.detach().cpu().item()/opts['batch_size'])
    # Every 10 epochs: save loss curves, evaluate on validation data, plot
    # the estimated sources of validation sample 0 and checkpoint the model.
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar)
            # 128 is the number of validation samples in xval_cuda.
            loss_eval.append((l1+l2).cpu().item()/128)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            # Source estimate for validation sample 0:
            # shat = Rs H^H Rx^{-1} x with Rx = H Rs H^H + Rb.
            hh = Hhat[0].detach()
            rs0 = Rs[0].detach() 
            Rx = hh @ rs0 @ hh.conj().t() + Rb.detach()[0]
            shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[0,:,:,:, None]).cpu() 
            for ii in range(J):
                plt.figure()
                plt.imshow(shat[:,:,ii,0].abs())
                plt.title(f'Epoch{epoch}_estimated sources-{ii}')
                plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources-{ii}')
                plt.show()

                # plt.figure()
                # plt.imshow(rs0[:,:,ii, ii].abs().cpu())
                # plt.title(f'Epoch{epoch}_estimated V-{ii}')
                # plt.savefig(fig_loc + f'Epoch{epoch}_estimated V-{ii}')
                # plt.show()
                plt.close('all')
            # h_corr comes from utils; presumably a correlation score between
            # estimated and ground-truth mixing matrices -- TODO confirm.
            print('h_corr', h_corr(hh, hgt[0]))
        # The whole model object is pickled every 10 epochs.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# In [None]:
#%% s31
# Run setup for experiment s31: device selection, plotting/printing options,
# RNG seed, per-run output folders and the torch-version-dependent aliases.
rid = 's31' # running id
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
torch.manual_seed(1)

# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the previous code only created the models
# folder when the figures folder was missing, and os.mkdir fails when a parent
# directory is absent or the target already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)

if torch.__version__[:5] != '1.8.1':
    def mydet(x):
        """Return the determinant of ``x`` via ``Tensor.det``."""
        return x.det()
    RAdam = torch.optim.RAdam
else:
    # On torch 1.8.1 `optim` (and `mydet`) are presumably supplied by
    # `from utils import *` -- TODO confirm.
    RAdam = optim.RAdam

torch.autograd.set_detect_anomaly(False)
from vae_model import *
class NN_s6(nn.Module):
    """Sequential component extractor with a single-angle steering model.

    For each of the J components the network predicts one scalar from the
    current covariance, turns it into a steering vector
    ``exp(1j * ang * pi * [0..M-1])``, beamforms the input with it and runs the
    magnitude of the beamformer output through a small VAE to get a variance
    map, which is passed through ``threshold(floor=1e-3, ceiling=1e2)``.
    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K
    """
    def __init__(self, M=3, K=3, im_size=100):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M

        # Angle regressor: four Linear+BatchNorm+ReLU stages and a scalar head.
        widths = [(18, 128), (128, 128), (128, 128), (128, 64)]
        stages = []
        for n_in, n_out in widths:
            stages.append(nn.Linear(n_in, n_out))
            stages.append(nn.BatchNorm1d(n_out))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.Linear(64, 1))
        self.hnet = nn.Sequential(*stages)

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
        )
        self.fc1 = nn.Linear(25*25, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, 25*25),
            Reshape(-1, 1, 25, 25),
            Up_(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=4),
            OutConv(in_channels=4, out_channels=1),
        )

    def reparameterize(self, mu, logvar):
        """Draw a latent sample ``mu + noise*exp(0.5*logvar)`` with standard normal noise."""
        sigma = torch.exp(0.5*logvar)
        noise = torch.randn_like(sigma)
        return mu + noise*sigma

    def forward(self, x):
        """Run the J extraction steps on ``x`` ([I,M,N,F], complex).

        Returns ``(Rs, Hhat, Rb, mu, logvar, zall)`` where ``Rs`` is the
        diagonal embedding of the stacked variance maps, ``Hhat`` stacks the
        steering vectors on dim 2, ``Rb`` is the mean outer product of the
        residual, ``mu``/``logvar`` come from the last step only and ``zall``
        stacks the latent samples on dim 1.
        """
        btsize, _, N, F = x.shape
        ch = np.pi*torch.arange(self.M, device=x.device)

        xperm = x.permute(0,2,3,1)
        x0 = xperm[...,None]
        rx = (x0 @ xperm.conj()[...,None,:]).mean(dim=(1,2)) # shape of [I,M,M]
        rx_inv0 = rx.pinverse()  # inverse of the initial covariance, reused at every step

        residual = x0
        latents, variances, steering = [], [], []
        for step in range(self.J):
            # After the first step the covariance is re-estimated from the residual.
            if step > 0:
                rx = (residual@residual.transpose(-1,-2).conj()).mean(dim=(1,2))

            features = torch.stack((rx.real, rx.imag), dim=1).reshape(btsize,-1)
            ang = self.hnet(features)
            hhat = (1j*(ang@ch[None,:])).exp()
            steering.append(hhat)

            # Beamformer weights from the initial inverse covariance.
            hcol = hhat[...,None]
            w = rx_inv0@hcol / \
                    (hhat[:,None,:].conj()@rx_inv0@hcol)
            shat = w.permute(0,2,1).conj()[:,None,None]@x0
            residual = residual - shat*hhat[:, None,None,:,None]

            # Encode the magnitude image; even columns are mu, odd are logvar.
            encoded = self.encoder(shat.squeeze()[:,None].abs())
            stats = self.fc1(encoded.reshape(btsize,-1))
            mu, logvar = stats[:,::2], stats[:,1::2]
            z = self.reparameterize(mu, logvar)
            latents.append(z)

            # Decode to a non-negative variance map.
            v = self.decoder(z).square()
            variances.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

        Hhat = torch.stack(steering, 2) # shape:[I, M, K]
        Rb = (residual@residual.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        vhat = torch.stack(variances, 4).squeeze().to(torch.cfloat) # shape:[I, N, F, K]
        zall = torch.stack(latents, dim=1)
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall


def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1e-3):
    """Compute the two training objectives for the s31 model.

    The model covariance is ``Rx = Hhat Rs Hhat^H + Rb``; the first return
    value is the negated mean of the log-likelihood expression (real part),
    the second is the KL term scaled by ``beta``. ``zall`` is accepted but
    only referenced by the disabled experiments kept below.
    """
    x = x.permute(0,2,3,1)

    # KL term from the latent statistics.
    kl_sum = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl = -0.5 * kl_sum

    # Model covariance, built with N,F leading so Hhat/Rb broadcast over them.
    Hherm = Hhat.transpose(-1,-2).conj()
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hherm + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]

    logdet_part = (np.pi*mydet(Rx)).log()
    quad_part = (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    ll = -logdet_part - quad_part

    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    # 'fisher'
    # _, J, D = zall.shape
    # m = zall.mean(dim=(0,1))
    # sw, sb = 0, 0
    # for j in range(J):
    #     z = zall[:,j]
    #     mj = z.mean(0)
    #     sw = sw + ((z - mj)[...,None]@(z - mj)[:,None]).sum(0)
    #     sb = sb + D*(mj-m)[...,None]@(mj-m)[None,:]
    # term = 1e-3*(sw-sb).trace()

    recon = -ll.mean().real
    return recon, beta*kl

#%% load data
# Experiment configuration, data loading and the training loop for run s31.
I = 6000 # how many samples
M, N, F, J = 3, 100, 100, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['lr'] = 1e-3
opts['n_epochs'] = 501

d = torch.load('../data/nem_ss/tr9kM3FT100_ang6915-30.pt')
# awgn_batch comes from utils; presumably adds white Gaussian noise at the
# given SNR -- TODO confirm.
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own maximum magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, _ , hgt0 = torch.load('../data/nem_ss/val500M3FT100_xsh_ang6915-30.pt')
# Ground-truth mixing data stays on the CPU in this run.
hgt = torch.tensor(hgt0).to(torch.cfloat)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
# Only the first 128 validation samples are evaluated.
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
NN = NN_s6
model = NN(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record the loss of every 30th batch.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10 epochs: save loss curves, evaluate on validation data, plot
    # estimated sources for the first 3 samples and checkpoint the model.
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat, Rb, mu, logvar, zall)
            # 128 is the number of validation samples in xval_cuda.
            loss_eval.append((l1+l2).cpu().item()/128)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            # h_corr comes from utils; presumably a correlation score between
            # estimated and ground-truth mixing matrices -- TODO confirm.
            av_hcorr = []
            for ind in range(128):
                hh = Hhat[ind].detach()
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
            print('first 3 h_corr',av_hcorr[:3],' averaged128:', sum(av_hcorr)/128)

            # 3x3 grid: one row per validation sample, one column per source.
            # shat = Rs H^H Rx^{-1} x with Rx = H Rs H^H + Rb.
            plt.figure()
            for ind in range(3):
                hh = Hhat[ind].detach()
                rs0 = Rs[ind].detach() 
                Rx = hh @ rs0 @ hh.conj().t() + Rb.detach()[ind]
                shat = (rs0 @ hh.conj().t() @ Rx.inverse()@xval_cuda.permute(0,2,3,1)[ind,:,:,:, None]).cpu() 
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[:,:,ii,0].abs())
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        # The whole model object is pickled every 10 epochs.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')

# In [None]:
#%% s35
# Run setup for experiment s35: device selection, plotting/printing options,
# deterministic seeding and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

rid = 's35' 
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the previous code only created the models
# folder when the figures folder was missing, and os.mkdir fails when a parent
# directory is absent or the target already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12, m=3):
    """Build a batch of Hermitian m x m matrices from real parameter vectors.

    Args:
        rx12: real tensor of shape [batch, m*(m+1)]. The first half of the
            columns are the real parts and the second half the imaginary
            parts of the lower-triangular entries, in ``torch.tril_indices``
            (row-major) order.
        m: matrix size; the default 3 matches the 12-column input used by
            the callers in this file.

    Returns:
        Complex (``torch.cfloat``) tensor of shape [batch, m, m] on the same
        device as ``rx12`` (previously it was always moved to CUDA). The
        result equals ``L + L^H`` with the diagonal halved, so the diagonal is
        the real part of the supplied diagonal entries; their imaginary parts
        are discarded.
    """
    n_tri = m*(m + 1)//2
    device = rx12.device
    rows, cols = torch.tril_indices(m, m, device=device)
    lower = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :n_tri] + 1j*rx12[:, n_tri:]
    # Mirror the strictly-lower part; this doubles the (real) diagonal ...
    herm = lower + lower.permute(0,2,1).conj()
    # ... so halve it back.
    diag = torch.arange(m, device=device)
    herm[:, diag, diag] = herm[:, diag, diag]/2
    return herm

class NNet(nn.Module):
    """Recursive component-extraction VAE with a learned inverse covariance.

    Decoded variance maps V are passed through
    ``threshold(floor=1e-3, ceiling=1e2)``; Rb is a fixed ``1e-3 * I``.
    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of input channels. The 12-wide input of ``mainnet`` and
            output of ``rxnet`` correspond to the 6 lower-triangular entries
            (real and imaginary) of a 3x3 matrix, i.e. M=3.
        K: number of components extracted (stored as ``self.J``).
        im_size: spatial size of the input; the VAE bottleneck is
            ``int(im_size/4)`` squared.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk feeding both the angle head and the inverse-covariance head.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # Scalar head; Tanh bounds the output to (-1, 1).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 parameters consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Extract J components from ``x`` (complex, [I,M,N,F]).

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is
            ``vhat.diag_embed()`` (complex), ``Hhat`` is [I, M, J], ``Rb`` is
            ``1e-3 * I`` repeated over the batch, ``mu``/``logvar`` are the
            latent statistics of the LAST component only, and ``zall`` stacks
            ``z`` and ``bilinear(z)`` alternately on dim 1 (2*J entries).
        """
        btsize, M = x.shape[:2]
        device = x.device
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        # Loop-invariant lower-triangular index, created once on x's device.
        ind = torch.tril_indices(M, M, device=device)
        for i in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # Remove the current component
            # NOTE(review): the un-thresholded `v` is used here while the
            # thresholded copy is returned -- confirm this is intended.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed diagonal Rb on the input's device (was hard-coded to 'cuda').
        eye = torch.eye(M, device=device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the summed negative log-likelihood and the weighted KL term.

    Args:
        x: complex input of shape [I,M,N,F].
        Rs: component covariance of shape [I,N,F,J,J].
        Hhat: mixing matrix estimate of shape [I,M,J].
        Rb: residual covariance of shape [I,M,M].
        mu, logvar: latent Gaussian statistics used for the KL term.
        zall: unused; only referenced by the disabled experiments below.
        beta: weight applied to the KL term.

    Returns:
        ``(-ll.sum().real, beta*kl)``. The real part is taken so the loss is
        a real scalar, matching the other ``loss_fun`` variants in this file;
        the log-likelihood expression itself is complex-typed because ``Rx``
        is complex.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    ll = -(np.pi*Rx.det()).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()

    # "Slot contrastive loss"
    # I, J = x.shape[0], Rs.shape[-1]
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()/I

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum().real, beta*kl  #+ 10*loss_slotCEL

#%%
# Experiment configuration, data loading and the training loop for run s35.
I = 6000 # how many samples
M, N, F, J = 3, 64, 64, 3
NF = N*F
eps = 5e-4
opts = {}
opts['batch_size'] = 64
opts['n_epochs'] = 301
opts['lr'] = 1e-3

d = torch.load('../data/nem_ss/tr9kM3FT64_data0.pt')
# awgn_batch comes from utils; presumably adds white Gaussian noise at the
# given SNR -- TODO confirm.
d = awgn_batch(d, snr=30, seed=1)
# Scale every sample by its own maximum magnitude.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)
xval, sval, hgt0 = torch.load('../data/nem_ss/val500M3FT64_xsh_data0.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat)
# Move the second axis of sval to the end so it indexes like shat below.
sval= sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
# Only the first 128 validation samples are evaluated.
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record the loss of every 30th batch.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10 epochs: save loss curves, evaluate on validation data, plot
    # estimated sources for the first 3 samples and checkpoint the model.
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
            # 128 is the number of validation samples in xval_cuda.
            loss_eval.append((l1+l2).cpu().item()/128)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            # Batched source estimate shat = Rs H^H Rx^{-1} x with
            # Rx = H Rs H^H + Rb, computed with N,F as leading axes.
            av_hcorr, av_scorr = [], []
            Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
            shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                    @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
            shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
            # h_corr / s_corr come from utils; presumably correlation scores
            # against the ground truth -- TODO confirm.
            for ind in range(128):
                hh = Hhat_val[ind]
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
                av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
            print('first 3 h_corr',av_hcorr[:3],' averaged128:', sum(av_hcorr)/128)
            print('first 3 s_corr',av_scorr[:3],' averaged128:', sum(av_scorr)/128)

            # 3x3 grid: one row per validation sample, one column per source.
            plt.figure()
            for ind in range(3):
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[ind,:,:,ii])
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        # The whole model object is pickled every 10 epochs.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


# In [None]:
#%% s52
# Run setup for experiment s52: device selection, plotting/printing options,
# deterministic seeding and per-run output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

rid = 's52' 
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the previous code only created the models
# folder when the figures folder was missing, and os.mkdir fails when a parent
# directory is absent or the target already exists.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12, m=3):
    """Build a batch of Hermitian m x m matrices from real parameter vectors.

    Args:
        rx12: real tensor of shape [batch, m*(m+1)]. The first half of the
            columns are the real parts and the second half the imaginary
            parts of the lower-triangular entries, in ``torch.tril_indices``
            (row-major) order.
        m: matrix size; the default 3 matches the 12-column input used by
            the callers in this file.

    Returns:
        Complex (``torch.cfloat``) tensor of shape [batch, m, m] on the same
        device as ``rx12`` (previously it was always moved to CUDA). The
        result equals ``L + L^H`` with the diagonal halved, so the diagonal is
        the real part of the supplied diagonal entries; their imaginary parts
        are discarded.
    """
    n_tri = m*(m + 1)//2
    device = rx12.device
    rows, cols = torch.tril_indices(m, m, device=device)
    lower = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :n_tri] + 1j*rx12[:, n_tri:]
    # Mirror the strictly-lower part; this doubles the (real) diagonal ...
    herm = lower + lower.permute(0,2,1).conj()
    # ... so halve it back.
    diag = torch.arange(m, device=device)
    herm[:, diag, diag] = herm[:, diag, diag]/2
    return herm

class NNet(nn.Module):
    """Recursive component-extraction VAE with a learned inverse covariance.

    Decoded variance maps V are passed through
    ``threshold(floor=1e-6, ceiling=1e2)``; Rb is a fixed ``1e-5 * I``.
    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    Args:
        M: number of input channels. The 12-wide input of ``mainnet`` and
            output of ``rxnet`` correspond to the 6 lower-triangular entries
            (real and imaginary) of a 3x3 matrix, i.e. M=3.
        K: number of components extracted (stored as ``self.J``).
        im_size: spatial size of the input; the VAE bottleneck is
            ``int(im_size/4)`` squared.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk feeding both the angle head and the inverse-covariance head.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # Scalar head; Tanh bounds the output to (-1, 1).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 parameters consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``std = exp(0.5*logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Extract J components from ``x`` (complex, [I,M,N,F]).

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is
            ``vhat.diag_embed()`` (complex), ``Hhat`` is [I, M, J], ``Rb`` is
            ``1e-5 * I`` repeated over the batch, ``mu``/``logvar`` are the
            latent statistics of the LAST component only, and ``zall`` stacks
            ``z`` and ``bilinear(z)`` alternately on dim 1 (2*J entries).
        """
        btsize, M = x.shape[:2]
        device = x.device
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        # Loop-invariant lower-triangular index, created once on x's device.
        ind = torch.tril_indices(M, M, device=device)
        for i in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component
            # NOTE(review): the un-thresholded `v` is used here while the
            # thresholded copy is returned -- confirm this is intended.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed diagonal Rb on the input's device (was hard-coded to 'cuda').
        eye = torch.eye(M, device=device)
        Rb = eye.repeat(btsize, 1, 1)*1e-5

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the two training-loss terms for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F] (permuted to [I, N, F, M] here).
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: additive covariance term, broadcast against [N, F, I, M, M].
        mu, logvar: latent mean and log-variance used in the KL term.
        zall: accepted for interface compatibility; not used in this version
            (the slot-contrastive term is disabled).
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``: the negated summed log-likelihood and
        the weighted KL term.
    """
    obs = x.permute(0, 2, 3, 1)

    # Model covariance H Rs H^H + Rb, built as [N, F, I, M, M] then moved
    # back to sample-major order.
    mix_cov = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = mix_cov.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]

    # NOTE(review): det/inverse of a complex matrix are complex, so the first
    # returned term is complex-valued; confirm callers expect that.
    log_norm = (np.pi * Rx.det()).log()
    quad_form = (obs[..., None, :].conj() @ Rx.inverse() @ obs[..., None]).squeeze()
    ll = -log_norm - quad_form

    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return -ll.sum(), beta * kl

#%%
# Experiment sizes and optimisation settings.
I = 6000  # how many samples
M, J = 3, 3
N = F = 48
NF = N * F
eps = 5e-4
opts = {'batch_size': 64, 'n_epochs': 301, 'lr': 1e-3}

# Training mixtures: add noise, then scale every sample by its own peak magnitude.
d = awgn_batch(torch.load('../data/nem_ss/tr9kM3FT48_data0.pt'), snr=30, seed=1)
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation mixtures with their reference sources and mixing vectors.
xval, sval, hgt0 = torch.load('../data/nem_ss/val500M3FT48_xsh_data0.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat)
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
# Only the first 128 validation samples are kept on the GPU for evaluation.
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
# Training script: optimise the model, and every 10 epochs save loss plots,
# evaluate on the validation subset and checkpoint the whole model.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Number of validation samples actually evaluated. Derived from the tensor so
# the averages below stay correct if the validation subset size changes.
n_val = xval_cuda.shape[0]

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th step only.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%10 == 0:
        print(epoch)
        # One figure per tracked curve; all are closed after validation below.
        for series, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(series, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
            loss_eval.append((l1+l2).cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            # Source estimate Rs H^H Rx^-1 x, with Rx = H Rs H^H + Rb.
            av_hcorr, av_scorr = [], []
            Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
            shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                    @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
            shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
            for ind in range(n_val):
                hh = Hhat_val[ind]
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
                av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
            print('first 3 h_corr',av_hcorr[:3],f' averaged{n_val}:', sum(av_hcorr)/n_val)
            print('first 3 s_corr',av_scorr[:3],f' averaged{n_val}:', sum(av_scorr)/n_val)

            # 3x3 grid: rows are the first three samples, columns the J estimates.
            plt.figure()
            for ind in range(3):
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[ind,:,:,ii])
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


In [None]:
#%% s54
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's54' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check and create BOTH folders independently: the figure folder existing
# says nothing about the model folder, and makedirs(exist_ok=True) neither
# fails on an existing folder nor on a missing parent.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Assemble a batch of 3x3 Hermitian matrices from 12 real values each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in the order given by ``torch.tril_indices(3, 3)``.

    Returns:
        ``torch.cfloat`` tensor of shape [I, 3, 3], allocated on the same
        device as ``rx12``. The strict upper triangle is the conjugate of the
        lower one and the diagonal keeps only the real parts supplied.
    """
    # Allocate on the input's device instead of forcing CUDA, so the helper
    # also works on CPU or on a non-default GPU.
    device = rx12.device
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    rx_inv_hat = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose fills the upper triangle; it also doubles
    # the diagonal's real part and cancels its imaginary part, hence the /2.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE per pass.

    ``forward`` makes ``K`` passes over the mixture. Each pass estimates one
    mixing vector and one variance map from the current residual, then
    subtracts that component before the next pass. The variance map that is
    returned is clamped with ``threshold(floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100].
    The covariance features (12 inputs) and the 3x3 triangular indexing are
    hard-wired, so the network as written only fits M = 3.

    Args:
        M: number of channels; length of each estimated mixing vector.
        K: number of passes / extracted components (stored as ``self.J``).
        im_size: side length used to size the bottleneck as ``int(im_size/4)``
            squared (assumes the encoder shrinks each side by 4 -- TODO confirm
            against ``Down_g``).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real+imag parts of the 6 lower-triangular
        # covariance entries (2*6 = 12 inputs).
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, turned into phases in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 real numbers per sample, consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # Interleaved (mu, logvar) pairs, hence 2*dz outputs.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run ``self.J`` extraction passes over the mixture ``x``.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
                Rs: clamped variances embedded as diagonal matrices
                    (``vhat.diag_embed()``), complex.
                Hhat: estimated mixing vectors, shape [I, M, J].
                Rb: ``1e-3 * eye(M)`` repeated per sample, [I, M, M], created
                    on ``x.device``.
                mu, logvar: latent statistics of the LAST pass only.
                zall: for each pass the sample ``z`` followed by
                    ``bilinear(z)``, stacked along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, computed once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            # Get H estimation from the residual's averaged outer product.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: beamformer w = Rx^-1 h / (h^H Rx^-1 h)
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): squeeze() without a dim would also drop the batch
            # axis when btsize == 1 -- confirm batches are never that small.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # Remove the current component (uses the un-clamped v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Build on the input's device rather than a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-3

        # NOTE(review): mu/logvar are overwritten on every pass, so only the
        # last pass reaches the caller -- confirm this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the likelihood term and the regularisation term for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F] (permuted to [I, N, F, M] here).
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: additive covariance term, broadcast against [N, F, I, M, M].
        mu, logvar: latent mean and log-variance used in the KL term.
        zall: [I, 2*J, dz]; even rows along dim 1 are paired with odd rows to
            form a J x J score matrix per sample.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl + loss_slotCEL)``.
    """
    I, J = x.shape[0], Rs.shape[-1]
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    # NOTE(review): det/inverse of a complex matrix are complex, so the first
    # returned term is complex-valued; confirm callers expect that.
    ll = -(np.pi*Rx.det()).log() - (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze() 

    # Slot contrastive loss: row j of each sample's score matrix should pick
    # column j.
    inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [I*J, J]
    # Targets live on the same device as the logits (no hard-coded .cuda()).
    slot_ids = torch.arange(J, device=inp.device)
    target = torch.cat([slot_ids for i in range(I)])
    loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl + loss_slotCEL

#%%
# Experiment sizes and optimisation settings.
I = 6000  # how many samples
M, J = 3, 3
N = F = 64
NF = N * F
eps = 5e-4
opts = {'batch_size': 64, 'n_epochs': 301, 'lr': 1e-3}

# Training mixtures: add noise, then scale every sample by its own peak magnitude.
d = awgn_batch(torch.load('../data/nem_ss/tr9kM3FT64_data0.pt'), snr=30, seed=1)
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation mixtures with their reference sources and mixing vectors.
xval, sval, hgt0 = torch.load('../data/nem_ss/val500M3FT64_xsh_data0.pt')
hgt = torch.tensor(hgt0).to(torch.cfloat)
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
# Only the first 128 validation samples are kept on the GPU for evaluation.
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
# Training script: optimise the model, and every 10 epochs save loss plots,
# evaluate on the validation subset and checkpoint the whole model.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

# Number of validation samples actually evaluated. Derived from the tensor so
# the averages below stay correct if the validation subset size changes.
n_val = xval_cuda.shape[0]

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th step only.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%10 == 0:
        print(epoch)
        # One figure per tracked curve; all are closed after validation below.
        for series, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(series, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
            loss_eval.append((l1+l2).cpu().item()/n_val)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            # Source estimate Rs H^H Rx^-1 x, with Rx = H Rs H^H + Rb.
            av_hcorr, av_scorr = [], []
            Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
            shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                    @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
            shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
            for ind in range(n_val):
                hh = Hhat_val[ind]
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
                av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
            print('first 3 h_corr',av_hcorr[:3],f' averaged{n_val}:', sum(av_hcorr)/n_val)
            print('first 3 s_corr',av_scorr[:3],f' averaged{n_val}:', sum(av_scorr)/n_val)

            # 3x3 grid: rows are the first three samples, columns the J estimates.
            plt.figure()
            for ind in range(3):
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[ind,:,:,ii])
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s59
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's59' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check and create BOTH folders independently: the figure folder existing
# says nothing about the model folder, and makedirs(exist_ok=True) neither
# fails on an existing folder nor on a missing parent.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Assemble a batch of 3x3 Hermitian matrices from 12 real values each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in the order given by ``torch.tril_indices(3, 3)``.

    Returns:
        ``torch.cfloat`` tensor of shape [I, 3, 3], allocated on the same
        device as ``rx12``. The strict upper triangle is the conjugate of the
        lower one and the diagonal keeps only the real parts supplied.
    """
    # Allocate on the input's device instead of forcing CUDA, so the helper
    # also works on CPU or on a non-default GPU.
    device = rx12.device
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    rx_inv_hat = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose fills the upper triangle; it also doubles
    # the diagonal's real part and cancels its imaginary part, hence the /2.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE per pass.

    ``forward`` makes ``K`` passes over the mixture. Each pass estimates one
    mixing vector and one variance map from the current residual, then
    subtracts that component before the next pass. The variance map that is
    returned is clamped with ``threshold(floor=1e-3, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100].
    The covariance features (12 inputs) and the 3x3 triangular indexing are
    hard-wired, so the network as written only fits M = 3.

    Args:
        M: number of channels; length of each estimated mixing vector.
        K: number of passes / extracted components (stored as ``self.J``).
        im_size: side length used to size the bottleneck as ``int(im_size/4)``
            squared (assumes the encoder shrinks each side by 4 -- TODO confirm
            against ``Down_g``).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real+imag parts of the 6 lower-triangular
        # covariance entries (2*6 = 12 inputs).
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, turned into phases in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 real numbers per sample, consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # Interleaved (mu, logvar) pairs, hence 2*dz outputs.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run ``self.J`` extraction passes over the mixture ``x``.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
                Rs: clamped variances embedded as diagonal matrices
                    (``vhat.diag_embed()``), complex.
                Hhat: estimated mixing vectors, shape [I, M, J].
                Rb: ``1e-3 * eye(M)`` repeated per sample, [I, M, M], created
                    on ``x.device``.
                mu, logvar: latent statistics of the LAST pass only.
                zall: for each pass the sample ``z`` followed by
                    ``bilinear(z)``, stacked along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, computed once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            # Get H estimation from the residual's averaged outer product.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: beamformer w = Rx^-1 h / (h^H Rx^-1 h)
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): squeeze() without a dim would also drop the batch
            # axis when btsize == 1 -- confirm batches are never that small.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # Remove the current component (uses the un-clamped v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Build on the input's device rather than a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-3

        # NOTE(review): mu/logvar are overwritten on every pass, so only the
        # last pass reaches the caller -- confirm this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the two training-loss terms for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F] (permuted to [I, N, F, M] here).
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: additive covariance term, broadcast against [N, F, I, M, M].
        mu, logvar: latent mean and log-variance used in the KL term.
        zall: accepted for interface compatibility; not used in this version
            (the slot-contrastive term is disabled).
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``: the negated summed log-likelihood and
        the weighted KL term.
    """
    obs = x.permute(0, 2, 3, 1)

    # Model covariance H Rs H^H + Rb, built as [N, F, I, M, M] then moved
    # back to sample-major order.
    mix_cov = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = mix_cov.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]

    # NOTE(review): det/inverse of a complex matrix are complex, so the first
    # returned term is complex-valued; confirm callers expect that.
    log_norm = (np.pi * Rx.det()).log()
    quad_form = (obs[..., None, :].conj() @ Rx.inverse() @ obs[..., None]).squeeze()
    ll = -log_norm - quad_form

    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return -ll.sum(), beta * kl

#%%
# Experiment sizes and optimisation settings.
I = 18000  # how many samples
M, J = 3, 3
N = F = 64
NF = N * F
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 201, 'lr': 1e-3}

# Training mixtures, each scaled by its own peak magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data1.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation triples (mixture, reference sources, reference mixing vectors),
# served in fixed-order batches of 200. `data` is deliberately rebound here.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data1.pt')
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training script: optimise the model, and every 5 epochs save loss plots,
# evaluate on the validation loader and checkpoint the whole model.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th step only.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # One figure per tracked curve; all are closed during validation below.
        for series, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(series, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate Rs H^H Rx^-1 x, with Rx = H Rs H^H + Rb.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))

                # Picture only the first validation batch: rows are its first
                # three samples, columns the J estimates.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,3,ii+1+ind*3)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Each average is divided by the length of its own list.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s61_1
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's61_1' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check and create BOTH folders independently: the figure folder existing
# says nothing about the model folder, and makedirs(exist_ok=True) neither
# fails on an existing folder nor on a missing parent.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Assemble a batch of 3x3 Hermitian matrices from 12 real values each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the six lower-triangular
            entries, in the order given by ``torch.tril_indices(3, 3)``.

    Returns:
        ``torch.cfloat`` tensor of shape [I, 3, 3], allocated on the same
        device as ``rx12``. The strict upper triangle is the conjugate of the
        lower one and the diagonal keeps only the real parts supplied.
    """
    # Allocate on the input's device instead of forcing CUDA, so the helper
    # also works on CPU or on a non-default GPU.
    device = rx12.device
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    rx_inv_hat = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]
    # Adding the conjugate transpose fills the upper triangle; it also doubles
    # the diagonal's real part and cancels its imaginary part, hence the /2.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE per pass.

    ``forward`` makes ``K`` passes over the mixture. Each pass estimates one
    mixing vector and one variance map from the current residual, then
    subtracts that component before the next pass. The variance map that is
    returned is clamped with ``threshold(floor=1e-6, ceiling=1e2)``.

    Input shape [I,M,N,F], e.g.[32,3,100,100].
    The covariance features (12 inputs) and the 3x3 triangular indexing are
    hard-wired, so the network as written only fits M = 3.

    Args:
        M: number of channels; length of each estimated mixing vector.
        K: number of passes / extracted components (stored as ``self.J``).
        im_size: side length used to size the bottleneck as ``int(im_size/4)``
            squared (assumes the encoder shrinks each side by 4 -- TODO confirm
            against ``Down_g``).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real+imag parts of the 6 lower-triangular
        # covariance entries (2*6 = 12 inputs).
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, turned into phases in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 real numbers per sample, consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # Interleaved (mu, logvar) pairs, hence 2*dz outputs.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps*std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run ``self.J`` extraction passes over the mixture ``x``.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
                Rs: clamped variances embedded as diagonal matrices
                    (``vhat.diag_embed()``), complex.
                Hhat: estimated mixing vectors, shape [I, M, J].
                Rb: ``1e-3 * eye(M)`` repeated per sample, [I, M, M], created
                    on ``x.device``.
                mu, logvar: latent statistics of the LAST pass only.
                zall: for each pass the sample ``z`` followed by
                    ``bilinear(z)``, stacked along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, computed once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            # Get H estimation from the residual's averaged outer product.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: beamformer w = Rx^-1 h / (h^H Rx^-1 h)
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): squeeze() without a dim would also drop the batch
            # axis when btsize == 1 -- confirm batches are never that small.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the un-clamped v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Build on the input's device rather than a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-3

        # NOTE(review): mu/logvar are overwritten on every pass, so only the
        # last pass reaches the caller -- confirm this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    The covariance of every bin is modelled as ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored as ``-log(pi*det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixture, indexed as [I, M, N, F] (moved to [I, N, F, M] here).
        Rs: source covariances, [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: noise covariance added to every bin (broadcast over [..., M, M]).
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used (it fed a slot-contrastive loss that is disabled).
        beta: weight applied to the KL term.

    Returns:
        ``(-sum(log-likelihood), beta * kl)``.
    """
    obs = x.permute(0, 2, 3, 1)[..., None]  # column vectors, [I, N, F, M, 1]

    # KL divergence to a standard normal, summed over every element
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * torch.sum(kl_elementwise)

    Hherm = Hhat.transpose(-1, -2).conj()
    Rx = (Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hherm + Rb).permute(2, 0, 1, 3, 4)  # [I, N, F, M, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (obs.transpose(-1, -2).conj() @ Rx.inverse() @ obs).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta*kl

#%%
# Sizes follow the [sample, M, N, F] layout of the stored tensors; J is the
# number of sources handed to the model.
M, N, F, J = 3, 64, 64, 3
NF = N*F
I = 18000  # how many training samples to keep
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 501, 'lr': 1e-3}

# Training set: every sample is divided by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data1.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, shuffle=True, drop_last=True,
                     batch_size=opts['batch_size'])

# Validation set: mixtures, reference sources and reference mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data1.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# loss_tr / loss1 / loss2 hold the total / reconstruction / KL loss per sample,
# recorded every 30 iterations; loss_eval gets one point per validation pass.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save the loss curves, validate, and checkpoint.
    if epoch%5 == 0:
        print(epoch)
        # (data, line style, title, file stem) for each training curve
        for curve, fmt, title, stem in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + stem)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs H^H Rx^{-1} x
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the references -- confirm there.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                if i == 0:
                    # First 3 validation samples (rows) x J sources (columns)
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average over the list being summed (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s62
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's62' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check both folders: previously only the figures folder was tested, so a
# missing models folder was never created and an existing one made mkdir raise.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian 3x3 matrices from 12 real numbers each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the lower-triangular entries,
            in ``torch.tril_indices(3, 3)`` order.

    Returns:
        cfloat tensor of shape [I, 3, 3] on the same device as ``rx12``. The
        upper triangle is the conjugate of the lower one and the diagonal is
        real (the imaginary parts given for the diagonal cancel out).
    """
    device = rx12.device  # follow the input instead of forcing CUDA
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]

    # lower + lower^H mirrors the off-diagonal entries and doubles the real
    # part of the diagonal, hence the division by 2 below.
    rx_inv_hat = lower + lower.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter network that extracts one source per pass.

    On each of the ``K`` passes the network
      1. averages the outer product of the current residual over the N and F
         axes and feeds the lower triangle (real and imaginary parts) to
         ``mainnet``;
      2. maps that feature to one angle in (-1, 1) with ``hnet`` (Tanh) and
         forms the vector ``exp(1j*pi*m*angle)``, m = 0..M-1;
      3. maps the same feature to a Hermitian 3x3 matrix with ``rxnet`` and
         ``lower2matrix`` and uses it to filter the residual;
      4. encodes the filtered magnitude to a latent ``z``, decodes it and
         squares the output to get the source variance ``v``;
      5. subtracts the estimated component from the residual.

    The noise covariance ``Rb`` is not learned; it is fixed to ``1e-4 * I``.

    Input shape [I,M,N,F], e.g. [128,3,64,64]. The covariance features are
    hard-wired to 3x3 (12 inputs to ``mainnet``), so ``M`` must be 3.

    Args:
        M: number of channels.
        K: number of passes / sources (stored as ``self.J``).
        im_size: side length of the square N x F input; ``fc1`` expects the
            encoder to return ``(im_size/4)**2`` values per sample.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into ``self.J`` sources.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            Rs -- diagonal source covariances, [I, N, F, J, J] (cfloat);
            Hhat -- one vector per pass, [I, M, J];
            Rb -- fixed noise covariance ``1e-4 * I``, [I, M, M], on x's device;
            mu, logvar -- latent parameters of the last pass only;
            zall -- ``z`` and ``bilinear(z)`` of every pass stacked on dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, so built once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the two trailing singleton axes explicitly: a bare
            # squeeze() would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            # NOTE(review): mu/logvar are overwritten on every pass, so only
            # the last pass reaches the caller's KL term -- confirm intended.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # Drop only the channel axis (decoder ends with out_channels=1),
            # so a batch of one keeps its batch axis.
            v = self.decoder(z).square().squeeze(1)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-4

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    The covariance of every bin is modelled as ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored as ``-log(pi*det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixture, indexed as [I, M, N, F] (moved to [I, N, F, M] here).
        Rs: source covariances, [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: noise covariance added to every bin (broadcast over [..., M, M]).
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used (it fed a slot-contrastive loss that is disabled).
        beta: weight applied to the KL term.

    Returns:
        ``(-sum(log-likelihood), beta * kl)``.
    """
    obs = x.permute(0, 2, 3, 1)[..., None]  # column vectors, [I, N, F, M, 1]

    # KL divergence to a standard normal, summed over every element
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * torch.sum(kl_elementwise)

    Hherm = Hhat.transpose(-1, -2).conj()
    Rx = (Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hherm + Rb).permute(2, 0, 1, 3, 4)  # [I, N, F, M, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (obs.transpose(-1, -2).conj() @ Rx.inverse() @ obs).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta*kl

#%%
# Sizes follow the [sample, M, N, F] layout of the stored tensors; J is the
# number of sources handed to the model.
M, N, F, J = 3, 64, 64, 3
NF = N*F
I = 18000  # how many training samples to keep
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 501, 'lr': 1e-3}

# Training set: every sample is divided by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data1.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, shuffle=True, drop_last=True,
                     batch_size=opts['batch_size'])

# Validation set: mixtures, reference sources and reference mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data1.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# loss_tr / loss1 / loss2 hold the total / reconstruction / KL loss per sample,
# recorded every 30 iterations; loss_eval gets one point per validation pass.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save the loss curves, validate, and checkpoint.
    if epoch%5 == 0:
        print(epoch)
        # (data, line style, title, file stem) for each training curve
        for curve, fmt, title, stem in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + stem)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs H^H Rx^{-1} x
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the references -- confirm there.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                if i == 0:
                    # First 3 validation samples (rows) x J sources (columns)
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average over the list being summed (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s65
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's65' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check both folders: previously only the figures folder was tested, so a
# missing models folder was never created and an existing one made mkdir raise.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian 3x3 matrices from 12 real numbers each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the lower-triangular entries,
            in ``torch.tril_indices(3, 3)`` order.

    Returns:
        cfloat tensor of shape [I, 3, 3] on the same device as ``rx12``. The
        upper triangle is the conjugate of the lower one and the diagonal is
        real (the imaginary parts given for the diagonal cancel out).
    """
    device = rx12.device  # follow the input instead of forcing CUDA
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]

    # lower + lower^H mirrors the off-diagonal entries and doubles the real
    # part of the diagonal, hence the division by 2 below.
    rx_inv_hat = lower + lower.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter network that extracts one source per pass.

    On each of the ``K`` passes the network
      1. averages the outer product of the current residual over the N and F
         axes and feeds the lower triangle (real and imaginary parts) to
         ``mainnet``;
      2. maps that feature to one angle in (-1, 1) with ``hnet`` (Tanh) and
         forms the vector ``exp(1j*pi*m*angle)``, m = 0..M-1;
      3. maps the same feature to a Hermitian 3x3 matrix with ``rxnet`` and
         ``lower2matrix`` and uses it to filter the residual;
      4. encodes the filtered magnitude to a latent ``z``, decodes it and
         squares the output to get the source variance ``v``;
      5. subtracts the estimated component from the residual.

    The noise covariance ``Rb`` is not learned; it is fixed to ``1e-3 * I``.

    Input shape [I,M,N,F], e.g. [128,3,64,64]. The covariance features are
    hard-wired to 3x3 (12 inputs to ``mainnet``), so ``M`` must be 3.

    Args:
        M: number of channels.
        K: number of passes / sources (stored as ``self.J``).
        im_size: side length of the square N x F input; ``fc1`` expects the
            encoder to return ``(im_size/4)**2`` values per sample.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into ``self.J`` sources.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            Rs -- diagonal source covariances, [I, N, F, J, J] (cfloat);
            Hhat -- one vector per pass, [I, M, J];
            Rb -- fixed noise covariance ``1e-3 * I``, [I, M, M], on x's device;
            mu, logvar -- latent parameters of the last pass only;
            zall -- ``z`` and ``bilinear(z)`` of every pass stacked on dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, so built once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the two trailing singleton axes explicitly: a bare
            # squeeze() would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            # NOTE(review): mu/logvar are overwritten on every pass, so only
            # the last pass reaches the caller's KL term -- confirm intended.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # Drop only the channel axis (decoder ends with out_channels=1),
            # so a batch of one keeps its batch axis.
            v = self.decoder(z).square().squeeze(1)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the two loss terms: negative log-likelihood and weighted KL.

    The covariance of every bin is modelled as ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored as ``-log(pi*det(Rx)) - x^H Rx^{-1} x``.

    Args:
        x: mixture, indexed as [I, M, N, F] (moved to [I, N, F, M] here).
        Rs: source covariances, [I, N, F, J, J].
        Hhat: mixing vectors, [I, M, J].
        Rb: noise covariance added to every bin (broadcast over [..., M, M]).
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used (it fed a slot-contrastive loss that is disabled).
        beta: weight applied to the KL term.

    Returns:
        ``(-sum(log-likelihood), beta * kl)``.
    """
    obs = x.permute(0, 2, 3, 1)[..., None]  # column vectors, [I, N, F, M, 1]

    # KL divergence to a standard normal, summed over every element
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * torch.sum(kl_elementwise)

    Hherm = Hhat.transpose(-1, -2).conj()
    Rx = (Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hherm + Rb).permute(2, 0, 1, 3, 4)  # [I, N, F, M, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (obs.transpose(-1, -2).conj() @ Rx.inverse() @ obs).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta*kl

#%%
# Sizes follow the [sample, M, N, F] layout of the stored tensors; J is the
# number of sources handed to the model.
M, N, F, J = 3, 64, 64, 3
NF = N*F
I = 18000  # how many training samples to keep
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-3}

# Training set: every sample is divided by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data2.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, shuffle=True, drop_last=True,
                     batch_size=opts['batch_size'])

# Validation set: mixtures, reference sources and reference mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data2.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# loss_tr / loss1 / loss2 hold the total / reconstruction / KL loss per sample,
# recorded every 30 iterations; loss_eval gets one point per validation pass.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save the loss curves, validate, and checkpoint.
    if epoch%5 == 0:
        print(epoch)
        # (data, line style, title, file stem) for each training curve
        for curve, fmt, title, stem in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + stem)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs H^H Rx^{-1} x
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the references -- confirm there.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                if i == 0:
                    # First 3 validation samples (rows) x J sources (columns)
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average over the list being summed (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s67
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's67' 
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check both folders: previously only the figures folder was tested, so a
# missing models folder was never created and an existing one made mkdir raise.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian 3x3 matrices from 12 real numbers each.

    Args:
        rx12: real tensor of shape [I, 12]. Columns 0-5 are the real parts and
            columns 6-11 the imaginary parts of the lower-triangular entries,
            in ``torch.tril_indices(3, 3)`` order.

    Returns:
        cfloat tensor of shape [I, 3, 3] on the same device as ``rx12``. The
        upper triangle is the conjugate of the lower one and the diagonal is
        real (the imaginary parts given for the diagonal cancel out).
    """
    device = rx12.device  # follow the input instead of forcing CUDA
    rows, cols = torch.tril_indices(3, 3, device=device)
    diag = torch.arange(3, device=device)

    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :6] + 1j*rx12[:, 6:]

    # lower + lower^H mirrors the off-diagonal entries and doubles the real
    # part of the diagonal, hence the division by 2 below.
    rx_inv_hat = lower + lower.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter network that extracts one source per pass.

    On each of the ``K`` passes the network
      1. averages the outer product of the current residual over the N and F
         axes and feeds the lower triangle (real and imaginary parts) to
         ``mainnet``;
      2. maps that feature to one angle in (-1, 1) with ``hnet`` (Tanh) and
         forms the vector ``exp(1j*pi*m*angle)``, m = 0..M-1;
      3. maps the same feature to a Hermitian 3x3 matrix with ``rxnet`` and
         ``lower2matrix`` and uses it to filter the residual;
      4. encodes the filtered magnitude to a latent ``z``, decodes it and
         squares the output to get the source variance ``v``;
      5. subtracts the estimated component from the residual.

    The noise covariance ``Rb`` is not learned; it is fixed to ``1e-3 * I``.

    Input shape [I,M,N,F], e.g. [128,3,64,64]. The covariance features are
    hard-wired to 3x3 (12 inputs to ``mainnet``), so ``M`` must be 3.

    Args:
        M: number of channels.
        K: number of passes / sources (stored as ``self.J``).
        im_size: side length of the square N x F input; ``fc1`` expects the
            encoder to return ``(im_size/4)**2`` values per sample.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into ``self.J`` sources.

        Args:
            x: complex mixture of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            Rs -- diagonal source covariances, [I, N, F, J, J] (cfloat);
            Hhat -- one vector per pass, [I, M, J];
            Rb -- fixed noise covariance ``1e-3 * I``, [I, M, M], on x's device;
            mu, logvar -- latent parameters of the last pass only;
            zall -- ``z`` and ``bilinear(z)`` of every pass stacked on dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        ind = torch.tril_indices(3,3)  # loop-invariant, so built once
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the two trailing singleton axes explicitly: a bare
            # squeeze() would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            # NOTE(review): mu/logvar are overwritten on every pass, so only
            # the last pass reaches the caller's KL term -- confirm intended.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # Drop only the channel axis (decoder ends with out_channels=1),
            # so a batch of one keeps its batch axis.
            v = self.decoder(z).square().squeeze(1)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(negative log-likelihood, beta * KL)`` for one batch.

    The mixture covariance is modelled per bin as ``Hhat @ Rs @ Hhat^H + Rb``
    and the data term is ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` summed over
    every element.  The KL term is the closed-form divergence between the
    diagonal Gaussian described by ``mu``/``logvar`` and a standard normal.

    Args:
        x: mixture, indexed as [I, M, N, F].
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing matrices, indexed as [I, M, J].
        Rb: noise covariances, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters.
        zall: latent codes; only used by the disabled contrastive loss.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``; the first entry inherits the dtype of
        ``Rx`` (complex when the inputs are complex).
    """
    I, M, J = x.shape[0], x.shape[1], Rs.shape[-1]  # only read by the disabled losses below

    # KL( N(mu, exp(logvar)) || N(0, 1) ), summed over all entries.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance; the batch axis is moved next to the matrix axes so
    # that Hhat and Rb broadcast, then moved back to the front.
    Hhat_herm = Hhat.transpose(-1,-2).conj()
    Rs_nfi = Rs.permute(1,2,0,3,4)
    Rx = (Hhat @ Rs_nfi @ Hhat_herm + Rb).permute(2,0,1,3,4)  # [I, N, F, M, M]

    # Observation as column / conjugated row vectors per bin.
    x_col = x.permute(0,2,3,1)[...,None]      # [I, N, F, M, 1]
    x_row = x_col.transpose(-1,-2).conj()     # [I, N, F, 1, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (x_row @ Rx.inverse() @ x_col).squeeze()
    ll = -log_det - quad

    # "Slot contrastive loss"
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj()
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl #+ loss_slotCEL

#%%
# Hyper-parameters and data loaders for run s30.
I = 18000 # how many samples
# M, J and N are passed to the model constructor below; F is not referenced
# again in the visible part of this cell.
M, N, F, J = 3, 64, 64, 3
NF = N*F  # not referenced in the visible part of this cell
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

d = torch.load('../data/nem_ss/tr18kM3FT64_data0.pt')
# Scale every sample so its largest magnitude is 1.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'].
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file holds a 3-tuple unpacked as mixture / sources / mixing
# ground truth (meaning inferred from the variable names - TODO confirm).
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data0.pt')
sval= sval.permute(0,2,3,1)
# NOTE(review): unlike xtr, xval is not cast to torch.cfloat here - confirm
# the stored dtype already matches.
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Train the model, logging losses every 30 iterations and running plots,
# validation and a checkpoint every 5 epochs.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th iteration (loss_fun sums over the batch).
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs @ H^H @ Rx^{-1} @ x, with the
                # batch axis moved next to the matrix axes for broadcasting.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the ground truth - TODO confirm.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Plot the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,3,ii+1+ind*3)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (previously divided by
            # len(av_hcorr); both lists grow in lockstep, so the value is the same).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object every 5th epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s69
# Run setup for experiment s69: environment, seeding and output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # no auto-tuning of conv algorithms; may be slower

rid = 's69' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create both folders independently: the old code tested only the figure
# folder and then called os.mkdir on both, which raised FileExistsError when
# the model folder already existed and left it missing in the opposite case.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real numbers each.

    Columns 0-5 of ``rx12`` are the real parts and columns 6-11 the imaginary
    parts of the lower-triangular entries, in ``torch.tril_indices(3, 3)``
    order: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).  The upper triangle is
    the conjugate of the lower one and the diagonal keeps only its real part
    (the imaginary diagonal inputs cancel out).

    Args:
        rx12: real tensor indexed as [batch, 12].

    Returns:
        ``torch.cfloat`` tensor of shape [batch, 3, 3] on the same device as
        ``rx12`` (previously always moved to CUDA, which failed on CPU input).
    """
    ind = torch.tril_indices(3,3)
    indx = np.diag_indices(3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, ind[0], ind[1]] = rx12[:, :6] + 1j*rx12[:,6:]
    # L + L^H mirrors the strict lower triangle and doubles the real diagonal.
    rx_inv_hat = lower + lower.permute(0,2,1).conj()
    # Halve the diagonal to undo that doubling.
    rx_inv_hat[:, indx[0], indx[1]] = rx_inv_hat[:, indx[0], indx[1]]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive separator: estimate one source at a time and subtract it.

    For each of the J sources, ``forward`` (1) summarises the current
    residual by its spatial covariance, (2) predicts a unit-modulus phase-ramp
    channel vector and a Hermitian "inverse covariance" from that summary,
    (3) beamforms the residual, (4) encodes the magnitude into a latent code
    and decodes it into a non-negative map ``v``, and (5) removes the
    component from the residual.

    Input shape [I,M,N,F], e.g. [32,3,64,66].  The front end consumes the 12
    real numbers of the lower triangle of a 3x3 matrix, so M is effectively
    fixed to 3.  The noise covariance Rb is not learned: it is ``1e-3 * I``.
    Decoded maps are passed through ``threshold(v, floor=1e-6, ceiling=1e2)``
    before being returned.

    Args:
        M: number of channels.
        K: number of sources extracted (stored as ``self.J``).
        im_size: image height; the latent bottleneck works on
            ``(im_size/4) x (im_size/4)`` maps.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk on the 12 covariance features.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, used as the phase-ramp slope.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 numbers that lower2matrix turns into a 3x3 Hermitian matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into J components.

        Args:
            x: complex mixture indexed as [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            ``Rs`` [I,N,F,J,J] diagonal source covariances (cfloat),
            ``Hhat`` [I,M,J] channel vectors, ``Rb`` [I,M,M] = 1e-3*I,
            ``mu``/``logvar`` of the LAST extracted source only, and
            ``zall`` stacking ``z`` and ``bilinear(z)`` for every source
            along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(3,3)  # loop-invariant, hoisted out of the loop
        for i in range(self.J):
            # Get H estimation: features are the lower triangle of the
            # residual's covariance averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M], entries exp(1j*pi*m*ang)
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part: w = Rx^{-1} h / (h^H Rx^{-1} h), i.e. MVDR-style
            # weights, applied to every bin of the residual.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): the bare squeeze() would also drop a batch axis of size 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns logvar.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V (squared, hence non-negative).
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component: cj = h * (v * Rx^{-1} h)^H xj.
            # The un-thresholded v is used here.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, allocated on the input's device instead of
        # a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        # NOTE(review): mu/logvar are overwritten on every loop pass, so the
        # KL term computed from them covers only the last source - confirm
        # this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(negative log-likelihood, beta * KL)`` for one batch.

    The mixture covariance is modelled per bin as ``Hhat @ Rs @ Hhat^H + Rb``
    and the data term is ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` summed over
    every element.  The KL term is the closed-form divergence between the
    diagonal Gaussian described by ``mu``/``logvar`` and a standard normal.

    Args:
        x: mixture, indexed as [I, M, N, F].
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing matrices, indexed as [I, M, J].
        Rb: noise covariances, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters.
        zall: latent codes; only used by the disabled contrastive loss.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``; the first entry inherits the dtype of
        ``Rx`` (complex when the inputs are complex).
    """
    I, M, J = x.shape[0], x.shape[1], Rs.shape[-1]  # only read by the disabled losses below

    # KL( N(mu, exp(logvar)) || N(0, 1) ), summed over all entries.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance; the batch axis is moved next to the matrix axes so
    # that Hhat and Rb broadcast, then moved back to the front.
    Hhat_herm = Hhat.transpose(-1,-2).conj()
    Rs_nfi = Rs.permute(1,2,0,3,4)
    Rx = (Hhat @ Rs_nfi @ Hhat_herm + Rb).permute(2,0,1,3,4)  # [I, N, F, M, M]

    # Observation as column / conjugated row vectors per bin.
    x_col = x.permute(0,2,3,1)[...,None]      # [I, N, F, M, 1]
    x_row = x_col.transpose(-1,-2).conj()     # [I, N, F, 1, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (x_row @ Rx.inverse() @ x_col).squeeze()
    ll = -log_det - quad

    # "Slot contrastive loss"
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj()
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl #+ loss_slotCEL

#%%
# Hyper-parameters and data loaders for run s69.
I = 18000 # how many samples
# M, J and N are passed to the model constructor below; F is not referenced
# again in the visible part of this cell.
M, N, F, J = 3, 64, 66, 3
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

d = torch.load('../data/nem_ss/tr18kM3FT64_data3.pt')
# Scale every sample so its largest magnitude is 1.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'].
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file holds a 3-tuple unpacked as mixture / sources / mixing
# ground truth (meaning inferred from the variable names - TODO confirm).
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)
# NOTE(review): unlike xtr, xval is not cast to torch.cfloat here - confirm
# the stored dtype already matches.
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Train the model, logging losses every 30 iterations and running plots,
# validation and a checkpoint every 5 epochs.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th iteration (loss_fun sums over the batch).
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs @ H^H @ Rx^{-1} @ x, with the
                # batch axis moved next to the matrix axes for broadcasting.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the ground truth - TODO confirm.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Plot the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,3,ii+1+ind*3)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (previously divided by
            # len(av_hcorr); both lists grow in lockstep, so the value is the same).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object every 5th epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s71
# Run setup for experiment s71: environment, seeding and output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # no auto-tuning of conv algorithms; may be slower

rid = 's71' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create both folders independently: the old code tested only the figure
# folder and then called os.mkdir on both, which raised FileExistsError when
# the model folder already existed and left it missing in the opposite case.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real numbers each.

    Columns 0-5 of ``rx12`` are the real parts and columns 6-11 the imaginary
    parts of the lower-triangular entries, in ``torch.tril_indices(3, 3)``
    order: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).  The upper triangle is
    the conjugate of the lower one and the diagonal keeps only its real part
    (the imaginary diagonal inputs cancel out).

    Args:
        rx12: real tensor indexed as [batch, 12].

    Returns:
        ``torch.cfloat`` tensor of shape [batch, 3, 3] on the same device as
        ``rx12`` (previously always moved to CUDA, which failed on CPU input).
    """
    ind = torch.tril_indices(3,3)
    indx = np.diag_indices(3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, ind[0], ind[1]] = rx12[:, :6] + 1j*rx12[:,6:]
    # L + L^H mirrors the strict lower triangle and doubles the real diagonal.
    rx_inv_hat = lower + lower.permute(0,2,1).conj()
    # Halve the diagonal to undo that doubling.
    rx_inv_hat[:, indx[0], indx[1]] = rx_inv_hat[:, indx[0], indx[1]]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive separator: estimate one source at a time and subtract it.

    For each of the J sources, ``forward`` (1) summarises the current
    residual by its spatial covariance, (2) predicts a unit-modulus phase-ramp
    channel vector and a Hermitian "inverse covariance" from that summary,
    (3) beamforms the residual, (4) encodes the magnitude into a latent code
    and decodes it into a non-negative map ``v``, and (5) removes the
    component from the residual.

    Input shape [I,M,N,F], e.g. [32,3,64,66].  The front end consumes the 12
    real numbers of the lower triangle of a 3x3 matrix, so M is effectively
    fixed to 3.  The noise covariance Rb is not learned: it is ``1e-4 * I``.
    Decoded maps are passed through ``threshold(v, floor=1e-6, ceiling=1e2)``
    before being returned.

    Args:
        M: number of channels.
        K: number of sources extracted (stored as ``self.J``).
        im_size: image height; the latent bottleneck works on
            ``(im_size/4) x (im_size/4)`` maps.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk on the 12 covariance features.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, used as the phase-ramp slope.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 numbers that lower2matrix turns into a 3x3 Hermitian matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into J components.

        Args:
            x: complex mixture indexed as [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            ``Rs`` [I,N,F,J,J] diagonal source covariances (cfloat),
            ``Hhat`` [I,M,J] channel vectors, ``Rb`` [I,M,M] = 1e-4*I,
            ``mu``/``logvar`` of the LAST extracted source only, and
            ``zall`` stacking ``z`` and ``bilinear(z)`` for every source
            along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(3,3)  # loop-invariant, hoisted out of the loop
        for i in range(self.J):
            # Get H estimation: features are the lower triangle of the
            # residual's covariance averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M], entries exp(1j*pi*m*ang)
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part: w = Rx^{-1} h / (h^H Rx^{-1} h), i.e. MVDR-style
            # weights, applied to every bin of the residual.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): the bare squeeze() would also drop a batch axis of size 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns logvar.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V (squared, hence non-negative).
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component: cj = h * (v * Rx^{-1} h)^H xj.
            # The un-thresholded v is used here.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, allocated on the input's device instead of
        # a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-4

        # NOTE(review): mu/logvar are overwritten on every loop pass, so the
        # KL term computed from them covers only the last source - confirm
        # this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(negative log-likelihood, beta * KL)`` for one batch.

    The mixture covariance is modelled per bin as ``Hhat @ Rs @ Hhat^H + Rb``
    and the data term is ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` summed over
    every element.  The KL term is the closed-form divergence between the
    diagonal Gaussian described by ``mu``/``logvar`` and a standard normal.

    Args:
        x: mixture, indexed as [I, M, N, F].
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing matrices, indexed as [I, M, J].
        Rb: noise covariances, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters.
        zall: latent codes; only used by the disabled contrastive loss.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``; the first entry inherits the dtype of
        ``Rx`` (complex when the inputs are complex).
    """
    I, M, J = x.shape[0], x.shape[1], Rs.shape[-1]  # only read by the disabled losses below

    # KL( N(mu, exp(logvar)) || N(0, 1) ), summed over all entries.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance; the batch axis is moved next to the matrix axes so
    # that Hhat and Rb broadcast, then moved back to the front.
    Hhat_herm = Hhat.transpose(-1,-2).conj()
    Rs_nfi = Rs.permute(1,2,0,3,4)
    Rx = (Hhat @ Rs_nfi @ Hhat_herm + Rb).permute(2,0,1,3,4)  # [I, N, F, M, M]

    # Observation as column / conjugated row vectors per bin.
    x_col = x.permute(0,2,3,1)[...,None]      # [I, N, F, M, 1]
    x_row = x_col.transpose(-1,-2).conj()     # [I, N, F, 1, M]

    log_det = (np.pi*Rx.det()).log()
    quad = (x_row @ Rx.inverse() @ x_col).squeeze()
    ll = -log_det - quad

    # "Slot contrastive loss"
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj()
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl #+ loss_slotCEL

#%%
# Hyper-parameters and data loaders for run s71.
I = 18000 # how many samples
# M, J and N are passed to the model constructor below; F is not referenced
# again in the visible part of this cell.
M, N, F, J = 3, 64, 66, 3
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

d = torch.load('../data/nem_ss/tr18kM3FT64_data3.pt')
# Scale every sample so its largest magnitude is 1.
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'].
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation file holds a 3-tuple unpacked as mixture / sources / mixing
# ground truth (meaning inferred from the variable names - TODO confirm).
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)
# NOTE(review): unlike xtr, xval is not cast to torch.cfloat here - confirm
# the stored dtype already matches.
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Train the model, logging losses every 30 iterations and running plots,
# validation and a checkpoint every 5 epochs.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th iteration (loss_fun sums over the batch).
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate per bin: Rs @ H^H @ Rx^{-1} @ x, with the
                # batch axis moved next to the matrix axes for broadcasting.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                # h_corr / s_corr come from utils; presumably correlation
                # scores against the ground truth - TODO confirm.
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Plot the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,3,ii+1+ind*3)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (previously divided by
            # len(av_hcorr); both lists grow in lockstep, so the value is the same).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object every 5th epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s71_
# Run setup for experiment s71_: environment, seeding and output folders.
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # no auto-tuning of conv algorithms; may be slower

rid = 's71_' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create both folders independently: the old code tested only the figure
# folder and then called os.mkdir on both, which raised FileExistsError when
# the model folder already existed and left it missing in the opposite case.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of 3x3 Hermitian matrices from 12 real numbers each.

    Columns 0-5 of ``rx12`` are the real parts and columns 6-11 the imaginary
    parts of the lower-triangular entries, in ``torch.tril_indices(3, 3)``
    order: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).  The upper triangle is
    the conjugate of the lower one and the diagonal keeps only its real part
    (the imaginary diagonal inputs cancel out).

    Args:
        rx12: real tensor indexed as [batch, 12].

    Returns:
        ``torch.cfloat`` tensor of shape [batch, 3, 3] on the same device as
        ``rx12`` (previously always moved to CUDA, which failed on CPU input).
    """
    ind = torch.tril_indices(3,3)
    indx = np.diag_indices(3)
    lower = torch.zeros(rx12.shape[0], 3, 3, dtype=torch.cfloat, device=rx12.device)
    lower[:, ind[0], ind[1]] = rx12[:, :6] + 1j*rx12[:,6:]
    # L + L^H mirrors the strict lower triangle and doubles the real diagonal.
    rx_inv_hat = lower + lower.permute(0,2,1).conj()
    # Halve the diagonal to undo that doubling.
    rx_inv_hat[:, indx[0], indx[1]] = rx_inv_hat[:, indx[0], indx[1]]/2
    return rx_inv_hat

class NNet_s11(nn.Module):
    """Recursive separator: estimate one source at a time and subtract it.

    For each of the J sources, ``forward`` (1) summarises the current
    residual by its spatial covariance, (2) predicts a unit-modulus phase-ramp
    channel vector and a Hermitian "inverse covariance" from that summary,
    (3) beamforms the residual, (4) encodes the magnitude into a latent code
    and decodes it into a non-negative map ``v``, and (5) removes the
    component from the residual.

    In this decoder the width-extending 3x3 convolution sits directly after
    the second up-sampling stage (on 16 channels), before the final
    ``DoubleConv_g``.

    Input shape [I,M,N,F], e.g. [32,3,64,66].  The front end consumes the 12
    real numbers of the lower triangle of a 3x3 matrix, so M is effectively
    fixed to 3.  The noise covariance Rb is not learned: it is ``1e-4 * I``.
    Decoded maps are passed through ``threshold(v, floor=1e-6, ceiling=1e2)``
    before being returned.

    Args:
        M: number of channels.
        K: number of sources extracted (stored as ``self.J``).
        im_size: image height; the latent bottleneck works on
            ``(im_size/4) x (im_size/4)`` maps.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk on the 12 covariance features.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # One scalar in (-1, 1) per sample, used as the phase-ramp slope.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # 12 numbers that lower2matrix turns into a 3x3 Hermitian matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(16, 16, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(16//4,1), num_channels=16),
            nn.LeakyReLU(inplace=True),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*exp(0.5*logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Separate ``x`` into J components.

        Args:
            x: complex mixture indexed as [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``:
            ``Rs`` [I,N,F,J,J] diagonal source covariances (cfloat),
            ``Hhat`` [I,M,J] channel vectors, ``Rb`` [I,M,M] = 1e-4*I,
            ``mu``/``logvar`` of the LAST extracted source only, and
            ``zall`` stacking ``z`` and ``bilinear(z)`` for every source
            along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(3,3)  # loop-invariant, hoisted out of the loop
        for i in range(self.J):
            # Get H estimation: features are the lower triangle of the
            # residual's covariance averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M], entries exp(1j*pi*m*ang)
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part: w = Rx^{-1} h / (h^H Rx^{-1} h), i.e. MVDR-style
            # weights, applied to every bin of the residual.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): the bare squeeze() would also drop a batch axis of size 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns logvar.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V (squared, hence non-negative).
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component: cj = h * (v * Rx^{-1} h)^H xj.
            # The un-thresholded v is used here.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, allocated on the input's device instead of
        # a hard-coded 'cuda'.
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-4

        # NOTE(review): mu/logvar are overwritten on every loop pass, so the
        # KL term computed from them covers only the last source - confirm
        # this is intended.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the reconstruction loss and the weighted KL term.

    Args:
        x: mixture batch; permuted with (0,2,3,1) so the channel axis (dim 1)
            becomes the last axis before the likelihood is evaluated.
        Rs: source covariance; permuted with (1,2,0,3,4), so it must be 5-D
            (the in-code shape comments suggest [I, N, F, J, J]).
        Hhat: mixing-matrix estimate, multiplied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to Hhat @ Rs @ Hhat^H.
        mu, logvar: latent Gaussian parameters used for the KL term.
        zall: accepted for interface compatibility only; the slot-contrastive
            term that consumed it is disabled in this experiment.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``. The first entry is complex-valued
        whenever Rx is complex, because det/log/inverse are evaluated on Rx
        directly.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance of the mixture for every bin.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]

    # Per-bin log-likelihood: -log(pi * det(Rx)) - x^H Rx^{-1} x
    log_det = (np.pi * Rx.det()).log()
    quad = (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta * kl

#%%
# Experiment configuration and data loading for this run.
I = 18000 # how many samples
# M, J and N are passed to the model constructor below as (M, J, N);
# the training tensor is indexed as [sample, M, N, F].
M, N, F, J = 3, 64, 66, 3
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

# Training set: each sample is scaled by its own maximum magnitude
# (max taken over dims 1,2,3), then cast to complex64.
d = torch.load('../data/nem_ss/tr18kM3FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, sources and a third tensor (hgt) compared against
# the estimated Hhat during validation.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)  # move dim 1 of the sources to the last axis
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Build the model and optimizer, then train; every 5th epoch the loss curves
# are plotted, the validation set is evaluated and a checkpoint is saved.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s11(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th mini-batch, normalised to a per-sample value.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate: Rs Hhat^H Rx^{-1} x for every bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))

                # Show the estimated sources of the first 3 samples of the first batch.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s75
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's75'
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')):
    print('made a new folder')
# Create the two folders independently: the old code tested only the figure
# folder, so it crashed (or left the model folder missing) when just one of
# them already existed.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Rebuild a batch of Hermitian matrices from packed lower triangles.

    Args:
        rx42: real tensor of shape [I, 2*L] with L = M*(M+1)/2. The first L
            columns hold the real parts and the last L columns the imaginary
            parts of the lower-triangular entries, in ``torch.tril_indices``
            order. The original use case is M = 6, i.e. 42 columns.

    Returns:
        Complex64 tensor of shape [I, M, M] on the same device as ``rx42``.
        The strict upper triangle is the conjugate of the lower one and the
        diagonal keeps only the real part of the packed diagonal entries.

    Raises:
        ValueError: if the number of columns is not 2*M*(M+1)/2 for an integer M.
    """
    n_cols = rx42.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if 2 * n_lower != n_cols or m * (m + 1) // 2 != n_lower:
        raise ValueError(
            f'expected 2*M*(M+1)/2 columns for some integer M, got {n_cols}')

    ind = torch.tril_indices(m, m)
    diag = torch.arange(m)
    # Allocate on the input's device instead of unconditionally on CUDA.
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat,
                             device=rx42.device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_lower] + 1j*rx42[:, n_lower:]
    # L + L^H is Hermitian but counts the diagonal twice, so halve it afterwards.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE source model.

    For each of the ``K`` components, ``forward`` estimates a steering vector
    ``hhat`` and a matrix ``rx_inv`` from the spatial covariance of the current
    residual, beamforms the residual, encodes the magnitude of the beamformer
    output into a latent ``z``, decodes ``z`` into a non-negative map ``v`` and
    subtracts the estimated component from the residual.

    ``mainnet`` takes 42 input features and ``rxnet`` emits 42 values. That
    matches the real plus imaginary parts of the M*(M+1)/2 lower-triangular
    covariance entries for M = 6; other M need these sizes changed.

    Args:
        M: number of input channels.
        K: number of components to extract (stored as ``self.J``).
        im_size: image size; the latent bottleneck works on a
            ``int(im_size/4)`` square map (assumes the two Down_g / Up_g stages
            each change resolution by 2 -- TODO confirm against vae_modules).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with the packed covariance entries.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one value in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed entries consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values; even columns are read as mu, odd ones as logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # kernel 3 with padding (1,2): keeps the first spatial size and
            # widens the last one by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Extract ``self.J`` components from the mixture ``x`` of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the complex
            diagonal embedding of the thresholded ``v`` maps, ``Hhat`` stacks
            the steering vectors as [I, M, J], ``Rb`` is ``1e-3 * identity``
            per sample, ``mu``/``logvar`` are those of the LAST extracted
            component only, and ``zall`` stacks ``z`` and ``bilinear(z)``
            alternately along dim 1.
        """
        # NOTE(review): the bare .squeeze() calls below would also drop the
        # batch axis for a batch of one; the loaders here use drop_last=True.
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation from the residual's spatial covariance,
            # averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: weights w = rx_inv h / (h^H rx_inv h), applied as w^H x.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V; squaring makes it non-negative.
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unthresholded v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device rather than a
        # hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the reconstruction loss and the weighted KL term.

    Args:
        x: mixture batch; permuted with (0,2,3,1) so the channel axis (dim 1)
            becomes the last axis before the likelihood is evaluated.
        Rs: source covariance; permuted with (1,2,0,3,4), so it must be 5-D
            (the in-code shape comments suggest [I, N, F, J, J]).
        Hhat: mixing-matrix estimate, multiplied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to Hhat @ Rs @ Hhat^H.
        mu, logvar: latent Gaussian parameters used for the KL term.
        zall: accepted for interface compatibility only; the slot-contrastive
            term that consumed it is disabled in this experiment.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``. The first entry is complex-valued
        whenever Rx is complex, because det/log/inverse are evaluated on Rx
        directly.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance of the mixture for every bin.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]

    # Per-bin log-likelihood: -log(pi * det(Rx)) - x^H Rx^{-1} x
    log_det = (np.pi * Rx.det()).log()
    quad = (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta * kl

#%%
# Experiment configuration and data loading for this run.
I = 18000 # how many samples
# M, J and N are passed to the model constructor below as (M, J, N);
# the training tensor is indexed as [sample, M, N, F].
M, N, F, J = 6, 64, 66, 6
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

# Training set: each sample is scaled by its own maximum magnitude
# (max taken over dims 1,2,3), then cast to complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, sources and a third tensor (hgt) compared against
# the estimated Hhat during validation.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)  # move dim 1 of the sources to the last axis
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Build the model and optimizer, then train; every 5th epoch the loss curves
# are plotted, the validation set is evaluated and a checkpoint is saved.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th mini-batch, normalised to a per-sample value.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate: Rs Hhat^H Rx^{-1} x for every bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))

                # Show the estimated sources of the first 3 samples of the first batch.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s76ca
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's76ca'
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')):
    print('made a new folder')
# Create the two folders independently: the old code tested only the figure
# folder, so it crashed (or left the model folder missing) when just one of
# them already existed.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Rebuild a batch of Hermitian matrices from packed lower triangles.

    Args:
        rx42: real tensor of shape [I, 2*L] with L = M*(M+1)/2. The first L
            columns hold the real parts and the last L columns the imaginary
            parts of the lower-triangular entries, in ``torch.tril_indices``
            order. The original use case is M = 6, i.e. 42 columns.

    Returns:
        Complex64 tensor of shape [I, M, M] on the same device as ``rx42``.
        The strict upper triangle is the conjugate of the lower one and the
        diagonal keeps only the real part of the packed diagonal entries.

    Raises:
        ValueError: if the number of columns is not 2*M*(M+1)/2 for an integer M.
    """
    n_cols = rx42.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if 2 * n_lower != n_cols or m * (m + 1) // 2 != n_lower:
        raise ValueError(
            f'expected 2*M*(M+1)/2 columns for some integer M, got {n_cols}')

    ind = torch.tril_indices(m, m)
    diag = torch.arange(m)
    # Allocate on the input's device instead of unconditionally on CUDA.
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat,
                             device=rx42.device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_lower] + 1j*rx42[:, n_lower:]
    # L + L^H is Hermitian but counts the diagonal twice, so halve it afterwards.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE source model.

    For each of the ``K`` components, ``forward`` estimates a steering vector
    ``hhat`` and a matrix ``rx_inv`` from the spatial covariance of the current
    residual, beamforms the residual, encodes the magnitude of the beamformer
    output into a latent ``z``, decodes ``z`` into a non-negative map ``v`` and
    subtracts the estimated component from the residual.

    ``mainnet`` takes 42 input features and ``rxnet`` emits 42 values. That
    matches the real plus imaginary parts of the M*(M+1)/2 lower-triangular
    covariance entries for M = 6; other M need these sizes changed.

    Args:
        M: number of input channels.
        K: number of components to extract (stored as ``self.J``).
        im_size: image size; the latent bottleneck works on a
            ``int(im_size/4)`` square map (assumes the two Down_g / Up_g stages
            each change resolution by 2 -- TODO confirm against vae_modules).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with the packed covariance entries.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one value in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed entries consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values; even columns are read as mu, odd ones as logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # kernel 3 with padding (1,2): keeps the first spatial size and
            # widens the last one by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Extract ``self.J`` components from the mixture ``x`` of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the complex
            diagonal embedding of the thresholded ``v`` maps, ``Hhat`` stacks
            the steering vectors as [I, M, J], ``Rb`` is ``1e-3 * identity``
            per sample, ``mu``/``logvar`` are those of the LAST extracted
            component only, and ``zall`` stacks ``z`` and ``bilinear(z)``
            alternately along dim 1.
        """
        # NOTE(review): the bare .squeeze() calls below would also drop the
        # batch axis for a batch of one; the loaders here use drop_last=True.
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation from the residual's spatial covariance,
            # averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: weights w = rx_inv h / (h^H rx_inv h), applied as w^H x.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V; squaring makes it non-negative.
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unthresholded v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device rather than a
        # hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the reconstruction loss and the weighted KL term.

    Args:
        x: mixture batch; permuted with (0,2,3,1) so the channel axis (dim 1)
            becomes the last axis before the likelihood is evaluated.
        Rs: source covariance; permuted with (1,2,0,3,4), so it must be 5-D
            (the in-code shape comments suggest [I, N, F, J, J]).
        Hhat: mixing-matrix estimate, multiplied as Hhat @ Rs @ Hhat^H.
        Rb: noise covariance added to Hhat @ Rs @ Hhat^H.
        mu, logvar: latent Gaussian parameters used for the KL term.
        zall: accepted for interface compatibility only; the slot-contrastive
            term that consumed it is disabled in this experiment.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl)``. The first entry is complex-valued
        whenever Rx is complex, because det/log/inverse are evaluated on Rx
        directly.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance of the mixture for every bin.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]

    # Per-bin log-likelihood: -log(pi * det(Rx)) - x^H Rx^{-1} x
    log_det = (np.pi * Rx.det()).log()
    quad = (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()
    ll = -log_det - quad

    return -ll.sum(), beta * kl

#%%
# Experiment configuration and data loading for this run.
I = 18000 # how many samples
# The training tensor is indexed as [sample, M, N, F]; J is the number of
# source images plotted per sample during validation.
M, N, F, J = 6, 64, 66, 6
eps = 5e-4  # not referenced in the visible part of this cell
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

# Training set: each sample is scaled by its own maximum magnitude
# (max taken over dims 1,2,3), then cast to complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
# drop_last=True keeps every training batch at exactly opts['batch_size'] samples.
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, sources and a third tensor (hgt) compared against
# the estimated Hhat during validation.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)  # move dim 1 of the sources to the last axis
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Continue training from the s75 epoch-400 checkpoint with tighter gradient
# clipping (max_norm=10); every 5th epoch the loss curves are plotted, the
# validation set is evaluated and a checkpoint is saved.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
# model = NNet_s10(M,J,N).cuda()
# NOTE(review): this unpickles a whole model object, so the class definition
# it was saved with must be importable under the same name.
model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th mini-batch, normalised to a per-sample value.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate: Rs Hhat^H Rx^{-1} x for every bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), shat[ind:ind+1].cuda()).cpu())

                # Show the estimated sources of the first 3 samples of the first batch.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s76s
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's76s'
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
if not(os.path.isdir(fig_loc + f'/{rid}/')):
    print('made a new folder')
# Create the two folders independently: the old code tested only the figure
# folder, so it crashed (or left the model folder missing) when just one of
# them already existed.
os.makedirs(fig_loc + f'{rid}/', exist_ok=True)
os.makedirs(mod_loc + f'{rid}/', exist_ok=True)
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Rebuild a batch of Hermitian matrices from packed lower triangles.

    Args:
        rx42: real tensor of shape [I, 2*L] with L = M*(M+1)/2. The first L
            columns hold the real parts and the last L columns the imaginary
            parts of the lower-triangular entries, in ``torch.tril_indices``
            order. The original use case is M = 6, i.e. 42 columns.

    Returns:
        Complex64 tensor of shape [I, M, M] on the same device as ``rx42``.
        The strict upper triangle is the conjugate of the lower one and the
        diagonal keeps only the real part of the packed diagonal entries.

    Raises:
        ValueError: if the number of columns is not 2*M*(M+1)/2 for an integer M.
    """
    n_cols = rx42.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if 2 * n_lower != n_cols or m * (m + 1) // 2 != n_lower:
        raise ValueError(
            f'expected 2*M*(M+1)/2 columns for some integer M, got {n_cols}')

    ind = torch.tril_indices(m, m)
    diag = torch.arange(m)
    # Allocate on the input's device instead of unconditionally on CUDA.
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat,
                             device=rx42.device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_lower] + 1j*rx42[:, n_lower:]
    # L + L^H is Hermitian but counts the diagonal twice, so halve it afterwards.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive Wiener-filter style separation network with a VAE source model.

    For each of the ``K`` components, ``forward`` estimates a steering vector
    ``hhat`` and a matrix ``rx_inv`` from the spatial covariance of the current
    residual, beamforms the residual, encodes the magnitude of the beamformer
    output into a latent ``z``, decodes ``z`` into a non-negative map ``v`` and
    subtracts the estimated component from the residual.

    ``mainnet`` takes 42 input features and ``rxnet`` emits 42 values. That
    matches the real plus imaginary parts of the M*(M+1)/2 lower-triangular
    covariance entries for M = 6; other M need these sizes changed.

    Args:
        M: number of input channels.
        K: number of components to extract (stored as ``self.J``).
        im_size: image size; the latent bottleneck works on a
            ``int(im_size/4)`` square map (assumes the two Down_g / Up_g stages
            each change resolution by 2 -- TODO confirm against vae_modules).
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with the packed covariance entries.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one value in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed entries consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values; even columns are read as mu, odd ones as logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # kernel 3 with padding (1,2): keeps the first spatial size and
            # widens the last one by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps*std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar/2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Extract ``self.J`` components from the mixture ``x`` of shape [I, M, N, F].

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the complex
            diagonal embedding of the thresholded ``v`` maps, ``Hhat`` stacks
            the steering vectors as [I, M, J], ``Rb`` is ``1e-3 * identity``
            per sample, ``mu``/``logvar`` are those of the LAST extracted
            component only, and ``zall`` stacks ``z`` and ``bilinear(z)``
            alternately along dim 1.
        """
        # NOTE(review): the bare .squeeze() calls below would also drop the
        # batch axis for a batch of one; the loaders here use drop_last=True.
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation from the residual's spatial covariance,
            # averaged over the N and F axes.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: weights w = rx_inv h / (h^H rx_inv h), applied as w^H x.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V; squaring makes it non-negative.
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unthresholded v).
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device rather than a
        # hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(reconstruction_loss, regulariser)`` for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F]; it is permuted to [I, N, F, M] here.
        Rs: per-pixel source covariance, indexed as [I, N, F, J, J].
        Hhat: mixing matrix, indexed as [I, M, J].
        Rb: noise covariance, broadcast-added to ``Hhat Rs Hhat^H`` ([I, M, M]).
        mu, logvar: Gaussian posterior parameters used by the KL term.
        zall: latent slots stacked along dim 1; even positions are ``z`` and
            odd positions are the projected ``W z`` of the same slot.
        beta: weight of the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl + slot_contrastive_loss)``.
        NOTE(review): when Rs/Hhat are complex the first element keeps a
        complex dtype - confirm this is what the caller expects.
    """
    I, J = x.shape[0], Rs.shape[-1]
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    # Index the trailing 1x1 dims explicitly: a bare squeeze() would also drop
    # I, N or F when one of them is 1 and then mis-broadcast against det().
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
    ll = -(np.pi*Rx.det()).log() - quad

    # Slot contrastive loss: row j of z @ (W z)^T should peak at column j.
    inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # [I*J, J]
    # Build the targets on the logits' device instead of forcing CUDA.
    target = torch.arange(J, device=inp.device).repeat(I)
    loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl + loss_slotCEL

#%%
# Experiment constants and data pipelines for this run.
I = 18000 # how many samples
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-3}

# Training mixtures: scale every sample by its own peak magnitude, then cast.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat) # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation triple (mixture, sources, mixing matrix); sources become channel-last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training driver: optimise the model, and every 5 epochs save loss curves,
# run validation (loss, h/s correlation, example source images) and checkpoint.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        # loss2 is the second value returned by loss_fun (KL plus slot loss).
        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate: Rs H^H Rx^{-1} x with Rx = H Rs H^H + Rb.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())
                
                # Plot the estimated sources of the first three validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s76s5
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's76s5' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the old single check on the figure folder
# failed whenever exactly one of the two folders already existed.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx42: real tensor of shape [B, m*(m+1)]. The first half of each row
            holds the real parts and the second half the imaginary parts of
            the lower triangle (row-major ``torch.tril_indices`` order).
            The original 6x6 case corresponds to 42 columns.

    Returns:
        Complex tensor of shape [B, m, m] on the same device as ``rx42``.
        Off-diagonal entries satisfy ``out[i, j] == conj(out[j, i])``; the
        diagonal keeps only the real part of the packed value.

    Raises:
        ValueError: if the column count is not of the form m*(m+1).
    """
    width = rx42.shape[1]
    half = width // 2
    # Solve m*(m+1)/2 == half for the matrix size m.
    m = int(((8*half + 1)**0.5 - 1)/2 + 0.5)
    if m*(m + 1) != width:
        raise ValueError(f'expected m*(m+1) columns, got {width}')

    rows, cols = torch.tril_indices(m, m, device=rx42.device)
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat,
                             device=rx42.device)
    rx_inv_hat[:, rows, cols] = rx42[:, :half] + 1j*rx42[:, half:]
    # Mirror the lower triangle; the diagonal gets added to its own conjugate
    # (2 * real part), hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    diag = torch.arange(m, device=rx42.device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Slot-wise separation network: one steering vector, latent code and
    variance map per slot, extracted sequentially from a shrinking residual.

    For each of the K slots ``forward``:
      1. summarises the current residual by its channel covariance,
      2. predicts a phase-ramp vector ``hhat`` and a Hermitian matrix ``rx_inv``,
      3. beamforms with ``w = rx_inv hhat / (hhat^H rx_inv hhat)``,
      4. encodes ``|w^H x|`` to a latent ``z`` and decodes ``z`` to a map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H xj`` from the residual.
    The returned noise covariance Rb is fixed to 1e-3 * identity.

    Input shape [I,M,N,F] (batch, channels, two image axes).

    Args:
        M: number of channels; the covariance feature width is M*(M+1).
        K: number of slots (stored as ``self.J``).
        im_size: image size; the latent image is ``int(im_size/4)`` squared.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real + imaginary parts of the lower triangle of an MxM matrix
        # (42 for M=6, the value that used to be hard-coded).
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        # Scalar in (-1, 1) per sample; turned into a phase ramp in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Packed lower triangle consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 kernel with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run the slot loop on ``x`` ([I, M, N, F]).

        Returns:
            ``(Rs, Hhat, Rb, mu, logvar, zall)`` where Rs is the diag-embedded
            variance stack [I, N, F, J, J], Hhat is [I, M, J], Rb is
            [I, M, M], and zall stacks ``z`` and ``W z`` of every slot along
            dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST slot
            only - confirm the KL term is meant to ignore the earlier slots.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): bare squeeze() assumes batch/N/F are all > 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the un-thresholded v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(reconstruction_loss, regulariser)`` for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F]; it is permuted to [I, N, F, M] here.
        Rs: per-pixel source covariance, indexed as [I, N, F, J, J].
        Hhat: mixing matrix, indexed as [I, M, J].
        Rb: noise covariance, broadcast-added to ``Hhat Rs Hhat^H`` ([I, M, M]).
        mu, logvar: Gaussian posterior parameters used by the KL term.
        zall: latent slots stacked along dim 1; even positions are ``z`` and
            odd positions are the projected ``W z`` of the same slot.
        beta: weight of the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl + slot_contrastive_loss)``.
        NOTE(review): when Rs/Hhat are complex the first element keeps a
        complex dtype - confirm this is what the caller expects.
    """
    I, J = x.shape[0], Rs.shape[-1]
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    # Index the trailing 1x1 dims explicitly: a bare squeeze() would also drop
    # I, N or F when one of them is 1 and then mis-broadcast against det().
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
    ll = -(np.pi*Rx.det()).log() - quad

    # Slot contrastive loss: row j of z @ (W z)^T should peak at column j.
    inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # [I*J, J]
    # Build the targets on the logits' device instead of forcing CUDA.
    target = torch.arange(J, device=inp.device).repeat(I)
    loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl + loss_slotCEL

#%%
# Experiment constants and data pipelines for this run.
I = 18000 # how many samples
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-4}

# Training mixtures: scale every sample by its own peak magnitude, then cast.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat) # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation triple (mixture, sources, mixing matrix); sources become channel-last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training driver: optimise the model, and every 5 epochs save loss curves,
# run validation (loss, h/s correlation, example source images) and checkpoint.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        # loss2 is the second value returned by loss_fun (KL plus slot loss).
        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate: Rs H^H Rx^{-1} x with Rx = H Rs H^H + Rb.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())
                
                # Plot the estimated sources of the first three validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s76sc
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's76sc' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the old single check on the figure folder
# failed whenever exactly one of the two folders already existed.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx42: real tensor of shape [B, m*(m+1)]. The first half of each row
            holds the real parts and the second half the imaginary parts of
            the lower triangle (row-major ``torch.tril_indices`` order).
            The original 6x6 case corresponds to 42 columns.

    Returns:
        Complex tensor of shape [B, m, m] on the same device as ``rx42``.
        Off-diagonal entries satisfy ``out[i, j] == conj(out[j, i])``; the
        diagonal keeps only the real part of the packed value.

    Raises:
        ValueError: if the column count is not of the form m*(m+1).
    """
    width = rx42.shape[1]
    half = width // 2
    # Solve m*(m+1)/2 == half for the matrix size m.
    m = int(((8*half + 1)**0.5 - 1)/2 + 0.5)
    if m*(m + 1) != width:
        raise ValueError(f'expected m*(m+1) columns, got {width}')

    rows, cols = torch.tril_indices(m, m, device=rx42.device)
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat,
                             device=rx42.device)
    rx_inv_hat[:, rows, cols] = rx42[:, :half] + 1j*rx42[:, half:]
    # Mirror the lower triangle; the diagonal gets added to its own conjugate
    # (2 * real part), hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    diag = torch.arange(m, device=rx42.device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Slot-wise separation network: one steering vector, latent code and
    variance map per slot, extracted sequentially from a shrinking residual.

    For each of the K slots ``forward``:
      1. summarises the current residual by its channel covariance,
      2. predicts a phase-ramp vector ``hhat`` and a Hermitian matrix ``rx_inv``,
      3. beamforms with ``w = rx_inv hhat / (hhat^H rx_inv hhat)``,
      4. encodes ``|w^H x|`` to a latent ``z`` and decodes ``z`` to a map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H xj`` from the residual.
    The returned noise covariance Rb is fixed to 1e-3 * identity.

    Input shape [I,M,N,F] (batch, channels, two image axes).

    Args:
        M: number of channels; the covariance feature width is M*(M+1).
        K: number of slots (stored as ``self.J``).
        im_size: image size; the latent image is ``int(im_size/4)`` squared.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real + imaginary parts of the lower triangle of an MxM matrix
        # (42 for M=6, the value that used to be hard-coded).
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        # Scalar in (-1, 1) per sample; turned into a phase ramp in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Packed lower triangle consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 kernel with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run the slot loop on ``x`` ([I, M, N, F]).

        Returns:
            ``(Rs, Hhat, Rb, mu, logvar, zall)`` where Rs is the diag-embedded
            variance stack [I, N, F, J, J], Hhat is [I, M, J], Rb is
            [I, M, M], and zall stacks ``z`` and ``W z`` of every slot along
            dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST slot
            only - confirm the KL term is meant to ignore the earlier slots.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): bare squeeze() assumes batch/N/F are all > 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the un-thresholded v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed noise covariance, created on the input's device.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return ``(reconstruction_loss, regulariser)`` for one batch.

    Args:
        x: mixture, indexed as [I, M, N, F]; it is permuted to [I, N, F, M] here.
        Rs: per-pixel source covariance, indexed as [I, N, F, J, J].
        Hhat: mixing matrix, indexed as [I, M, J].
        Rb: noise covariance, broadcast-added to ``Hhat Rs Hhat^H`` ([I, M, M]).
        mu, logvar: Gaussian posterior parameters used by the KL term.
        zall: latent slots stacked along dim 1; even positions are ``z`` and
            odd positions are the projected ``W z`` of the same slot.
        beta: weight of the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta*kl + slot_contrastive_loss)``.
        NOTE(review): when Rs/Hhat are complex the first element keeps a
        complex dtype - confirm this is what the caller expects.
    """
    I, J = x.shape[0], Rs.shape[-1]
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    # Index the trailing 1x1 dims explicitly: a bare squeeze() would also drop
    # I, N or F when one of them is 1 and then mis-broadcast against det().
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None])[..., 0, 0]
    ll = -(np.pi*Rx.det()).log() - quad

    # Slot contrastive loss: row j of z @ (W z)^T should peak at column j.
    inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # [I*J, J]
    # Build the targets on the logits' device instead of forcing CUDA.
    target = torch.arange(J, device=inp.device).repeat(I)
    loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl + loss_slotCEL

#%%
# Experiment constants and data pipelines for this run.
I = 18000 # how many samples
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-3}

# Training mixtures: scale every sample by its own peak magnitude, then cast.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)).to(torch.cfloat) # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation triple (mixture, sources, mixing matrix); sources become channel-last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training driver: continue from a saved checkpoint, and every 5 epochs save
# loss curves, run validation (loss, h/s correlation, example source images)
# and checkpoint.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
# NOTE(review): this freshly built model is immediately replaced by the loaded
# checkpoint; it is kept because building it consumes random numbers.
model = NNet_s10(M,J,N).cuda()
model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        # loss2 is the second value returned by loss_fun (KL plus slot loss).
        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate: Rs H^H Rx^{-1} x with Rx = H Rs H^H + Rb.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())
                
                # Plot the estimated sources of the first three validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list length (was len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s77
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 's77' 
fig_loc = '../data/data_ss/figures/'
mod_loc = '../data/data_ss/models/'
fig_loc = fig_loc + f'{rid}/'
mod_loc = mod_loc + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently: the old single check on the figure folder
# failed whenever exactly one of the two folders already existed.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [B, m*(m+1)]. The first half of each row
            holds the real parts and the second half the imaginary parts of
            the lower triangle (row-major ``torch.tril_indices`` order).
            The original 3x3 case corresponds to 12 columns.

    Returns:
        Complex tensor of shape [B, m, m] on the same device as ``rx12``.
        Off-diagonal entries satisfy ``out[i, j] == conj(out[j, i])``; the
        diagonal keeps only the real part of the packed value.

    Raises:
        ValueError: if the column count is not of the form m*(m+1).
    """
    width = rx12.shape[1]
    half = width // 2
    # Solve m*(m+1)/2 == half for the matrix size m.
    m = int(((8*half + 1)**0.5 - 1)/2 + 0.5)
    if m*(m + 1) != width:
        raise ValueError(f'expected m*(m+1) columns, got {width}')

    rows, cols = torch.tril_indices(m, m, device=rx12.device)
    rx_inv_hat = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat,
                             device=rx12.device)
    rx_inv_hat[:, rows, cols] = rx12[:, :half] + 1j*rx12[:, half:]
    # Mirror the lower triangle; the diagonal gets added to its own conjugate
    # (2 * real part), hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    diag = torch.arange(m, device=rx12.device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Slot-wise separation network: one steering vector, latent code and
    variance map per slot, extracted sequentially from a shrinking residual.

    For each of the K slots ``forward``:
      1. summarises the current residual by its channel covariance,
      2. predicts a phase-ramp vector ``hhat`` and a Hermitian matrix ``rx_inv``,
      3. beamforms with ``w = rx_inv hhat / (hhat^H rx_inv hhat)``,
      4. encodes ``|w^H x|`` to a latent ``z`` and decodes ``z`` to a map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H xj`` from the residual.
    The returned noise covariance Rb is the channel covariance of whatever
    residual is left after the last slot.

    Input shape [I,M,N,F] (batch, channels, two image axes).

    Args:
        M: number of channels; the covariance feature width is M*(M+1).
        K: number of slots (stored as ``self.J``).
        im_size: image size; the latent image is ``int(im_size/4)`` squared.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real + imaginary parts of the lower triangle of an MxM matrix
        # (12 for M=3, the value that used to be hard-coded).
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        # Scalar in (-1, 1) per sample; turned into a phase ramp in forward().
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Packed lower triangle consumed by lower2matrix().
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 kernel with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run the slot loop on ``x`` ([I, M, N, F]).

        Returns:
            ``(Rs, Hhat, Rb, mu, logvar, zall)`` where Rs is the diag-embedded
            variance stack [I, N, F, J, J], Hhat is [I, M, J], Rb is the
            residual covariance, and zall stacks ``z`` and ``W z`` of every
            slot along dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the LAST slot
            only - confirm the KL term is meant to ignore the earlier slots.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        # Loop-invariant; sized from the input instead of a literal 3.
        ind = torch.tril_indices(M,M)
        for i in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): bare squeeze() assumes batch/N/F are all > 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the un-thresholded v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
        
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Noise covariance estimated from what is left after all slots.
        b = xj
        Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # eye = torch.eye(M, device='cuda')
        # Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-4

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood and the weighted KL term of a batch.

    The mixture covariance of every time-frequency bin is modelled as
    ``Rx = Hhat @ Rs @ Hhat^H + Rb`` and each bin contributes
    ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` to the log-likelihood.

    Args:
        x: complex mixture, indexed as [I, M, N, F] (permuted to [I, N, F, M]).
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing-matrix estimate, indexed as [I, M, J].
        Rb: residual covariance, broadcast against [N, F, I, M, M].
        mu, logvar: latent Gaussian parameters used by the KL term.
        zall: accepted for call compatibility; not used by the active terms
            (it fed a slot-contrastive term that is currently disabled).
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is a real scalar summed over the
        batch and all bins.
    """
    x = x.permute(0,2,3,1)  # [I, M, N, F] -> [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    # det(Rx) and the quadratic form are complex tensors whose imaginary parts
    # are only numerical noise for a Hermitian Rx. Keep the real part so the
    # loss is a real scalar that backward(), .item() and plotting can consume.
    ll = (-(np.pi*Rx.det()).log() - quad).real

    return -ll.sum(), beta*kl

#%%
# Run configuration: I training samples, M channels, N x F bins, J sources.
I = 18000 # how many samples
M, N, F, J = 3, 64, 66, 3
eps = 5e-4
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

# Training mixtures: each sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation mixtures with reference sources (sval) and mixing vectors (hgt).
# NOTE(review): unlike xtr, xval is not cast to cfloat -- confirm the stored dtype.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save loss curves, evaluate on validation data, checkpoint.
    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate in Wiener-filter form: Rs H^H Rx^{-1} x per bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Show the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            # One row per sample, one column per source.
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average each metric over its own list length.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s79
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

rid = 's79' 
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently (parents included), so a run cannot crash
# on a half-created pair or fail at the first checkpoint save.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build Hermitian matrices from packed lower-triangular coefficients.

    Args:
        rx12: real tensor of shape [batch, n*(n+1)]. The first n*(n+1)/2
            columns are the real parts of the lower triangle, in
            ``torch.tril_indices`` order; the remaining columns are the
            matching imaginary parts. n is inferred from the width
            (12 columns -> 3x3).

    Returns:
        Complex (cfloat) tensor of shape [batch, n, n] on the same device as
        ``rx12``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal is real.

    Raises:
        ValueError: if the width is not of the form n*(n+1).
    """
    n_tri, odd = divmod(rx12.shape[1], 2)
    n = round(((8*n_tri + 1)**0.5 - 1)/2)  # solve n*(n+1)/2 == n_tri
    if odd or n*(n + 1)//2 != n_tri:
        raise ValueError(
            f'expected n*(n+1) packed columns, got {rx12.shape[1]}')
    device = rx12.device
    ind = torch.tril_indices(n, n, device=device)
    diag = torch.arange(n, device=device)
    rx_inv_hat = torch.zeros(rx12.shape[0], n, n, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx12[:, :n_tri] + 1j*rx12[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The diagonal was added to its own conjugate: halve it back to the real part.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s11(nn.Module):
    """Recursive extract-and-subtract separation network (3-channel sizing).

    For each of the J sources the forward pass estimates a steering vector and
    a Hermitian matrix from the residual's spatial covariance, beamforms the
    residual, encodes the beamformer magnitude into a latent z, decodes z into
    a non-negative map v, and subtracts the corresponding component.

    Input shape [I,M,N,F], e.g. [32,3,100,100]. The layer widths (12 packed
    covariance features) and the hard-coded 3x3 indices assume M == 3.
    NOTE(review): the earlier docstring mentioned an Rb threshold of
    [1e-3, 1e2]; no thresholding of Rb is applied in this class.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of channels (stored as self.M).
            K: number of extraction rounds (stored as self.J).
            im_size: image side used to size the latent layers
                (down_size = int(im_size/4)).
        """
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with 12 features: real and imaginary parts of the
        # 6 lower-triangular entries of a 3x3 covariance.
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one scalar in (-1, 1) (Tanh) used as a phase factor.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 packed coefficients consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values that forward() splits into mu and logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(16, 16, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(16//4,1), num_channels=16),
            nn.LeakyReLU(inplace=True),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            ) 
        # Bias-free linear map of z; its output is only stored in zall.
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run self.J extraction rounds on x.

        Args:
            x: complex mixture, unpacked as [batch, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall): Rs is vhat.diag_embed(),
            Hhat stacks the steering vectors, Rb is the covariance of the
            final residual, zall stacks z and bilinear(z) of every round.
            mu and logvar are those of the LAST round only, because they are
            overwritten on every loop iteration.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        # Per-channel phase multipliers 0, pi, 2*pi, ...
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            "Get H estimation"
            ind = torch.tril_indices(3,3)  # hard-coded 3x3: assumes M == 3
            # Covariance of the current residual, averaged over dims 1 and 2.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            # Features: all real parts followed by all imaginary parts.
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            # Unit-modulus steering vector with phase ang*pi*m on channel m.
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            "Get Rx inverse"
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            "Encoder part"
            # Weights rx_inv h / (h^H rx_inv h); shat = w^H xj per bin.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            "Get latent variable"
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]       # even columns
            logvar = zz[:,1::2]  # odd columns
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            # z and wz are interleaved: zall[:, 0::2] are z, zall[:, 1::2] are wz.
            z_all.append(z)
            z_all.append(wz)
            
            "Decoder to get V"
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            # Only the stored copy is clamped; the removal step below uses the raw v.
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            "Remove the current component"
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Whatever is left after J removals is treated as the residual term.
        b = xj
        # NOTE(review): squeeze() would also drop the batch dim when btsize == 1.
        Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # eye = torch.eye(M, device='cuda')
        # Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-4

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood and the weighted KL term of a batch.

    The mixture covariance of every time-frequency bin is modelled as
    ``Rx = Hhat @ Rs @ Hhat^H + Rb`` and each bin contributes
    ``-log(pi * det(Rx)) - x^H Rx^{-1} x`` to the log-likelihood.

    Args:
        x: complex mixture, indexed as [I, M, N, F] (permuted to [I, N, F, M]).
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing-matrix estimate, indexed as [I, M, J].
        Rb: residual covariance, broadcast against [N, F, I, M, M].
        mu, logvar: latent Gaussian parameters used by the KL term.
        zall: accepted for call compatibility; not used by the active terms
            (it fed a slot-contrastive term that is currently disabled).
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)`` where ``nll`` is a real scalar summed over the
        batch and all bins.
    """
    x = x.permute(0,2,3,1)  # [I, M, N, F] -> [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    # det(Rx) and the quadratic form are complex tensors whose imaginary parts
    # are only numerical noise for a Hermitian Rx. Keep the real part so the
    # loss is a real scalar that backward(), .item() and plotting can consume.
    ll = (-(np.pi*Rx.det()).log() - quad).real

    return -ll.sum(), beta*kl

#%%
# Run configuration: I training samples, M channels, N x F bins, J sources.
I = 18000 # how many samples
M, N, F, J = 3, 64, 66, 3
eps = 5e-4
opts = {}
opts['batch_size'] = 128
opts['n_epochs'] = 701
opts['lr'] = 1e-3

# Training mixtures: each sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM3FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation mixtures with reference sources (sval) and mixing vectors (hgt).
# NOTE(review): unlike xtr, xval is not cast to cfloat -- confirm the stored dtype.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM3FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s11(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save loss curves, evaluate on validation data, checkpoint.
    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate in Wiener-filter form: Rs H^H Rx^{-1} x per bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Show the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            # One row per sample, one column per source.
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average each metric over its own list length.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s81
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

rid = 's81' 
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently (parents included), so a run cannot crash
# on a half-created pair or fail at the first checkpoint save.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# Anomaly detection is left enabled for this run, as in the original cell.
torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Build Hermitian matrices from packed lower-triangular coefficients.

    Args:
        rx42: real tensor of shape [batch, n*(n+1)]. The first n*(n+1)/2
            columns are the real parts of the lower triangle, in
            ``torch.tril_indices`` order; the remaining columns are the
            matching imaginary parts. n is inferred from the width
            (42 columns -> 6x6).

    Returns:
        Complex (cfloat) tensor of shape [batch, n, n] on the same device as
        ``rx42``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal is real.

    Raises:
        ValueError: if the width is not of the form n*(n+1).
    """
    n_tri, odd = divmod(rx42.shape[1], 2)
    n = round(((8*n_tri + 1)**0.5 - 1)/2)  # solve n*(n+1)/2 == n_tri
    if odd or n*(n + 1)//2 != n_tri:
        raise ValueError(
            f'expected n*(n+1) packed columns, got {rx42.shape[1]}')
    device = rx42.device
    ind = torch.tril_indices(n, n, device=device)
    diag = torch.arange(n, device=device)
    rx_inv_hat = torch.zeros(rx42.shape[0], n, n, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_tri] + 1j*rx42[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The diagonal was added to its own conjugate: halve it back to the real part.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive extract-and-subtract separation network (6-channel sizing).

    For each of the J sources the forward pass estimates a steering vector and
    a Hermitian matrix from the residual's spatial covariance, beamforms the
    residual, encodes the beamformer magnitude into a latent z, decodes z into
    a non-negative map v, and subtracts the corresponding component.

    Input shape [I,M,N,F]. The layer widths (42 packed covariance features,
    i.e. real and imaginary parts of 21 lower-triangular entries) assume
    M == 6.
    NOTE(review): the earlier docstring mentioned an Rb threshold of
    [1e-3, 1e2]; no thresholding of Rb is applied in this class.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of channels (stored as self.M).
            K: number of extraction rounds (stored as self.J).
            im_size: image side used to size the latent layers
                (down_size = int(im_size/4)).
        """
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with the 42 packed covariance features.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one scalar in (-1, 1) (Tanh) used as a phase factor.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed coefficients consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values that forward() splits into mu and logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        # Bias-free linear map of z; its output is only stored in zall.
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run self.J extraction rounds on x.

        Args:
            x: complex mixture, unpacked as [batch, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall): Rs is vhat.diag_embed(),
            Hhat stacks the steering vectors, Rb is the covariance of the
            final residual, zall stacks z and bilinear(z) of every round.
            mu and logvar are those of the LAST round only, because they are
            overwritten on every loop iteration.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        # Per-channel phase multipliers 0, pi, 2*pi, ...
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            "Get H estimation"
            ind = torch.tril_indices(M,M)
            # Covariance of the current residual, averaged over dims 1 and 2.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            # Features: all real parts followed by all imaginary parts.
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            # Unit-modulus steering vector with phase ang*pi*m on channel m.
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            "Get Rx inverse"
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            "Encoder part"
            # Weights rx_inv h / (h^H rx_inv h); shat = w^H xj per bin.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            "Get latent variable"
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]       # even columns
            logvar = zz[:,1::2]  # odd columns
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            # z and wz are interleaved: zall[:, 0::2] are z, zall[:, 1::2] are wz.
            z_all.append(z)
            z_all.append(wz)
            
            "Decoder to get V"
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            # Only the stored copy is clamped; the removal step below uses the raw v.
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            "Remove the current component"
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Whatever is left after J removals is treated as the residual term.
        b = xj
        # NOTE(review): squeeze() would also drop the batch dim when btsize == 1.
        Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # eye = torch.eye(M, device='cuda')
        # Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-4

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Compute the two loss terms of a batch.

    Each time-frequency bin is scored with ``-log det(Rx).real`` minus the
    real part of ``x^H Rx^{-1} x``, where ``Rx = Hhat @ Rs @ Hhat^H + Rb``.

    Args:
        x: complex mixture, indexed as [I, M, N, F].
        Rs: source covariances, indexed as [I, N, F, J, J].
        Hhat: mixing-matrix estimate, indexed as [I, M, J].
        Rb: residual covariance added to every bin.
        mu, logvar: latent Gaussian parameters for the KL term.
        zall: unused by the active terms (kept for the disabled
            slot-contrastive and H-regularisation experiments).
        beta: weight of the KL term.

    Returns:
        ``(negative log-likelihood summed over everything, beta * KL)``.
    """
    obs = x.permute(0,2,3,1)  # [I, N, F, M]
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Model covariance, built bin-major then moved back to batch-major.
    src_cov = Rs.permute(1,2,0,3,4)
    mix_herm = Hhat.transpose(-1,-2).conj()
    cov = (Hhat @ src_cov @ mix_herm + Rb).permute(2,0,1,3,4)  # [I, N, F, M, M]

    log_det = cov.det().real.log()
    quad = (obs[...,None,:].conj() @ cov.inverse() @ obs[...,None]).squeeze().real
    log_lik = -log_det - quad

    return -log_lik.sum(), beta*kl_div

#%%
# Run configuration: I training samples, M channels, N x F bins, J sources.
I = 18000 # how many samples
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {}
opts['batch_size'] = 16
opts['n_epochs'] = 901
opts['lr'] = 1e-3

# Training mixtures: each sample is scaled by its own largest magnitude.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d/d.abs().amax(dim=(1,2,3), keepdim=True)) # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation mixtures with reference sources (sval) and mixing vectors (hgt).
# NOTE(review): unlike xtr, xval is not cast to cfloat -- confirm the stored dtype.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval= sval.permute(0,2,3,1)
xval = xval/xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Log every 30th batch, normalised to a per-sample loss.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 5th epoch: save loss curves, evaluate on validation data, checkpoint.
    if epoch%5 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                 
                # Source estimate in Wiener-filter form: Rs H^H Rx^{-1} x per bin.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb

                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()

                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                # Show the J estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average each metric over its own list length.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s81_
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

rid = 's81_' 
# Per-run output folders for figures and model checkpoints.
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently (parents included), so a run cannot crash
# on a half-created pair or fail at the first checkpoint save.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Build Hermitian matrices from packed lower-triangular coefficients.

    Args:
        rx42: real tensor of shape [batch, n*(n+1)]. The first n*(n+1)/2
            columns are the real parts of the lower triangle, in
            ``torch.tril_indices`` order; the remaining columns are the
            matching imaginary parts. n is inferred from the width
            (42 columns -> 6x6).

    Returns:
        Complex (cfloat) tensor of shape [batch, n, n] on the same device as
        ``rx42``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal is real.

    Raises:
        ValueError: if the width is not of the form n*(n+1).
    """
    n_tri, odd = divmod(rx42.shape[1], 2)
    n = round(((8*n_tri + 1)**0.5 - 1)/2)  # solve n*(n+1)/2 == n_tri
    if odd or n*(n + 1)//2 != n_tri:
        raise ValueError(
            f'expected n*(n+1) packed columns, got {rx42.shape[1]}')
    device = rx42.device
    ind = torch.tril_indices(n, n, device=device)
    diag = torch.arange(n, device=device)
    rx_inv_hat = torch.zeros(rx42.shape[0], n, n, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_tri] + 1j*rx42[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The diagonal was added to its own conjugate: halve it back to the real part.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10_(nn.Module):
    """Wider variant of the recursive extract-and-subtract network (6 channels).

    For each of the J sources the forward pass estimates a steering vector and
    a Hermitian matrix from the residual's spatial covariance, beamforms the
    residual, encodes the beamformer magnitude into a latent z, decodes z into
    a non-negative map v, and subtracts the corresponding component.

    Input shape [I,M,N,F]. The layer widths (42 packed covariance features,
    i.e. real and imaginary parts of 21 lower-triangular entries) assume
    M == 6. Fully-connected width is n_feat = 192 and the convolutional base
    width is n_channel = 96.
    NOTE(review): the earlier docstring mentioned an Rb threshold of
    [1e-3, 1e2]; no thresholding of Rb is applied in this class.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of channels (stored as self.M).
            K: number of extraction rounds (stored as self.J).
            im_size: image side used to size the latent layers
                (down_size = int(im_size/4)).
        """
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        n_feat, n_channel = 192, 96
        # Shared trunk fed with the 42 packed covariance features.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, n_feat),
            FC_layer_g(n_feat, n_feat),
        )
        # Head producing one scalar in (-1, 1) (Tanh) used as a phase factor.
        self.hnet = nn.Sequential(
            FC_layer_g(n_feat, n_feat),
            FC_layer_g(n_feat, n_feat),
            FC_layer_g(n_feat, n_feat//2),
            nn.Linear(n_feat//2, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed coefficients consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(n_feat, n_feat),
            FC_layer_g(n_feat, n_feat),
            FC_layer_g(n_feat, n_feat//2),
            FC_layer_g(n_feat//2, n_feat//2),
            nn.Linear(n_feat//2, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=n_channel),
            DoubleConv_g(in_channels=n_channel, out_channels=n_channel//2),
            Down_g(in_channels=n_channel//2, out_channels=n_channel//4),
            DoubleConv_g(in_channels=n_channel//4, out_channels=1),
            )
        # fc1 emits 2*dz values that forward() splits into mu and logvar.
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=n_channel),
            DoubleConv_g(in_channels=n_channel, out_channels=n_channel//2),
            Up_g(in_channels=n_channel//2, out_channels=n_channel//4),
            DoubleConv_g(in_channels=n_channel//4, out_channels=n_channel//8),
            # 3x3 conv with padding (1,2): height unchanged, width grows by 2.
            nn.Conv2d(n_channel//8, n_channel//8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(n_channel//32,1), num_channels=n_channel//8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=n_channel//8, out_channels=1),
            ) 
        # Bias-free linear map of z; its output is only stored in zall.
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*std with eps ~ N(0, 1) and std = exp(0.5*logvar)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Run self.J extraction rounds on x.

        Args:
            x: complex mixture, unpacked as [batch, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall): Rs is vhat.diag_embed(),
            Hhat stacks the steering vectors, Rb is the covariance of the
            final residual, zall stacks z and bilinear(z) of every round.
            mu and logvar are those of the LAST round only, because they are
            overwritten on every loop iteration.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        # Per-channel phase multipliers 0, pi, 2*pi, ...
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for i in range(self.J):
            "Get H estimation"
            ind = torch.tril_indices(M,M)
            # Covariance of the current residual, averaged over dims 1 and 2.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            # Features: all real parts followed by all imaginary parts.
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            # Unit-modulus steering vector with phase ang*pi*m on channel m.
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            "Get Rx inverse"
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            "Encoder part"
            # Weights rx_inv h / (h^H rx_inv h); shat = w^H xj per bin.
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            xx = self.encoder(shat.squeeze()[:,None].abs())

            "Get latent variable"
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]       # even columns
            logvar = zz[:,1::2]  # odd columns
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            # z and wz are interleaved: zall[:, 0::2] are z, zall[:, 1::2] are wz.
            z_all.append(z)
            z_all.append(wz)
            
            "Decoder to get V"
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            # Only the stored copy is clamped; the removal step below uses the raw v.
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            "Remove the current component"
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Whatever is left after J removals is treated as the residual term.
        b = xj
        # NOTE(review): squeeze() would also drop the batch dim when btsize == 1.
        Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # eye = torch.eye(M, device='cuda')
        # Rb = torch.stack(tuple(eye for ii in range(btsize)), 0)*1e-4

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood term and the weighted KL term.

    The model covariance at every point is ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored with
    ``-log(det(Rx).real) - Re(x^H Rx^{-1} x)``.

    Args:
        x: observations indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: per-point source covariances indexed as [I, N, F, J, J].
        Hhat: mixing estimates indexed as [I, M, J].
        Rb: additive covariance, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used; kept so existing call sites stay valid.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta * kl)``.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    quad = x[..., None, :].conj() @ Rx.inverse() @ x[..., None]  # [I, N, F, 1, 1]
    # Drop only the two trailing singleton axes. A bare .squeeze() would also
    # drop a size-1 batch/N/F axis, after which the subtraction below would
    # broadcast against det() and sum the wrong number of terms.
    ll = -Rx.det().real.log() - quad[..., 0, 0].real
    return -ll.sum(), beta * kl

#%%
# Run configuration: I caps the number of training samples; M, N, F follow
# the [sample,M,N,F] layout of the data and J is passed to the model.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 901, 'lr': 1e-3}

# Training data: every sample is scaled by its own largest magnitude and
# stored as complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation data: the file holds three tensors that are batched together;
# xval gets the same per-sample scaling, sval is permuted to put axis 1 last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Histories: training losses are sampled every 30th batch, the validation loss
# every 5th epoch. loss_iter is created here but not appended to in this cell.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10_(M,J,N).cuda()
# Overwrite every parameter tensor with draws from N(0, 0.01^2).
for w in model.parameters():
    nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            # Store per-sample averages (the losses are sums over the batch).
            loss_tr.append(loss.item()/opts['batch_size'])
            loss1.append(l1.item()/opts['batch_size'])
            loss2.append(l2.item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # One figure per curve; titles and file names are unchanged.
        for curve, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate Rs H^H Rx^{-1} x, evaluated in [N, F, I, ...] layout.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr(hh.cpu(), h[ind]))
                    av_scorr.append(s_corr(s[ind].abs(), shat[ind]))
                
                if i == 0:
                    # Show the J estimates for the first three validation samples.
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object on every evaluated epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s82
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # True: use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # False: keep a fixed conv algorithm, may be slower

# Per-run output folders for figures and model checkpoints.
rid = 's82' 
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create both folders independently. The earlier version tested
# only the figures folder and then called os.mkdir for both, which raised
# FileExistsError or left the models folder missing when the two were out
# of sync.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Args:
        rx42: real tensor of shape [batch, 2*L] with L = M*(M+1)/2. The first
            L columns are the real parts and the last L columns the imaginary
            parts of the lower triangle, in ``torch.tril_indices`` order.
            A width of 42 gives M = 6, the only size handled previously.

    Returns:
        complex64 tensor of shape [batch, M, M] on the same device as
        ``rx42``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal keeps only the real part.

    Raises:
        ValueError: if the width of ``rx42`` is not of the form M*(M+1).
    """
    n_tri, odd = divmod(rx42.shape[1], 2)
    M = int(round(((8 * n_tri + 1) ** 0.5 - 1) / 2))
    if odd or M * (M + 1) // 2 != n_tri:
        raise ValueError(
            f'expected width M*(M+1) for some integer M, got {rx42.shape[1]}')

    # Build everything on the input's device instead of hard-coding .cuda().
    dev = rx42.device
    ind = torch.tril_indices(M, M, device=dev)
    diag = torch.arange(M, device=dev)
    rx_inv_hat = torch.zeros(rx42.shape[0], M, M, dtype=torch.cfloat, device=dev)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_tri] + 1j*rx42[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The sum above doubled the (now real) diagonal; halve it back.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive component-extraction network with a small VAE per component.

    ``forward`` takes ``x`` indexed as [I, M, N, F] and removes ``J``
    components one after another. For each component it
      1. averages the outer product of the current residual over N and F,
      2. maps that covariance to a unit-modulus phase-ramp vector ``hhat``
         (``hnet``) and to a Hermitian matrix ``rx_inv`` (``rxnet``),
      3. applies the weight ``rx_inv h / (h^H rx_inv h)`` to the residual,
      4. encodes the magnitude to a latent ``z`` and decodes ``z`` to a
         non-negative map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H x`` from the residual.

    NOTE(review): the previous docstring mentioned an "Rb threshold of
    [1e-3, 1e2]". No clamp is applied to Rb in this class; only ``v`` goes
    through ``threshold(..., floor=1e-6, ceiling=1e2)``.

    Args:
        M: number of channels (``x.shape[1]``).
        K: number of components to extract; stored as ``self.J``.
        im_size: ``int(im_size/4)`` is the side of the square grid used by
            ``fc1`` and the first decoder layers.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real and imaginary parts of the lower triangle of an MxM matrix
        # (42 for M=6). rxnet emits the same width for lower2matrix to unpack.
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Extract ``self.J`` components from ``x``.

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the
            diagonal embedding of the stacked clamped ``v`` maps
            ([I, N, F, J, J], complex), ``Hhat`` is [I, M, J], ``Rb`` is the
            N/F-averaged outer product of the final residual, and ``zall``
            stacks ``z`` and ``bilinear(z)`` alternately along dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the last
            component only, so a KL term built from them ignores the
            earlier components. Confirm this is intended.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, hoisted out of the loop
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the trailing singleton axes explicitly; a bare squeeze()
            # would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # reshape (instead of a bare squeeze) keeps the batch axis for
            # btsize == 1; v must be [I,N,F] to broadcast against xj below.
            v = self.decoder(z).square().reshape(btsize, N, F)
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unclamped v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Covariance of what is left after all J removals, [I, M, M].
        Rb = (xj@xj.conj().permute(0,1,2,4,3)).mean(dim=(1,2))

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood term and the weighted KL term.

    The model covariance at every point is ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored with
    ``-log(det(Rx).real) - Re(x^H Rx^{-1} x)``.

    Args:
        x: observations indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: per-point source covariances indexed as [I, N, F, J, J].
        Hhat: mixing estimates indexed as [I, M, J].
        Rb: additive covariance, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used; kept so existing call sites stay valid.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta * kl)``.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    quad = x[..., None, :].conj() @ Rx.inverse() @ x[..., None]  # [I, N, F, 1, 1]
    # Drop only the two trailing singleton axes. A bare .squeeze() would also
    # drop a size-1 batch/N/F axis, after which the subtraction below would
    # broadcast against det() and sum the wrong number of terms.
    ll = -Rx.det().real.log() - quad[..., 0, 0].real
    return -ll.sum(), beta * kl

#%%
# Run configuration: I caps the number of training samples; M, N, F follow
# the [sample,M,N,F] layout of the data and J is passed to the model.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 901, 'lr': 1e-3}

# Training data: every sample is scaled by its own largest magnitude and
# stored as complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation data: the file holds three tensors that are batched together;
# xval gets the same per-sample scaling, sval is permuted to put axis 1 last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Histories: training losses are sampled every 30th batch, the validation loss
# every 5th epoch. loss_iter is created here but not appended to in this cell.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
# NOTE(review): this fresh instance is replaced by the checkpoint on the next
# line. It is kept because building it may consume seeded RNG state, so
# removing it could change the rest of the run.
model = NNet_s10(M,J,N).cuda()
model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            # Store per-sample averages (the losses are sums over the batch).
            loss_tr.append(loss.item()/opts['batch_size'])
            loss1.append(l1.item()/opts['batch_size'])
            loss2.append(l2.item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # One figure per curve; titles and file names are unchanged.
        for curve, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate Rs H^H Rx^{-1} x, evaluated in [N, F, I, ...] layout.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())
                
                if i == 0:
                    # Show the J estimates for the first three validation samples.
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object on every evaluated epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s82c
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # True: use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # False: keep a fixed conv algorithm, may be slower

# Per-run output folders for figures and model checkpoints.
rid = 's82c' 
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create both folders independently. The earlier version tested
# only the figures folder and then called os.mkdir for both, which raised
# FileExistsError or left the models folder missing when the two were out
# of sync.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Args:
        rx42: real tensor of shape [batch, 2*L] with L = M*(M+1)/2. The first
            L columns are the real parts and the last L columns the imaginary
            parts of the lower triangle, in ``torch.tril_indices`` order.
            A width of 42 gives M = 6, the only size handled previously.

    Returns:
        complex64 tensor of shape [batch, M, M] on the same device as
        ``rx42``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal keeps only the real part.

    Raises:
        ValueError: if the width of ``rx42`` is not of the form M*(M+1).
    """
    n_tri, odd = divmod(rx42.shape[1], 2)
    M = int(round(((8 * n_tri + 1) ** 0.5 - 1) / 2))
    if odd or M * (M + 1) // 2 != n_tri:
        raise ValueError(
            f'expected width M*(M+1) for some integer M, got {rx42.shape[1]}')

    # Build everything on the input's device instead of hard-coding .cuda().
    dev = rx42.device
    ind = torch.tril_indices(M, M, device=dev)
    diag = torch.arange(M, device=dev)
    rx_inv_hat = torch.zeros(rx42.shape[0], M, M, dtype=torch.cfloat, device=dev)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_tri] + 1j*rx42[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The sum above doubled the (now real) diagonal; halve it back.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive component-extraction network with a small VAE per component.

    ``forward`` takes ``x`` indexed as [I, M, N, F] and removes ``J``
    components one after another. For each component it
      1. averages the outer product of the current residual over N and F,
      2. maps that covariance to a unit-modulus phase-ramp vector ``hhat``
         (``hnet``) and to a Hermitian matrix ``rx_inv`` (``rxnet``),
      3. applies the weight ``rx_inv h / (h^H rx_inv h)`` to the residual,
      4. encodes the magnitude to a latent ``z`` and decodes ``z`` to a
         non-negative map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H x`` from the residual.

    NOTE(review): the previous docstring mentioned an "Rb threshold of
    [1e-3, 1e2]". No clamp is applied to Rb in this class; only ``v`` goes
    through ``threshold(..., floor=1e-6, ceiling=1e2)``.

    Args:
        M: number of channels (``x.shape[1]``).
        K: number of components to extract; stored as ``self.J``.
        im_size: ``int(im_size/4)`` is the side of the square grid used by
            ``fc1`` and the first decoder layers.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real and imaginary parts of the lower triangle of an MxM matrix
        # (42 for M=6). rxnet emits the same width for lower2matrix to unpack.
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Extract ``self.J`` components from ``x``.

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the
            diagonal embedding of the stacked clamped ``v`` maps
            ([I, N, F, J, J], complex), ``Hhat`` is [I, M, J], ``Rb`` is the
            N/F-averaged outer product of the final residual, and ``zall``
            stacks ``z`` and ``bilinear(z)`` alternately along dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the last
            component only, so a KL term built from them ignores the
            earlier components. Confirm this is intended.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, hoisted out of the loop
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the trailing singleton axes explicitly; a bare squeeze()
            # would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # reshape (instead of a bare squeeze) keeps the batch axis for
            # btsize == 1; v must be [I,N,F] to broadcast against xj below.
            v = self.decoder(z).square().reshape(btsize, N, F)
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unclamped v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Covariance of what is left after all J removals, [I, M, M].
        Rb = (xj@xj.conj().permute(0,1,2,4,3)).mean(dim=(1,2))

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood term and the weighted KL term.

    The model covariance at every point is ``Hhat @ Rs @ Hhat^H + Rb`` and
    each observation vector is scored with
    ``-log(det(Rx).real) - Re(x^H Rx^{-1} x)``.

    Args:
        x: observations indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: per-point source covariances indexed as [I, N, F, J, J].
        Hhat: mixing estimates indexed as [I, M, J].
        Rb: additive covariance, broadcast against [..., I, M, M].
        mu, logvar: latent Gaussian parameters entering the KL term.
        zall: not used; kept so existing call sites stay valid.
        beta: weight applied to the KL term.

    Returns:
        Tuple ``(-ll.sum(), beta * kl)``.
    """
    x = x.permute(0, 2, 3, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    quad = x[..., None, :].conj() @ Rx.inverse() @ x[..., None]  # [I, N, F, 1, 1]
    # Drop only the two trailing singleton axes. A bare .squeeze() would also
    # drop a size-1 batch/N/F axis, after which the subtraction below would
    # broadcast against det() and sum the wrong number of terms.
    ll = -Rx.det().real.log() - quad[..., 0, 0].real
    return -ll.sum(), beta * kl

#%%
# Run configuration: I caps the number of training samples; M, N, F follow
# the [sample,M,N,F] layout of the data and J is passed to the model.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 901, 'lr': 1e-3}

# Training data: every sample is scaled by its own largest magnitude and
# stored as complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)  # [sample,M,N,F]
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation data: the file holds three tensors that are batched together;
# xval gets the same per-sample scaling, sval is permuted to put axis 1 last.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
sval = sval.permute(0, 2, 3, 1)
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Histories: training losses are sampled every 30th batch, the validation loss
# every 5th epoch. loss_iter is created here but not appended to in this cell.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
# NOTE(review): this fresh instance is replaced by the checkpoint on the next
# line. It is kept because building it may consume seeded RNG state, so
# removing it could change the rest of the run.
model = NNet_s10(M,J,N).cuda()
model = torch.load('../data/data_ss/models/s75/model_epoch400.pt')
# The N(0, 0.01^2) re-initialisation used in earlier runs is disabled here:
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        if i%30 == 0:
            # Store per-sample averages (the losses are sums over the batch).
            loss_tr.append(loss.item()/opts['batch_size'])
            loss1.append(l1.item()/opts['batch_size'])
            loss2.append(l2.item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # One figure per curve; titles and file names are unchanged.
        for curve, fmt, title, tag in (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', 'LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', 'Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', 'Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', 'last50'),
        ):
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + f'Epoch{epoch}_{tag}')

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])
                     
                # Source estimate Rs H^H Rx^{-1} x, evaluated in [N, F, I, ...] layout.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())
                
                if i == 0:
                    # Show the J estimates for the first three validation samples.
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list (was divided by len(av_hcorr)).
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all') 

        # Checkpoint the whole model object on every evaluated epoch.
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s87
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # True: use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # False: keep a fixed conv algorithm, may be slower

# Per-run output folders for figures and model checkpoints.
rid = 's87' 
fig_loc = f'../data/data_ss/figures/{rid}/'
mod_loc = f'../data/data_ss/models/{rid}/'
# Check and create both folders independently. The earlier version tested
# only the figures folder and then called os.mkdir for both, which raised
# FileExistsError or left the models folder missing when the two were out
# of sync.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Args:
        rx42: real tensor of shape [batch, 2*L] with L = M*(M+1)/2. The first
            L columns are the real parts and the last L columns the imaginary
            parts of the lower triangle, in ``torch.tril_indices`` order.
            A width of 42 gives M = 6, the only size handled previously.

    Returns:
        complex64 tensor of shape [batch, M, M] on the same device as
        ``rx42``. The lower triangle holds the packed values, the upper
        triangle their conjugates, and the diagonal keeps only the real part.

    Raises:
        ValueError: if the width of ``rx42`` is not of the form M*(M+1).
    """
    n_tri, odd = divmod(rx42.shape[1], 2)
    M = int(round(((8 * n_tri + 1) ** 0.5 - 1) / 2))
    if odd or M * (M + 1) // 2 != n_tri:
        raise ValueError(
            f'expected width M*(M+1) for some integer M, got {rx42.shape[1]}')

    # Build everything on the input's device instead of hard-coding .cuda().
    dev = rx42.device
    ind = torch.tril_indices(M, M, device=dev)
    diag = torch.arange(M, device=dev)
    rx_inv_hat = torch.zeros(rx42.shape[0], M, M, dtype=torch.cfloat, device=dev)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_tri] + 1j*rx42[:, n_tri:]
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0,2,1).conj()
    # The sum above doubled the (now real) diagonal; halve it back.
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10(nn.Module):
    """Recursive component-extraction network with a small VAE per component.

    ``forward`` takes ``x`` indexed as [I, M, N, F] and removes ``J``
    components one after another. For each component it
      1. averages the outer product of the current residual over N and F,
      2. maps that covariance to a unit-modulus phase-ramp vector ``hhat``
         (``hnet``) and to a Hermitian matrix ``rx_inv`` (``rxnet``),
      3. applies the weight ``rx_inv h / (h^H rx_inv h)`` to the residual,
      4. encodes the magnitude to a latent ``z`` and decodes ``z`` to a
         non-negative map ``v``,
      5. subtracts ``hhat * (v * rx_inv hhat)^H x`` from the residual.

    In this variant ``Rb`` is not estimated from the residual; it is fixed
    to ``1e-3 * I``.

    NOTE(review): the previous docstring mentioned an "Rb threshold of
    [1e-3, 1e2]". No clamp is applied to Rb in this class; only ``v`` goes
    through ``threshold(..., floor=1e-6, ceiling=1e2)``.

    Args:
        M: number of channels (``x.shape[1]``).
        K: number of components to extract; stored as ``self.J``.
        im_size: ``int(im_size/4)`` is the side of the square grid used by
            ``fc1`` and the first decoder layers.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Real and imaginary parts of the lower triangle of an MxM matrix
        # (42 for M=6). rxnet emits the same width for lower2matrix to unpack.
        n_cov = M*(M + 1)
        self.mainnet = nn.Sequential(
            FC_layer_g(n_cov, 128),
            FC_layer_g(128, 128),
        )
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, n_cov)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            ) 
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):     
        """Extract ``self.J`` components from ``x``.

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)``: ``Rs`` is the
            diagonal embedding of the stacked clamped ``v`` maps
            ([I, N, F, J, J], complex), ``Hhat`` is [I, M, J], ``Rb`` is
            ``1e-3 * I`` repeated per batch element ([I, M, M], real), and
            ``zall`` stacks ``z`` and ``bilinear(z)`` alternately along dim 1.
            NOTE(review): ``mu`` and ``logvar`` are those of the last
            component only, so a KL term built from them ignores the
            earlier components. Confirm this is intended.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, hoisted out of the loop
        for _ in range(self.J):
            # Get H estimation
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]
        
            # Encoder part
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Index the trailing singleton axes explicitly; a bare squeeze()
            # would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat[...,0,0][:,None].abs())

            # Get latent variable
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)
            
            # Decoder to get V
            # reshape (instead of a bare squeeze) keeps the batch axis for
            # btsize == 1; v must be [I,N,F] to broadcast against xj below.
            v = self.decoder(z).square().reshape(btsize, N, F)
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component (uses the unclamped v)
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj() 
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj
       
        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Fixed Rb = 1e-3 * I, built on the input's device rather than a
        # hard-coded 'cuda'. The residual-based estimate is disabled here.
        Rb = torch.eye(M, device=x.device).expand(btsize, M, M)*1e-3

        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood of the mixture and the weighted KL term.

    Every time-frequency bin of ``x`` is scored under a zero-mean complex
    Gaussian with covariance ``Hhat @ Rs @ Hhat^H + Rb``.

    Args:
        x: mixture indexed as [I, M, N, F]; the channel axis is moved last here.
        Rs: per-bin source covariance, [I, N, F, J, J].
        Hhat: estimated mixing matrix, [I, M, J].
        Rb: noise covariance added to every bin, [I, M, M].
        mu: latent means used in the KL term.
        logvar: latent log-variances used in the KL term.
        zall: latent codes; kept for interface compatibility, not used.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)``: two real scalar tensors, each summed over all
        elements.
    """
    x = x.permute(0, 2, 3, 1)  # [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Rs is permuted to [N, F, I, J, J] so that Hhat ([I, M, J]) and Rb
    # ([I, M, M]) broadcast over the leading N, F axes.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    ll = -(np.pi*Rx.det()).log() \
        - (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()

    # The likelihood of a Hermitian covariance is real; the complex dtype only
    # carries numerical residue in the imaginary part. Dropping it keeps the
    # loss a real scalar, which is what backward() is defined for.
    if ll.is_complex():
        ll = ll.real

    # Two earlier experiments (a slot-contrastive loss on ``zall`` and a
    # covariance-matching penalty on ``Hhat``) are disabled and not included.
    return -ll.sum(), beta*kl

#%%
# Problem size: I training samples, M channels, N x F bins, J sources.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-3}

# Training mixtures, laid out as [sample, M, N, F]. Each sample is divided by
# its own largest magnitude before being cast to complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, ground-truth sources and ground-truth mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
sval = sval.permute(0, 2, 3, 1)  # move axis 1 to the end
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training run: optimise the model, and every 5 epochs plot the loss curves,
# evaluate on the validation loader and checkpoint the whole model object.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30 iterations.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # (series, line format, title, file name) for the four loss figures.
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        )
        for curve, fmt, title, fname in curves:
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + fname)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall = model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate per bin: Rs H^H (H Rs H^H + Rb)^-1 x.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())

                # Show the estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list, not over av_hcorr.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s91
from utils import *
# Expose only GPU 0 to this process and set display defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # no per-shape algorithm search; may be slower

# Per-run output folders for figures and model checkpoints.
rid = 's91'
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently so a run with only one of them present
# neither raises FileExistsError nor leaves the other one missing.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Each row of ``rx42`` holds the real parts of the lower triangle (in the
    order given by ``torch.tril_indices``) followed by the matching imaginary
    parts. The imaginary value supplied for a diagonal entry is discarded, so
    the diagonal of the result is real.

    Args:
        rx42: real tensor of shape [I, M*(M+1)]; 42 columns give 6x6 matrices.

    Returns:
        Complex64 tensor of shape [I, M, M], on the same device as ``rx42``.

    Raises:
        ValueError: if the number of columns is not of the form M*(M+1).
    """
    n_cols = rx42.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if m * (m + 1) != n_cols:
        raise ValueError(
            f'expected M*(M+1) packed values per row, got {n_cols}')

    device = rx42.device
    ind = torch.tril_indices(m, m, device=device)
    diag = torch.arange(m, device=device)
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_lower] + 1j*rx42[:, n_lower:]
    # Mirror the lower triangle; this doubles the real part of the diagonal
    # and cancels its imaginary part, hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s12(nn.Module):
    """Recursive, source-by-source separation network with a VAE bottleneck.

    For each of the ``K`` sources the forward pass estimates a steering
    vector and a Hermitian matrix used as the inverse covariance from the
    current residual, beamforms the residual, encodes the beamformed magnitude
    into a latent code, decodes that code into a non-negative map ``v`` and
    subtracts the corresponding component from the residual.

    The noise covariance ``Rb`` is not learned; it is fixed to ``1e-3 * I``.
    Input shape is [I, M, N, F]. The 42-wide input/output layers (real and
    imaginary parts of a lower triangle) tie this class to ``M = 6``.

    Args:
        M: number of channels.
        K: number of sources to extract (stored as ``self.J``).
        im_size: image side length; the bottleneck is ``(im_size/4)**2`` wide.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with the packed covariance of the residual.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one scalar in (-1, 1) that parameterises the steering vector.
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed values consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 kernel with padding (1, 2): height kept, last axis grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Separate ``x`` ([I, M, N, F]) into ``self.J`` components.

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)`` where ``Rs`` is the
            diagonal-embedded source map [I, N, F, J, J], ``Hhat`` is
            [I, M, J], ``Rb`` is ``1e-3 * I`` per sample, ``mu``/``logvar``
            come from the last loop iteration only, and ``zall`` stacks
            ``z`` and ``bilinear(z)`` of every iteration along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation from the residual's spatial covariance.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse.
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: MVDR-style weights w = Rx^-1 h / (h^H Rx^-1 h).
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): squeeze() would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V; squaring makes it non-negative.
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component. The un-thresholded v is used here;
            # only the stored copy above goes through threshold().
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Fixed noise covariance, built on the input's device (as ch is above).
        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        # NOTE(review): mu and logvar are those of the last source only, so a
        # KL term built from them covers one of the J latent codes. Confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood of the mixture and the weighted KL term.

    Every time-frequency bin of ``x`` is scored under a zero-mean complex
    Gaussian with covariance ``Hhat @ Rs @ Hhat^H + Rb``.

    Args:
        x: mixture indexed as [I, M, N, F]; the channel axis is moved last here.
        Rs: per-bin source covariance, [I, N, F, J, J].
        Hhat: estimated mixing matrix, [I, M, J].
        Rb: noise covariance added to every bin, [I, M, M].
        mu: latent means used in the KL term.
        logvar: latent log-variances used in the KL term.
        zall: latent codes; kept for interface compatibility, not used.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)``: two real scalar tensors, each summed over all
        elements.
    """
    x = x.permute(0, 2, 3, 1)  # [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Rs is permuted to [N, F, I, J, J] so that Hhat ([I, M, J]) and Rb
    # ([I, M, M]) broadcast over the leading N, F axes.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    ll = -(np.pi*Rx.det()).log() \
        - (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()

    # The likelihood of a Hermitian covariance is real; the complex dtype only
    # carries numerical residue in the imaginary part. Dropping it keeps the
    # loss a real scalar, which is what backward() is defined for.
    if ll.is_complex():
        ll = ll.real

    # Two earlier experiments (a slot-contrastive loss on ``zall`` and a
    # covariance-matching penalty on ``Hhat``) are disabled and not included.
    return -ll.sum(), beta*kl

#%%
# Problem size: I training samples, M channels, N x F bins, J sources.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-3}

# Training mixtures, laid out as [sample, M, N, F]. Each sample is divided by
# its own largest magnitude before being cast to complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, ground-truth sources and ground-truth mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
sval = sval.permute(0, 2, 3, 1)  # move axis 1 to the end
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training run: optimise the model, and every 5 epochs plot the loss curves,
# evaluate on the validation loader and checkpoint the whole model object.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s12(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30 iterations.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # (series, line format, title, file name) for the four loss figures.
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        )
        for curve, fmt, title, fname in curves:
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + fname)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall = model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate per bin: Rs H^H (H Rs H^H + Rb)^-1 x.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())

                # Show the estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list, not over av_hcorr.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

In [None]:
#%% s95
from utils import *
# Expose only GPU 0 to this process and set display defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False  # no per-shape algorithm search; may be slower

# Per-run output folders for figures and model checkpoints.
rid = 's95'
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently so a run with only one of them present
# neither raises FileExistsError nor leaves the other one missing.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx42):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Each row of ``rx42`` holds the real parts of the lower triangle (in the
    order given by ``torch.tril_indices``) followed by the matching imaginary
    parts. The imaginary value supplied for a diagonal entry is discarded, so
    the diagonal of the result is real.

    Args:
        rx42: real tensor of shape [I, M*(M+1)]; 42 columns give 6x6 matrices.

    Returns:
        Complex64 tensor of shape [I, M, M], on the same device as ``rx42``.

    Raises:
        ValueError: if the number of columns is not of the form M*(M+1).
    """
    n_cols = rx42.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if m * (m + 1) != n_cols:
        raise ValueError(
            f'expected M*(M+1) packed values per row, got {n_cols}')

    device = rx42.device
    ind = torch.tril_indices(m, m, device=device)
    diag = torch.arange(m, device=device)
    rx_inv_hat = torch.zeros(rx42.shape[0], m, m, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx42[:, :n_lower] + 1j*rx42[:, n_lower:]
    # Mirror the lower triangle; this doubles the real part of the diagonal
    # and cancels its imaginary part, hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet_s10_1(nn.Module):
    """Recursive, source-by-source separation network with a VAE bottleneck.

    For each of the ``K`` sources the forward pass estimates a steering
    vector and a Hermitian matrix used as the inverse covariance from the
    current residual, beamforms the residual, encodes the beamformed magnitude
    into a latent code, decodes that code into a non-negative map ``v`` and
    subtracts the corresponding component from the residual.

    The noise covariance ``Rb`` is not learned; it is fixed to ``1e-3 * I``.
    Input shape is [I, M, N, F]. The 42-wide input/output layers (real and
    imaginary parts of a lower triangle) tie this class to ``M = 6``.

    Args:
        M: number of channels.
        K: number of sources to extract (stored as ``self.J``).
        im_size: image side length; the bottleneck is ``(im_size/4)**2`` wide.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        feat = 192  # width of the fully connected trunk and heads
        # Shared trunk fed with the packed covariance of the residual.
        self.mainnet = nn.Sequential(
            FC_layer_g(42, feat),
            FC_layer_g(feat, feat),
        )
        # Head producing one scalar in (-1, 1) that parameterises the steering vector.
        self.hnet = nn.Sequential(
            FC_layer_g(feat, feat),
            FC_layer_g(feat, feat),
            FC_layer_g(feat, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 42 packed values consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(feat, feat),
            FC_layer_g(feat, feat),
            FC_layer_g(feat, 64),
            FC_layer_g(64, 64),
            nn.Linear(64, 42)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            # 3x3 kernel with padding (1, 2): height kept, last axis grows by 2.
            nn.Conv2d(8, 8, kernel_size=3, padding=(1,2)),
            nn.GroupNorm(num_groups=max(8//4,1), num_channels=8),
            nn.LeakyReLU(inplace=True),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Separate ``x`` ([I, M, N, F]) into ``self.J`` components.

        Returns:
            Tuple ``(Rs, Hhat, Rb, mu, logvar, zall)`` where ``Rs`` is the
            diagonal-embedded source map [I, N, F, J, J], ``Hhat`` is
            [I, M, J], ``Rb`` is ``1e-3 * I`` per sample, ``mu``/``logvar``
            come from the last loop iteration only, and ``zall`` stacks
            ``z`` and ``bilinear(z)`` of every iteration along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        ind = torch.tril_indices(M,M)  # loop-invariant, computed once
        for i in range(self.J):
            # Get H estimation from the residual's spatial covariance.
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid =self.mainnet(torch.stack((rx_lower.real,rx_lower.imag),\
                 dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # Get Rx inverse.
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # Encoder part: MVDR-style weights w = Rx^-1 h / (h^H Rx^-1 h).
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj
            # NOTE(review): squeeze() would also drop the batch axis when btsize == 1.
            xx = self.encoder(shat.squeeze()[:,None].abs())

            # Get latent variable: even columns are mu, odd columns are logvar.
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # Decoder to get V; squaring makes it non-negative.
            v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

            # Remove the current component. The un-thresholded v is used here;
            # only the stored copy above goes through threshold().
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Fixed noise covariance, built on the input's device (as ch is above).
        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        eye = torch.eye(M, device=x.device)
        Rb = eye.repeat(btsize, 1, 1)*1e-3

        # NOTE(review): mu and logvar are those of the last source only, so a
        # KL term built from them covers one of the J latent codes. Confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the negative log-likelihood of the mixture and the weighted KL term.

    Every time-frequency bin of ``x`` is scored under a zero-mean complex
    Gaussian with covariance ``Hhat @ Rs @ Hhat^H + Rb``.

    Args:
        x: mixture indexed as [I, M, N, F]; the channel axis is moved last here.
        Rs: per-bin source covariance, [I, N, F, J, J].
        Hhat: estimated mixing matrix, [I, M, J].
        Rb: noise covariance added to every bin, [I, M, M].
        mu: latent means used in the KL term.
        logvar: latent log-variances used in the KL term.
        zall: latent codes; kept for interface compatibility, not used.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)``: two real scalar tensors, each summed over all
        elements.
    """
    x = x.permute(0, 2, 3, 1)  # [I, N, F, M]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Rs is permuted to [N, F, I, J, J] so that Hhat ([I, M, J]) and Rb
    # ([I, M, M]) broadcast over the leading N, F axes.
    Rxperm = Hhat @ Rs.permute(1, 2, 0, 3, 4) @ Hhat.transpose(-1, -2).conj() + Rb
    Rx = Rxperm.permute(2, 0, 1, 3, 4)  # shape of [I, N, F, M, M]
    ll = -(np.pi*Rx.det()).log() \
        - (x[..., None, :].conj() @ Rx.inverse() @ x[..., None]).squeeze()

    # The likelihood of a Hermitian covariance is real; the complex dtype only
    # carries numerical residue in the imaginary part. Dropping it keeps the
    # loss a real scalar, which is what backward() is defined for.
    if ll.is_complex():
        ll = ll.real

    # Two earlier experiments (a slot-contrastive loss on ``zall`` and a
    # covariance-matching penalty on ``Hhat``) are disabled and not included.
    return -ll.sum(), beta*kl

#%%
# Problem size: I training samples, M channels, N x F bins, J sources.
I = 18000
M, N, F, J = 6, 64, 66, 6
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 701, 'lr': 1e-4}

# Training mixtures, laid out as [sample, M, N, F]. Each sample is divided by
# its own largest magnitude before being cast to complex64.
d = torch.load('../data/nem_ss/tr18kM6FT64_data3.pt')
xtr = (d / d.abs().amax(dim=(1, 2, 3), keepdim=True)).to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(data, batch_size=opts['batch_size'], shuffle=True, drop_last=True)

# Validation set: mixtures, ground-truth sources and ground-truth mixing vectors.
xval, sval, hgt = torch.load('../data/nem_ss/val1kM6FT64_xsh_data3.pt')
xval = xval / xval.abs().amax(dim=(1, 2, 3), keepdim=True)
sval = sval.permute(0, 2, 3, 1)  # move axis 1 to the end
data = Data.TensorDataset(xval, sval, hgt)
dval = Data.DataLoader(data, batch_size=200, drop_last=True)

#%%
# Training run: optimise the model, and every 5 epochs plot the loss curves,
# evaluate on the validation loader and checkpoint the whole model object.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet_s10_1(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30 iterations.
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    if epoch%5 == 0:
        print(epoch)
        # (series, line format, title, file name) for the four loss figures.
        curves = (
            (loss_tr, '-or', f'Loss fuction at epoch{epoch}', f'Epoch{epoch}_LossFunAll'),
            (loss1, '-og', f'Reconstruction loss at epoch{epoch}', f'Epoch{epoch}_Loss1'),
            (loss2, '-og', f'KL loss at epoch{epoch}', f'Epoch{epoch}_Loss2'),
            (loss_tr[-50:], '-or', f'Last 50 of loss at epoch{epoch}', f'Epoch{epoch}_last50'),
        )
        for curve, fmt, title, fname in curves:
            plt.figure()
            plt.plot(curve, fmt)
            plt.title(title)
            plt.savefig(fig_loc + fname)

        model.eval()
        with torch.no_grad():
            av_hcorr, av_scorr, temp = [], [], []
            for i, (x, s, h) in enumerate(dval):
                xval_cuda = x.cuda()
                Rs, Hhat_val, Rb, mu, logvar, zall = model(xval_cuda)
                l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
                temp.append((l1+l2).cpu().item()/x.shape[0])

                # Source estimate per bin: Rs H^H (H Rs H^H + Rb)^-1 x.
                Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
                shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                        @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
                shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
                for ind in range(x.shape[0]):
                    hh = Hhat_val[ind]
                    av_hcorr.append(h_corr_cuda(hh, h[ind].cuda()).cpu().item())
                    av_scorr.append(s_corr_cuda(s[ind:ind+1].abs().cuda(), \
                        shat[ind:ind+1].cuda()).cpu().item())

                # Show the estimated sources of the first 3 validation samples.
                if i == 0:
                    plt.figure()
                    for ind in range(3):
                        for ii in range(J):
                            plt.subplot(3,J,ii+1+ind*J)
                            plt.imshow(shat[ind,:,:,ii])
                            # plt.tight_layout(pad=1.1)
                            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
                    plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
                    plt.show()
                    plt.close('all')

            loss_eval.append(sum(temp)/len(temp))
            print('first 3 h_corr',av_hcorr[:3],' averaged:', sum(av_hcorr)/len(av_hcorr))
            # Average s_corr over its own list, not over av_hcorr.
            print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val')
            plt.close('all')

        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())

## V series

In [None]:
#%% v79
from utils import *
# Expose only GPU 0 to this process and set display defaults.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())
# Make the result reproducible.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Per-run output folders for figures and model checkpoints.
rid = 'v79'
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently so a run with only one of them present
# neither raises FileExistsError nor leaves the other one missing.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% generate data
# Synthesise the toy dataset: three source maps, each sent through a random
# steering vector, circularly shifted per sample, summed and passed to awgn_batch.
# The order of the random draws below determines the data, so do not reorder.
if True:
    M, n_data= 3, int(9e3)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # One 2-D map per source, taken from the last axis of 'v'.
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    NandF = 48
    # Each map is resized to 50x50 and then to NandF x NandF, passed through
    # awgn (helper defined elsewhere; presumably adds noise at the given SNR
    # -- TODO confirm) and kept as a non-negative magnitude in complex dtype.
    v0 = resize(v0, (50, 50), preserve_range=True)
    v0 = torch.tensor(resize(v0, (NandF, NandF), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = resize(v1, (50, 50), preserve_range=True)
    v1 = torch.tensor(resize(v1, (NandF, NandF), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = resize(v2, (50, 50), preserve_range=True)
    v2 = torch.tensor(resize(v2, (NandF, NandF), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # Steering vectors h[m] = exp(1j * sin(angle) * m), one random angle per sample.
    angs = (torch.rand(n_data,1)*180 -0)/180*np.pi  # signal aoa [0, 180]
    h_1 = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n1 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_2 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n2 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_3 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Outer product of each steering vector with one flattened random map.
    # torch.randn is called once per source, so all n_data samples share the
    # same realisation and differ only in h and in the shift applied below.
    sig1 = (h_1[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    sig2 = (h_2[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    sig3 = (h_3[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    sig1 = sig1.reshape(n_data, M, NandF, NandF)
    sig2 = sig2.reshape(n_data, M, NandF, NandF)
    sig3 = sig3.reshape(n_data, M, NandF, NandF)
    # Independent circular shift of every source in every sample.
    # NOTE(review): shifts are drawn from [0, 50) while the axes are NandF = 48
    # long; roll wraps, so 48 and 49 act as 0 and 1. Confirm this is intended.
    for i in range(n_data):
        a = torch.randint(0,50,(2,))
        sig1[i] = sig1[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig2[i] = sig2[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig3[i] = sig3[i].roll((a[0], a[1]),dims=(1,2))
    # Sources as seen on channel 0, stacked on a new last axis.
    s = torch.stack((sig1[:,0], sig2[:,0], sig3[:,0]), dim=3)
    mix = sig1 + sig2 + sig3
    mix_all = mix.permute(0,2,3,1)
    mixn = awgn_batch(mix_all,seed=0).permute(0,3,1,2) # shape of [I, M, N, F]
    H_all = torch.stack((h_1, h_2, h_3), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

# Disabled diagnostic: report samples whose 3x3 matrix x^H x is rank-deficient.
if False: # check data low rank or not
    for i in range(n_data):
        # NOTE(review): mix[i] is [M, NandF, NandF]; reshaping it straight to
        # (NandF**2, 3) does not put the channel axis last -- verify before enabling.
        x = mix[i,:,:].reshape(NandF**2, 3)
        xbar = x - x.mean(0)  # computed but not used below
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != 3:
            print('low rank', i, 'rank is ', r)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Assemble Hermitian matrices from packed lower-triangular entries.

    Each row of ``rx12`` holds the real parts of the lower triangle (in the
    order given by ``torch.tril_indices``) followed by the matching imaginary
    parts. The imaginary value supplied for a diagonal entry is discarded, so
    the diagonal of the result is real.

    Args:
        rx12: real tensor of shape [I, M*(M+1)]; 12 columns give 3x3 matrices.

    Returns:
        Complex64 tensor of shape [I, M, M], on the same device as ``rx12``.

    Raises:
        ValueError: if the number of columns is not of the form M*(M+1).
    """
    n_cols = rx12.shape[1]
    n_lower = n_cols // 2
    # Solve M*(M+1)/2 == n_lower for M.
    m = int(round(((8 * n_lower + 1) ** 0.5 - 1) / 2))
    if m * (m + 1) != n_cols:
        raise ValueError(
            f'expected M*(M+1) packed values per row, got {n_cols}')

    device = rx12.device
    ind = torch.tril_indices(m, m, device=device)
    diag = torch.arange(m, device=device)
    rx_inv_hat = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat, device=device)
    rx_inv_hat[:, ind[0], ind[1]] = rx12[:, :n_lower] + 1j*rx12[:, n_lower:]
    # Mirror the lower triangle; this doubles the real part of the diagonal
    # and cancels its imaginary part, hence the halving below.
    rx_inv_hat = rx_inv_hat + rx_inv_hat.permute(0, 2, 1).conj()
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Recursive Wiener-filter style network that extracts one source per pass.

    Each of the J passes estimates a steering vector and a Hermitian
    inverse-covariance surrogate from the spatial covariance of the current
    residual, beamforms the residual, encodes the beamformer magnitude into a
    latent code, decodes a per-bin power map (stored after thresholding to
    [1e-3, 1e2]) and subtracts the estimated component from the residual.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE: the width 12 (= 2 * 6 lower-triangle entries) of mainnet/rxnet and
    the 3x3 lower-triangle indexing in forward() tie this class to M == 3.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of input channels (only M == 3 matches the hard-coded
                layer widths).
            K: number of extraction passes; stored as self.J.
            im_size: side length of the square input. The encoder output and
                decoder input are sized (im_size/4) x (im_size/4) --
                presumably Down halves and Up doubles the resolution; confirm
                against vae_modules.
        """
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real and imaginary parts of the covariance's
        # lower triangle (6 + 6 = 12 values).
        self.mainnet = nn.Sequential(
            FC_layer(12, 128),
            FC_layer(128, 128),
        )
        # Head producing one scalar in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer(128, 128),
            FC_layer(128, 128),
            FC_layer(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 packed values consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer(128, 128),
            FC_layer(128, 128),
            FC_layer(128, 64),
            FC_layer(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Down(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up(in_channels=1, out_channels=64),
            DoubleConv(in_channels=64, out_channels=32),
            Up(in_channels=32, out_channels=16),
            DoubleConv(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*exp(0.5*logvar) with eps drawn from N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run the J extraction passes on a batch.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall):
              Rs     -- diag-embedded source powers, complex, [I, N, F, J, J]
              Hhat   -- estimated steering vectors, [I, M, J]
              Rb     -- fixed noise covariance 1e-4 * identity, [I, M, M]
              mu, logvar -- latent statistics of the LAST pass only, [I, dz]
              zall   -- z and bilinear(z) of every pass, interleaved,
                        [I, 2J, dz]
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        # Loop-invariant lower-triangle indices of the 3x3 covariance.
        ind = torch.tril_indices(3, 3, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # --- Get H estimation ---
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid = self.mainnet(torch.stack((rx_lower.real, rx_lower.imag),
                                           dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # --- Get Rx inverse ---
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # --- Encoder part ---
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Explicit reshape instead of a bare squeeze(), which would also
            # drop the batch axis when btsize == 1.
            xx = self.encoder(shat.reshape(btsize, 1, N, F).abs())

            # --- Get latent variable ---
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # --- Decoder to get V ---
            # Same reason as above for the explicit reshape.
            v = self.decoder(z).square().reshape(btsize, N, F)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # --- Remove the current component ---
            # NOTE(review): the un-thresholded v is used here while the
            # thresholded one is stored in v_all -- confirm this is intended.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed diagonal noise covariance, built on the input's device rather
        # than a hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-4

        # NOTE(review): mu/logvar come from the last pass only, so a KL term
        # built from them ignores the earlier passes -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Negative complex-Gaussian log-likelihood and beta-weighted KL term.

    Args:
        x: complex observations, shape [I, M, N, F].
        Rs: source covariances, shape [I, N, F, J, J].
        Hhat: steering matrices, shape [I, M, J].
        Rb: noise covariance, broadcastable to [N, F, I, M, M]
            (e.g. [I, M, M]).
        mu, logvar: latent mean and log-variance used for the KL term.
        zall: latent codes; unused here (kept for the disabled contrastive
            loss and for interface compatibility).
        beta: weight applied to the KL term.

    Returns:
        Tuple (nll, beta*kl) of real-valued scalar tensors, where nll is the
        negative log-likelihood summed over all samples and bins.
    """
    x = x.permute(0,2,3,1)[..., None]  # shape of [I, N, F, M, 1]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]

    # Rx is complex, so det() and the quadratic form are complex tensors with
    # (numerically) zero imaginary part. Keep only the real part so the loss
    # is a real scalar: log|z| is exactly Re(log z).
    logdet = (np.pi*Rx.det()).abs().log()
    # Solve Rx y = x instead of forming the explicit inverse.
    quad = (x.transpose(-1,-2).conj() @ torch.linalg.solve(Rx, x))[..., 0, 0].real
    ll = -logdet - quad

    # "Slot contrastive loss"
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()/I

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl  #+ 10*loss_slotCEL

# Experiment-wide constants
I = 6000  # size of the training split (first I generated samples)
M = J = 3
N = F = NandF
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'n_epochs': 201, 'lr': 1e-3}

# Training split: noisy mixtures, each divided by the peak magnitude of the
# corresponding clean mixture
xtr = torch.div(mixn[:I], torch.amax(mix[:I].abs(), dim=(1,2,3), keepdim=True))
data = Data.TensorDataset(xtr)
tr = Data.DataLoader(data, shuffle=True, drop_last=True, batch_size=opts['batch_size'])

# Validation split: the 128 samples right after the training split, together
# with the matching steering vectors and sources; each sample is divided by
# its own peak magnitude
hgt, sval = H_all[I:I+128], s[I:I+128]
xval = mixn[I:I+128]
xval = torch.div(xval, torch.amax(xval.abs(), dim=(1,2,3), keepdim=True))
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
# Training driver: optimises NNet with RAdam, logs the training loss every
# 30th batch, and every 10th epoch saves loss plots, evaluates on the 128
# validation samples and stores the whole model.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        # l1: negative log-likelihood term, l2: KL term (see loss_fun)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm to 1 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th batch
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10th epoch: save loss curves, validate, and checkpoint
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
            # Validation loss per sample (128 validation samples)
            loss_eval.append((l1+l2).cpu().item()/128)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            av_hcorr, av_scorr = [], []
            # Model covariance Rx = H Rs H^H + Rb, laid out as [N, F, I, M, M]
            Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
            # Source estimate Rs H^H Rx^-1 x for every bin, [N, F, I, J, 1]
            shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                    @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
            # Magnitude of the estimates, rearranged to [I, N, F, J]
            shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
            # h_corr / s_corr come from outside this block; presumably they
            # score estimated vs. ground-truth steering vectors / sources --
            # confirm in utils
            for ind in range(128):
                hh = Hhat_val[ind]
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
                av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
            print('first 3 h_corr',av_hcorr[:3],' averaged128:', sum(av_hcorr)/128)
            print('first 3 s_corr',av_scorr[:3],' averaged128:', sum(av_scorr)/128)

            # 3x3 grid: rows are the first three validation samples, columns
            # are the J estimated sources
            plt.figure()
            for ind in range(3):
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[ind,:,:,ii])
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        # Checkpoint the whole model object (not just the state_dict)
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')



# #%%
# model = torch.load('../data/data_ss/models/v76/model_epoch200.pt')



# #%%
# rr = xval.permute(0,2,3,1)[...,None]@xval.permute(0,2,3,1)[:,:,:,None,:].conj()
# rxval = rr.mean(dim=(1,2))
# for i in range(5):
#     r0 = rxval[i].pinverse()
#     hgt0 = hgt[i]
#     for j in range(3):
#         w = r0@hgt0[:,j:j+1] / \
#                 (eps + hgt0[:,j:j+1].t().conj()@r0@hgt0[:,j:j+1])
#         shat = w.t().conj()@xval.permute(0,2,3,1)[i,..., None]
#         plt.figure()
#         plt.imshow(shat.squeeze().abs())
#         plt.colorbar()

In [None]:
#%% v80
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints
rid = 'v80' # standard
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check both folders and create each one independently: testing only the
# figures folder either raised FileExistsError or left the models folder
# missing whenever just one of the two already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% generate data
# Build the synthetic dataset: three sources with fixed power maps, each
# observed through a random M-element steering vector, summed and corrupted
# by noise. The order of the random calls below determines the dataset.
if True:
    M, n_data= 3, int(9e3)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # One power map per slice of the last axis of 'v'
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    NandF = 48
    # Each map is resized to 50x50 and then to NandF x NandF; awgn is a helper
    # defined outside this block (presumably adds white Gaussian noise at the
    # given SNR -- confirm in utils). The magnitude is kept and cast to cfloat.
    v0 = resize(v0, (50, 50), preserve_range=True)
    v0 = torch.tensor(resize(v0, (NandF, NandF), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = resize(v1, (50, 50), preserve_range=True)
    v1 = torch.tensor(resize(v1, (NandF, NandF), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = resize(v2, (50, 50), preserve_range=True)
    v2 = torch.tensor(resize(v2, (NandF, NandF), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # Steering vectors h[m] = exp(1j * sin(angle) * m), m = 0..M-1, with one
    # uniformly drawn angle in [0, pi) per sample and source
    angs = (torch.rand(n_data,1)*180 -0)/180*np.pi  # signal aoa [0, 180]
    h_1 = (1j*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n1 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_2 = (1j*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n2 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_3 = (1j*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Outer product of each steering vector with one complex Gaussian draw
    # scaled by sqrt(v). Note the draw is made once per source and shared by
    # all n_data samples; only the steering vector and the roll below differ.
    sig1 = (h_1[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    sig2 = (h_2[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    sig3 = (h_3[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    sig1 = sig1.reshape(n_data, M, NandF, NandF)
    sig2 = sig2.reshape(n_data, M, NandF, NandF)
    sig3 = sig3.reshape(n_data, M, NandF, NandF)
    # Circularly shift every source image by random offsets (same offset for
    # all M channels). NOTE(review): offsets are drawn from [0, 50) although
    # the image side is NandF = 48, so shifts 48 and 49 wrap to 0 and 1.
    for i in range(n_data):
        a = torch.randint(0,50,(2,))
        sig1[i] = sig1[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig2[i] = sig2[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig3[i] = sig3[i].roll((a[0], a[1]),dims=(1,2))
    # Channel 0 of every source (steering entry exp(0) = 1), shape [n, N, F, 3]
    s = torch.stack((sig1[:,0], sig2[:,0], sig3[:,0]), dim=3)
    mix = sig1 + sig2 + sig3
    mix_all = mix.permute(0,2,3,1)
    # awgn_batch is defined outside this block (presumably adds noise per
    # sample -- confirm in utils)
    mixn = awgn_batch(mix_all,seed=0).permute(0,3,1,2) # shape of [I, M, N, F]
    H_all = torch.stack((h_1, h_2, h_3), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not (debug only, disabled by default)
    # mix has shape [n_data, M, N, F]. Move the sensor axis last so that every
    # row of x is one time-frequency bin and x^H x is the M x M sensor
    # covariance. A plain reshape of mix[i] to (N*F, M) would interleave
    # pixels with sensors and test the rank of an unrelated matrix.
    n_sensor = mix.shape[1]
    for i in range(n_data):
        x = mix[i].reshape(n_sensor, -1).t()  # shape of [N*F, M]
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != n_sensor:
            print('low rank', i, 'rank is ', r)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [I, 2*T] with T = m*(m+1)/2. The first T
            columns are the real parts and the last T columns the imaginary
            parts of the lower triangle (row-major order of
            torch.tril_indices). For the 12-column input used in this file
            m is 3.

    Returns:
        Complex (cfloat) tensor of shape [I, m, m], allocated on the same
        device as rx12. The result is Hermitian; the imaginary parts supplied
        for the diagonal cancel out, so the diagonal is real.

    Raises:
        ValueError: if the number of columns is not 2 * m*(m+1)/2 for an
            integer m.
    """
    n_cols = rx12.shape[1]
    n_tri = n_cols // 2
    # Invert T = m*(m+1)/2 to recover the matrix size.
    m = int(round(((8 * n_tri + 1) ** 0.5 - 1) / 2))
    if n_cols != 2 * n_tri or m * (m + 1) // 2 != n_tri:
        raise ValueError(
            f'expected 2*m*(m+1)/2 columns, got {n_cols}')

    # Allocate on the input's device instead of forcing .cuda(), so the helper
    # also works on CPU and on a non-default GPU.
    rows, cols = torch.tril_indices(m, m, device=rx12.device)
    lower = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat,
                        device=rx12.device)
    lower[:, rows, cols] = rx12[:, :n_tri] + 1j*rx12[:, n_tri:]

    # lower + lower^H is Hermitian but counts the diagonal twice.
    herm = lower + lower.transpose(-1, -2).conj()
    diag = torch.arange(m, device=rx12.device)
    herm[:, diag, diag] = herm[:, diag, diag]/2
    return herm

class NNet(nn.Module):
    """Recursive Wiener-filter style network that extracts one source per pass.

    Each of the J passes estimates a steering vector and a Hermitian
    inverse-covariance surrogate from the spatial covariance of the current
    residual, beamforms the residual, encodes the beamformer magnitude into a
    latent code, decodes a per-bin power map (stored after thresholding to
    [1e-3, 1e2]) and subtracts the estimated component from the residual.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE: the width 12 (= 2 * 6 lower-triangle entries) of mainnet/rxnet and
    the 3x3 lower-triangle indexing in forward() tie this class to M == 3.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of input channels (only M == 3 matches the hard-coded
                layer widths).
            K: number of extraction passes; stored as self.J.
            im_size: side length of the square input. The encoder output and
                decoder input are sized (im_size/4) x (im_size/4) --
                presumably Down_ halves and Up_ doubles the resolution;
                confirm against vae_modules.
        """
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real and imaginary parts of the covariance's
        # lower triangle (6 + 6 = 12 values).
        self.mainnet = nn.Sequential(
            FC_layer_(12, 128),
            FC_layer_(128, 128),
        )
        # Head producing one scalar in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer_(128, 128),
            FC_layer_(128, 128),
            FC_layer_(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 packed values consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_(128, 128),
            FC_layer_(128, 128),
            FC_layer_(128, 64),
            FC_layer_(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_(in_channels=1, out_channels=64),
            DoubleConv_(in_channels=64, out_channels=32),
            Down_(in_channels=32, out_channels=16),
            DoubleConv_(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_(in_channels=1, out_channels=64),
            DoubleConv_(in_channels=64, out_channels=32),
            Up_(in_channels=32, out_channels=16),
            DoubleConv_(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*exp(0.5*logvar) with eps drawn from N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run the J extraction passes on a batch.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall):
              Rs     -- diag-embedded source powers, complex, [I, N, F, J, J]
              Hhat   -- estimated steering vectors, [I, M, J]
              Rb     -- fixed noise covariance 1e-3 * identity, [I, M, M]
              mu, logvar -- latent statistics of the LAST pass only, [I, dz]
              zall   -- z and bilinear(z) of every pass, interleaved,
                        [I, 2J, dz]
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        # Loop-invariant lower-triangle indices of the 3x3 covariance.
        ind = torch.tril_indices(3, 3, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # --- Get H estimation ---
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid = self.mainnet(torch.stack((rx_lower.real, rx_lower.imag),
                                           dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # --- Get Rx inverse ---
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # --- Encoder part ---
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Explicit reshape instead of a bare squeeze(), which would also
            # drop the batch axis when btsize == 1.
            xx = self.encoder(shat.reshape(btsize, 1, N, F).abs())

            # --- Get latent variable ---
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # --- Decoder to get V ---
            # Same reason as above for the explicit reshape.
            v = self.decoder(z).square().reshape(btsize, N, F)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # --- Remove the current component ---
            # NOTE(review): the un-thresholded v is used here while the
            # thresholded one is stored in v_all -- confirm this is intended.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed diagonal noise covariance, built on the input's device rather
        # than a hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        # NOTE(review): mu/logvar come from the last pass only, so a KL term
        # built from them ignores the earlier passes -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Negative complex-Gaussian log-likelihood and beta-weighted KL term.

    Args:
        x: complex observations, shape [I, M, N, F].
        Rs: source covariances, shape [I, N, F, J, J].
        Hhat: steering matrices, shape [I, M, J].
        Rb: noise covariance, broadcastable to [N, F, I, M, M]
            (e.g. [I, M, M]).
        mu, logvar: latent mean and log-variance used for the KL term.
        zall: latent codes; unused here (kept for the disabled contrastive
            loss and for interface compatibility).
        beta: weight applied to the KL term.

    Returns:
        Tuple (nll, beta*kl) of real-valued scalar tensors, where nll is the
        negative log-likelihood summed over all samples and bins.
    """
    x = x.permute(0,2,3,1)[..., None]  # shape of [I, N, F, M, 1]
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]

    # Rx is complex, so det() and the quadratic form are complex tensors with
    # (numerically) zero imaginary part. Keep only the real part so the loss
    # is a real scalar: log|z| is exactly Re(log z).
    logdet = (np.pi*Rx.det()).abs().log()
    # Solve Rx y = x instead of forming the explicit inverse.
    quad = (x.transpose(-1,-2).conj() @ torch.linalg.solve(Rx, x))[..., 0, 0].real
    ll = -logdet - quad

    # "Slot contrastive loss"
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()/I

    # "My own loss for H"
    # HHt = Hhat@Hhat.permute(0,2,1).conj() 
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    return -ll.sum(), beta*kl  #+ 10*loss_slotCEL

#%%
# Experiment-wide constants
I = 6000  # size of the training split (first I generated samples)
M = J = 3
N = F = NandF
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'n_epochs': 201, 'lr': 1e-3}

# Training split: noisy mixtures, each divided by the peak magnitude of the
# corresponding clean mixture
xtr = torch.div(mixn[:I], torch.amax(mix[:I].abs(), dim=(1,2,3), keepdim=True))
data = Data.TensorDataset(xtr)
tr = Data.DataLoader(data, shuffle=True, drop_last=True, batch_size=opts['batch_size'])

# Validation split: the 128 samples right after the training split, together
# with the matching steering vectors and sources; each sample is divided by
# its own peak magnitude
hgt, sval = H_all[I:I+128], s[I:I+128]
xval = mixn[I:I+128]
xval = torch.div(xval, torch.amax(xval.abs(), dim=(1,2,3), keepdim=True))
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
# Training driver: optimises NNet with RAdam, logs the training loss every
# 30th batch, and every 10th epoch saves loss plots, evaluates on the 128
# validation samples and stores the whole model.
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(model.parameters(),
                lr= opts['lr'],
                betas=(0.9, 0.999), 
                eps=1e-8,
                weight_decay=0)

for epoch in range(opts['n_epochs']):
    model.train()
    for i, (x,) in enumerate(tr): 
        x = x.cuda()
        optimizer.zero_grad()         
        Rs, Hhat, Rb, mu, logvar, zall= model(x)
        # l1: negative log-likelihood term, l2: KL term (see loss_fun)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        # Clip the global gradient norm to 1 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        # Record per-sample losses every 30th batch
        if i%30 == 0:
            loss_tr.append(loss.detach().cpu().item()/opts['batch_size'])
            loss1.append(l1.detach().cpu().item()/opts['batch_size'])
            loss2.append(l2.detach().cpu().item()/opts['batch_size'])

    # Every 10th epoch: save loss curves, validate, and checkpoint
    if epoch%10 == 0:
        print(epoch)
        plt.figure()
        plt.plot(loss_tr, '-or')
        plt.title(f'Loss fuction at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_LossFunAll')

        plt.figure()
        plt.plot(loss1, '-og')
        plt.title(f'Reconstruction loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss1')

        plt.figure()
        plt.plot(loss2, '-og')
        plt.title(f'KL loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_Loss2')

        plt.figure()
        plt.plot(loss_tr[-50:], '-or')
        plt.title(f'Last 50 of loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_last50')

        model.eval()
        with torch.no_grad():
            Rs, Hhat_val, Rb, mu, logvar, zall= model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
            # Validation loss per sample (128 validation samples)
            loss_eval.append((l1+l2).cpu().item()/128)
            plt.figure()
            plt.plot(loss_eval[-50:], '-xb')
            plt.title(f'last 50 validation loss at epoch{epoch}')
            plt.savefig(fig_loc + f'Epoch{epoch}_val') 
            plt.close('all')           

            av_hcorr, av_scorr = [], []
            # Model covariance Rx = H Rs H^H + Rb, laid out as [N, F, I, M, M]
            Rxperm = Hhat_val@Rs.permute(1,2,0,3,4)@Hhat_val.transpose(-1,-2).conj() + Rb
            # Source estimate Rs H^H Rx^-1 x for every bin, [N, F, I, J, 1]
            shatperm = Rs.permute(1,2,0,3,4)@Hhat_val.conj().transpose(-1,-2)\
                    @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
            # Magnitude of the estimates, rearranged to [I, N, F, J]
            shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()
            # h_corr / s_corr come from outside this block; presumably they
            # score estimated vs. ground-truth steering vectors / sources --
            # confirm in utils
            for ind in range(128):
                hh = Hhat_val[ind]
                av_hcorr.append(h_corr(hh.cpu(), hgt[ind]))
                av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
            print('first 3 h_corr',av_hcorr[:3],' averaged128:', sum(av_hcorr)/128)
            print('first 3 s_corr',av_scorr[:3],' averaged128:', sum(av_scorr)/128)

            # 3x3 grid: rows are the first three validation samples, columns
            # are the J estimated sources
            plt.figure()
            for ind in range(3):
                for ii in range(J):
                    plt.subplot(3,3,ii+1+ind*3)
                    plt.imshow(shat[ind,:,:,ii])
                    # plt.tight_layout(pad=1.1)
                    # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')
            
        # Checkpoint the whole model object (not just the state_dict)
        torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


# #%%
# model = torch.load('../data/data_ss/models/v76/model_epoch200.pt')

# #%%
# rr = xval.permute(0,2,3,1)[...,None]@xval.permute(0,2,3,1)[:,:,:,None,:].conj()
# rxval = rr.mean(dim=(1,2))
# for i in range(5):
    # r0 = rxval[i].pinverse()
    # hgt0 = hgt[i]
    # for j in range(3):
    #     w = r0@hgt0[:,j:j+1] / \
    #             (eps + hgt0[:,j:j+1].t().conj()@r0@hgt0[:,j:j+1])
    #     shat = w.t().conj()@xval.permute(0,2,3,1)[i,..., None]
    #     plt.figure()
    #     plt.imshow(shat.squeeze().abs())
    #     plt.colorbar()

In [None]:
#%% v84
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

# make the result reproducible
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints
rid = 'v84' # standard
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
# Check both folders and create each one independently: testing only the
# figures folder either raised FileExistsError or left the models folder
# missing whenever just one of the two already existed.
if not (os.path.isdir(fig_loc) and os.path.isdir(mod_loc)):
    print('made a new folder')
    os.makedirs(fig_loc, exist_ok=True)
    os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% generate data
# Build the synthetic dataset: three sources with fixed power maps, each
# observed through a random M-element steering vector, summed and corrupted
# by noise. The order of the random calls below determines the dataset.
if True:
    M, n_data= 3, int(9e3)
    dicts = sio.loadmat('../data/nem_ss/v2.mat')
    # One power map per slice of the last axis of 'v'
    v0 = dicts['v'][..., 0]
    v1 = dicts['v'][..., 1]
    v2 = dicts['v'][..., 2]
    from skimage.transform import resize
    NandF = 48
    # Each map is resized to 50x50 and then to NandF x NandF; awgn is a helper
    # defined outside this block (presumably adds white Gaussian noise at the
    # given SNR -- confirm in utils). The magnitude is kept and cast to cfloat.
    v0 = resize(v0, (50, 50), preserve_range=True)
    v0 = torch.tensor(resize(v0, (NandF, NandF), preserve_range=True))
    v0 = awgn(v0, snr=30, seed=0).abs().to(torch.cfloat)
    plt.imshow(v0.abs())
    plt.colorbar()
    v1 = resize(v1, (50, 50), preserve_range=True)
    v1 = torch.tensor(resize(v1, (NandF, NandF), preserve_range=True))
    v1 = awgn(v1, snr=30, seed=1).abs().to(torch.cfloat)
    v2 = resize(v2, (50, 50), preserve_range=True)
    v2 = torch.tensor(resize(v2, (NandF, NandF), preserve_range=True))
    v2 = awgn(v2, snr=30, seed=2).abs().to(torch.cfloat)

    # Steering vectors h[m] = exp(1j * pi * sin(angle) * m), m = 0..M-1, with
    # one uniformly drawn angle in [0, pi) per sample and source
    angs = (torch.rand(n_data,1)*180 -0)/180*np.pi  # signal aoa [0, 180]
    h_1 = (1j*np.pi*angs.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n1 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_2 = (1j*np.pi*angs_n1.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]
    angs_n2 = (torch.rand(n_data,1)*180 -0)/180*np.pi  # noise aoa [0, 180]
    h_3 = (1j*np.pi*angs_n2.sin()@torch.arange(M).to(torch.cfloat)[None,:]).exp() #shape of [n,M]

    # Outer product of each steering vector with one complex Gaussian draw
    # scaled by sqrt(v). Note the draw is made once per source and shared by
    # all n_data samples; only the steering vector and the roll below differ.
    sig1 = (h_1[..., None]@(torch.randn(v0.shape, dtype=torch.cfloat)*(v0**0.5)).flatten()[None,:])
    sig2 = (h_2[..., None]@(torch.randn(v1.shape, dtype=torch.cfloat)*(v1**0.5)).flatten()[None,:])
    sig3 = (h_3[..., None]@(torch.randn(v2.shape, dtype=torch.cfloat)*(v2**0.5)).flatten()[None,:])
    sig1 = sig1.reshape(n_data, M, NandF, NandF)
    sig2 = sig2.reshape(n_data, M, NandF, NandF)
    sig3 = sig3.reshape(n_data, M, NandF, NandF)
    # Circularly shift every source image by random offsets (same offset for
    # all M channels). NOTE(review): offsets are drawn from [0, 50) although
    # the image side is NandF = 48, so shifts 48 and 49 wrap to 0 and 1.
    for i in range(n_data):
        a = torch.randint(0,50,(2,))
        sig1[i] = sig1[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig2[i] = sig2[i].roll((a[0], a[1]),dims=(1,2))
        a = torch.randint(0,50,(2,))
        sig3[i] = sig3[i].roll((a[0], a[1]),dims=(1,2))
    # Channel 0 of every source (steering entry exp(0) = 1), shape [n, N, F, 3]
    s = torch.stack((sig1[:,0], sig2[:,0], sig3[:,0]), dim=3)
    mix = sig1 + sig2 + sig3
    mix_all = mix.permute(0,2,3,1)
    # awgn_batch is defined outside this block (presumably adds noise per
    # sample -- confirm in utils)
    mixn = awgn_batch(mix_all,seed=0).permute(0,3,1,2) # shape of [I, M, N, F]
    H_all = torch.stack((h_1, h_2, h_3), dim=-1)

    # torch.save((mix, sig, h), 'toy_matrix_inv.pt') # generate data is faster than loading it...
    plt.figure()
    plt.imshow(mix_all[0,:,:,0].abs())
    plt.colorbar()

if False: # check data low rank or not (debug only, disabled by default)
    # mix has shape [n_data, M, N, F]. Move the sensor axis last so that every
    # row of x is one time-frequency bin and x^H x is the M x M sensor
    # covariance. A plain reshape of mix[i] to (N*F, M) would interleave
    # pixels with sensors and test the rank of an unrelated matrix.
    n_sensor = mix.shape[1]
    for i in range(n_data):
        x = mix[i].reshape(n_sensor, -1).t()  # shape of [N*F, M]
        cov = x.conj().t() @ x
        r = torch.linalg.matrix_rank(cov)
        if r != n_sensor:
            print('low rank', i, 'rank is ', r)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [I, 2*T] with T = m*(m+1)/2. The first T
            columns are the real parts and the last T columns the imaginary
            parts of the lower triangle (row-major order of
            torch.tril_indices). For the 12-column input used in this file
            m is 3.

    Returns:
        Complex (cfloat) tensor of shape [I, m, m], allocated on the same
        device as rx12. The result is Hermitian; the imaginary parts supplied
        for the diagonal cancel out, so the diagonal is real.

    Raises:
        ValueError: if the number of columns is not 2 * m*(m+1)/2 for an
            integer m.
    """
    n_cols = rx12.shape[1]
    n_tri = n_cols // 2
    # Invert T = m*(m+1)/2 to recover the matrix size.
    m = int(round(((8 * n_tri + 1) ** 0.5 - 1) / 2))
    if n_cols != 2 * n_tri or m * (m + 1) // 2 != n_tri:
        raise ValueError(
            f'expected 2*m*(m+1)/2 columns, got {n_cols}')

    # Allocate on the input's device instead of forcing .cuda(), so the helper
    # also works on CPU and on a non-default GPU.
    rows, cols = torch.tril_indices(m, m, device=rx12.device)
    lower = torch.zeros(rx12.shape[0], m, m, dtype=torch.cfloat,
                        device=rx12.device)
    lower[:, rows, cols] = rx12[:, :n_tri] + 1j*rx12[:, n_tri:]

    # lower + lower^H is Hermitian but counts the diagonal twice.
    herm = lower + lower.transpose(-1, -2).conj()
    diag = torch.arange(m, device=rx12.device)
    herm[:, diag, diag] = herm[:, diag, diag]/2
    return herm

class NNet(nn.Module):
    """Recursive Wiener-filter style network that extracts one source per pass.

    Each of the J passes estimates a steering vector and a Hermitian
    inverse-covariance surrogate from the spatial covariance of the current
    residual, beamforms the residual, encodes the beamformer magnitude into a
    latent code, decodes a per-bin power map (stored after thresholding to
    [1e-3, 1e2]) and subtracts the estimated component from the residual.

    Input shape [I,M,N,F], e.g.[32,3,100,100]
    J <=K

    NOTE: the width 12 (= 2 * 6 lower-triangle entries) of mainnet/rxnet and
    the 3x3 lower-triangle indexing in forward() tie this class to M == 3.
    """
    def __init__(self, M, K, im_size):
        """Build the sub-networks.

        Args:
            M: number of input channels (only M == 3 matches the hard-coded
                layer widths).
            K: number of extraction passes; stored as self.J.
            im_size: side length of the square input. The encoder output and
                decoder input are sized (im_size/4) x (im_size/4) --
                presumably Down_g halves and Up_g doubles the resolution;
                confirm against vae_modules.
        """
        super().__init__()
        self.dz = 32
        self.J, self.M = K, M
        down_size = int(im_size/4)
        # Shared trunk fed with real and imaginary parts of the covariance's
        # lower triangle (6 + 6 = 12 values).
        self.mainnet = nn.Sequential(
            FC_layer_g(12, 128),
            FC_layer_g(128, 128),
        )
        # Head producing one scalar in (-1, 1) per sample (Tanh output).
        self.hnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            nn.Linear(64, 1),
            nn.Tanh()
        )
        # Head producing the 12 packed values consumed by lower2matrix.
        self.rxnet = nn.Sequential(
            FC_layer_g(128, 128),
            FC_layer_g(128, 128),
            FC_layer_g(128, 64),
            FC_layer_g(64, 32),
            nn.Linear(32, 12)
        )

        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps*exp(0.5*logvar) with eps drawn from N(0, 1)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Run the J extraction passes on a batch.

        Args:
            x: complex tensor of shape [I, M, N, F].

        Returns:
            Tuple (Rs, Hhat, Rb, mu, logvar, zall):
              Rs     -- diag-embedded source powers, complex, [I, N, F, J, J]
              Hhat   -- estimated steering vectors, [I, M, J]
              Rb     -- fixed noise covariance 1e-3 * identity, [I, M, M]
              mu, logvar -- latent statistics of the LAST pass only, [I, dz]
              zall   -- z and bilinear(z) of every pass, interleaved,
                        [I, 2J, dz]
        """
        btsize, M, N, F = x.shape
        z_all, v_all, h_all = [], [], []
        ch = np.pi*torch.arange(self.M, device=x.device)
        # Loop-invariant lower-triangle indices of the 3x3 covariance.
        ind = torch.tril_indices(3, 3, device=x.device)
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]
        for _ in range(self.J):
            # --- Get H estimation ---
            rx = (xj@xj.transpose(-1,-2).conj()).mean(dim=(1,2))
            rx_lower = rx[:, ind[0], ind[1]]
            mid = self.mainnet(torch.stack((rx_lower.real, rx_lower.imag),
                                           dim=1).reshape(btsize,-1))
            ang = self.hnet(mid)
            temp = ang@ch[None,:]
            hhat = (1j*temp).exp()  # shape of [I, M]
            h_all.append(hhat)

            # --- Get Rx inverse ---
            rx_index = self.rxnet(mid)
            rx_inv = lower2matrix(rx_index) # shape of [I, M, M]

            # --- Encoder part ---
            w = rx_inv@hhat[...,None] / \
                (hhat[:,None,:].conj()@rx_inv@hhat[...,None])
            shat = w.permute(0,2,1).conj()[:,None,None]@xj  # [I,N,F,1,1]
            # Explicit reshape instead of a bare squeeze(), which would also
            # drop the batch axis when btsize == 1.
            xx = self.encoder(shat.reshape(btsize, 1, N, F).abs())

            # --- Get latent variable ---
            zz = self.fc1(xx.reshape(btsize,-1))
            mu = zz[:,::2]
            logvar = zz[:,1::2]
            z = self.reparameterize(mu, logvar)
            wz = self.bilinear(z)
            z_all.append(z)
            z_all.append(wz)

            # --- Decoder to get V ---
            # Same reason as above for the explicit reshape.
            v = self.decoder(z).square().reshape(btsize, N, F)  # shape of [I,N,F]
            v_all.append(threshold(v, floor=1e-3, ceiling=1e2)) # 1e-3 to 1e2

            # --- Remove the current component ---
            # NOTE(review): the un-thresholded v is used here while the
            # thresholded one is stored in v_all -- confirm this is intended.
            rxinvh = rx_inv@hhat[...,None]  # shape of [I, M, 1]
            v_rxinv_h_herm = (v[...,None, None]*rxinvh[:,None, None]).transpose(-1,-2).conj()
            cj = hhat[:,None,None,:,None] * (v_rxinv_h_herm @ xj) # shape of [I,N,F,M,1]
            xj = xj - cj

        Hhat = torch.stack(h_all, 2) # shape:[I, M, J]
        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Fixed diagonal noise covariance, built on the input's device rather
        # than a hard-coded 'cuda'.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        # NOTE(review): mu/logvar come from the last pass only, so a KL term
        # built from them ignores the earlier passes -- confirm intent.
        return vhat.diag_embed(), Hhat, Rb, mu, logvar, zall

def loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall, beta=1):
    """Return the reconstruction and KL loss terms for one batch.

    The reconstruction term is the negative complex-Gaussian log-likelihood of
    every bin of `x` under the covariance ``Rx = Hhat @ Rs @ Hhat^H + Rb``,
    summed over the whole batch.

    Args:
        x: mixture indexed as [I, M, N, F]; permuted to [I, N, F, M] here.
        Rs: rank-5 source covariance with the batch axis first (the model
            returns `vhat.diag_embed()`, commented there as [I, N, F, J, J]).
        Hhat: mixing-matrix estimate (commented as [I, M, J] in the model).
        Rb: noise covariance added to every bin (built as [I, M, M]).
        mu, logvar: latent mean / log-variance used by the KL term.
        zall: not read here; only the disabled slot-contrastive term uses it.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)``. `nll` is always real-valued: the complex
        arithmetic leaves only numerical noise in the imaginary part, and a
        real scalar is what ``loss.backward()`` needs.
    """
    x = x.permute(0,2,3,1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rxperm = Hhat @ Rs.permute(1,2,0,3,4) @ Hhat.transpose(-1,-2).conj() + Rb
    Rx = Rxperm.permute(2,0,1,3,4) # shape of [I, N, F, M, M]
    # NOTE: uses pi*det(Rx); the textbook density has pi**M, which only shifts
    # the loss by a constant and does not affect the gradients.
    quad = (x[...,None,:].conj()@Rx.inverse()@x[...,None]).squeeze()
    ll = -(np.pi*Rx.det()).log() - quad

    # "Slot contrastive loss"  (disabled; needs I, J = x.shape[0], Rs.shape[-1])
    # inp = (zall[:,0::2]@zall[:,1::2].permute(0,2,1)).reshape(I*J, J) # shape of [N,J,J]
    # target = torch.cat([torch.arange(J) for i in range(I)]).cuda()
    # loss_slotCEL = nn.CrossEntropyLoss(reduction='none')(inp, target).sum()/I

    # "My own loss for H"  (disabled)
    # HHt = Hhat@Hhat.permute(0,2,1).conj()
    # temp = x[...,None]@ x[:,:,:,None].conj()
    # rx = temp.mean(dim=(1,2))
    # term = (((rx- HHt/100).abs())**2).mean()

    nll = -ll.sum()
    if nll.is_complex():
        # Keep only the real part so the loss is a real scalar.
        nll = nll.real
    return nll, beta*kl  #+ 10*loss_slotCEL

#%%
# Experiment configuration and data split for this run.
I = 6000  # number of training samples taken from the front of the data
M = J = 3
N = F = NandF
NF = N*F
eps = 5e-4
opts = {'batch_size': 64, 'n_epochs': 201, 'lr': 1e-3}

# Training set.
# NOTE(review): the noisy mixture `mixn` is scaled by the peak of `mix`,
# whereas the validation set below is scaled by its own peak -- confirm
# this asymmetry is intended.
xtr = mixn[:I] / mix[:I].abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xtr)
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation set: the 128 samples that follow the training range, together
# with their reference mixing matrices and sources.
hgt = H_all[I:I+128]
sval = s[I:I+128]
xval = mixn[I:I+128]
xval = xval / xval.abs().amax(dim=(1,2,3), keepdim=True)
xval_cuda = xval[:128].to(torch.cfloat).cuda()

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(
    model.parameters(),
    lr=opts['lr'],
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

for epoch in range(opts['n_epochs']):
    # ---- one pass over the training loader ----
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Hhat, Rb, mu, logvar, zall = model(x)
        l1, l2 = loss_fun(x, Rs, Hhat, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        torch.cuda.empty_cache()

        if i % 30 != 0:
            continue
        # Every 30th batch: record total / reconstruction / KL loss,
        # each divided by the batch size.
        for history, term in ((loss_tr, loss), (loss1, l1), (loss2, l2)):
            history.append(term.detach().cpu().item()/opts['batch_size'])

    # Everything below (plots, validation, checkpoint) runs every 10th epoch.
    if epoch % 10 != 0:
        continue
    print(epoch)

    # ---- training-loss curves: (series, line style, title stem, file suffix) ----
    curves = (
        (loss_tr, '-or', 'Loss fuction at epoch', '_LossFunAll'),
        (loss1, '-og', 'Reconstruction loss at epoch', '_Loss1'),
        (loss2, '-og', 'KL loss at epoch', '_Loss2'),
        (loss_tr[-50:], '-or', 'Last 50 of loss at epoch', '_last50'),
    )
    for series, fmt, heading, suffix in curves:
        plt.figure()
        plt.plot(series, fmt)
        plt.title(f'{heading}{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        Rs, Hhat_val, Rb, mu, logvar, zall = model(xval_cuda)
        l1, l2 = loss_fun(xval_cuda, Rs, Hhat_val, Rb, mu, logvar, zall)
        loss_eval.append((l1+l2).cpu().item()/128)
        plt.figure()
        plt.plot(loss_eval[-50:], '-xb')
        plt.title(f'last 50 validation loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_val')
        plt.close('all')

        # Source estimate: Rs @ H^H @ Rx^{-1} @ x for every bin, with
        # Rx = H @ Rs @ H^H + Rb.
        Rs_perm = Rs.permute(1,2,0,3,4)
        Rxperm = Hhat_val@Rs_perm@Hhat_val.transpose(-1,-2).conj() + Rb
        shatperm = Rs_perm@Hhat_val.conj().transpose(-1,-2)\
                @Rxperm.inverse()@xval_cuda.permute(2,3,0,1)[...,None]
        shat = shatperm.permute(2,0,1,3,4).squeeze().cpu().abs()

        # Correlation scores of the estimates against the references.
        av_hcorr, av_scorr = [], []
        for ind in range(128):
            av_hcorr.append(h_corr(Hhat_val[ind].cpu(), hgt[ind]))
            av_scorr.append(s_corr(sval[ind].abs(), shat[ind]))
        print('first 3 h_corr',av_hcorr[:3],' averaged128:', sum(av_hcorr)/128)
        print('first 3 s_corr',av_scorr[:3],' averaged128:', sum(av_scorr)/128)

        # Grid of estimated sources: one row per sample (first 3), one
        # column per source.
        plt.figure()
        for cell in range(3*J):
            ind, ii = divmod(cell, J)
            plt.subplot(3,3,ii+1+ind*3)
            plt.imshow(shat[ind,:,:,ii])
            # plt.tight_layout(pad=1.1)
            # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
        plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
        plt.show()
        plt.close('all')

    torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')


# #%%
# model = torch.load('../data/data_ss/models/v76/model_epoch200.pt')

# #%%
# rr = xval.permute(0,2,3,1)[...,None]@xval.permute(0,2,3,1)[:,:,:,None,:].conj()
# rxval = rr.mean(dim=(1,2))
# for i in range(5):
    # r0 = rxval[i].pinverse()
    # hgt0 = hgt[i]
    # for j in range(3):
    #     w = r0@hgt0[:,j:j+1] / \
    #             (eps + hgt0[:,j:j+1].t().conj()@r0@hgt0[:,j:j+1])
    #     shat = w.t().conj()@xval.permute(0,2,3,1)[i,..., None]
    #     plt.figure()
    #     plt.imshow(shat.squeeze().abs())
    #     plt.colorbar()

In [None]:
#%% v86  # 1 channel, 3 classes, see if the VAE capacity is enough
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"]="0"
plt.rcParams['figure.dpi'] = 100
torch.set_printoptions(linewidth=160)
from datetime import datetime
print('starting date time ', datetime.now())

#%%
"make the result reproducible"
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)       # current GPU seed
torch.cuda.manual_seed_all(seed)   # all GPUs seed
torch.backends.cudnn.deterministic = True  #True uses deterministic alg. for cuda
torch.backends.cudnn.benchmark = False  #False cuda use the fixed alg. for conv, may slower

# Per-run output folders for figures and model checkpoints.
rid = 'v86'
fig_loc = '../data/data_ss/figures/' + f'{rid}/'
mod_loc = '../data/data_ss/models/' + f'{rid}/'
if not os.path.isdir(fig_loc):
    print('made a new folder')
# Create each folder independently (and any missing parents). Previously both
# were created only when the figure folder was absent, so a missing model
# folder was never created and an existing one raised FileExistsError.
os.makedirs(fig_loc, exist_ok=True)
os.makedirs(mod_loc, exist_ok=True)
# torch.autograd.set_detect_anomaly(True)

#%% define models and functions
from vae_modules import *
def lower2matrix(rx12, size=3):
    """Build a batch of Hermitian matrices from packed lower-triangle values.

    Args:
        rx12: real tensor of shape [batch, 2*L] with L = size*(size+1)/2.
            The first L columns are the real parts and the last L columns the
            imaginary parts of the lower-triangle entries, in the order given
            by `torch.tril_indices(size, size)`.
        size: side length of the output matrices (default 3, i.e. 12 columns).

    Returns:
        Complex (cfloat) tensor of shape [batch, size, size] on the same device
        as `rx12`. The strict upper triangle is the conjugate of the lower one
        and the diagonal keeps only the real part of the packed values.
    """
    n_low = size * (size + 1) // 2
    device = rx12.device  # follow the input instead of forcing the default GPU
    rows, cols = torch.tril_indices(size, size, device=device)
    lower = torch.zeros(rx12.shape[0], size, size, dtype=torch.cfloat, device=device)
    lower[:, rows, cols] = rx12[:, :n_low] + 1j*rx12[:, n_low:]
    # Mirror the lower triangle; this doubles the (real part of the) diagonal,
    # which is halved again below.
    rx_inv_hat = lower + lower.permute(0,2,1).conj()
    diag = torch.arange(size, device=device)
    rx_inv_hat[:, diag, diag] = rx_inv_hat[:, diag, diag]/2
    return rx_inv_hat

class NNet(nn.Module):
    """Convolutional VAE that maps the magnitude of the input to a variance map.

    The forward pass runs encoder -> latent sample -> decoder once and returns
    the squared, thresholded decoder output together with a fixed noise
    covariance ``1e-3 * eye(M)`` per sample.

    Input shape [I,M,N,F], e.g. [128,1,64,64].
    NOTE(review): the `.squeeze()` before the encoder presumably assumes a
    single channel (M == 1) -- confirm before using M > 1.

    Args:
        M: number of channels of the input (second axis of `x`).
        K: stored as `self.J`; not read by `forward` in this version.
        im_size: side length of the input image; the latent bottleneck works
            on maps of side ``int(im_size/4)``.
    """
    def __init__(self, M, K, im_size):
        super().__init__()
        self.dz = 32  # latent dimension
        self.J, self.M = K, M
        down_size = int(im_size/4)
        self.encoder = nn.Sequential(
            Down_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Down_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=1),
            )
        # fc1 emits 2*dz values: interleaved mean / log-variance (see forward).
        self.fc1 = nn.Linear(down_size*down_size, 2*self.dz)
        self.decoder = nn.Sequential(
            nn.Linear(self.dz, down_size*down_size),
            Reshape(-1, 1, down_size, down_size),
            Up_g(in_channels=1, out_channels=64),
            DoubleConv_g(in_channels=64, out_channels=32),
            Up_g(in_channels=32, out_channels=16),
            DoubleConv_g(in_channels=16, out_channels=8),
            OutConv(in_channels=8, out_channels=1),
            )
        self.bilinear = nn.Linear(self.dz, self.dz, bias=False)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Encode `x`, sample a latent, and decode a variance map.

        Returns:
            ``(vhat, Rb, mu, logvar, zall)`` where `vhat` is the complex-typed
            variance map with singleton axes squeezed out, `Rb` is
            ``1e-3 * eye(M)`` repeated for each sample (on the device of `x`),
            `mu`/`logvar` are the even/odd columns of the `fc1` output, and
            `zall` stacks the latent `z` and `self.bilinear(z)` along dim 1.
        """
        btsize, M, N, F = x.shape
        z_all, v_all = [], []
        xj = x.permute(0,2,3,1)[...,None]  # shape of [I,N,F,M,1]

        "Encoder part"
        xx = self.encoder(xj.squeeze()[:,None].abs())

        "Get latent variable"
        zz = self.fc1(xx.reshape(btsize,-1))
        mu = zz[:,::2]
        logvar = zz[:,1::2]
        z = self.reparameterize(mu, logvar)
        wz = self.bilinear(z)
        z_all.append(z)
        z_all.append(wz)

        "Decoder to get V"
        v = self.decoder(z).square().squeeze()  # shape of [I,N,F]
        v_all.append(threshold(v, floor=1e-6, ceiling=1e2)) # 1e-6 to 1e2

        vhat = torch.stack(v_all, 3).to(torch.cfloat) # shape:[I, N, F, J]
        zall = torch.stack(z_all, dim=1)

        # Rb = (b@b.conj().permute(0,1,2,4,3)).mean(dim=(1,2)).squeeze()
        # Build the noise covariance on the input's device rather than a
        # hard-coded 'cuda', so the model also works on CPU or another GPU.
        Rb = torch.eye(M, device=x.device).repeat(btsize, 1, 1)*1e-3

        return vhat.squeeze(), Rb, mu, logvar, zall

def loss_fun(x, Rs, Rb, mu, logvar, zall, beta=1):
    """Return the reconstruction and KL loss terms for the single-channel model.

    Each bin of `x` is scored under a scalar complex-Gaussian with variance
    ``Rx = Rs + Rb``; the negative log-likelihood is summed over the batch.

    Args:
        x: input batch; singleton axes are squeezed away before use.
        Rs: variance map from the model (commented there as [I, N, F]).
        Rb: noise term added to `Rs` by broadcasting.
        mu, logvar: latent mean / log-variance used by the KL term.
        zall: not read here; kept so the call signature matches the model output.
        beta: weight applied to the KL term.

    Returns:
        ``(nll, beta * kl)``. `nll` is always real-valued: ``x.conj()*x/Rx``
        is computed in complex arithmetic, and a real scalar is what
        ``loss.backward()`` needs.
    """
    x = x.squeeze()
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    Rx = Rs.squeeze() + Rb # shape of [I, N, F]
    ll = -(np.pi*Rx.abs()).log() - (x.conj()*(1/Rx)*x)
    nll = -ll.sum()
    if nll.is_complex():
        # Keep only the real part so the loss is a real scalar.
        nll = nll.real
    return nll, beta*kl

#%%
"raw data processing"
FT = 64  #48, 64, 80, 100, 128, 200, 256
var_name = ['ble', 'bt', 'fhss1', 'fhss2', 'wifi1', 'wifi2']
data = {}

def get_ftdata(data_pool):
    """Return the STFT of `data_pool` as a cfloat tensor.

    The transform uses ``fs=4e7``, window length `FT` and no boundary
    extension; the result is then rolled by ``FT//2`` along axis 1.
    """
    spec = stft(data_pool, fs=4e7, nperseg=FT, boundary=None)[-1]
    rolled = np.roll(spec, FT//2, axis=1)  # roll nperseg//2
    return torch.tensor(rolled).to(torch.cfloat)

# Load the six raw recordings and scale each row to unit energy along dim 1.
for i, name in zip(range(6), var_name):
    mat = sio.loadmat('/home/chenhao1/Matlab/LMdata/compressed/'+name+f'_{FT}_2k.mat')
    raw = torch.tensor(mat['x'])
    energy = (raw.abs()**2).sum(dim=(1),keepdim=True)
    data[i] = raw/(energy**0.5)  # normalize

# Three of the six classes are used as sources.
s1 = get_ftdata(data[0])  # ble [2000,F,T]
s2 = get_ftdata(data[2])  # fhss1
s3 = get_ftdata(data[5])  # wifi2
s = [s1, s2, s3]

torch.manual_seed(1)
J, M = 3, 1
# training data: 4 rounds x J classes, each a fresh random permutation of the
# 2000 examples of that class, stacked along dim 1
x = torch.stack(
    [s[j][torch.randperm(2000)] for _ in range(4) for j in range(J)],
    dim=1,
)
# The first 9 slices along dim 1 feed training; the rest feed validation.
xtr = x[:,:9].reshape(-1,1,FT,FT)
d = awgn_batch(xtr, snr=40, seed=1)  # added white noise

xvt = x[:,9:].reshape(-1,1,FT,FT)
sval = xvt[:1000]
xval = awgn_batch(sval, snr=40, seed=10)

#%%
# Experiment configuration for this run.
I = 18000  # number of training samples kept from xtr
M = J = 1
N = F = 64
NF = N*F
eps = 5e-4
opts = {'batch_size': 128, 'n_epochs': 301, 'lr': 1e-3}

# Training loader: each sample is divided by its own peak magnitude.
xtr = d / d.abs().amax(dim=(1,2,3), keepdim=True)  # [sample,M,N,F]
xtr = xtr.to(torch.cfloat)
data = Data.TensorDataset(xtr[:I])
tr = Data.DataLoader(
    data,
    batch_size=opts['batch_size'],
    shuffle=True,
    drop_last=True,
)

# Validation loader: pairs each peak-normalised noisy sample with its
# reference, whose axis 1 is moved to the end.
sval = sval.permute(0,2,3,1)
xval = xval / xval.abs().amax(dim=(1,2,3), keepdim=True)
data = Data.TensorDataset(xval, sval)
dval = Data.DataLoader(data, batch_size=1000, drop_last=True)

#%%
loss_iter, loss_tr, loss1, loss2, loss_eval = [], [], [], [], []
model = NNet(M,J,N).cuda()
# for w in model.parameters():
#     nn.init.normal_(w, mean=0., std=0.01)

optimizer = torch.optim.RAdam(
    model.parameters(),
    lr=opts['lr'],
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

for epoch in range(opts['n_epochs']):
    # ---- one pass over the training loader ----
    model.train()
    for i, (x,) in enumerate(tr):
        x = x.cuda()
        optimizer.zero_grad()
        Rs, Rb, mu, logvar, zall = model(x)
        l1, l2 = loss_fun(x, Rs, Rb, mu, logvar, zall)
        loss = l1 + l2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        optimizer.step()
        torch.cuda.empty_cache()

        if i % 30 != 0:
            continue
        # Every 30th batch: record total / reconstruction / KL loss,
        # each divided by the batch size.
        for history, term in ((loss_tr, loss), (loss1, l1), (loss2, l2)):
            history.append(term.detach().cpu().item()/opts['batch_size'])

    # Everything below (plots, validation, checkpoint) runs every 5th epoch.
    if epoch % 5 != 0:
        continue
    print(epoch)

    # ---- training-loss curves: (series, line style, title stem, file suffix) ----
    curves = (
        (loss_tr, '-or', 'Loss fuction at epoch', '_LossFunAll'),
        (loss1, '-og', 'Reconstruction loss at epoch', '_Loss1'),
        (loss2, '-og', 'KL loss at epoch', '_Loss2'),
        (loss_tr[-50:], '-or', 'Last 50 of loss at epoch', '_last50'),
    )
    for series, fmt, heading, suffix in curves:
        plt.figure()
        plt.plot(series, fmt)
        plt.title(f'{heading}{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}{suffix}')

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        av_hcorr, av_scorr, temp = [], [], []
        for i, (x, s) in enumerate(dval):
            xval_cuda = x.cuda()
            Rs, Rb, mu, logvar, zall = model(xval_cuda)
            l1, l2 = loss_fun(xval_cuda, Rs, Rb, mu, logvar, zall)
            temp.append((l1+l2).cpu().item()/x.shape[0])

            # Per-bin estimate: scale the input by Rs / (Rs + Rb).
            Rx = Rs + Rb
            shatperm = Rs/Rx*xval_cuda.squeeze()
            shat = shatperm[...,None].cpu().abs()
            av_scorr.extend(
                s_corr(s[ind].abs(), shat[ind]) for ind in range(x.shape[0])
            )

            # Only the first validation batch is visualised.
            if i != 0:
                continue
            plt.figure()
            for cell in range(3*J):
                ind, ii = divmod(cell, J)
                plt.subplot(3,3,ii+1+ind*3)
                plt.imshow(shat[ind,:,:,ii])
                # plt.tight_layout(pad=1.1)
                # if ii == 0 : plt.title(f'Epoch{epoch}_sample{ind}')
            plt.savefig(fig_loc + f'Epoch{epoch}_estimated sources')
            plt.show()
            plt.close('all')

        # Mean of the per-batch validation losses for this epoch.
        loss_eval.append(sum(temp)/len(temp))
        print('first 3 s_corr',av_scorr[:3],' averaged:', sum(av_scorr)/len(av_scorr))

        plt.figure()
        plt.plot(loss_eval[-50:], '-xb')
        plt.title(f'last 50 validation loss at epoch{epoch}')
        plt.savefig(fig_loc + f'Epoch{epoch}_val')
        plt.close('all')

    torch.save(model, mod_loc+f'model_epoch{epoch}.pt')
print('done')
print('End date time ', datetime.now())