In [6]:
# %load train_kronnc.py
# random del 15 个nodes  email.txt
import torch
import pickle
import random
import numpy as np
import networkx as nx 
from networkx.convert import from_dict_of_dicts
from networkx.classes.graph import Graph
from kronEM import *
seed =1900
np.random.seed(seed)
random.seed(seed)

In [12]:
def kronecker(A,B):
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))

def swapnodes(adj_before,i,j):
    sz = adj_before.shape[0]
    l = list(range(sz))
    l[i],l[j] = l[j],l[i]
    perm = torch.eye(sz)[l]
    if use_cuda:
        perm = perm.cuda()
    adj_after = torch.mm(torch.mm(perm,adj_before),perm.T)
    return adj_after    
def metropolis_update_ratio_Pk(Pk_before,Pk_later,A):
    '''
    if memory is sufficinet, this one is much cleaner to execute
    '''
    Nll_before=(1-A)*torch.log(1-Pk_before) +A*torch.log(Pk_before)
    Nll_later=(1-A)*torch.log(1-Pk_later) +A*torch.log(Pk_later)
    ratio=torch.exp(torch.sum(Nll_later-Nll_before))
    return ratio

def swap_pk(pk,pos1,pos2,A,u):
    pk_later = swapnodes(pk,pos1,pos2)
    ratio = metropolis_update_ratio_Pk(pk,pk_later,A)
    if u<ratio:
        pk=pk_later
    return pk


class Gumbel_union_kron(nn.Module):
    def __init__(self,p0,korder,label_non_obs,init_H,warmup,temp=10,temp_drop_frac = 0.9999):
        super(Gumbel_union_kron,self).__init__()
        self.p = Parameter(p0,requires_grad = True)
        self.korder = korder
        self.label_non_obs = label_non_obs
        self.init_H  = init_H
        self.warmup = warmup
    
    def generator_adjacency(self):
        k = self.korder
        p0 = self.p
        adj = self.p
        for i in range(k-1):
            adj = kronecker(adj,p0)
        return adj
    
    def shuffle_adj(self):
        pk = self.generator_adjacency()
        if use_cuda:
            pk = pk.cuda()
        warmup =self.warmup
        u1= torch.rand(warmup)
        Node_list=np.arange(len(pk))
        element_to_swap=np.random.choice(a=Node_list,size=(2,3*(warmup)))
        mask=element_to_swap[1,:]!=element_to_swap[0,:]# it is pointless to swap the same element
        n1_swap=element_to_swap[0,:][mask]
        n2_swap=element_to_swap[1,:][mask]
        for i in range(warmup):
            pk = swap_pk(pk,n1_swap[i],n2_swap[i],self.init_H,u1[i])
        return pk
    
    def sample_all(self,hard = True):
        pk = self.shuffle_adj()
        un_index = torch.nonzero(label_non_obs)
        un_pk = pk[(un_index[:,0],un_index[:,1])].unsqueeze(1)
        logp = torch.cat((un_pk,1-un_pk),1)
        if use_cuda:
            logp = logp.cuda()
        out = gumbel_softmax(self.logp,self.temperature,hard)
        if hard:
            hh = torch.zeros_like(logp)
            for i in range(out.size()[0]):
                hh[i,out[i]] = 1
                out = hh
        out = out[:,0]
        matrix = torch.zeros_like(pk)
        matrix[(un_index[:,0],un_index[:,1])] = out
        
        return matrix

In [13]:
import torch
import time
import random
import torch.nn.functional as F
import torch.nn.utils as U
import torch.optim as optim
import argparse
import sys 
import pickle
import matplotlib.pyplot as plt
from model_gy import *
from tools_changedelmethod import *
import os


parser = argparse.ArgumentParser(description = "ER network")
parser.add_argument('--node_num', type=int, default=128,
                    help='number of epochs to train')
parser.add_argument("--seed",type =int,default = 135,help = "random seed (default: 2050)")
# parser.add_argument("--sysdyn",type = str,default = 'voter',help = "the type of dynamics")
parser.add_argument("--dim",type = int,default = 2,help = "information diminsion of each node cml")
parser.add_argument("--hidden_size",type = int,default = 64,help = "hidden size of GGN model (default:128)")
parser.add_argument("--epoch_num",type = int,default = 700,help = "train epoch of model (default:1000)")                    
parser.add_argument("--batch_size",type = int,default = 1024,help = "input batch size for training (default: 128)")
parser.add_argument("--cuda",type = int, default = 3,help = "choose the GPU (default: 0)")
parser.add_argument("--lr_net",type = float,default = 0.004,help = "gumbel generator learning rate (default:0.004) ")
parser.add_argument("--lr_dyn",type = float,default = 0.001,help = "dynamic learning rate (default:0.001)")
parser.add_argument("--lr_state",type = float,default = 0.1,help = "state learning rate (default:0.1)")
parser.add_argument("--miss_percent",type = float,default = 0.1,help = "missing percent node (default:0.1)")
parser.add_argument("--data_path",type =str,default = '/data/chenmy/voter/seed2050email128128500',help = "the path of simulation data (default:ER_p0.04100300) ")
args = parser.parse_args([])
# from model_old_generator import *
# configuration
HYP = {
    'node_num': args.node_num,  # node num
    'seed': args.seed,  # the seed
    'dim': args.dim,  # information diminsion of each node cml:1 spring:4
    'hid': args.hidden_size,  # hidden size
    'epoch_num': args.epoch_num,  # epoch
    'batch_size': args.batch_size,  # batch size
    'lr_net':args.lr_net,  # lr for net generator
    'lr_dyn': args.lr_dyn,  # lr for dyn learner
    'lr_state':args.lr_state,
    'miss_percent':args.miss_percent,
    'data_path':args.data_path,
    'temp': 1,  # 温度
    'drop_frac': 1,  # temperature drop frac
    "cuda":args.cuda,
}

print("all parameter ",HYP)
# partial known adj 
torch.cuda.set_device(HYP["cuda"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(HYP['seed'])
random.seed(HYP['seed'])
del_num = int(args.node_num*args.miss_percent)
kn_nodes = HYP['node_num'] - del_num
# filename = os.path.abspath(__file__)
# print("filename",filename)


num = 3
def onehot_state(x_un_pre):
    if x_un_pre.shape[1]==1:
        state = torch.argmax(x_un_pre,2)
        pre_state = torch.cat((state,1-state),1).unsqueeze(1)
    else:
        state = torch.argmax(x_un_pre,2).unsqueeze(2)
        pre_state = torch.cat((state,1-state),2)  
    return pre_state

# load data
adj_address = HYP['data_path']+'-adjmat.pickle'
series_address = HYP['data_path']+'-series.pickle'
# train_loader, val_loader, test_loader, object_matrix = load_bn_ggn_ori(series_address,adj_address,batch_size=HYP['batch_size'])
train_loader, val_loader, test_loader, object_matrix = load_bn_ggn(series_address,adj_address,batch_size=HYP['batch_size'],seed = HYP['seed'])
print(object_matrix[-del_num:,-del_num:,].sum(0))

# check  未知与未知部分的link 
unedges  = torch.sum(object_matrix[-del_num:,-del_num:])
while unedges.item() == 0:
    print("sample 0 edges")
    sys.exit()

start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print('start_time:', start_time)
    
# dyn learner isomorphism
dyn_isom = IO_B_discrete(HYP['dim'], HYP['hid']).to(device)
op_dyn = optim.Adam(dyn_isom.parameters(), lr=HYP['lr_dyn'])



# states learner
sample_num = len(train_loader.dataset)
states_learner = Generator_states_discrete(sample_num,del_num).double()
if use_cuda:
    states_learner = states_learner.cuda()
opt_states = optim.Adam(states_learner.parameters(),lr = HYP['lr_state'])
train_index = DataLoader([i for i in range(sample_num)],HYP['batch_size'])

# val  states learner 
v_sample_num = len(val_loader.dataset)
states_learner_v = Generator_states_discrete(v_sample_num,del_num).double()
if use_cuda:
    states_learner_v = states_learner_v.cuda()
opt_states_v = optim.Adam(states_learner_v.parameters(),lr = HYP['lr_state'])
val_index = DataLoader([i for i in range(v_sample_num)],HYP['batch_size'])

observed_adj = object_matrix[:-del_num,:-del_num]
kn_mask,left_mask,un_un_mask,kn_un_mask = partial_mask(HYP['node_num'],del_num)
print("data_num",len(train_loader.dataset),"object_matrix",object_matrix)
loss_fn = torch.nn.NLLLoss()

all parameter  {'node_num': 128, 'seed': 135, 'dim': 2, 'hid': 64, 'epoch_num': 700, 'batch_size': 1024, 'lr_net': 0.004, 'lr_dyn': 0.001, 'lr_state': 0.1, 'miss_percent': 0.1, 'data_path': '/data/chenmy/voter/seed2050email128128500', 'temp': 1, 'drop_frac': 1, 'cuda': 3}
tensor([1., 3., 0., 3., 0., 0., 1., 3., 0., 0., 0., 3.], device='cuda:3')
start_time: 2020-12-22 22:44:08
data_num 357 object_matrix tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]], device='cuda:3')


In [14]:
def gener_init_H(obs,obj,label_non_obs):
    missing_edges = int(torch.sum(obj*label_non_obs).item())
    z_ele_num = int(torch.sum(label_non_obs).item())
    z_element_choice = torch.randperm(z_ele_num)[:missing_edges]
    if use_cuda:
        z_element_choice = z_element_choice.cuda()
    z_element = torch.nonzero(label_non_obs)
#     print(z_element,z_element_choice)
    init_z_edges = torch.index_select(z_element,0,z_element_choice)

    obs[init_z_edges[:,0],init_z_edges[:,1]] = 1
    return obs

In [15]:
H = gener_init_H(object_matrix*kn_mask.cuda(),object_matrix,left_mask.cuda())

In [17]:
p0 = torch.FloatTensor([[0.8, 0.5],[0.4, 0.1]]) 
generator = Gumbel_union_kron(p0,korder=7,label_non_obs = left_mask,init_H = H,warmup = 30)
op_net = optim.Adam(generator.parameters(), lr=HYP['lr_net'])
# generator.sample_all()

In [None]:
# 训练全网的正向动力学
'''voter states,dyn,NET 同时训练 '''
# indices to draw the curve
num = 0
def train_dyn_net_state():
    loss_batch = []
    ymae_batch = []
    for idx,(data,states_id) in enumerate(zip(train_loader,train_index)):
        x = data[0].float().to(device)
        y = data[1].to(device).long()

        x_kn = x[:,:-del_num,:]
        y_kn = y[:,:-del_num]
    
        x_un = x[:,-del_num:,:]
        generator.drop_temp()
        outputs = torch.zeros(y.size(0), y.size(1),2)
        loss_node = []
        ymae_node = []
        
        for j in range(HYP['node_num']-del_num):
            op_net.zero_grad()
            op_dyn.zero_grad()
            opt_states.zero_grad()
            
            x_un_pre = states_learner(states_id.cuda())
            
            x_hypo = torch.cat((x_kn,x_un_pre.float()),1)
    
            adj_col = generator.sample_all()[:,j]
            adj_col[:-del_num] = observed_adj[j].cuda()

            
            y_hat = dyn_isom(x_hypo, adj_col, j, num, HYP['node_num']-del_num)
            loss = loss_fn(y_hat,y[:,j])
            loss.backward()

            mae = torch.mean(abs(y[:,j] - torch.argmax(y_hat,1)).float())  

            # cut gradient in case nan shows up
            U.clip_grad_norm_(generator.gen_matrix, 0.000075)

            op_dyn.step()
            op_net.step()
            opt_states.step()

            # use outputs to caculate mse
            outputs[:, j, :] = y_hat
            
            # record
            ymae_node.append(mae.item())
            loss_node.append(loss.item())

        ymae_batch.append(np.mean(ymae_node))
        loss_batch.append(np.mean(loss_node))
        
    return np.mean(loss_batch),np.mean(ymae_batch)

In [None]:




# train
val_epoch =  100
choose_num  = int(HYP['node_num']/3)
losses = []
val_loss_epoch = []
metric_epoch_yt=[]

for epoch in range(HYP['epoch_num']):
    start_time = time.time()
    loss_y,mse_y = train_dyn_net_state()
    (index_order,auc_net,precision,kn_un_precision,un_un_precision) = part_constructor_evaluator_sgm(generator,1,object_matrix,HYP["node_num"],del_num)
    metric_epoch_yt.append([auc_net,precision,kn_un_precision,un_un_precision])
    losses.append([loss_y,mse_y])
    # if auc_net[0]>0.95:
    #     break
    print(epoch,'gumbel all Net error: %f,kn_un :%f,un_un:%f'%(round(float(precision[1].item()/left_mask.sum()),2),round(float(kn_un_precision[1].item()/kn_un_mask.sum()),2),round(float(un_un_precision[1].item()/un_un_mask.sum()),2)))   
    print("index_order",index_order)
    print('loss_y:%f,mse_y:%f'%(loss_y,mse_y))
    end_time = time.time()
    print("cost_time",str(round(end_time -start_time, 2)))

    if (epoch+1)%100 ==0:
        print("\nstatred val",epoch)
        val_losses = [];val_maes = []
        for i in range(val_epoch):
            val_loss,val_mae = val()
            val_losses.append(val_loss)
            val_maes.append(val_mae)
        vloss = torch.mean(torch.FloatTensor(val_losses))
        vmae = torch.mean(torch.FloatTensor(val_maes))
        val_loss_epoch.append([vloss,vmae])
        print('     val loss:' + str(vloss)+' val mse:' + str(vmae),'\n')#
end_time = time.time()
print("cost_time",str(round(end_time -start_time, 2)))

In [None]:
def val():
    loss_batch = []
    ymae_batch = []
    choose_node = torch.randint(0,HYP['node_num']-del_num,[choose_num])
    hypo_adj = generator.sample_all().detach()
    hypo_adj[:-del_num,:-del_num] = observed_adj
    for (idx,data),state_id in zip(enumerate(val_loader),val_index):
        
        x = data[0].float().to(device)
        y = data[1].to(device).long()
        
        x_kn = x[:,:-del_num,:]
        y_kn = y[:,:-del_num]
    
        x_un = x[:,-del_num:,:]
        generator.drop_temp()
        outputs = torch.zeros(y.size(0), y.size(1),2)
        loss_node = []
        states_node = []

        for index,j in enumerate(choose_node):

            x_un_pre = states_learner_v(state_id.cuda())
            x = torch.cat((x_kn,x_un_pre.float()),1)
            
            opt_states_v.zero_grad()
            adj_col = hypo_adj[:,j].cuda()# hard = true
            y_hat = dyn_isom(x, adj_col, j, num, HYP['node_num'])
            
            loss = loss_fn(y_hat,y[:,j])
            loss.backward()
            opt_states_v.step()
            ymae = torch.mean(abs(y[:,j] - torch.argmax(y_hat,1)).float())  
            # states_node
            outputs[:, index, :] = y_hat
            
            # record
            loss_node.append(loss.item())            
            states_node.append(ymae.item())

        ymae_batch.append(np.mean(states_node))
        loss_batch.append(np.mean(loss_node))
        
    return torch.mean(torch.FloatTensor(loss_batch)),torch.mean(torch.FloatTensor(ymae_batch))

In [2]:
class kronecker_Generator(nn.Module):
    def __init__(self,p0,korder = 3,node_num = 2):
        super(kronecker_Generator,self).__init__()
        self.p = Parameter(p0,requires_grad = True)
        # self.p = Parameter(torch.rand(node_num,node_num,requires_grad=True))
        self.korder = korder
        # print(self.p)
    def generator_adjacency(self):
        k = self.korder
        p0 = self.p
        adj = self.p
        for i in range(k-1):
            adj = kronecker(adj,p0)
        return adj
def loss_func(sigma,Pk):
    loss = -torch.sum((1-sigma)*torch.log(1-Pk)+sigma*torch.log(Pk))
    return loss
def metropolis_update_ratio(sigma_before,sigma_later,Pk):
    '''
    if memory is sufficinet, this one is much cleaner to execute
    '''
    Nll_before=(1-sigma_before)*np.log(1-Pk) +sigma_before*np.log(Pk)
    Nll_later=(1-sigma_later)*np.log(1-Pk) +sigma_later*np.log(Pk)
    ratio=np.exp(np.sum(Nll_later-Nll_before))
    return ratio
def SwapElement(sigma_before,i,j):
    i_topology=sigma_before[i,:]
    j_topology=sigma_before[j,:]
    sigma_later=np.copy(sigma_before)
    sigma_later[i,:]=j_topology
    sigma_later[j,:]=i_topology
    sigma_later[:,i]=sigma_before[:,j]
    sigma_later[:,j]=sigma_before[:,i]
    sigma_later[i,j]=sigma_before[i,j]
    sigma_later[j,i]=sigma_before[j,i]
    sigma_later[i,i]=sigma_before[j,j]
    sigma_later[j,j]=sigma_before[i,i]
    return sigma_later

def SamplePermutation(Pk,sigma,u,n1_swap,n2_swap,label_non_obs,obj_adj):
    sigma_later=SwapElement(sigma,n1_swap, n2_swap)
    ratio=metropolis_update_ratio(sigma,sigma_later,Pk)
    if u<ratio:
        sigma=sigma_later
        label_non_obs=SwapElement(label_non_obs,n1_swap, n2_swap)
        obj_adj = SwapElement(obj_adj,n1_swap,n2_swap)
#         print(ifswap)
    return sigma,label_non_obs,obj_adj
def SampleZ(H,Pk,label_non_obs,u):
    mat_size=len(Pk)
    edge_in_non_obs=(H>0)*label_non_obs  
    edge_position=np.where(edge_in_non_obs)
    edge_removed=np.random.randint(len(edge_position[0]))
    px=Pk[edge_position[0][edge_removed],edge_position[1][edge_removed]]
    non_edge_in_non_obs=(H<1)*label_non_obs
    ##return 
    py_array=non_edge_in_non_obs*Pk
    py_array=np.ravel(py_array)
    py=np.random.choice(range(len(py_array)), size=1, p=py_array/np.sum(py_array))
    ratio=(1-py_array[py])/(1-px)
    if ratio<u:
        H[py//mat_size,py%mat_size]=1
        H[edge_position[0][edge_removed],edge_position[1][edge_removed]]=0
        # print('accept')
    return H

def missing_label(sz,missing_percent):
     # random sample
    missing_num = int(sz*missing_percent)
    idx = torch.randperm(sz)[:missing_num]
    mask_un_obs = torch.zeros(sz,sz)
    for i in idx:
        mask_un_obs[i] = 1
        mask_un_obs[:,i:i+1] = 1
    mask_obs = 1- mask_un_obs
    return mask_un_obs,mask_obs,idx
# only sigma and kronecker pk are itered
def E_step(sigma,N,warmup,Pk,label_non_obs,obj_adj):
    u1=np.random.rand(N+2*warmup)
    u2=np.random.rand(N+2*warmup)
    sigma_hist=[]
    Z_label=[]
    Node_list=np.arange(len(Pk))
    element_to_swap=np.random.choice(a=Node_list,size=(2,3*(N+warmup)))
    mask=element_to_swap[1,:]!=element_to_swap[0,:]# it is pointless to swap the same element
    n1_swap=element_to_swap[0,:][mask]
    n2_swap=element_to_swap[1,:][mask]
    
    for i in range(warmup):
        sigma,label_non_obs,obj_adj=SamplePermutation(Pk,sigma,u2[i],n1_swap[i],n2_swap[i],label_non_obs,obj_adj)  
    for j in range(warmup):
        sigma=SampleZ(sigma,Pk,label_non_obs,u1[i+j:i+j+1]) 
    for k in range(N):
        sigma=SampleZ(sigma,Pk,label_non_obs,u1[i+j+1+k:i+j+1+k+1]) 
        sigma_hist.append(sigma)
    Z_label.append(label_non_obs)
    return sigma_hist,Z_label,obj_adj #（H+G）

def M_step(epoch,sigma_train,p0,k,N):
    losses = []
    generator = kronecker_Generator(p0,k,2)
    learning_rate = 1e-5#0.0000001
    opt_net = optim.SGD(generator.parameters(),lr = learning_rate)
    decayRate = 0.95

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt_net, gamma=decayRate) 
    for i in range(epoch):   
        opt_net.zero_grad()
        Pk = generator.generator_adjacency()
        loss = loss_func(sigma_train,Pk,N)
        loss2 = loss2_func(sigma_train.detach().numpy(),Pk.detach().numpy(),N)
        loss.backward()
        losses.append(loss.item())
        # print(str(i),loss.item(),loss2.item(),(loss2-loss).item())
        opt_net.step()
        scheduler.step()
        for group in opt_net.param_groups:
            for param in group["params"]: 
              # print("before param",param)
              param.data.clamp_(0.0001,0.9999)
              # print("after param",param)
    
        for p in generator.parameters():
            p0 = p.data  
        generator = kronecker_Generator(p0,k,2)
        Pk = generator.generator_adjacency()
        Pk = Pk.detach().numpy() 
        # evaluation
        #   nll infer  nll true
    return np.mean(losses),Pk,p0

def kronEM(iterstep,N,warmup,H,Pk,label_non_obs,epoch,p0,obj_adj,k):
    emlosses = []
    for i in range(iterstep):
        print("start iterstep",i)
        sigma_hist,Z_label,obj_adj = E_step(H,N,warmup,Pk,label_non_obs,obj_adj)
        print("E-step")
        H = sigma_hist[-1]
        label_non_obs = Z_label[-1]
        sigma_hist_train = np.array(sigma_hist)
        sigma_train = torch.DoubleTensor(sigma_hist_train)
        emloss,Pk,p0 = M_step(epoch,sigma_train,p0,k,N)
        emlosses.append(emloss)
        print("*************\n EM loss is %f"%emloss,"p0 is ",p0)
        
        # evaluation
        inferp_nll = NLL(H,Pk)
        print("inferp_nll is %f "%(inferp_nll))
        obs_mask = (1-label_non_obs).astype(bool)
        obs_auc = calauc(obj_adj,Pk,obs_mask)
        non_obs_auc = calauc(obj_adj,Pk,label_non_obs.astype(bool))
        #         print("perm is ", perm)
        
        label_obs = (1-label_non_obs)
        obs_diff_edge = abs((H - obj_adj)*label_obs).sum()
        unobs_diff_edge =  abs((H - obj_adj)*label_non_obs).sum()
  
        all_diff_edge = abs((H - obj_adj)).sum()
        # D_abs = torch.abs(orip-p0).mean()
        c=calauc(H,Pk,(label_non_obs).astype(bool))
        print(c)
        print("infer p is ",p0,  "obseved auc  is %f and non_obs auc is %f and obs_diff_edge is %f and un_obs_diff_edge %f and all_diff_edge is %f"%(obs_auc,non_obs_auc,obs_diff_edge,unobs_diff_edge,all_diff_edge))
    
    return emlosses,H,Pk,label_non_obs,obj_adj

def NLL(sigma_true,pk):
    Nll_before=(1-sigma_true)*np.log(1-pk) +sigma_true*np.log(pk)
    return -np.sum(Nll_before)

def loss2_func(sigma_train,Pk,N):
    loss2 = 0
    for i in range(N):
        loss2+= NLL(sigma_train[i],Pk)
    # print(loss2)
    return loss2/N

def loss_func(sigma,Pk,N):
    loss = -torch.sum((1-sigma)*torch.log(1-Pk)+sigma*torch.log(Pk))
    return loss/N

def generator_adj(korder,p):
    k = korder
    p0 = p
    adj = p
    for i in range(k-1):
        adj = kronecker(adj,p0)
    return adj
def calauc(H,Pk,mask):
    fpr, tpr, thresholds = roc_curve(H[mask],Pk[mask])
    Auc = auc(fpr, tpr)
    return Auc

In [3]:
k =7
sz = 2**k
remove_proportion = 0.25
del_num = int(sz*remove_proportion)

orip = torch.FloatTensor([[0.9,0.6],[0.4, 0.2]]) # torch.FloatTensor([[0.1981, 0.6427],[0.9684, 0.4522]])  
print("objective p", orip)
ground_adj = generator_adj(k,orip)

Ground_truth_adj_before = (ground_adj>torch.rand((ground_adj.shape))).float()

objective p tensor([[0.9000, 0.6000],
        [0.4000, 0.2000]])


In [46]:
# sigma=Ground_truth_adj
baseline = loss_func(Ground_truth_adj_before,ground_adj,1)
print("base",baseline)

np_baseline = NLL(Ground_truth_adj_before.data.numpy(),ground_adj.data.numpy())
print("base2",np_baseline)

# shuffle 
node_num=sz
permu_m = torch.eye(node_num)
permutetion_arrange = torch.randperm(node_num)
permu_m_shuffle= torch.index_select(permu_m,0,permutetion_arrange) 

Ground_truth_adj = torch.mm(permu_m_shuffle,Ground_truth_adj_before).mm(permu_m_shuffle.t())
Ground_truth_adj =Ground_truth_adj_before
mask_un_obs,mask_obs,miss_idx = missing_label(sz,remove_proportion)
G = Ground_truth_adj*mask_obs
init_z = Ground_truth_adj*mask_un_obs
missing_edges = int(init_z.sum())

# initial partial z with fixed missing edges num
z_ele_num = int(mask_un_obs.sum())
z_element_choice = torch.randperm(z_ele_num)[:missing_edges]
z_element = torch.nonzero(mask_un_obs)
init_z_edges = torch.index_select(z_element,0,z_element_choice)
H = G.clone()
H[init_z_edges[:,0],init_z_edges[:,1]] = 1
# print(G,H,z_element_choice)
# initial kronecker

p0 = torch.FloatTensor([[0.8, 0.5],[0.4, 0.1]])   # torch.rand([2,2])
generator = kronecker_Generator(p0,k,2)
Pk = generator.generator_adjacency()
init_H = H.data.numpy()
label_non_obs = mask_un_obs.data.numpy()
init_Pk = Pk.detach().data.numpy()
init_label_non_obs= label_non_obs.copy()
objective_adj = np.array(Ground_truth_adj)

base tensor(791.8375)
base2 791.83716


In [None]:
def SwapElement(sigma_before,i,j):
    i_topology=sigma_before[i,:]
    j_topology=sigma_before[j,:]
    sigma_later=np.copy(sigma_before)
    sigma_later[i,:]=j_topology
    sigma_later[j,:]=i_topology
    sigma_later[:,i]=sigma_before[:,j]
    sigma_later[:,j]=sigma_before[:,i]
    sigma_later[i,j]=sigma_before[i,j]
    sigma_later[j,i]=sigma_before[j,i]
    sigma_later[i,i]=sigma_before[j,j]
    sigma_later[j,j]=sigma_before[i,i]
    
    return sigma_later

In [109]:
def getswappk(warmup):
    u1=np.random.rand(warmup)
    Node_list=np.arange(len(Pk))
    element_to_swap=np.random.choice(a=Node_list,size=(2,3*(warmup)))
    mask=element_to_swap[1,:]!=element_to_swap[0,:]# it is pointless to swap the same element
    n1_swap=element_to_swap[0,:][mask]
    n2_swap=element_to_swap[1,:][mask]
    for i in range(warmup):
        pk = SamplePermutationPk(init_Pk,n1_swap[i],n2_swap[i],index,init_H,u1[i])
    return pk

In [110]:
torch.rand(100)

tensor([0.9255, 0.6935, 0.1863, 0.6986, 0.5002, 0.5029, 0.6360, 0.4869, 0.2306,
        0.7959, 0.0481, 0.2977, 0.9555, 0.3731, 0.7308, 0.7063, 0.9098, 0.8276,
        0.6475, 0.0466, 0.5031, 0.4216, 0.4395, 0.4553, 0.8027, 0.4982, 0.1635,
        0.7212, 0.7280, 0.0915, 0.3064, 0.8402, 0.0142, 0.3212, 0.1841, 0.7992,
        0.7759, 0.9877, 0.0595, 0.4103, 0.5849, 0.0304, 0.3970, 0.1216, 0.7669,
        0.8689, 0.4970, 0.0456, 0.3238, 0.5318, 0.6152, 0.3125, 0.8939, 0.8318,
        0.4548, 0.1058, 0.7522, 0.5967, 0.1546, 0.7176, 0.4060, 0.2343, 0.7893,
        0.7611, 0.1940, 0.8506, 0.7182, 0.0822, 0.2599, 0.0214, 0.8159, 0.0642,
        0.6513, 0.0272, 0.2667, 0.5180, 0.7042, 0.4255, 0.0788, 0.2195, 0.1731,
        0.8244, 0.6016, 0.7863, 0.3811, 0.7120, 0.6345, 0.5151, 0.6208, 0.6456,
        0.5823, 0.6553, 0.3517, 0.1856, 0.6264, 0.9657, 0.0824, 0.3345, 0.2266,
        0.4657])

In [None]:
 def sample_all(self, hard=False):
        self.logp = self.gen_matrix
        
        if use_cuda:
            self.logp = self.gen_matrix.cuda()
        
        out = gumbel_softmax(self.logp, self.temperature, hard)
        if hard:
            hh = torch.zeros((self.del_num*(2*self.sz - self.del_num-1),2))
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh                    
        out = out[:, 0]
        if use_cuda:
            out = out.cuda()
            
        matrix = torch.zeros(self.sz,self.sz).cuda()
        left_mask = torch.ones(self.sz,self.sz)
        left_mask[:-self.del_num,:-self.del_num] = 0
        left_mask = left_mask - torch.diag(torch.diag(left_mask))
        un_index = left_mask.nonzero()
        matrix[(un_index[:,0],un_index[:,1])] = out
        out_matrix = matrix
        # out_matrix = out[:, 0].view(self.gen_matrix.size()[0], self.gen_matrix.size()[0])
        return out_matrix
    def init(self, mean, var):
        init.normal_(self.gen_matrix, mean=mean, std=var)

tensor([[2.0972e-01, 1.3107e-01, 1.3107e-01,  ..., 1.2500e-02, 1.2500e-02,
         7.8125e-03],
        [1.0486e-01, 2.6214e-02, 6.5536e-02,  ..., 2.5000e-03, 6.2500e-03,
         1.5625e-03],
        [1.0486e-01, 6.5536e-02, 2.6214e-02,  ..., 6.2500e-03, 2.5000e-03,
         1.5625e-03],
        ...,
        [3.2768e-03, 8.1920e-04, 2.0480e-03,  ..., 8.0000e-07, 2.0000e-06,
         5.0000e-07],
        [3.2768e-03, 2.0480e-03, 8.1920e-04,  ..., 2.0000e-06, 8.0000e-07,
         5.0000e-07],
        [1.6384e-03, 4.0960e-04, 4.0960e-04,  ..., 4.0000e-07, 4.0000e-07,
         1.0000e-07]], grad_fn=<ViewBackward>)
tensor([[2.0972e-01, 1.3107e-01, 1.3107e-01,  ..., 1.2500e-02, 1.2500e-02,
         7.8125e-03],
        [1.0486e-01, 2.6214e-02, 6.5536e-02,  ..., 2.5000e-03, 6.2500e-03,
         1.5625e-03],
        [1.0486e-01, 6.5536e-02, 2.6214e-02,  ..., 6.2500e-03, 2.5000e-03,
         1.5625e-03],
        ...,
        [3.2768e-03, 8.1920e-04, 2.0480e-03,  ..., 8.0000e-07, 2.0000e-06,
 

tensor([[2.0972e-01, 1.3107e-01, 1.3107e-01,  ..., 1.2500e-02, 1.2500e-02,
         7.8125e-03],
        [1.0486e-01, 2.6214e-02, 6.5536e-02,  ..., 2.5000e-03, 6.2500e-03,
         1.5625e-03],
        [1.0486e-01, 6.5536e-02, 2.6214e-02,  ..., 6.2500e-03, 2.5000e-03,
         1.5625e-03],
        ...,
        [3.2768e-03, 8.1920e-04, 2.0480e-03,  ..., 8.0000e-07, 2.0000e-06,
         5.0000e-07],
        [3.2768e-03, 2.0480e-03, 8.1920e-04,  ..., 2.0000e-06, 8.0000e-07,
         5.0000e-07],
        [1.6384e-03, 4.0960e-04, 4.0960e-04,  ..., 4.0000e-07, 4.0000e-07,
         1.0000e-07]], grad_fn=<ViewBackward>)

In [118]:
pk = generator.shuffle_adj(1)

In [129]:
label_non_obs = torch.FloatTensor(label_non_obs)

In [125]:
pk

tensor([[2.0972e-01, 1.3107e-01, 1.3107e-01,  ..., 1.2500e-02, 1.2500e-02,
         7.8125e-03],
        [1.0486e-01, 2.6214e-02, 6.5536e-02,  ..., 2.5000e-03, 6.2500e-03,
         1.5625e-03],
        [1.0486e-01, 6.5536e-02, 2.6214e-02,  ..., 6.2500e-03, 2.5000e-03,
         1.5625e-03],
        ...,
        [3.2768e-03, 8.1920e-04, 2.0480e-03,  ..., 8.0000e-07, 2.0000e-06,
         5.0000e-07],
        [3.2768e-03, 2.0480e-03, 8.1920e-04,  ..., 2.0000e-06, 8.0000e-07,
         5.0000e-07],
        [1.6384e-03, 4.0960e-04, 4.0960e-04,  ..., 4.0000e-07, 4.0000e-07,
         1.0000e-07]], grad_fn=<ViewBackward>)

tensor([[1.3107e-01, 8.6893e-01],
        [8.1920e-02, 9.1808e-01],
        [5.1200e-02, 9.4880e-01],
        ...,
        [1.6000e-06, 1.0000e+00],
        [1.6000e-06, 1.0000e+00],
        [1.6000e-06, 1.0000e+00]], grad_fn=<CatBackward>)

In [None]:
    
    
def sample_all(self,hard=False):
'''sample from kron para'''
self.logp = self.gen_matrix


# out_matrix = out[:, 0].view(self.gen_matrix.size()[0], self.gen_matrix.size()[0])
return out_matrix



In [None]:
torch.cat()

In [None]:
if use_cuda:
    self.logp = self.gen_matrix.cuda()
out = gumbel_softmax(self.logp, self.temperature, hard)

if hard:
    hh = torch.zeros((self.del_num*(2*self.sz - self.del_num-1),2))
    for i in range(out.size()[0]):
        hh[i, out[i]] = 1
    out = hh                    
out = out[:, 0]
if use_cuda:
    out = out.cuda()
matrix = torch.zeros(self.sz,self.sz).cuda()
left_mask = torch.ones(self.sz,self.sz)
left_mask[:-self.del_num,:-self.del_num] = 0
left_mask = left_mask - torch.diag(torch.diag(left_mask))
un_index = left_mask.nonzero()
matrix[(un_index[:,0],un_index[:,1])] = out
out_matrix = matrix

In [None]:
        if hard:
            hh = torch.zeros((self.del_num*(2*self.sz - self.del_num-1),2))
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh                    
        out = out[:, 0]
        if use_cuda:
            out = out.cuda()
        matrix = torch.zeros(self.sz,self.sz).cuda()
        left_mask = torch.ones(self.sz,self.sz)
        left_mask[:-self.del_num,:-self.del_num] = 0
        left_mask = left_mask - torch.diag(torch.diag(left_mask))
        un_index = left_mask.nonzero()
        matrix[(un_index[:,0],un_index[:,1])] = out
        out_matrix = matrix
        # out_matrix = out[:, 0].view(self.gen_matrix.size()[0], self.gen_matrix.size()[0])

tensor([[1.3107e-01],
        [8.1920e-02],
        [5.1200e-02],
        ...,
        [1.6000e-06],
        [1.6000e-06],
        [1.6000e-06]], grad_fn=<UnsqueezeBackward0>)

In [None]:
sample_all 


for i in range

In [108]:


def SamplePermutationPk(Pk,pos_1,pos_2,index_pk,A,u):
    Pk=Pk.ravel()[index_pk.ravel()].reshape(Pk.shape)
    index_pk_later=SwapElement(index_pk,pos_1,pos_2)
    print(pos_1,pos_2)
#     print(index_pk_later[pos_1],index_pk_later[pos_2])
    Pk_later=Pk.ravel()[index_pk_later.ravel()].reshape(Pk.shape)
    print((Pk[pos_1]-Pk_later[pos_2]).sum())
    ratio=metropolis_update_ratio_Pk(Pk,Pk_later,A)
#     print(ratio)
    if u<ratio:
        Pk=Pk_later
        index_pk=index_pk_later
        print((Pk[pos_1]-Pk_later[pos_2]).sum())
#         print(index_pk_later)
    return index_pk,Pk

In [146]:
torch.rand(del_num*(2*sz-del_num-1), 2)

tensor([[0.4231, 0.5536],
        [0.8359, 0.3100],
        [0.4056, 0.9016],
        ...,
        [0.5848, 0.5582],
        [0.6054, 0.2039],
        [0.4154, 0.2149]])

In [109]:
def init_index(Pk):
    index=np.arange(len(Pk.ravel())).reshape(Pk.shape) 
    return index

In [None]:
Ground_truth_adj

In [None]:
# sigma=Ground_truth_adj
baseline = loss_func(Ground_truth_adj_before,ground_adj,1)
print("base",baseline)

np_baseline = NLL(Ground_truth_adj_before.data.numpy(),ground_adj.data.numpy())
print("base2",np_baseline)

# shuffle 
node_num=sz
permu_m = torch.eye(node_num)
permutetion_arrange = torch.randperm(node_num)
permu_m_shuffle= torch.index_select(permu_m,0,permutetion_arrange) 

# Ground_truth_adj = torch.mm(permu_m_shuffle,Ground_truth_adj_before).mm(permu_m_shuffle.t())
Ground_truth_adj =Ground_truth_adj_before
mask_un_obs,mask_obs,miss_idx = missing_label(sz,remove_proportion)
G = Ground_truth_adj*mask_obs
init_z = Ground_truth_adj*mask_un_obs
missing_edges = int(init_z.sum())

# initial partial z with fixed missing edges num
z_ele_num = int(mask_un_obs.sum())
z_element_choice = torch.randperm(z_ele_num)[:missing_edges]
z_element = torch.nonzero(mask_un_obs)
init_z_edges = torch.index_select(z_element,0,z_element_choice)
H = G.clone()
H[init_z_edges[:,0],init_z_edges[:,1]] = 1
# print(G,H,z_element_choice)
# initial kronecker

p0 = torch.FloatTensor([[0.8, 0.5],[0.4, 0.1]])   # torch.rand([2,2])
generator = kronecker_Generator(p0,k,2)
Pk = generator.generator_adjacency()
init_H = H.data.numpy()
label_non_obs = mask_un_obs.data.numpy()
init_Pk = Pk.detach().data.numpy()
init_label_non_obs= label_non_obs.copy()
objective_adj = np.array(Ground_truth_adj)

calauc(Ground_truth_adj_before.numpy(),ground_adj.numpy(),(1-label_non_obs).astype(bool))
c=(ground_adj*Ground_truth_adj_before)[label_non_obs.astype(bool)][(ground_adj*Ground_truth_adj_before)[label_non_obs.astype(bool)]>0]
print(np.sort(c)[::-1])

In [None]:
adj_address ="/data/chenmy/voter/seed1051email128-adjmat.pickle"
with open(adj_address,'rb') as f:
    objective_adj = pickle.load(f,encoding='latin1')
objective_adj = np.array(objective_adj)

In [8]:
Ground_truth_adj = torch.FloatTensor(objective_adj)
sz = Ground_truth_adj.shape[0]
k = int(np.log2(sz))
remove_proportion = 0.1
del_num = int(sz*remove_proportion)
mask_un_obs,mask_obs,miss_idx = missing_label(sz,remove_proportion)
G = Ground_truth_adj*mask_obs
nG = G.data.numpy()
H = G.clone()
# print(2,(abs(H.data.numpy()-objective_adj)*mask_obs.data.numpy()).sum())
init_z = Ground_truth_adj*mask_un_obs
missing_edges = int(init_z.sum())
print("miss nodes ",miss_idx)
# initial partial z with fixed missing edges num
z_ele_num = int(mask_un_obs.sum())
print("miss_edge",z_ele_num)
z_element_choice = torch.randperm(z_ele_num)[:missing_edges]
z_element = torch.nonzero(mask_un_obs)
init_z_edges = torch.index_select(z_element,0,z_element_choice)
H[init_z_edges[:,0],init_z_edges[:,1]] = 1
# print(abs((H - Ground_truth_adj)*mask_obs).sum())
# print(G,H,z_element_choice)
# initial kronecker
p0 = torch.FloatTensor([[0.9,0.7],[0.7, 0.3]])#torch.FloatTensor([[0.4408, 0.1770],[0.4951, 0.2585]])  # 
generator = kronecker_Generator(p0,k,2)
Pk = generator.generator_adjacency()
init_H = H.data.numpy()
label_non_obs = mask_un_obs.data.numpy()
init_Pk = Pk.detach().data.numpy()
perm = np.eye(sz)
print(abs((init_H- objective_adj)*mask_obs.data.numpy()).sum())
# init_label_non_obs= label_non_obs.copy()
init_H_ori = init_H.copy()

miss nodes  tensor([ 35, 111,  22,  27,  51, 126,  83,  69,  74,  91,  99,  25])
miss_edge 2928
0.0


In [9]:
missing_edges

222

In [10]:
def gumbel_sample(shape, eps=1e-20):
    u = torch.rand(shape)
    gumbel = - np.log(- np.log(u + eps) + eps)
    if use_cuda:
        gumbel = gumbel.cuda()
    return gumbel

def gumbel_softmax_sample(logits, temperature):
    """ Draw a sample from the Gumbel-Softmax distribution"""
    # gumbel_sample 返回一个sample采样
    y = logits + gumbel_sample(logits.size())
    return torch.nn.functional.softmax(y/temperature, dim=1)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probabilitiy distribution that sums to 1 across classes
    """
    
    y = gumbel_softmax_sample(logits, temperature)
    
    if hard:
        k = logits.size()[-1]
        y_hard = torch.max(y.data, 1)[1]
        y = y_hard
    return y


In [15]:
# %load kronEM.py
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.optim as optim
from sklearn.metrics import roc_curve,auc
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve,auc
import numpy as np

def kronecker(A,B):
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
class kronecker_Generator(nn.Module):
    def __init__(self,p0,korder = 3,node_num = 2):
        super(kronecker_Generator,self).__init__()
        self.p = Parameter(p0,requires_grad = True)
        # self.p = Parameter(torch.rand(node_num,node_num,requires_grad=True))
        self.korder = korder
        # print(self.p)
    def generator_adjacency(self):
        k = self.korder
        p0 = self.p
        adj = self.p
        for i in range(k-1):
            adj = kronecker(adj,p0)
        return adj
def loss_func(sigma,Pk):
    loss = -torch.sum((1-sigma)*torch.log(1-Pk)+sigma*torch.log(Pk))
    return loss
def metropolis_update_ratio(sigma_before,sigma_later,Pk):
    '''
    if memory is sufficinet, this one is much cleaner to execute
    '''
    Nll_before=(1-sigma_before)*np.log(1-Pk) +sigma_before*np.log(Pk)
    Nll_later=(1-sigma_later)*np.log(1-Pk) +sigma_later*np.log(Pk)
    ratio=np.exp(np.sum(Nll_later-Nll_before))
    return ratio
def SwapElement(sigma_before,i,j):
    i_topology=sigma_before[i,:]
    j_topology=sigma_before[j,:]
    sigma_later=np.copy(sigma_before)
    sigma_later[i,:]=j_topology
    sigma_later[j,:]=i_topology
    sigma_later[:,i]=sigma_before[:,j]
    sigma_later[:,j]=sigma_before[:,i]
    sigma_later[i,j]=sigma_before[i,j]
    sigma_later[j,i]=sigma_before[j,i]
    sigma_later[i,i]=sigma_before[j,j]
    sigma_later[j,j]=sigma_before[i,i]
    return sigma_later

def SamplePermutation(Pk,sigma,u,n1_swap,n2_swap,label_non_obs,perm):
    sigma_later=SwapElement(sigma,n1_swap, n2_swap)
    ratio=metropolis_update_ratio(sigma,sigma_later,Pk)
#     Pk_ori = Pk.copy()
    if u<ratio:
        sigma=sigma_later
        label_non_obs=SwapElement(label_non_obs,n1_swap, n2_swap)
        perm = SwapElement(perm,n1_swap,n2_swap)
#         perm[n1_swap], perm[n2_swap] =  perm[n2_swap],perm[n1_swap]
#         Pk_ori = SwapElement(Pk_ori,n1_swap,n2_swap)
#     print(abs(Pk_ori-Pk).sum())
    return sigma,label_non_obs,perm
def SampleZ(H,Pk,label_non_obs,u):
    mat_size=len(Pk)
    edge_in_non_obs=(H>0)*label_non_obs  
    edge_position=np.where(edge_in_non_obs)
    edge_removed=np.random.randint(len(edge_position[0]))
    px=Pk[edge_position[0][edge_removed],edge_position[1][edge_removed]]
    non_edge_in_non_obs=(H<1)*label_non_obs
    ##return 
    py_array=non_edge_in_non_obs*Pk
    py_array=np.ravel(py_array)
    py=np.random.choice(range(len(py_array)), size=1, p=py_array/np.sum(py_array))
    ratio=(1-py_array[py])/(1-px)
    if ratio<u:
        H[py//mat_size,py%mat_size]=1
        H[edge_position[0][edge_removed],edge_position[1][edge_removed]]=0
        # print('accept')
    return H

def missing_label(sz,missing_percent):
     # random sample
    missing_num = int(sz*missing_percent)
    idx = torch.randperm(sz)[:missing_num]
    mask_un_obs = torch.zeros(sz,sz)
    for i in idx:
        mask_un_obs[i] = 1
        mask_un_obs[:,i:i+1] = 1
    mask_obs = 1- mask_un_obs
    return mask_un_obs,mask_obs,idx
# only sigma and kronecker pk are itered
def E_step(sigma,N,warmup,Pk,label_non_obs,perm):
    u1=np.random.rand(N+warmup)
    u2=np.random.rand(N+warmup)
    sigma_hist=[]
    Z_label=[]
    Node_list=np.arange(len(Pk))
    element_to_swap=np.random.choice(a=Node_list,size=(2,3*(N+warmup)))
    mask=element_to_swap[1,:]!=element_to_swap[0,:]# it is pointless to swap the same element
    n1_swap=element_to_swap[0,:][mask]
    n2_swap=element_to_swap[1,:][mask]
    
    for i in range(warmup):
        sigma,label_non_obs,Pk_ori=SamplePermutation(Pk,sigma,u2[i],n1_swap[i],n2_swap[i],label_non_obs,perm)  
    for j in range(warmup):
        sigma=SampleZ(sigma,Pk,label_non_obs,u1[i]) 
    for k in range(N):
        sigma=SampleZ(sigma,Pk,label_non_obs,u1[i]) 
        sigma_hist.append(sigma)
    
    Z_label.append(label_non_obs)
    return sigma_hist,Z_label,Pk_ori #（H+G）
def M_step(epoch,sigma_train,p0,k,N):
    losses = []
    generator = kronecker_Generator(p0,k,2)
    learning_rate = 1e-5#0.0000001
    opt_net = optim.SGD(generator.parameters(),lr = learning_rate)
    decayRate = 0.95

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt_net, gamma=decayRate) 
    for i in range(epoch):   
        opt_net.zero_grad()
        Pk = generator.generator_adjacency()
        loss = loss_func(sigma_train,Pk,N)
        loss2 = loss2_func(sigma_train.detach().numpy(),Pk.detach().numpy(),N)
        loss.backward()
        losses.append(loss.item())
        # print(str(i),loss.item(),loss2.item(),(loss2-loss).item())
        opt_net.step()
        scheduler.step()
        for group in opt_net.param_groups:
            for param in group["params"]: 
              # print("before param",param)
              param.data.clamp_(0.0001,0.9999)
              # print("after param",param)
    
        for p in generator.parameters():
            p0 = p.data  
        generator = kronecker_Generator(p0,k,2)
        Pk = generator.generator_adjacency()
        Pk = Pk.detach().numpy() 
        # evaluation
        #   nll infer  nll true
    return np.mean(losses),Pk,p0

def kronEM(iterstep,N,warmup,H,Pk,label_non_obs,epoch,p0,perm,obj_adj,k):
    emlosses = []
    for i in range(iterstep):
        print("start iterstep",i)
        sigma_hist,Z_label,Pk_ori = E_step(H,N,warmup,Pk,label_non_obs,perm)
        print("E-step")
        H = sigma_hist[-1]
        label_non_obs = Z_label[-1]
        sigma_hist_train = np.array(sigma_hist)
        sigma_train = torch.DoubleTensor(sigma_hist_train)
        emloss,Pk,p0 = M_step(epoch,sigma_train,p0,k,N)
        emlosses.append(emloss)
        print("*************\n EM loss is %f"%emloss,"p0 is ",p0)
        
        # evaluation
        inferp_nll = NLL(H,Pk)
#         inferp_nll2 = NLL(obj_adj,Pk_ori)
        
        print("inferp_nll is %f "%(inferp_nll))
        obs_mask = (1-label_non_obs).astype(bool)
#         print("2",perm)
#         perm_mat = np.eye(H.shape[0])[:,perm]
        obj_adj = np.dot(perm,np.dot(obj_adj,perm.T))
    
#         obs_auc = calauc(obj_adj,Pk,obs_mask)
#         non_obs_auc = calauc(obj_adj,Pk,label_non_obs.astype(bool))
#         #         print("perm is ", perm)
        
        label_obs = (1-label_non_obs)
        obs_diff_edge = abs((H - obj_adj)*label_obs).sum()
        unobs_diff_edge =  abs((H - obj_adj)*label_non_obs).sum()
        all_diff_edge = abs((H - obj_adj)).sum()
        # D_abs = torch.abs(orip-p0).mean()
#         "obseved auc  is %f and non_obs auc is %f obs_auc,non_obs_auc,
        print("infer p is ",p0, "and obs_diff_edge is %f and un_obs_diff_edge %f and all_diff_edge is %f"%(obs_diff_edge,unobs_diff_edge,all_diff_edge))
    
    return emlosses,perm,H,Pk

def NLL(sigma_true,pk):
    Nll_before=(1-sigma_true)*np.log(1-pk) +sigma_true*np.log(pk)
    return -np.sum(Nll_before)

def loss2_func(sigma_train,Pk,N):
    loss2 = 0
    for i in range(N):
        loss2+= NLL(sigma_train[i],Pk)
    # print(loss2)
    return loss2/N

def loss_func(sigma,Pk,N):
    loss = -torch.sum((1-sigma)*torch.log(1-Pk)+sigma*torch.log(Pk))
    return loss/N

def generator_adj(korder,p):
    k = korder
    p0 = p
    adj = p
    for i in range(k-1):
        adj = kronecker(adj,p0)
    return adj
def calauc(H,Pk,mask):
    fpr, tpr, thresholds = roc_curve(H[mask],Pk[mask])
    Auc = auc(fpr, tpr)
    return Auc

In [16]:
warmup=100
N = 1
iterstep = 1
epoch = 1
emlosses2,perm2,H,pk= kronEM(iterstep,N,warmup,init_H,init_Pk,label_non_obs,epoch,p0,perm,init_H_ori,k)

start iterstep 0
E-step
*************
 EM loss is 4129.971266 p0 is  tensor([[0.8963, 0.7064],
        [0.7060, 0.3557]])
inferp_nll is 4075.004395 
infer p is  tensor([[0.8963, 0.7064],
        [0.7060, 0.3557]]) and obs_diff_edge is 1230.000000 and un_obs_diff_edge 262.000000 and all_diff_edge is 1492.000000


In [25]:
objective_adj.T

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [9]:
remove_para =  label_non_obs.sum()

In [None]:
def kronecker(A,B):
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
class kronecker_Generator(nn.Module):
    def __init__(self,p0,korder = 3,node_num = 2):
        super(kronecker_Generator,self).__init__()
        self.p = Parameter(p0,requires_grad = True)
        # self.p = Parameter(torch.rand(node_num,node_num,requires_grad=True))
        self.korder = korder
        # print(self.p)
    def generator_adjacency(self):
        k = self.korder
        p0 = self.p
        adj = self.p
        for i in range(k-1):
            adj = kronecker(adj,p0)
        return adj

In [None]:
    u1=np.random.rand(N+warmup)
    u2=np.random.rand(N+warmup)
    sigma_hist=[]
    Z_label=[]
    Node_list=np.arange(len(Pk))
    element_to_swap=np.random.choice(a=Node_list,size=(2,3*(N+warmup)))
    mask=element_to_swap[1,:]!=element_to_swap[0,:]# it is pointless to swap the same element
    n1_swap=element_to_swap[0,:][mask]
    n2_swap=element_to_swap[1,:][mask]
    
    for i in range(warmup):
        sigma,label_non_obs,obj_adj=SamplePermutation(Pk,sigma,u2[i],n1_swap[i],n2_swap[i],label_non_obs,obj_adj)

In [None]:
'''在网络补全任务中生成未知部分结构'''
class Gumbel_Generator_nc_asy(nn.Module):
    def __init__(self, sz=10,del_num = 1,temp=10, temp_drop_frac=0.9999):
        super(Gumbel_Generator_nc_asy, self).__init__()
        self.sz = sz
        self.del_num = del_num
        self.gen_matrix = Parameter(torch.rand(del_num*(2*sz-del_num-1), 2)) #cmy get only unknown part parameter
        self.temperature = temp
        self.temp_drop_frac = temp_drop_frac

    def drop_temp(self):
        # 降温过程
        self.temperature = self.temperature * self.temp_drop_frac

    def sample_all(self, hard=False):
        self.logp = self.gen_matrix
        if use_cuda:
            self.logp = self.gen_matrix.cuda()
        
        out = gumbel_softmax(self.logp, self.temperature, hard)
        if hard:
            hh = torch.zeros((self.del_num*(2*self.sz - self.del_num-1),2))
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh                    
        out = out[:, 0]
        if use_cuda:
            out = out.cuda()
            
        matrix = torch.zeros(self.sz,self.sz).cuda()
        left_mask = torch.ones(self.sz,self.sz)
        left_mask[:-self.del_num,:-self.del_num] = 0
        left_mask = left_mask - torch.diag(torch.diag(left_mask))
        un_index = left_mask.nonzero()
        matrix[(un_index[:,0],un_index[:,1])] = out
        out_matrix = matrix
        # out_matrix = out[:, 0].view(self.gen_matrix.size()[0], self.gen_matrix.size()[0])
        return out_matrix
    def init(self, mean, var):
        init.normal_(self.gen_matrix, mean=mean, std=var)

In [20]:
## union code 
kronfit 已知部分 + gumbel   + 两个loss （ll+states loss）

0.0

In [13]:
perm

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  85,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  65,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        37,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  22,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127])

In [16]:
permutation_mat  = np.eye(H.shape[0])[perm]
shuffle_H = np.dot(np.dot(permutation_mat,objective_adj),permutation_mat.T)

In [10]:
sigma_later1=SwapElement(init_H,37,65)

In [8]:
sigma_later=SwapElement(sigma_later,22,85)

In [12]:
(abs(sigma_later1-H)*mask_obs.data.numpy()).sum()

68.0

In [18]:
(abs(shuffle_H-H)*mask_obs.data.numpy()).sum()

0.0