# HyperAid

Please follow the instructions in the README file to run our code correctly.

## 1. Import some basic packages

In [1]:
"""Train a hyperbolic embedding model for hierarchical clustering."""

import argparse
from config import config_args
from utils.training import add_flags_from_config, get_savedir
import os

## 2. Setup default arguments

You may postpone choosing `--enc_method,--dec_method,--dataset,--num_nodes,--noise_scale` until later.
The most important settings to configure here are:

1) the negative curvature `--c`
Here's the choice of `--c` for reproducing our results for all datasets

Real-world datasets
Zoo: c=200,
Iris: c=100,
Glass: c=250,
Segmentation: c=300,
Spambase: c=100.

Synthetic datasets
c=100

2) GPU `--device`
Which GPU to use. Set it to -1 to use the CPU.

3) `--verbose`
Whether to print detailed information during training.

4) `--save`
Whether to save results. Set it to `True` UNLESS you want to fine-tune the model yourself.
Note that in order to use decoders such as T-REX and Ufit, you will need the saved results.
The encoder will skip training hyperbolic embeddings for a dataset if a previous result has already been saved.
This saves a lot of time and ensures a fair comparison if you later want to test multiple decoders on the same embeddings.
Results will be saved in the folder `./embeddings`

In [None]:
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` is a trap: it calls ``bool('False')``, which is
    True because every non-empty string is truthy. Defaults given as real
    bools are not passed through ``type`` by argparse and are kept as-is.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}.'.format(value))


# Decoders handled by the experiment wrapper: TreeRep, NJ and scipy linkage
# methods (centroid/median/ward are accepted here but rejected downstream for
# non-Euclidean input).
_DEC_METHODS = ['TreeRep', 'NJ', 'single', 'complete', 'average', 'weighted',
                'centroid', 'median', 'ward']

parser = argparse.ArgumentParser("Hyperbolic Hierarchical Clustering.")
# int, not str: the seed is fed to np.random.seed / torch.manual_seed.
parser.add_argument("--seed", type=int, default=0, help="model seed to use")
parser.add_argument("--enc_method", type=str, default='Hyp', choices=['Hyp', 'Direct'], help="Method for encoding.")
parser.add_argument("--dec_method", type=str, default='TreeRep', choices=_DEC_METHODS, help="Method for decoding.")
parser.add_argument("--dataset", type=str, default='zoo', help="dataset option")
parser.add_argument("--num_nodes", type=int, default=64, help="#nodes of synthetic dataset")
parser.add_argument("--noise_scale", type=float, default=0.0, help="noise_scale for random_tree synthetic dataset")
parser.add_argument("--p", type=int, default=2, help="Lp norm in the cost")
# Exported below through os.environ['CURVATURE'] for the hyperbolic utilities.
parser.add_argument("--c", type=float, default=1, help="negative curvature")
parser.add_argument("--burnin", type=int, default=30, help="Burn-in stage epoch.")
parser.add_argument("--batch_size", type=int, default=100000, help="Batch size. Recommend to set it as large as possible.")
parser.add_argument("--dtype", type=str, default='double', help="dtype. Set to double for better precision")
parser.add_argument("--scaling_factor", type=float, default=1, help="The scaling factor of all metric.")
parser.add_argument("--burnin_factor", type=float, default=10, help="The factor of lr/burnin_lr")
parser.add_argument("--verbose", type=_str2bool, default=True, help="To print complete info or not.")
parser.add_argument("--dec_repeat", type=int, default=20, help="Number of times to repeat decodeing steps.")
parser.add_argument("--check_tree", type=_str2bool, default=False, help="Check if the (best) output is a tree metric.")
parser.add_argument("--eval_every", type=int, default=1)
parser.add_argument("--patience", type=int, default=50)
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=500)
parser.add_argument("--rank", type=int, default=2)
parser.add_argument("--Normalization", type=_str2bool, default=False)
parser.add_argument("--init_size", type=float, default=1e-6)
parser.add_argument("--device", type=int, default=0, help="-1 for cpu.")

parser.add_argument("--anneal_every", type=int, default=200)
parser.add_argument("--anneal_factor", type=float, default=1.0)
parser.add_argument("--save", type=_str2bool, default=True)
parser.add_argument("--num_workers", type=int, default=1)
# Load other args from default config.
parser = add_flags_from_config(parser, config_args)

# Notebook usage: parse an empty argv so only defaults (plus config) apply.
args = parser.parse_args([])

# We use os.environ['CURVATURE'] to pass the args.curvature to util.poincare.py!
os.environ['CURVATURE']=str(args.c)

## 3. Import the remaining packages

Remember to check whether the negative curvature is consistent with your choice.

In [3]:
import json
import logging
import ipdb

import numpy as np
import torch
import torch.utils.data as data
from tqdm import tqdm
import time

import optim
from datasets.hc_dataset import MetricDataset
from datasets.loading import load_data
from model.hyphc import MetricHypHC
from utils.metrics import dasgupta_cost
from datasets.triples import generate_all_pairs

import itertools
from scipy.optimize import linprog
import networkx as nx
import matplotlib as plt
from scipy.sparse import csr_matrix, lil_matrix
from scipy.sparse.csgraph import shortest_path
from scipy.cluster.hierarchy import linkage as Linkage
from scipy.cluster.hierarchy import cophenet
import scipy.spatial.distance as ssd
import newick

# Input datasets are read from ./data; trained embeddings/results go to ./embeddings.
os.environ.update({'DATAPATH': "./data", 'SAVEPATH': "./embeddings"})

Current curvature is:-1.0


## 4. Import Julia and the NJ and TreeRep methods

It may take a while (2 minutes in my case)

Also, please set `jpath` to the path of the Julia binary installed in your conda environment.

In [4]:
###### import setting for TreeRep
# Start an embedded Julia runtime through the `julia` Python package, then
# import the Julia-backed TreeRep and Neighbor-Joining (NJ) wrappers that the
# decoder functions call (TreeRepy.TreeRep, NJpy.NJ).

# # # This allows multithreading for julia
# # # use 16 thread. Note that using more than 16 thread will incur errors due to the TreeRep implementation.
# os.environ["JULIA_NUM_THREADS"] = str(16)
# NOTE(review): JULIA_NUM_THREADS presumably must be set before the Julia
# runtime below is created to take effect — confirm before enabling.

from julia.api import Julia
# Placeholder path: replace YOUR_PATH with the location of your julia binary.
jpath = "YOUR_PATH/miniconda3/envs/HyperAid/bin/julia" # path to Julia, from current directory (your path may be slightly different)
# The runtime is created before the TreeRep/NJ imports below; keep this order,
# since those modules presumably need an initialised Julia session at import
# time (TODO confirm).
jl = Julia(runtime=jpath, compiled_modules=False) # compiled_modules=True may work for you; it didn't for me

# Time the (slow) import of the Julia-backed wrappers.
t = time.time()
import TreeRep_myVer.src.TreeRepy as TreeRepy
import TreeRep_myVer.src.NJpy as NJpy
print("Time for importing:",time.time() - t)

Time for importing: 136.43162775039673


## 5. Include our own functions (Utils)

In [4]:
def newick2W(trees,n):
    """Convert a parsed newick tree into a (2n-1)-by-(2n-1) sparse adjacency matrix.

    Rows/columns cover every node of the tree, including internal ones. Leaves
    must already be named "0".."n-1"; internal nodes (name None) receive ids
    n, n+1, ... assigned by ``Fill_W_from_tree``. Entries hold branch lengths.

    Args:
        trees: result of ``newick.loads``; only ``trees[0]`` (the root) is used.
        n: number of leaves.

    Returns:
        scipy ``csr_matrix`` of branch lengths between adjacent nodes.
    """
    size = 2*n - 1
    adjacency = csr_matrix((size, size))
    root = trees[0]
    # A rooted binary tree: the root must have exactly two children.
    assert len(root.descendants) == 2
    filled, _, _ = Fill_W_from_tree(root, adjacency, n)
    return filled

def Fill_W_from_tree(subtree,W,cur_idx):
    """Recursively write the edges of a newick (sub)tree into adjacency matrix W.

    Nodes whose ``name`` is None are given consecutive integer names (as
    strings) starting at ``cur_idx``; already-named nodes (the leaves) keep
    theirs. Every parent-child edge is stored symmetrically with the child's
    branch ``length``.

    Args:
        subtree: newick node exposing ``name``, ``length``, ``descendants``
            and ``is_leaf``.
        W: square matrix supporting ``W[i, j] = value`` assignment.
        cur_idx: next unused integer id for an unnamed node.

    Returns:
        (W, subtree, cur_idx): the filled matrix, the (renamed) subtree and the
        next unused id.
    """
    if subtree.name is None:
        subtree.name = str(cur_idx)
        cur_idx += 1

    if not subtree.is_leaf:
        for child in subtree.descendants:
            if child.name is None:
                child.name = str(cur_idx)
                cur_idx += 1

            parent_id = int(subtree.name)
            child_id = int(child.name)
            W[child_id,parent_id] = child.length
            W[parent_id,child_id] = child.length

            # Descend into every internal child. The previous version keyed
            # this on a flag that was never reset per child, so it depended on
            # the order of siblings rather than on the child itself.
            if not child.is_leaf:
                W, _, cur_idx = Fill_W_from_tree(child, W, cur_idx)

    return W,subtree,cur_idx


def generate_rand_TreeMetric_V2(N,option='unweight',noise_ratio=0.0):
    """Sample a random binary tree over N leaves and return its (noisy) leaf metric.

    Steps: draw a random newick tree (branch length 1.0 when
    ``option == 'unweight'``, uniform [0, 1) lengths rounded to 5 decimals
    otherwise), convert it to a sparse (2N-1)x(2N-1) adjacency matrix, add
    ``int(noise_ratio*(2N-2))`` random shortcut edges, and return all-pairs
    shortest-path distances restricted to the N leaves.

    Args:
        N: number of leaves.
        option: 'unweight' for unit weights; anything else for random weights.
        noise_ratio: number of noise edges as a fraction of the 2N-2 tree edges.

    Returns:
        N-by-N numpy array of leaf-to-leaf shortest-path distances.
    """
    # Generate (weighted) random binary tree in newick format.
    T = generate_tree_newick(np.arange(N),option=option)
    # Make it into a tree data structure in python
    T = newick.loads(T)
    # Sparse adjacency matrix over all 2N-1 nodes (leaves first).
    W = newick2W(T,N)

    # Add noise: int(noise_ratio*(2N-2)) random shortcut edges.
    # NOTE(review): endpoints are drawn from the N leaves only
    # (np.random.choice(N)); whether internal nodes (ids N..2N-2) should also be
    # eligible is unclear — TODO confirm against the paper's noise model.
    m_Noise = int(noise_ratio*(2*N-2))
    for _ in range(m_Noise):
        idx0,idx1 = np.random.choice(N,size=2,replace=False)
        if option == 'unweight':
            weight = 1.0
        else:
            # Draw a single weight per edge so the graph stays symmetric.
            # Independent draws per direction made the undirected shortest
            # path silently use the smaller of the two.
            weight = np.random.random()
        W[idx0,idx1] = weight
        W[idx1,idx0] = weight

    # Now construct the shortest path based on this noisy observation
    Noisy_D = get_leaves_Dmatrix(W,N)

    return Noisy_D


def generate_tree_newick(L,option='unweight',balance=False):
    """Return a random rooted binary tree over labels L as a newick string.

    The string has no trailing ';'. Every internal node splits a shuffled copy
    of its labels into two non-empty parts: exactly in half when ``balance`` is
    True, otherwise at a uniformly random cut. Branch lengths are 1.0 when
    ``option == 'unweight'``; otherwise each branch gets its own uniform [0, 1)
    length printed with 5 decimals.
    """
    def _pair(first, second):
        # '(first,second)' rendered via str() of a 2-tuple, minus spaces/quotes.
        return str((first, second)).replace(' ', '').replace("'", "")

    if len(L) == 1:
        return str(L[0])

    shuffled = np.random.permutation(L)
    cut = int(len(L)/2) if balance else np.random.randint(1, len(L))
    left_half, right_half = shuffled[:cut], shuffled[cut:]

    if option != 'unweight':
        # Both branch lengths are drawn before recursing, fixing the RNG order.
        left_len = ':{0:.5f}'.format(np.random.rand())
        right_len = ':{0:.5f}'.format(np.random.rand())
        return _pair(generate_tree_newick(left_half, option=option, balance=balance) + left_len,
                     generate_tree_newick(right_half, option=option, balance=balance) + right_len)

    return _pair(generate_tree_newick(left_half, balance=balance) + ':1.0',
                 generate_tree_newick(right_half, balance=balance) + ':1.0')

def newick2dist(trees,n):
    """Return the n-by-n leaf-to-leaf distance matrix of a parsed newick tree.

    Leaves must be named "0".."n-1"; internal nodes are expected to be unnamed
    (None), since ``Fill_D_from_tree`` only descends into unnamed children.

    Args:
        trees: result of ``newick.loads``; only ``trees[0]`` (the root) is used.
        n: number of leaves.
    """
    root = trees[0]
    # A rooted binary tree: the root must have exactly two children.
    assert len(root.descendants) == 2
    return Fill_D_from_tree(root, np.zeros((n, n)))

def Fill_D_from_tree(subtree,D):
    """Accumulate branch lengths of ``subtree`` into the leaf distance matrix D.

    For each child branch, every pair (a, b) with leaf a below the branch and
    leaf b outside it has the branch length added to D[a, b] and D[b, a].
    Branches with non-positive length are skipped, i.e. negative lengths are
    clipped to zero as is common practice. Unnamed (internal) children are
    processed recursively.

    Args:
        subtree: newick node exposing ``descendants``; children expose
            ``get_leaf_names()``, ``length`` and ``name``.
        D: n-by-n float array updated in place (and returned).
    """
    every_leaf = np.arange(D.shape[0])
    for child in subtree.descendants:
        inside = np.array(child.get_leaf_names()).astype(int)
        branch = child.length
        if branch > 0:
            outside = np.setdiff1d(every_leaf, inside)
            D[np.ix_(inside, outside)] += branch
            D[np.ix_(outside, inside)] += branch
        if child.name is None:
            D = Fill_D_from_tree(child, D)
    return D

def sim2metric(S):
    """Turn a similarity matrix into a dissimilarity matrix via D = 1 - S.

    Args:
        S: n-by-n numpy array of similarities (intended range [-1, 1]).

    Returns:
        Array D with D_ij = 1 - S_ij. Hence S_ij = 1 gives D_ij = 0, and when S
        lies in [-1, 1], S_ij = -1 gives the largest value D_ij = 2.

    Note: the triangle inequality is not enforced here; whether D is a true
    metric depends on the input S.
    """
    dissimilarity = 1 - S
    return dissimilarity

def generate_tree(L):
    """Build a random rooted binary tree over labels L as nested 2-tuples.

    A single label is returned as-is (a leaf). Otherwise the labels are
    shuffled and cut at a uniformly random point into two non-empty parts,
    which are recursed on left first, then right.
    """
    if len(L) == 1:
        return L[0]
    shuffled = np.random.permutation(L)
    cut = np.random.randint(1, len(L))
    left_subtree = generate_tree(shuffled[:cut])
    right_subtree = generate_tree(shuffled[cut:])
    return (left_subtree, right_subtree)

def plot_rand_binary_tree(brt,n,plot=True,normalization=False,d_mtx=None):
    """Compute (and optionally plot) the leaf metric of a nested-tuple binary tree.

    Leaves keep their integer labels 0..n-1 as node ids. Internal nodes get ids
    n..2n-2, with id n being the root (see ``Fill_adj``). Edges have weight 1
    unless ``normalization`` is set. In that case edge weights are fitted by
    least squares (pseudo-inverse of the normal equations) so that leaf path
    lengths best match ``d_mtx``.

    Args:
        brt: nested 2-tuples of integer leaf labels, e.g. from ``generate_tree``.
        n: number of leaves.
        plot: draw the tree with ``plot_Tree_metric``.
        normalization: fit edge weights to ``d_mtx``.
        d_mtx: n-by-n target distances; required when ``normalization`` is True.

    Returns:
        (d, w_star): the n-by-n leaf shortest-path distances and the fitted
        edge weights (``None`` when ``normalization`` is False).

    Raises:
        ValueError: if ``normalization`` is True and ``d_mtx`` is None.
    """
    if normalization and d_mtx is None:
        raise ValueError("d_mtx must be provided when normalization=True.")

    w_star = None
    A = np.zeros((2*n-1,2*n-1))
    # Leaf ids stay equal to their labels; internal ids start at n (the root).
    A,next_idx = Fill_adj(brt,n,A)
    assert next_idx == 2*n-1

    if normalization:
        # Integer pair count: n*(n-1)/2 is a float in Python 3 and np.zeros
        # rejects float shapes.
        num_pairs = n*(n-1)//2
        A_T = np.zeros((num_pairs,2*n-2))
        # Map every tree edge (u, v) to a column index of A_T.
        w_dict = -np.ones((2*n-1,2*n-1))
        w_dict,next_idx,cur_w_idx = Fill_w_dict(brt,n,w_dict)
        assert next_idx == 2*n-1
        assert cur_w_idx == 2*n-2

        # Unweighted (hop-count) shortest paths between all node pairs.
        G = nx.from_numpy_array(A)
        All_path = dict(nx.all_pairs_shortest_path(G))
        d_vec = np.zeros(num_pairs)
        count = 0
        for i in range(n):
            for j in range(i+1,n):
                d_vec[count] = d_mtx[i,j]
                path = All_path[i][j]
                # Row `count` marks the edges used by the path from leaf i to leaf j.
                for u, v in zip(path[:-1], path[1:]):
                    A_T[count,int(w_dict[u,v])] = 1
                count += 1

        # Least-squares edge weights.
        w_star = np.linalg.pinv(A_T.T @ A_T) @ A_T.T @ d_vec
        # Replace the unit weights of A by the fitted ones (upper triangle
        # drives the update; both directions are written).
        for i in range(2*n-1):
            for j in range(i+1,2*n-1):
                if A[i,j] == 1:
                    A[i,j] = w_star[int(w_dict[i,j])]
                    A[j,i] = w_star[int(w_dict[i,j])]

    # Compute d, the shortest path for leave nodes
    G = nx.from_numpy_array(A)
    d = nx.algorithms.shortest_paths.dense.floyd_warshall_numpy(G)
    d = d[:n, :n]

    if plot:
        plot_Tree_metric(A,n,n)

    return d,w_star

def Fill_adj(brt,cur_idx,A):
    """Write the edges of nested-tuple binary tree ``brt`` into adjacency matrix A.

    Leaves are integer labels and use their label as node id. Internal nodes
    are numbered in pre-order starting at ``cur_idx``, which is the id of
    ``brt`` itself. Each edge is stored symmetrically with weight 1.

    Args:
        brt: 2-tuple whose items are integer leaves or nested 2-tuples.
        cur_idx: node id assigned to ``brt``.
        A: square array updated in place.

    Returns:
        (A, next_idx): the filled matrix and the first internal id not used by
        this subtree.
    """
    # Any integer type counts as a leaf (Python int, np.int32, np.int64, ...).
    # Checking only np.int64 misread other integer labels as subtrees.
    left = brt[0]
    if isinstance(left, (int, np.integer)):
        A[cur_idx,left] = 1
        A[left,cur_idx] = 1
        next_idx = cur_idx+1
    else:
        cur_idx_left = cur_idx+1
        A[cur_idx,cur_idx_left] = 1
        A[cur_idx_left,cur_idx] = 1
        A,next_idx = Fill_adj(left,cur_idx_left,A)

    right = brt[1]
    if isinstance(right, (int, np.integer)):
        A[cur_idx,right] = 1
        A[right,cur_idx] = 1
    else:
        # The right subtree's root takes the first id after the left subtree.
        cur_idx_right = next_idx
        A[cur_idx,cur_idx_right] = 1
        A[cur_idx_right,cur_idx] = 1
        A,next_idx = Fill_adj(right,cur_idx_right,A)

    return A,next_idx
    
def Fill_w_dict(brt,cur_idx,w_dict,cur_w_idx=0):
    """Assign a unique edge index to every edge of nested-tuple binary tree ``brt``.

    Node ids follow the same scheme as ``Fill_adj``: leaves use their integer
    label, and internal nodes are numbered in pre-order from ``cur_idx``. Edge
    indices are handed out in traversal order starting at ``cur_w_idx`` and
    stored symmetrically in ``w_dict[u, v]`` / ``w_dict[v, u]``.

    Args:
        brt: 2-tuple whose items are integer leaves or nested 2-tuples.
        cur_idx: node id assigned to ``brt``.
        w_dict: square array updated in place (non-edges keep their value).
        cur_w_idx: next unused edge index.

    Returns:
        (w_dict, next_idx, cur_w_idx): the filled array, the first unused
        internal node id and the next unused edge index.
    """
    # Any integer type counts as a leaf (Python int, np.int32, np.int64, ...).
    left = brt[0]
    if isinstance(left, (int, np.integer)):
        w_dict[cur_idx,left] = cur_w_idx
        w_dict[left,cur_idx] = cur_w_idx
        next_idx = cur_idx+1
        cur_w_idx += 1
    else:
        cur_idx_left = cur_idx+1
        w_dict[cur_idx,cur_idx_left] = cur_w_idx
        w_dict[cur_idx_left,cur_idx] = cur_w_idx
        cur_w_idx += 1
        w_dict,next_idx,cur_w_idx = Fill_w_dict(left,cur_idx_left,w_dict,cur_w_idx)

    right = brt[1]
    if isinstance(right, (int, np.integer)):
        w_dict[cur_idx,right] = cur_w_idx
        w_dict[right,cur_idx] = cur_w_idx
        cur_w_idx += 1
    else:
        cur_idx_right = next_idx
        w_dict[cur_idx,cur_idx_right] = cur_w_idx
        w_dict[cur_idx_right,cur_idx] = cur_w_idx
        cur_w_idx += 1
        w_dict,next_idx,cur_w_idx = Fill_w_dict(right,cur_idx_right,w_dict,cur_w_idx)

    return w_dict,next_idx,cur_w_idx

def plot_Tree_metric(W,N_leaves,root_id=-1,Return_metric=False,plot=True,option='scipy'):
    """Optionally draw a weighted tree and/or return its leaf shortest-path metric.

    All-zero rows, then all-zero columns, of W (isolated nodes) are dropped
    first, so node ids in the drawing refer to the compacted matrix. In the
    drawing the first ``N_leaves`` nodes are blue (leaves), the remaining ones
    red (Steiner/internal nodes), and ``root_id`` (if >= 0) green.

    Args:
        W: square weighted adjacency matrix (numpy array; assumed symmetric).
        N_leaves: number of leaves; leaves are assumed to occupy the first ids.
        root_id: node to highlight as root; -1 for none.
        Return_metric: if True, return the N_leaves-by-N_leaves distance matrix.
        plot: draw the graph with networkx (Kamada-Kawai layout).
        option: 'scipy' (csgraph.shortest_path) or 'networkx' (Floyd-Warshall).

    Returns:
        The leaf distance matrix when ``Return_metric`` is True, else None.

    Raises:
        ValueError: if ``Return_metric`` is True and ``option`` is unknown.
    """
    print('Start computing shortest path')

    # Fail fast instead of silently returning None for a mistyped option.
    if Return_metric and option not in ('networkx', 'scipy'):
        raise ValueError("Unknown option {!r}; expected 'scipy' or 'networkx'.".format(option))

    # First to remove potential zero rows
    W = W[~np.all(W == 0, axis=1)]
    W = W[:,~np.all(W == 0, axis=0)]

    # The networkx graph is only needed for drawing or the networkx metric.
    # from_numpy_array replaces from_numpy_matrix, which networkx 3.0 removed.
    G = None
    if plot or (Return_metric and option == 'networkx'):
        G = nx.from_numpy_array(np.asarray(W))

    if plot:
        layout = nx.kamada_kawai_layout(G)

        # Leaves in 'tab:blue', Steiner (internal) nodes in 'tab:red'.
        N_all = W.shape[0]
        color_map = ['tab:blue'] * N_leaves + ['tab:red'] * (N_all - N_leaves)
        if root_id > -1:
            color_map[root_id] = 'tab:green'

        # Draw the graph using the layout - with_labels=True if you want node labels.
        nx.draw(G, layout, with_labels=True, node_color=color_map)
        # Label every edge with its weight.
        labels = nx.get_edge_attributes(G, "weight")
        nx.draw_networkx_edge_labels(G, pos=layout, edge_labels=labels)

    if not Return_metric:
        return None

    # Shortest distances restricted to the leaves.
    n = N_leaves
    if option == 'networkx':
        d = nx.algorithms.shortest_paths.dense.floyd_warshall_numpy(G)
        return np.array(d[:n, :n])
    graph = csr_matrix(W)
    d = shortest_path(csgraph=graph,directed=False)
    return d[:n, :n]

def get_leaves_Dmatrix(W,N_leaves):
    """Return undirected shortest-path distances among the first N_leaves nodes.

    Args:
        W: weighted adjacency matrix accepted by scipy's csgraph (e.g. a
            csr_matrix); leaves are assumed to occupy ids 0..N_leaves-1.
        N_leaves: number of leaf nodes.

    Returns:
        N_leaves-by-N_leaves numpy array of shortest-path lengths.
    """
    all_pairs = shortest_path(csgraph=W, directed=False)
    return all_pairs[:N_leaves, :N_leaves]

def Linkage2dist(Z):
    """Turn a scipy linkage matrix into a leaf-to-leaf path-length matrix.

    Builds the merge tree of Z: leaves have ids 0..n-1, and the cluster formed
    at row i has id n+i. Each child is connected to its parent with weight
    ``Z[i, 2]`` (the merge height). Undirected shortest paths between leaves
    are then returned.

    NOTE(review): because every edge carries the full merge height, the result
    is a sum of heights along the path, not the cophenetic distance
    (``scipy.cluster.hierarchy.cophenet``). Confirm this is the intended
    quantity before relying on it.

    Args:
        Z: (n-1)-by-4 linkage matrix as returned by scipy's ``linkage``.

    Returns:
        n-by-n numpy array of path lengths between leaves.
    """
    n = Z.shape[0]+1
    # Enough rows/cols for every node referenced as a child plus the final root.
    W_Size = int(max(Z[:,0].max(),Z[:,1].max()))+2
    W = csr_matrix((W_Size,W_Size))

    for i in range(Z.shape[0]):
        # Z stores cluster ids as floats; sparse matrices reject inexact indices.
        idx0 = int(Z[i,0])
        idx1 = int(Z[i,1])
        length = Z[i,2]
        pid = i+n
        W[idx0,pid] = length
        W[idx1,pid] = length
        W[pid,idx0] = length
        W[pid,idx1] = length

    d = shortest_path(csgraph=W,directed=False)
    return d[:n, :n]

## 6. Include decoder wrapper functions

In [5]:
def Decoding_TreeRep(D_input,D_target,p=2,repeat=20,verbose=False,rng_seed=0,tol=1e-6):
    """
    Decode D_input into tree metrics with TreeRep and score them against D_target.

    TreeRep is randomized: when D_input is not a tree metric, each run can
    return a different tree. We therefore collect ``repeat`` successful runs,
    using seeds rng_seed*repeat, rng_seed*repeat+1, ... (one seed per attempt).
    We report the mean/std of the Lp loss, the best loss, and the tree metric
    that achieved it.

    Runs whose decoded tree has negative edge weights, or whose leaf metric
    contains inf/nan, are discarded and retried with the next seed. After
    ``repeat`` such failures the function gives up.

    Args:
        D_input: n-by-n metric to decode; rounded to 5 decimals first to avoid
            numerical issues inside TreeRep.
        D_target: n-by-n metric the decoded tree metric is compared against.
        p: exponent of the Lp loss.
        repeat: number of successful runs to collect (must be >= 1).
        verbose: print extra diagnostics.
        rng_seed: base seed; seeds start at rng_seed*repeat.
        tol: currently unused.

    Returns:
        (mean_loss, std_loss, best_loss, best_tree_metric, num_failed_runs), or
        (-1, -1, -1, -1, num_failed_runs) if too many runs failed.

    Raises:
        ValueError: if ``repeat`` < 1.
    """
    if repeat < 1:
        raise ValueError("repeat must be at least 1, got {}".format(repeat))

    # Rounding D_input to prevent numerical issue! tol should be smaller than the min dif of entries in D_input!
    D_input = np.around(D_input,5)

    Results = np.zeros(repeat)
    Best_loss = np.inf
    D_output = None
    rp = 0
    BAD_time = 0
    BAD_MAX = repeat # prevent infinite loop!
    cur_seed = rng_seed*repeat
    with tqdm(total=repeat) as bar:
        while rp < repeat:
            D_input_tree = TreeRepy.TreeRep(D_input,rng_seed=cur_seed)
            # Every attempt consumes its seed, whether or not it is kept.
            cur_seed += 1

            #Check if D_input_tree has negative entries
            if D_input_tree[D_input_tree<0].shape[1]>0:
                print('TreeRep error: decoding tree contains negative weights!')
                BAD_time += 1
                if BAD_time>=BAD_MAX:
                    return -1, -1, -1, -1, BAD_time
                continue

            # Leaf-to-leaf metric of the decoded tree.
            D_input_tree_metric = get_leaves_Dmatrix(D_input_tree,D_input.shape[0])

            # Check if decoding tree has inf/nan
            if np.isinf(D_input_tree_metric).any() or np.isnan(D_input_tree_metric).any():
                if verbose:
                    print('TreeRep error: decoding tree metric contains inf/nan weight!')
                BAD_time += 1
                if BAD_time>=BAD_MAX:
                    return -1, -1, -1, -1, BAD_time
                continue

            if verbose:
                # Check if there's inf
                print(np.isinf(D_input_tree_metric).any())
                print(D_input_tree_metric)

            loss = (np.abs(D_target-D_input_tree_metric)**p).sum().sum()**(1/p)
            Results[rp] = loss

            # Keep the metric of the best run. Best_loss must be updated too;
            # otherwise the *last* run's metric is returned instead of the best.
            if loss < Best_loss:
                Best_loss = loss
                D_output = D_input_tree_metric

            rp += 1
            bar.update(1)

    return Results.mean(), Results.std(), Results.min(), D_output, BAD_time
    


    
def Decoding_NJ(D_input,D_target,p=2,verbose=False):
    """Decode D_input into a tree metric with Neighbor Joining and score it.

    Args:
        D_input: n-by-n metric to decode.
        D_target: n-by-n metric the decoded tree metric is compared against.
        p: exponent of the Lp loss.
        verbose: accepted for signature parity with the other decoders; unused.

    Returns:
        (loss, -1, loss, D_output, 0) on success. NJ is deterministic, so there
        is no std (-1). Returns (-1, -1, -1, -1, 1) when the decoded tree has
        negative edge weights or its leaf metric contains inf/nan.
    """
    D_input_tree = NJpy.NJ(D_input)

    #Check if D_input_tree has negative entries
    if D_input_tree[D_input_tree<0].shape[1]>0:
        print('NJ error: decoding tree contains negative weight!')
        return -1, -1, -1, -1, 1

    D_output = get_leaves_Dmatrix(D_input_tree,D_target.shape[0])

    # Reject inf/nan entries, matching the checks in the other decoders.
    if np.isinf(D_output).any() or np.isnan(D_output).any():
        print('NJ error: decoding tree metric contains inf/nan weight!')
        return -1, -1, -1, -1, 1

    loss = (np.abs(D_target-D_output)**p).sum().sum()**(1/p)

    return loss, -1, loss, D_output, 0

def Decoding_linkage(D_input,D_target,p=2,verbose=False,link_method='single'):
    """
    Decode D_input with a scipy linkage method and score it against D_target.

    Note that the output is a *ultrametric* (the cophenetic distance of the
    linkage tree) instead of a general *tree (additive) metric*!
    ``link_method`` is passed as ``method`` to scipy.cluster.hierarchy.linkage.

    Warning! If D_input is *NOT* Euclidean, then do not use 'centroid', 'median', and 'ward' method!
    See scipy website for more information: https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html#scipy.cluster.hierarchy.linkage

    Args:
        D_input: n-by-n symmetric metric with zero diagonal (scipy's squareform
            validates this).
        D_target: n-by-n metric to compare against.
        p: exponent of the Lp loss.
        verbose: accepted for signature parity with the other decoders; unused.
        link_method: scipy linkage method name.

    Returns:
        (loss, -1, loss, D_output, 0) on success, or (-1, -1, -1, -1, 1) if the
        linkage has negative merge heights or the output contains inf/nan.
    """
    Z = Linkage(ssd.squareform(D_input),method=link_method)
    # Reject negative merge heights. Testing the truth value of the filtered
    # array was wrong: empty -> falsy/raises, 2+ elements -> ValueError.
    if (Z[:,2] < 0).any():
        print('Linkage error: decoding tree contains negative weight!')
        return -1, -1, -1, -1, 1

    D_output = ssd.squareform(cophenet(Z))
    # Check if D_output has inf/nan weights
    if np.isinf(D_output).any() or np.isnan(D_output).any():
        print('Linkage error: decoding tree metric contains inf/nan weight!')
        return -1, -1, -1, -1, 1

    loss = (np.abs(D_target-D_output)**p).sum().sum()**(1/p)

    return loss, -1, loss, D_output, 0

## 7. Include encoder wrapper functions

In [6]:
def train(args):
    """Fit hyperbolic embeddings whose pairwise distances approximate a target metric.

    Builds the target metric ``D_metric`` for ``args.dataset``. Real datasets
    go through ``load_data`` + ``sim2metric``; ``'random_tree'`` goes through
    ``generate_rand_TreeMetric_V2``. The metric is scaled by
    ``args.scaling_factor``. If enabled, cached results under ``$SAVEPATH`` are
    reused. Otherwise ``MetricHypHC`` is trained with burn-in, early stopping on
    the Lp cost, and optional temperature/learning-rate annealing.

    Args:
        args: argparse namespace from the parser cell (plus config flags).

    Returns:
        * ``args.enc_method == 'Direct'``: ``D_metric / scaling_factor`` only.
        * cached D_hyp found on disk: ``(D_metric / s, cached_D_hyp, -1)``.
        * saved model found but no cached D_hyp: ``None``.
        * otherwise: ``(D_metric / s, D_hyp / s, model)`` with the model on CPU.

    Raises:
        ValueError: if ``args.dataset`` is not a supported option.
    """
    import copy  # snapshot of the best weights (state_dict() returns live references)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # set seed
    logging.info("Using seed {}.".format(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # set precision
    logging.info("Using {} precision.".format(args.dtype))
    if args.dtype == "double":
        torch.set_default_dtype(torch.float64)

    real_datasets = ['zoo','iris','glass','segmentation','spambase']

    # create dataset
    if args.dataset in real_datasets:
        _, y_true, similarities = load_data(args.dataset)
        D_metric = sim2metric(similarities)
        D_metric *= args.scaling_factor
    elif args.dataset in ['random_tree']:
        D_metric = generate_rand_TreeMetric_V2(args.num_nodes,option='unweight',noise_ratio=args.noise_scale)
        D_metric *= args.scaling_factor
        # Synthetic data has no class labels; use dummy all-zero labels.
        y_true = np.zeros(D_metric.shape[0],int)
    else:
        msg = 'No such dataset option! args.dataset:{}'.format(args.dataset)
        print(msg)
        # A bare `raise` outside an except block only produced
        # "RuntimeError: No active exception to reraise".
        raise ValueError(msg)

    # Return D_metric directly if arg.enc_method == 'Direct' (no encoding step)
    if args.enc_method == 'Direct':
        return np.array(D_metric)/args.scaling_factor

    # get saving directory
    hdlr = None
    if args.save:
        if args.dataset in real_datasets:
            save_dir = os.path.join(os.environ["SAVEPATH"],
                                args.dataset, 'seed_{}_curv_{}'.format(args.seed,os.environ["CURVATURE"]))
        else:
            save_dir = os.path.join(os.environ["SAVEPATH"],
                                args.dataset,'nodes_{}_noise_{}'.format(args.num_nodes,args.noise_scale),
                                'seed_{}_curv_{}'.format(args.seed,os.environ["CURVATURE"]))
        logging.info("Save directory: " + save_dir)
        save_path = os.path.join(save_dir, "model_{}.pkl".format(args.seed))
        save_Dpath = os.path.join(save_dir, "D_metric_{}_{}.npy".format(args.dataset,args.seed))
        if os.path.exists(save_dir):
            if os.path.exists(save_Dpath) and (args.enc_method != 'Direct'):
                logging.info("D_metric with the same configuration parameters already exists.")
                logging.info("Load and return existing D_metric")
                # The cached D_hyp was saved already divided by scaling_factor.
                return np.array(D_metric)/args.scaling_factor, np.load(save_Dpath), -1
            if os.path.exists(save_path):
                logging.info("Model with the same configuration parameters already exists.")
                logging.info("Exiting. Just generate D_metric from it!")
                # NOTE(review): returns None; callers unpacking three values will fail.
                return
        else:
            os.makedirs(save_dir)
            with open(os.path.join(save_dir, "config.json"), 'w') as fp:
                json.dump(args.__dict__, fp)
        log_path = os.path.join(save_dir, "train_{}.log".format(args.seed))
        hdlr = logging.FileHandler(log_path)
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        hdlr.setFormatter(formatter)
        logger.addHandler(hdlr)

    dataset = MetricDataset(D_metric, labels=y_true, num_samples=args.num_samples)
    dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                 num_workers=args.num_workers, pin_memory=True)

    # create model
    model = MetricHypHC(dataset.n_nodes, args.rank, args.temperature,
                        args.init_size, args.max_scale, scaling_factor=args.scaling_factor)

    if args.device == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.device))
    model.to(device)
    model.all_pairs = model.all_pairs.to(device)

    # create optimizer
    Optimizer = getattr(optim, args.optimizer)
    optimizer = Optimizer(model.parameters(), args.learning_rate)
    # train model
    best_cost = np.inf
    best_model = None
    counter = 0
    logging.info("Start training")
    for epoch in range(args.epochs):
        model.train()
        total_loss = 0.0

        # burn-in (assuming input lr is for burn-in stage)
        if epoch==args.burnin:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.burnin_factor
                lr = param_group['lr']
            logging.info("Burn-in stage end. Normal learning rate to: {}".format(lr))

        # The progress bar is shown only in verbose mode (disabled otherwise).
        with tqdm(total=len(dataloader), unit='ex', disable=not args.verbose) as bar:
            for step, (pair_ids, pair_D_metrics) in enumerate(dataloader):
                pair_ids = pair_ids.to(device)
                pair_D_metrics = pair_D_metrics.to(device)
                loss = model.loss(pair_ids, pair_D_metrics,p=args.p,Normalization=args.Normalization)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Accumulate a detached value so each batch's autograd graph
                # is not kept alive for the whole epoch.
                total_loss += loss.detach()
                bar.update(1)
                if args.verbose:
                    bar.set_postfix(loss=f'{loss.item():.6f}')

        total_loss = total_loss.item()**(1/args.p) / (step + 1.0)

        if args.verbose:
            logging.info("\t Epoch {} | average train loss: {:.6f}".format(epoch, total_loss))

        # keep best embeddings
        if (epoch + 1) % args.eval_every == 0:
            model.eval()
            cost = 0.0
            # Evaluation only: no gradients needed.
            with torch.no_grad():
                for pair_ids, pair_D_metrics in dataloader:
                    pair_ids = pair_ids.to(device)
                    pair_D_metrics = pair_D_metrics.to(device)
                    cost += model.loss(pair_ids, pair_D_metrics,p=args.p,Normalization=args.Normalization)
            cost = cost.item()**(1/args.p)
            if args.verbose:
                logging.info("{}:\t{:.4f}".format("Lp norm cost", cost))
            if cost < best_cost:
                counter = 0
                best_cost = cost
                # Deep copy: state_dict() returns references to the live
                # parameters, which later optimizer steps would overwrite.
                best_model = copy.deepcopy(model.state_dict())
            else:
                counter += 1
                if counter == args.patience:
                    if args.verbose:
                        logging.info("Early stopping.")
                    break

        # anneal temperature
        if (epoch + 1) % args.anneal_every == 0:
            model.anneal_temperature(args.anneal_factor)
            if args.verbose:
                logging.info("Annealing temperature to: {}".format(model.temperature))
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.anneal_factor
                lr = param_group['lr']
            if args.verbose:
                logging.info("Annealing learning rate to: {}".format(lr))

    logging.info("Optimization finished.")
    if best_model is not None:
        # load best model
        model.load_state_dict(best_model)

    if args.save:
        # save best embeddings (fall back to the final weights if no evaluation ran)
        logging.info("Saving best model at {}".format(save_path))
        torch.save(best_model if best_model is not None else model.state_dict(), save_path)

    # evaluation
    model.eval()
    logging.info("Compute Lp norm cost.")
    D_hyp = model.get_D_hyp(Normalization=args.Normalization).cpu()
    cost = torch.sum(torch.sum((torch.abs(torch.from_numpy(D_metric)-D_hyp)/args.scaling_factor)**args.p)/2)**(1/args.p)
    logging.info("{}:\t{:.4f}".format("Lp norm cost", cost))

    if args.save and (args.enc_method != 'Direct'):
        # save resulting distance matrix
        logging.info("Saving D_hyp at {}".format(save_Dpath))
        np.save(save_Dpath,D_hyp.detach().numpy()/args.scaling_factor)

    if hdlr is not None:
        # Detach and close the per-run log file so handles do not accumulate.
        logger.removeHandler(hdlr)
        hdlr.close()

    return np.array(D_metric)/args.scaling_factor, D_hyp.detach().numpy()/args.scaling_factor, model.cpu()

## 8. Wrapping for experiments

Note that all the hyperparameters for each dataset tested in our paper are included in this part.
If `args.save=True`, a summarized csv file will be saved at `./Results`.

In [15]:
import pandas as pd
from datetime import datetime

def One_Run_Exp(args):
    """Run one encode/decode experiment described by ``args``.

    With ``args.enc_method == 'Hyp'`` the input metric is first embedded by
    ``train(args)`` (HyperAid) and the resulting distance matrix ``D_hyp`` is
    decoded into a tree with ``args.dec_method``; otherwise the metric
    returned by ``train(args)`` is decoded directly.

    Returns
    -------
    np.ndarray of length 7
        [embedding loss, mean tree loss, std of tree loss, best (min) tree loss,
         number of failed decodings (inf distance), encoding time, decoding time].
        Entries that do not apply (embedding loss / encoding time for direct
        decoding, everything for unsupported Hyp decoders) are -1.

    Raises
    ------
    NameError
        If ``args.dec_method`` is not a supported decoder.
    """
    if args.enc_method in ['Hyp']:
        # Encoding part
        start = time.time()
        D_metric, D_hyp, model = train(args)
        Encode_time = time.time() - start

        # Decoding part
        decoded = True  # False when only placeholder results are available
        if args.dec_method in ['TreeRep']:
            start = time.time()
            D_hyp_results = Decoding_TreeRep(D_input=D_hyp, D_target=D_metric, p=args.p,
                                             repeat=args.dec_repeat, verbose=False,
                                             rng_seed=args.seed)
            D_hyp_decode_time = (time.time() - start) / args.dec_repeat
        elif args.dec_method in ['NJ']:
            start = time.time()
            D_hyp_results = Decoding_NJ(D_input=D_hyp, D_target=D_metric, p=args.p, verbose=False)
            D_hyp_decode_time = time.time() - start
        elif args.dec_method in ['single', 'complete', 'average', 'weighted']:
            start = time.time()
            D_hyp_results = Decoding_linkage(D_input=D_hyp, D_target=D_metric, p=args.p,
                                             verbose=False, link_method=args.dec_method)
            D_hyp_decode_time = time.time() - start
        elif args.dec_method in ['centroid', 'median', 'ward']:
            print('These method do not support non-Euclidean metric input!')
            # Placeholder results: there is no decoded tree to inspect.
            D_hyp_results = [-1, -1, -1, -1, -1]
            D_hyp_decode_time = -1
            decoded = False
        else:
            raise NameError('No such a decoding method! name={}'.format(args.dec_method))

        # D_hyp_results[3] is presumably the decoded tree metric; only check
        # it when a decoder actually produced one (not the -1 placeholder).
        if args.check_tree and decoded:
            print(is_treemetric(D_hyp_results[3]))

        # Results
        # Lp distance between the input metric and its hyperbolic embedding.
        # abs() is required: without it odd/fractional p sums signed (or NaN)
        # terms instead of computing a norm.
        diff = np.abs(np.asarray(D_metric) - np.asarray(D_hyp))
        emb_loss = np.sum(diff ** args.p) ** (1 / args.p)

        if args.verbose:
            print("Seed:{},p={},dataset:{},num_nodes={}".format(args.seed,args.p,args.dataset,D_hyp.shape[0]))
            print("Enc:{},Dec:{}".format(args.enc_method,args.dec_method))
            print("Original embedding loss:{}".format(emb_loss))
            print("Tree loss:              {}±{}, best loss:{}".format(np.around(D_hyp_results[0],4),np.around(D_hyp_results[1],4),np.around(D_hyp_results[2],4)))
            print("Decoding bad time:      {}".format(D_hyp_results[4]))
            print('Encoding time:          {}'.format(Encode_time))
            print("Decoding time:    {}".format(D_hyp_decode_time))

        return np.array([emb_loss,           # emb_loss
                         D_hyp_results[0],   # mean of tree loss
                         D_hyp_results[1],   # std of tree loss
                         D_hyp_results[2],   # best(min) of tree loss
                         D_hyp_results[4],   # times that decoding fail (inf distance)
                         Encode_time,        # Encoding time
                         D_hyp_decode_time]) # Decoding time

    # Encoding part (just generate/load data)
    D_metric = train(args)

    # Decoding part
    if args.dec_method in ['TreeRep']:
        start = time.time()
        D_metric_results = Decoding_TreeRep(D_input=D_metric, D_target=D_metric, p=args.p,
                                            repeat=args.dec_repeat, verbose=False,
                                            rng_seed=args.seed)
        D_metric_decode_time = (time.time() - start) / args.dec_repeat
    elif args.dec_method in ['NJ']:
        start = time.time()
        D_metric_results = Decoding_NJ(D_input=D_metric, D_target=D_metric, p=args.p, verbose=False)
        D_metric_decode_time = time.time() - start
    elif args.dec_method in ['MST']:
        start = time.time()
        D_metric_results = Decoding_MST(D_input=D_metric, D_target=D_metric, p=args.p, verbose=False)
        D_metric_decode_time = time.time() - start
    elif args.dec_method in ['single', 'complete', 'average', 'weighted', 'centroid', 'median', 'ward']:
        start = time.time()
        D_metric_results = Decoding_linkage(D_input=D_metric, D_target=D_metric, p=args.p,
                                            verbose=False, link_method=args.dec_method)
        D_metric_decode_time = time.time() - start
    else:
        raise NameError('No such a decoding method! name={}'.format(args.dec_method))

    # Check if the output is a tree metric
    if args.check_tree:
        print(is_treemetric(D_metric_results[3]))

    # Results
    if args.verbose:
        print("Seed:{},p={},dataset:{},num_nodes={}".format(args.seed,args.p,args.dataset,D_metric.shape[0]))
        print("Enc:{},Dec:{}".format(args.enc_method,args.dec_method))
        print("Tree loss:              {}±{}, best loss:{}".format(np.around(D_metric_results[0],4),np.around(D_metric_results[1],4),np.around(D_metric_results[2],4)))
        print("Decoding bad time:      {}".format(D_metric_results[4]))
        print("Decoding time:    {}".format(D_metric_decode_time))

    return np.array([-1,                    # No emb_loss. Set to -1.
                     D_metric_results[0],   # mean of tree loss
                     D_metric_results[1],   # std of tree loss
                     D_metric_results[2],   # best(min) of tree loss
                     D_metric_results[4],   # times that decoding fail (inf distance)
                     -1,                    # No Encoding time. Set to -1.
                     D_metric_decode_time]) # Decoding time
    
def _set_dataset_hparams(args, dname, n, noise):
    """Apply the paper's hyperparameters for (dname, n, noise) onto ``args``.

    Only the attributes listed for the matching case are overwritten; all
    other attributes keep their current values. For 'random_tree' settings
    exist only for n in {64, 128, 256} with noise in {0.1, 0.3, 0.5}; other
    combinations leave ``args`` untouched. Unknown datasets print a notice
    and keep the settings already in ``args``.
    """
    real_world = {
        'zoo': dict(learning_rate=0.05, scaling_factor=1.0, rank=64),
        'iris': dict(learning_rate=0.05, scaling_factor=1.0, rank=64),
        'glass': dict(learning_rate=0.05, scaling_factor=1.0, burnin_factor=50, rank=64),
        'segmentation': dict(learning_rate=0.02, scaling_factor=1.0, rank=64, epochs=100,
                             patience=10, burnin=20, burnin_factor=2, batch_size=10000,
                             num_workers=64),
        'spambase': dict(learning_rate=0.002, scaling_factor=1.0, rank=16, epochs=100,
                         patience=5, burnin=5, burnin_factor=1, batch_size=100000,
                         num_workers=64),
    }
    # random_tree settings keyed by number of nodes.
    random_tree = {
        64: dict(learning_rate=0.001, scaling_factor=0.01, rank=64, burnin=50,
                 burnin_factor=10, batch_size=1000000, num_workers=1),
        128: dict(learning_rate=0.05, scaling_factor=0.05, rank=64, burnin=50,
                  burnin_factor=10, batch_size=1000000, num_workers=1),
        256: dict(learning_rate=0.001, scaling_factor=0.01, rank=64, burnin=30,
                  burnin_factor=10, batch_size=1000000, num_workers=1),
    }

    if dname in real_world:
        hparams = real_world[dname]
    elif dname == 'random_tree':
        hparams = random_tree.get(n, {}) if noise in (0.1, 0.3, 0.5) else {}
    else:
        print('Use default settings given by args.')
        hparams = {}

    for key, value in hparams.items():
        setattr(args, key, value)


def All_Run_Exp(args,
                DATASET_list=('zoo',),
                NOISE_list=(1.0,),
                NUM_NODES_list=(16,),
                ENC_list=('Direct',),
                DEC_list=('TreeRep',),
                ENC_repeat=10):
    """Sweep datasets x #nodes x noise x encoders x decoders and collect results.

    For each (dataset, num_nodes, noise) combination one results table is
    built, with ``ENC_repeat`` rows per (encoder, decoder) pair. Each row
    comes from ``One_Run_Exp`` and uses consecutive seeds starting at the
    caller's ``args.seed``. If ``args.save`` is set, the table and the final
    ``args`` are written to a time-stamped CSV in ``./Results``; otherwise
    ``args`` is printed.

    Every (encoder, decoder) run starts from the caller's original ``args``
    values before the per-dataset hyperparameters are applied, so settings
    chosen for one dataset never leak into a later one. ``args`` is modified
    in place.

    Returns
    -------
    pd.DataFrame
        Results table of the last (dataset, num_nodes, noise) combination
        (empty if no combination was run).
    """
    column_name = ['Dataset','Num nodes','Enc','Dec',
                   'Emb loss','Mean','Std','Best','Bad time',
                   'Enc time','Dec time']

    # Snapshot the caller's settings so per-dataset hyperparameters can be
    # reset between combinations.
    base_config = dict(vars(args))
    # The CLI parser declares --seed as str; normalize so `+= 1` works.
    original_seed = int(args.seed)
    All_results = pd.DataFrame(columns=column_name)

    for dname in DATASET_list:
        args.dataset = dname
        for n in NUM_NODES_list:
            args.num_nodes = n
            for noise in NOISE_list:
                args.noise_scale = noise
                rows = []
                for enc in ENC_list:
                    for dec in DEC_list:
                        # Start from the caller's settings, then re-apply the
                        # loop values and the dataset-specific hyperparameters.
                        vars(args).update(base_config)
                        args.dataset = dname
                        args.num_nodes = n
                        args.noise_scale = noise
                        args.enc_method = enc
                        args.dec_method = dec
                        _set_dataset_hparams(args, dname, n, noise)

                        args.seed = original_seed
                        for _ in range(ENC_repeat):
                            rows.append([dname, n, enc, dec, *One_Run_Exp(args)])
                            # remember to change args.seed
                            args.seed += 1

                # DataFrame.append was removed in pandas 2.0; build from rows.
                All_results = pd.DataFrame(rows, columns=column_name)

                if args.save:
                    # Write to a csv file
                    now = datetime.now()
                    dt_string = now.strftime("%d-%m-%Y_%H-%M-%S")
                    if dname in ['zoo','iris','glass','segmentation','spambase','letter-recognition']:
                        SAVE_PATH = './Results/All_{}_results_{}.csv'.format(dname,dt_string)
                    else:
                        SAVE_PATH = './Results/All_{}_{}_{}_results_{}.csv'.format(dname,n,noise,dt_string)
                    os.makedirs('./Results', exist_ok=True)
                    All_results.to_csv(SAVE_PATH,index=False)
                    # Also write the args
                    with open(SAVE_PATH, 'a') as f:
                        print('#',args, file=f)

                    print('Results finished and saved at {}'.format(SAVE_PATH))
                else:
                    print(args)

    return All_results
                

## 9. Run the experiment

Note that if you use new datasets or other settings for the synthetic datasets, you have to tune the hyperparameters yourself (defaults are set in part 2; per-dataset settings are set in part 8).

In [None]:
# Available datasets (DATASET_list): 'zoo', 'iris', 'glass', 'segmentation', 'spambase', 'random_tree'
# Available decoders (DEC_list): 'NJ', 'TreeRep', 'single', 'complete', 'average', 'weighted',
#                                'centroid', 'median', 'ward'

All_Run_Exp(
    args,
    DATASET_list=['zoo'],        # datasets to test
    ENC_list=['Direct', 'Hyp'],  # 'Direct': decode the input metric directly; 'Hyp': HyperAid
    DEC_list=['NJ'],             # decoders to run
    NUM_NODES_list=[64],         # #nodes for synthetic datasets (not used for real-world datasets)
    NOISE_list=[0.3],            # edge noise ratio for synthetic datasets (not used for real-world datasets)
    ENC_repeat=1,                # independent encoder runs: 1 for real-world datasets, 3 for `random_tree`
)