In [None]:
import numpy as np 
import torch 
import ot 
import os
import matplotlib.pyplot as plt
os.chdir('.')
from lib.gromov_test import partial_gromov_ver1,cost_matrix_d,tensor_dot_param,tensor_dot_func,gwgrad_partial,partial_gromov_wasserstein,gwgrad_partial1
from lib.opt import *
from lib.pu_learning import *

import numpy as np 
import numba as nb
import warnings
import time
from ot.backend import get_backend, NumpyBackend
from ot.lp import emd

from sklearn.datasets import load_svmlight_file
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

from sklearn.metrics import accuracy_score, recall_score, precision_score


In [None]:
@nb.njit(cache=True)
def tensor_dot_param(C1,C2,Lambda=0,loss='square_loss'):
    """Pre-compute the element-wise transforms of C1 and C2 used by `tensor_dot_func`.

    For ``loss == 'square_loss'`` the four returned arrays are

    - ``fC1 = C1**2 - 2*Lambda``
    - ``fC2 = C2**2``
    - ``hC1 = C1``
    - ``hC2 = 2*C2``

    Parameters
    ----------
    C1, C2 : arrays
        Cost matrices; only element-wise operations are applied to them here.
    Lambda : scalar, optional
        Constant subtracted (as ``2*Lambda``) from every entry of ``C1**2``.
    loss : str, optional
        Only ``'square_loss'`` defines the helper functions below.

    Returns
    -------
    tuple
        ``(fC1, fC2, hC1, hC2)`` as listed above.

    Notes
    -----
    There is no fallback branch: for any other ``loss`` value the helpers
    ``f1``/``f2``/``h1``/``h2`` are never defined, so the calls below cannot succeed.
    """
    if loss=='square_loss':
        # Element-wise transforms for the squared loss (see the docstring above).
        def f1(r1):
            return r1**2-2*Lambda
        def f2(r2):
            return r2**2
        def h1(r1):
            return r1
        def h2(r2):
            return 2*r2
    # else:
    #     warnings.warn("loss function error")

    # Apply the transforms to the full cost matrices.
    fC1=f1(C1)
    fC2=f2(C2)
    hC1=h1(C1)
    hC2=h2(C2)
    
    return fC1,fC2,hC1,hC2

@nb.njit(cache=True)
def tensor_dot_func(fC1,fC2,hC1,hC2,Gamma):
    """Evaluate ``fC1 @ g1 @ 1^T + 1 @ g2^T @ fC2^T - hC1 @ Gamma @ hC2^T``.

    ``g1`` and ``g2`` are the row sums and column sums of ``Gamma`` reshaped to
    column vectors. With ``Gamma`` of shape ``(n, m)`` the result has shape ``(n, m)``.

    Parameters
    ----------
    fC1, fC2, hC1, hC2 : 2-D arrays
        Transformed cost matrices, e.g. the output of ``tensor_dot_param``.
    Gamma : 2-D array of shape (n, m)
        Matrix the tensor product is applied to.

    Returns
    -------
    2-D array of shape (n, m)
    """
    n_rows, n_cols = Gamma.shape

    # Marginals of Gamma as column vectors.
    row_mass = Gamma.sum(1).reshape((-1, 1))
    col_mass = Gamma.sum(0).reshape((-1, 1))

    # Rank-one terms: each is constant along one axis and broadcast with explicit ones.
    source_term = np.dot(np.dot(fC1, row_mass), np.ones((1, n_cols)))
    target_term = np.dot(np.dot(np.ones((n_rows, 1)), col_mass.T), fC2.T)

    # Cross term coupling both cost matrices through Gamma.
    cross_term = np.dot(np.dot(hC1, Gamma), hC2.T)

    return source_term + target_term - cross_term

@nb.njit(cache=True)
def gwgrad_partial1(C1, C2, T,loss='square'):
    """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`
    as the marginals may not sum to 1.

    For ``loss == 'square'`` the result is

        ``C1**2 @ (T @ 1) + (1^T @ T) @ C2**2 - 2 * C1 @ T @ C2^T``

    where the first term is a column vector and the second a row vector, so
    their sum broadcasts to the shape of ``T``. For ``loss == 'dot'`` only the
    cross term ``-2 * C1 @ T @ C2^T`` is returned.

    Parameters
    ----------
    C1: array of shape (n_p,n_p)
        intra-source (P) cost matrix

    C2: array of shape (n_u,n_u)
        intra-target (U) cost matrix

    T : array of shape (n_p, n_u)
        Transport matrix. The matrix products below require its row count to
        match ``C1.shape[0]`` and its column count to match ``C2.shape[0]``.

    loss : str, optional
        ``'square'`` (default) or ``'dot'``. No other value assigns ``tens``.

    Returns
    -------
    numpy.array of shape (n_p, n_u)
        gradient


    .. _references-gwgrad-partial:
    References
    ----------
    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.
    """
    #T=np.ascontiguousarray(T)
    if loss=='square':
        # (n_p, 1) column: squared source costs weighted by the row sums of T.
        cC1 = np.dot(C1 ** 2 , np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1)))
        # (1, n_u) row: column sums of T applied to the squared target costs.
        cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 )
        constC = cC1 + cC2
        A = -2*np.dot(C1, T).dot(C2.T)
        tens = constC + A
    elif loss=='dot':
        # NOTE(review): constC is an int here but an array in the 'square' branch;
        # presumably numba only compiles the branch selected by `loss` - confirm
        # before restructuring this function.
        constC=0
        A = -2*np.dot(C1, T).dot(C2.T)
        tens = constC + A
    return tens 

def partial_gromov_ver1(C1, C2, p, q, Lambda, G0=None,nb_dummies=1,
                               thres=1, numItermax_gw=1000,numItermax=None, tol=1e-7,
                               log=False, verbose=False, line_search=True,seed=0,truncate=True, **kwargs):
    r"""
    Frank-Wolfe style solver for a Lambda-penalised partial Gromov-Wasserstein plan.

    Each outer iteration

    1. linearises the objective at the current plan ``G`` with
       ``tensor_dot_func`` applied to the outputs of ``tensor_dot_param``
       (called with ``loss='square_loss'``). With the definitions of those two
       helpers in this notebook, entry ``(i, j)`` of the linear cost is
       ``sum_{k,l} ((C1[i,k] - C2[j,l])**2 - 2*Lambda) * G[k,l]``;
    2. solves a linear program with ``emd_lp`` on an extended problem with one
       extra row and one extra column. The extra row carries mass ``q.sum()``,
       the extra column carries mass ``p.sum()`` and all their costs are zero,
       so any amount of mass may stay untransported;
    3. optionally performs an exact line search between the previous plan and
       the LP solution (the objective is quadratic in the step size).

    Iterations stop when the Frobenius norm of the difference between the LP
    solution and the previous plan (checked every 10 iterations) drops below
    ``tol``, when ``numItermax_gw`` iterations have been run, or when the line
    search finds no improving step.

    Parameters
    ----------
    C1 : ndarray, shape (ns, ns)
        Cost matrix in the source space
    C2 : ndarray, shape (nt, nt)
        Cost matrix in the target space
    p : ndarray, shape (ns,)
        Weights in the source space
    q : ndarray, shape (nt,)
        Weights in the target space
    Lambda : float
        Penalty constant; enters the linear cost through ``C1**2 - 2*Lambda``.
    G0 : ndarray, shape (ns, nt), optional
        Initialization of the transportation matrix (default: ``np.outer(p, q)``)
    nb_dummies, thres, seed, truncate : optional
        Unused by this implementation; kept so existing callers keep working.
    numItermax_gw : int, optional
        Max number of outer (Frank-Wolfe) iterations
    numItermax : int, optional
        Max number of iterations passed to ``emd_lp`` (default: ``100 * ns``)
    tol : float, optional
        tolerance for stopping iterations
    log : bool, optional
        return log if True
    verbose : bool, optional
        Print information along iterations. This includes the numerical
        cross-check between ``tensor_dot_func`` and ``gwgrad_partial1`` that
        used to be printed unconditionally.
    line_search : bool, optional
        If False, the LP solution is taken as the next iterate (step size 1).
    **kwargs : dict
        parameters can be directly passed to ``emd_lp``


    Returns
    -------
    gamma : (ns, nt) ndarray
        Transportation matrix after the last iteration
    log : dict
        log dictionary returned only if `log` is `True`. ``'err'``,
        ``'G0_mass'`` and ``'Gprev_mass'`` are recorded every 10 iterations
        (error, total mass of the LP solution, total mass of the previous
        plan). If the last ``emd_lp`` call returned a dict as its second value,
        its entries are merged in.


    References
    ----------
    ..  [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
        Transport with Applications on Positive-Unlabeled Learning".
        NeurIPS.

    """
    if G0 is None:
        G0 = np.outer(p, q)

    cpt = 0
    err = 1

    # Outer-loop log. Kept separate from the log returned by emd_lp: the
    # original code overwrote one with the other and then indexed the boolean
    # `log` flag, so log=True could never work.
    log_dict = {'err': [], 'G0_mass': [], 'Gprev_mass': []} if log else None
    inner_log = None

    fC1,fC2,hC1,hC2=tensor_dot_param(C1,C2,Lambda=Lambda,loss='square_loss')
    # The numba-compiled helpers are fed C-contiguous arrays.
    fC1,fC2,hC1,hC2=np.ascontiguousarray(fC1),np.ascontiguousarray(fC2),np.ascontiguousarray(hC1),np.ascontiguousarray(hC2)
    C1,C2=np.ascontiguousarray(C1),np.ascontiguousarray(C2)
    n,m=C1.shape[0],C2.shape[0]
    if numItermax is None:
        numItermax=n*100
    p_sum,q_sum=p.sum(),q.sum()

    # Extended LP: one dummy source (mass q_sum) and one dummy target (mass
    # p_sum). Their costs stay at zero; only the top-left n x m block of
    # M_extended is refreshed at each iteration.
    mu_extended,nu_extended,M_extended=np.zeros(n+1),np.zeros(m+1),np.zeros((n+1,m+1))
    mu_extended[0:n],mu_extended[-1]=p,q_sum
    nu_extended[0:m],nu_extended[-1]=q,p_sum

    while (err > tol and cpt < numItermax_gw):
        Gprev = G0.copy()

        # Linearised cost at the current plan.
        Mt_circ_G=tensor_dot_func(fC1,fC2,hC1,hC2,Gprev)

        if verbose:
            # Cross-check against the explicit gradient formula (debug aid only).
            reg=2*Lambda*np.sum(Gprev)
            M_circ_G=gwgrad_partial1(C1, C2, Gprev)-reg
            print('difference is', np.linalg.norm(M_circ_G-Mt_circ_G))

        M_extended[0:n,0:m]=Mt_circ_G
        gamma_extended,inner_log=emd_lp(mu_extended,nu_extended,M_extended,numItermax=numItermax,log=log,**kwargs)

        # Drop the dummy row/column to get the LP direction on the real points.
        G0=np.zeros((n,m))
        G0[0:n,0:m]=gamma_extended[:-1,:-1]

        if cpt % 10 == 0:  # to speed up the computations
            err = np.linalg.norm(G0 - Gprev)
            if log:
                log_dict['err'].append(err)
                log_dict['G0_mass'].append(np.sum(G0))
                log_dict['Gprev_mass'].append(np.sum(Gprev))
            if verbose:
                if cpt % 200 == 0:
                    print('{:5s}|{:12s}|{:12s}'.format(
                        'It.', 'Err', 'Loss') + '\n' + '-' * 31)
                print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
                                                 gwloss_partial(C1, C2, G0)))

        deltaG = G0 - Gprev

        if line_search:
            # Along Gprev + alpha*deltaG the objective is a*alpha**2 + b*alpha + const.
            Mt_circ_deltaG=tensor_dot_func(fC1,fC2,hC1,hC2,deltaG)
            a=np.sum(Mt_circ_deltaG*deltaG)
            b=2 * (np.sum(Mt_circ_G * deltaG))

            if verbose:
                # Same coefficients from the explicit gradient (debug aid only).
                M_circ_deltaG=gwgrad_partial1(C1, C2, deltaG)
                deltaG_sum=np.sum(deltaG)
                a1=np.sum(M_circ_deltaG*deltaG)-2*Lambda*deltaG_sum**2
                b1= 2 * (np.sum(M_circ_G * deltaG)-reg*deltaG_sum)
                print('a1-a',a1-a)
                print('b1-b',b1-b)

            if a>0:  # due to numerical precision
                if b>=0:
                    # No descent along deltaG: stay at Gprev and stop.
                    alpha = 0
                    cpt = numItermax_gw
                else:
                    alpha = min(1, np.divide(-b, 2.0 * a))
            else:
                if (a + b) < 0:
                    alpha = 1
                else:
                    alpha = 0
                    cpt = numItermax_gw
        else:
            alpha=1

        G0 = Gprev + alpha * deltaG
        cpt += 1
        if verbose:
            print('cpt is',cpt)

    if log:
        if isinstance(inner_log, dict):
            log_dict.update(inner_log)
        return G0, log_dict
    else:
        return G0

In [None]:
def data_process(name='amazon_surf'):
    """Load a dataset from ``pu_learning/data/`` and return ``((X, l), classes)``.

    Parameters
    ----------
    name : str
        - ``'MNIST'`` or ``'EMNIST'``: ``pu_learning/data/<name>.pt`` is read with
          ``torch.load`` and must unpack into ``(X, l)``; ``classes`` is ``None``.
        - any name containing ``'surf'``: ``pu_learning/data/<name>_fts.pkl`` is
          unpickled, the ``'features'`` entry is used and reduced with a
          10-component PCA.
        - any name containing ``'decaf'`` (and not ``'surf'``): same file pattern,
          the ``'fc8'`` entry is used and reduced with a 40-component PCA.

    Returns
    -------
    (X, l) : tuple
        Features and labels. For surf/decaf data ``X`` is ``pca.components_.T``
        of a PCA fitted on the transposed feature matrix.
    classes :
        The ``'classes'`` entry of the pickle file, or ``None`` for (E)MNIST.

    Raises
    ------
    ValueError
        If ``name`` matches none of the supported patterns (previously this
        surfaced as an UnboundLocalError at the return statement).
    """
    # Imported locally so the function does not depend on `pickle` being
    # provided by one of the star imports at the top of the notebook.
    import pickle

    # NOTE(security): torch.load and pickle.load can execute arbitrary code
    # while deserialising; only use them on trusted files.
    if name in ['MNIST','EMNIST']:
        data_file=torch.load('pu_learning/data/'+name+'.pt')
        (X,l)=data_file
        return (X,l),None

    # 'surf' takes precedence over 'decaf', as in the original branch order.
    if 'surf' in name:
        feature_key,n_components='features',10
    elif 'decaf' in name:
        feature_key,n_components='fc8',40
    else:
        raise ValueError(
            "Unknown dataset name %r: expected 'MNIST', 'EMNIST', or a name "
            "containing 'surf' or 'decaf'." % (name,))

    with open('pu_learning/data/'+name+'_fts.pkl', 'rb') as f:
        data_file = pickle.load(f)

    X0=data_file[feature_key]
    l=data_file['labels']
    classes=data_file['classes']

    # The PCA is fitted on X0.T and the transposed components are the features.
    pca = PCA(n_components=n_components, random_state=0)
    pca.fit(X0.T)
    X = pca.components_.T
    return (X,l),classes


def MNIST_figure(figure_list,label_list):
    """Show the first ten images in a 2 x 5 grid, each titled with its label.

    Parameters
    ----------
    figure_list : sequence
        ``figure_list[i][0]`` is passed to ``plt.imshow`` for ``i`` in 0..9.
    label_list : sequence
        ``label_list[i]`` is shown in the title of panel ``i``.
    """
    n_rows, n_cols = 2, 5
    plt.figure(figsize=(10, 4))
    for panel in range(n_rows * n_cols):
        plt.subplot(n_rows, n_cols, panel + 1)
        image = figure_list[panel][0]
        plt.imshow(image, cmap='gray')
        caption = f"Label: {label_list[panel]}"
        plt.title(caption)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

def normalize_X(X):
    """Min-max scale each column of ``X`` to the range [0, 1].

    Columns whose minimum equals their maximum are divided by 1 instead of 0,
    so they come out as all zeros. The input array is not modified.
    """
    col_min = np.min(X, axis=0)
    col_range = np.max(X, axis=0) - col_min
    # Constant columns: avoid a division by zero.
    col_range[col_range == 0] = 1
    return (X - col_min) / col_range
    
# def convert_data(dataset,name='MNIST',visual=False):
#     if name in ['MNIST','EMNIST']:
#         X_list,label_list=dataset
#             label_list_all.append(label_list)
#         embedding_list_all=np.vstack(embedding_list_all)
#         label_list_all=np.vstack(label_list_all).reshape(-1).astype(np.int64)
#     return embedding_list_all,label_list_all


# This is a modified version of the original function.
def draw_pu_dataset_scar(dataset_p, dataset_u=None, size_p=10, size_u=20, prior=0.5, p_label=0,seed_nb=None,same_dataset=True):
    """Draw a Positive and Unlabeled dataset "at random".

    Labels equal to ``p_label`` are mapped to 1 and all others to 0, and the
    features are min-max scaled with ``normalize_X``. ``size_p`` positives are
    then drawn for the positive set. The unlabeled set is made of
    ``int(prior * size_u)`` positives followed by the remaining
    ``size_u - int(prior * size_u)`` negatives.

    Parameters
    ----------
    dataset_p: tuple ``(features, labels)`` among which the positives are drawn.
        Both entries must support ``.copy()`` and boolean-mask indexing
        (e.g. NumPy arrays); the inputs themselves are not modified.

    dataset_u: tuple ``(features, labels)`` among which the unlabeled are drawn.
        Only used when ``same_dataset`` is False and ``dataset_u`` is not None.

    size_p: number of points in the positive dataset

    size_u: number of points in the unlabeled dataset

    prior: proportion of positives in the unlabeled dataset

    p_label: original label value treated as the positive class

    seed_nb: seed passed as ``random_state`` to every ``train_test_split`` call

    same_dataset: if True (or if ``dataset_u`` is None), the unlabeled points
        come from ``dataset_p``: the unlabeled positives are drawn from the
        positives left over after the positive set has been drawn, and the
        negatives from its non-positive points.

    Returns
    -------
    xp : array
        Positive dataset (``size_p`` rows)

    xu : array
        Unlabeled dataset, positives first, then negatives

    yu : array
        Binary labels (1 = positive) of the unlabeled dataset
    """
    def _binarized_copy(dataset):
        # Work on copies: relabel to {1, 0} and scale the features.
        feats, labels = dataset[0].copy(), dataset[1].copy()
        is_pos = labels == p_label
        is_neg = labels != p_label
        labels[is_pos], labels[is_neg] = 1, 0
        return normalize_X(feats), labels

    x, l = _binarized_copy(dataset_p)

    size_u_p = int(prior * size_u)
    size_u_n = size_u - size_u_p

    # Positive set, plus the left-over positives it was not drawn from.
    xp_t = x[l == 1]
    tp_t = l[l == 1]
    xp, xp_other, _, tp_o = train_test_split(xp_t, tp_t, train_size=size_p,
                                             random_state=seed_nb)

    if not (same_dataset or dataset_u is None):
        # Unlabeled points come from a different dataset: replace both the
        # positive pool and the data the negatives are drawn from.
        x, l = _binarized_copy(dataset_u)
        xp_other = x[l == 1]
        tp_o = l[l == 1]

    xup, _, lup, _ = train_test_split(xp_other, tp_o, train_size=size_u_p,
                                      random_state=seed_nb)

    xn_t = x[l == 0]
    tn_t = l[l == 0]
    xun, _, lun, _ = train_test_split(xn_t, tn_t, train_size=size_u_n,
                                      random_state=seed_nb)

    xu = np.concatenate([xup, xun], axis=0)
    yu = np.concatenate((lup, lun))
    return xp, xu, yu

def init_pgw_param(C1,C2,r):
    """Build uniform weights for partial GW and the mass to transport.

    Returns
    -------
    p : array of length ``C1.shape[0]``, uniform, scaled so its total mass is ``r``
    q : array of length ``C2.shape[0]``, uniform with total mass 1
    mass : the smaller of ``p.sum()`` and ``r`` (guards against ``p.sum()``
        exceeding ``r`` by rounding error)
    """
    n_source, n_target = C1.shape[0], C2.shape[0]
    p = np.ones(n_source) / n_source * r
    q = np.ones(n_target) / n_target
    mass = np.min((p.sum(), r))
    return p, q, mass




            

def gamma_to_l(G,r):
    """Turn a transport plan into binary labels for its columns.

    A column is labelled 1 when the mass it receives (its column sum) is at
    least the ``1 - r`` quantile of all column sums, and 0 otherwise.

    Parameters
    ----------
    G : 2-D array of shape (n, m)
    r : float
        Fraction used to place the quantile threshold.

    Returns
    -------
    array of length ``m`` with float entries in {0, 1}
    """
    _, n_cols = G.shape
    received = G.sum(0)
    cutoff = np.quantile(received, 1 - r)
    labels = np.zeros(n_cols)
    selected = received >= cutoff
    labels[selected] = 1
    return labels

def init_param_ugw(C1,C2):
    """Build the default inputs for the batched unbalanced-GW routines.

    Returns
    -------
    mu, nu : torch tensors of shape (1, n) and (1, m)
        Uniform weights summing to 1 (expanded views with a batch axis of 1).
    eps, rho, rho2 : float
        ``2**-9``, ``2**-10`` and ``2**-10``.
    Cx, Cy : torch.float32 tensors of shape (1, n, n) and (1, m, m)
        ``C1`` and ``C2`` converted from NumPy with a leading batch axis.
    """
    n_pos, n_unl = C1.shape[0], C2.shape[0]
    batch = 1

    # First entries of the grids [2**k for k in range(-9, -8)] and
    # [2**k for k in range(-10, -4)] respectively.
    eps = 2. ** -9
    rho = 2. ** -10
    rho2 = 2. ** -10

    uniform_pos = torch.ones([n_pos]) / n_pos
    uniform_unl = torch.ones([n_unl]) / n_unl
    mu = uniform_pos.expand(batch, -1)
    nu = uniform_unl.expand(batch, -1)

    Cx = torch.from_numpy(C1).to(torch.float32).reshape((batch, n_pos, n_pos))
    Cy = torch.from_numpy(C2).to(torch.float32).reshape((batch, n_unl, n_unl))
    return mu, nu, eps, rho, rho2, Cx, Cy

def init_flb_uot(C1,C2):
    """Compute an initial plan with ``compute_batch_flb_plan``.

    Uses the default parameters from ``init_param_ugw`` and returns the first
    plan of the batch as a float64 NumPy array. Prints the ``eps`` value used.
    """
    mu, nu, eps, rho, rho2, Cx, Cy = init_param_ugw(C1, C2)
    print('eps in flb_uot is',eps)

    solver_options = dict(eps=eps, rho=rho, rho2=rho2,
                          nits_sinkhorn=50000, tol_sinkhorn=1e-5)
    # The third returned value is the batch of plans; the first two are unused here.
    _, _, init_plan = compute_batch_flb_plan(mu, Cx, nu, Cy, **solver_options)

    first_plan = init_plan[0]
    return first_plan.numpy().astype(np.float64)

def init_flb_pot(C1,C2,p,q,r,Lambda=30.0,n=100):
    """Compute an initial plan by solving ``opt_lp`` on column-mean summaries.

    The cost passed to ``opt_lp`` is ``cost_matrix`` applied to the column
    means of ``C1`` and ``C2``.

    Note: the ``p`` and ``q`` arguments are ignored; both are rebuilt from
    ``init_pgw_param(C1, C2, r)``. ``n`` only sets the solver iteration cap
    (``numItermax = n * 500``).
    """
    # Weights are recomputed here; the third value (mass) is not needed.
    p, q, _ = init_pgw_param(C1, C2, r)

    col_mean_1 = C1.mean(0)
    col_mean_2 = C2.mean(0)
    C = cost_matrix(col_mean_1, col_mean_2)

    max_iter = n * 500
    plan, _ = opt_lp(p, q, C, Lambda=Lambda, numItermax=max_iter)
    return plan

def pu_prediction_gw(C1,C2,r=0.2,G0=None,method='pgw',param={'Lambda':30.0}):
    """Compute a transport plan between two cost matrices with the chosen GW variant.

    Parameters
    ----------
    C1 : ndarray, shape (n, n)
        Cost matrix of the positive set.
    C2 : ndarray, shape (m, m)
        Cost matrix of the unlabeled set.
    r : float
        Mass given to the positive weights (see ``init_pgw_param``).
    G0 : ndarray or torch.Tensor, optional
        Initial plan. For ``'ugw'`` a NumPy array is converted to a float32
        tensor of shape ``(1, n, m)``; a tensor is passed through unchanged.
    method : {'primal_pgw', 'pgw', 'ugw'}
        - ``'primal_pgw'``: ``partial_gromov_wasserstein`` with ``m=mass``.
        - ``'pgw'``: ``partial_gromov_ver1`` with ``Lambda`` from ``param``.
        - ``'ugw'``: ``log_batch_ugw_sinkhorn``; ``param`` may override
          ``'rho'`` (also used for ``rho2``) and ``'eps'``.
    param : container or None
        Method options looked up by key. ``None`` is treated as "no options";
        a missing ``'Lambda'`` falls back to 30.0.

    Returns
    -------
    gamma
        The plan returned by the selected solver (for ``'ugw'`` the first
        element of the returned batch).

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported values. This includes
        ``'gw'``, which previously set up weights but never computed a plan
        and failed with UnboundLocalError at the return statement.
    """
    supported_methods = ('primal_pgw', 'pgw', 'ugw')
    if method not in supported_methods:
        raise ValueError(
            "Unsupported method %r; expected one of %s." % (method, supported_methods))
    if param is None:
        param = {}

    C1,C2=C1.astype(np.float64),C2.astype(np.float64)
    n,m=C1.shape[0],C2.shape[0]
    size_p=int(m*r)
    if size_p!=n:
        print('# of positives in X_p and X_u are different, we suggest to modify them')

    if method=='primal_pgw':
        p,q,mass=init_pgw_param(C1,C2,r)
        gamma=partial_gromov_wasserstein(C1,C2,p,q,m=mass,G0=G0,numItermax=n*1000,nb_dummies=1,line_search=False)

    elif method=='pgw':
        # Membership test rather than .get(): callers may pass any container.
        Lambda=param['Lambda'] if 'Lambda' in param else 30.0
        p,q,mass=init_pgw_param(C1,C2,r)
        gamma=partial_gromov_ver1(C1,C2,p,q,Lambda=Lambda,G0=G0,numItermax=n*1000,nb_dummies=1,line_search=False)

    else:  # method=='ugw'
        mu,nu,eps,rho,rho2,Cx,Cy=init_param_ugw(C1,C2)
        if 'rho' in param:
            rho=param['rho']
            rho2=rho
        if 'eps' in param:
            eps=param['eps']
        # need to try different rho for better performance
        # NOTE(review): init_plan used to be unbound when G0 was neither an
        # ndarray nor a tensor; None is passed instead - confirm that
        # log_batch_ugw_sinkhorn accepts init=None.
        init_plan=None
        if isinstance(G0,np.ndarray):
            init_plan=torch.from_numpy(G0).to(torch.float32).reshape((1,n,m))
        elif isinstance(G0,torch.Tensor):
            init_plan=G0
        gamma = log_batch_ugw_sinkhorn(mu, Cx, nu, Cy, init=init_plan,
                                eps=eps, rho=rho, rho2=rho2,
                                nits_plan=3000, tol_plan=1e-5,
                                nits_sinkhorn=3000, tol_sinkhorn=1e-6)
        print('gamma_mass_diff',gamma.sum()-r)
        gamma=gamma[0]
    return gamma

In [None]:
# Experiment configuration: positive class, dataset pair, problem sizes and seed.
p_label=1
nb_dummies=1
name1='MNIST' 
name2='EMNIST'
#name3='webcam_surf'
file_name=name1+name2+'.pt' #'surf.pt' #name1+'-'+name2+'.pt'

# Resume from previously saved results when the file exists, otherwise start
# fresh. The original code referenced the undefined name `filename` inside a
# bare `except:`, so the NameError was swallowed and saved results were never
# loaded.
# NOTE(security): torch.load unpickles its input; only load trusted files.
try:
    result=torch.load('pu_learning/result/'+file_name)
except FileNotFoundError:
    result={}

n=1000        # number of positive samples
r=1/5         # prior: share of positives among the unlabeled samples
m=int(n/r)    # number of unlabeled samples
seed_nb=3

dataset1,_=data_process(name=name1)
dataset2,_=data_process(name=name2)

dataname_list=[name1,name2]
dataset_list=[dataset1,dataset2]
init_method_list=['flb_uot','flb_pot']
method_list=['primal_pgw','pgw'] #
# For every ordered pair of datasets and every initialisation method: draw a
# PU problem, compute the initial plan, score it, then refine it with each
# method in `method_list` and score the refined plan. Scores and timings are
# stored in `result` under '<init>-<data1>-<data2>[-<method>]'.
for data1_name, data1 in zip(dataname_list, dataset_list):
    for data2_name, data2 in zip(dataname_list, dataset_list):

        print('data 1 is',data1_name)
        print('data 2 is',data2_name)
        same_dataset = (data1_name == data2_name)

        for init_method in init_method_list:
            G0 = None
            init_key = init_method+'-'+data1_name+'-'+data2_name

            X_p,X_u,label_u=draw_pu_dataset_scar(data1,data2,p_label=p_label,prior=r,size_p=n, size_u=m,seed_nb=seed_nb,same_dataset=same_dataset)
            C, C1, C2, mu, nu=compute_cost_matrices(P=X_p, U=X_u, prior=r, nb_dummies=1)
            # Keep only the first n / m entries of the weights and cost matrices.
            p,q=mu[0:n],nu[0:m]
            C1=C1[0:n,0:n]
            C2=C2[0:m,0:m]

            # --- initial plan ---
            time1=time.time()
            if init_method=='pot_r' and C is not None:
                G0=ot.emd(mu, nu, C)[:n, :]
            elif init_method=='flb_pot':
                G0=init_flb_pot(C1,C2,p,q,r,Lambda=30.0)
            elif init_method=='flb_uot':
                G0=init_flb_uot(C1,C2)
            time2=time.time()
            run_time=time2-time1

            if G0 is None:
                # No initial plan was produced: nothing to score or refine.
                continue

            l_G0=gamma_to_l(G0,r)
            acc0=accuracy_score(l_G0,label_u)
            result[init_key]={'accuracy':acc0,'time':run_time}
            print('init method is',init_method)
            print('accuracy is',acc0)
            print('time is',run_time)

            # --- refinement with each GW method ---
            for method in method_list:
                # Per-method options; methods without an entry get None.
                param={'ugw':{'None'},'pgw':{'Lambda':20.0}}.get(method)

                time1=time.time()
                G=pu_prediction_gw(C1.copy(),C2.copy(),r=r,G0=G0.copy(),method=method,param=param)
                time2=time.time()
                run_time=time2-time1

                l_G=gamma_to_l(G,r)
                acc=accuracy_score(l_G,label_u)
                result[init_key+'-'+method]={'time':run_time,'accuracy':acc}
                print('method is',method)
                print('accuracy is',acc)
                print('time is',run_time)
            #torch.save(result,'pu_learning/result/'+file_name)
                    
# p,q=np.ones(n)*r/n,np.ones(m)/m

# l_G0=gamma_to_l(G0,r)
# acc_G0=accuracy_score(l_G0,label_u)
# print('acc_G0',acc_G0)

# if C is not None:
#     G0=pu_w_emd(mu, nu, C, nb_dummies=nb_dummies)
#     G0=G0[0:-nb_dummies,:]

#     l_G0=gamma_to_l(G0,r)
#     acc_G0=accuracy_score(l_G0,label_u)
#     print('acc_G0',acc_G0)
# gamma=pu_prediction_gw(C1,C2,r=r,method='ugw',G0=G0,param={'Lambda':30.0})
# l_G=gamma_to_l(gamma,r)
# acc=accuracy_score(l_G,label_u)
# print('acc',acc)

In [None]:
print('done')
# Load saved results and print them. Keys with three '-'-separated fields are
# initialisation entries; keys with four fields are method entries.
result=torch.load('pu_learning/result/MNIST-EMNIST.pt')
for key in result:
    str_list=key.split("-")
    n_fields=len(str_list)
    if n_fields not in (3, 4):
        continue

    entry=result[key]
    if n_fields==3:
        init,data1,data2=str_list
        print('data1 is',data1)
        print('data2 is',data2)
        print('init method is',init)
    else:
        init,data1,data2,method=str_list
        print('method is',method)
    print('accuracy is', entry['accuracy'])
    print('time is', entry['time'])
    