In [1]:
## custom
import lovasz_losses as L

## sys
import random
import time

## numeric
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

## vis
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits import mplot3d
from matplotlib import collections  as mc
from mpl_toolkits.mplot3d.art3d import Line3DCollection

## notebook
from IPython import display
from tqdm import tqdm_notebook as tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

plt.style.use('ggplot')
# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8-*'
# and later releases dropped the old names. Prefer the new name when this
# Matplotlib ships it; otherwise use the legacy name (which still raises, as
# before, if it is not available either).
_colorblind_style = next(
    (name for name in ('seaborn-v0_8-colorblind', 'seaborn-colorblind')
     if name in plt.style.available),
    'seaborn-colorblind',
)
plt.style.use(_colorblind_style)

## Utils

In [4]:
def draw_graph_3d(ax, x, G, grad=None, k2i=None):
    """Draw a 3D node-link diagram of ``G`` on ``ax``.

    Args:
        ax: axes object providing ``scatter``, ``add_collection`` and
            ``quiver`` (the notebook passes a matplotlib 3D axes).
        x: array of node coordinates indexed as ``x[row, dim]``; the first
            three columns are used.
        G: graph whose ``edges`` attribute yields pairs of node keys.
        grad: optional array with the same layout as ``x``. When given, its
            negation is drawn as arrows anchored at the node positions.
        k2i: optional mapping from node key to row index in ``x``. When
            omitted, node keys are used directly as row indices (the original
            behaviour), which is only correct when the keys are exactly the
            row indices of ``x``.

    Returns:
        The same ``ax`` that was passed in.
    """
    ax.scatter(x[:,0], x[:,1], x[:,2])
    # ax.view_init(elev=20.0, azim=0)

    def row_of(node):
        # Translate a node key to its row in x; identity when no mapping is given.
        return node if k2i is None else k2i[node]

    edgeLines = [(x[row_of(e0)][:3], x[row_of(e1)][:3]) for e0,e1 in G.edges]
    lc = Line3DCollection(edgeLines, linewidths=1)
    ax.add_collection(lc)

    if grad is not None:
        # Arrows show the negative gradient for each node.
        ax.quiver(x[:,0], x[:,1], x[:,2],
                 -grad[:,0], -grad[:,1], -grad[:,2], length=4, colors='C1')
    return ax

def colorScale2cmap(domain, range1):
    """Build a ``LinearSegmentedColormap`` from value stops and 0-255 RGB colours.

    ``domain`` is rescaled linearly so its minimum maps to 0 and its maximum
    to 1; ``range1`` holds one ``[r, g, b]`` triple (0-255) per stop.
    """
    stops = np.array(domain)
    lowest, highest = stops.min(), stops.max()
    stops = (stops - lowest) / (highest - lowest)
    rgb = np.array(range1) / 255.0

    def channel(index):
        # One (position, value, value) entry per stop for a single colour channel.
        return tuple((stop, color[index], color[index]) for stop, color in zip(stops, rgb))

    segments = {'red': channel(0), 'green': channel(1), 'blue': channel(2)}
    return LinearSegmentedColormap('asdasdas', segments)
    
# RGB anchor colours (0-255 per channel) handed to colorScale2cmap by the
# plotting code, listed in the order of the colormap stops.
colors = [
    [44,52,179],
    [0,0,0],
    [174,33,57],
]



#https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
def pairwise_distances(x, y=None, w=None):
    '''
    Squared Euclidean distances between the rows of two matrices.

    Input: x is a Nxd matrix
           y is an optional Mxd matrix; if y is not given then 'y=x' is used
           w is an optional per-dimension weight that broadcasts against the
             rows of x and y; both are scaled by it before measuring
    Output: dist is a NxM matrix where dist[i,j] is the square norm between
            x[i,:] and y[j,:] (after the optional scaling by w),
            i.e. dist[i,j] = ||(x[i,:]-y[j,:]) * w||^2
    '''
    if y is None:
        y = x
    # Scale first so the norms and the cross term all see the same vectors.
    if w is not None:
        x = x * w
        y = y * w

    # Column of ||x_i||^2 plus row of ||y_j||^2 broadcasts to the NxM sum.
    x_norm = (x**2).sum(1).view(-1, 1)
    y_norm = (y**2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    # Rounding can push true zeros slightly negative; clip them.
    return torch.clamp(dist, min=0.0)


def file2graph(fn='./facebook/0.edges'):
    """Build an undirected graph from a whitespace-separated edge-list file.

    Each non-blank line contributes one edge made of its first two tokens,
    parsed as integers; any further tokens on the line are ignored. Blank
    lines are skipped.

    Args:
        fn: path of the edge-list file.

    Returns:
        An ``nx.Graph`` containing every node and edge found in the file.
    """
    edges = []
    with open(fn) as f:
        for line in f:
            tokens = line.split()[:2]
            if not tokens:
                # Blank line: would otherwise become an empty edge tuple.
                continue
            edges.append(tuple(int(token) for token in tokens))

    # Single pass over the edges; set(sum(edges, ())) rebuilt a growing tuple
    # for every edge and was quadratic in the number of edges.
    nodes = {node for edge in edges for node in edge}
#         edges += [(-1, n) for n in nodes]
#         nodes.update({-1})
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    return G


def dict2tensor(d, fill=None):
    """Turn a dict-of-dicts keyed by node into a dense square tensor.

    Keys of ``d`` are sorted and numbered 0..n-1. Entry ``[i, j]`` receives
    ``d[src][dst]`` (or ``fill`` when it is not None) for every pair present
    in ``d``; all other entries stay 0.

    Returns:
        (tensor, k2i) where ``k2i`` maps each key of ``d`` to its row/column.
    """
    ordered_keys = sorted(d.keys())
    size = len(ordered_keys)
    k2i = {key: index for index, key in enumerate(ordered_keys)}
    res = torch.zeros(size, size, device=device)
    for src_node, dst_nodes in d.items():
        row = k2i[src_node]
        for dst_node, distance in dst_nodes.items():
            res[row, k2i[dst_node]] = distance if fill is None else fill
    return res, k2i

## Optimization Procedures

In [4]:
# def stress_minimization(X, D, Adj, optimizer, max_iter=10):
#     def stress(x, d, adjacency):
#         pdist = pairwise_distances(x)
#         s = ((pdist - d)**2).mean()
#         return s

#     for i in range(max_iter):
#         s = stress(X, D, Adj)
#         s.backward()
#         optimizer.step()
#         X.grad.data.fill_(0)
#     return X, s

In [24]:
def neighbor_preservation_bce(X, Adj, optimizer, max_iter=10, w=None, grad_sign_threshold=0.01):
    """Run optimizer steps on X under a binary-cross-entropy neighbour loss.

    For every node a soft "is a neighbour" score is derived from its pairwise
    distances to all other rows of ``X`` and compared with ``Adj`` using
    ``nn.BCELoss``.

    Args:
        X: tensor of node positions that ``optimizer`` updates; its ``.grad``
            is zeroed at the start of every iteration.
        Adj: square adjacency matrix; its row sums are used as the number of
            neighbours of each node.
        optimizer: optimizer whose ``step()`` is called once per iteration.
        max_iter: number of iterations to run.
        w: optional weight forwarded to ``pairwise_distances``.
        grad_sign_threshold: gradient entries whose magnitude exceeds this are
            replaced by their sign before the step (fast gradient sign
            method). Pass None to leave the gradient untouched.

    Returns:
        (X, loss, pred) from the last iteration; ``loss`` and ``pred`` are
        None when ``max_iter`` is 0.
    """
    sigmoid = nn.Sigmoid()
    neighborSizes = Adj.sum(dim=1).int()
    nodeCount = Adj.shape[0]

    def euclidean_neighbor(x):
        pdist = pairwise_distances(x, w=w)
        res = torch.zeros([x.shape[0], x.shape[0]], device=x.device)
        for i, [distances, ns] in enumerate(zip(pdist, neighborSizes)):
            # Keep the (nodeCount - ns) largest distances; the mean of the last
            # two returned values is the cut-off (assumes topk's default
            # descending order, so those are the smallest kept distances).
            topk = distances.topk(nodeCount-ns.item())
            thresh = topk.values[-2:].mean()
            scale = 2.0
            # Distances below the cut-off map to scores above 0.5.
            res[i,:] = 1-sigmoid((distances-thresh) * scale)
        return res

    bce = nn.BCELoss()
#     jaccard_loss = L.lovasz_softmax

    loss, pred = None, None
    for i in range(max_iter):
        if X.grad is not None:
            X.grad.data.fill_(0)
        pred = euclidean_neighbor(X)
        truth = Adj
        # Zero the diagonal so a node is never predicted as its own neighbour.
        eye = torch.eye(pred.shape[0], device=pred.device)
        pred *= (1-eye)
        loss = bce(pred, truth)
#         loss = jaccard_loss(pred.view(1,*pred.shape), truth, classes=[1])
        loss.backward()
        if grad_sign_threshold is not None:
            ## (optional) fast gradient sign method
            # Boolean-mask indexing returns a copy, so calling sign_() on it
            # never reached X.grad; assign the signs back through the mask.
            grad = X.grad.data
            large = grad.abs() > grad_sign_threshold
            grad[large] = grad[large].sign()
        optimizer.step()

    return X, loss, pred



def neighbor_preservation_jaccard(X, Adj, optimizer, max_iter=10, w=None):
    """Run optimizer steps on X under the ``L.lovasz_hinge`` neighbour loss.

    Each row of the prediction is "cut-off minus distance", so entries are
    positive for points closer than the row's cut-off. The target is ``Adj``
    plus the identity, i.e. every node also counts itself.

    Args:
        X: tensor of node positions updated by ``optimizer``.
        Adj: square adjacency matrix; row sums give per-node neighbour counts.
        optimizer: optimizer whose ``step()`` is called once per iteration.
        max_iter: number of iterations to run.
        w: optional weight forwarded to ``pairwise_distances``.

    Returns:
        (X, loss, pred) from the last iteration.
    """
    nodeCount = Adj.shape[0]
    neighborSizes = Adj.sum(dim=1).int()

#     jaccard_loss = lovasz(hingeError) #option 3 (slow and does not work as expected)
    jaccard_loss = L.lovasz_hinge #option 4 the official lovasz hinge!

    truth = Adj + torch.eye(nodeCount, device=device)

    def model(x):
        pdist = pairwise_distances(x, w=w)
        rows = x.shape[0]
        res = torch.zeros([rows, rows], device=device)
        for i, (distances, ns) in enumerate(zip(pdist, neighborSizes)):
            # Cut-off: mean of the last two of the (nodeCount - ns) largest distances.
            kept = distances.topk(nodeCount - ns.item()).values
            cutoff = kept[-2:].mean()
            res[i, :] = cutoff - pdist[i, :]
        return res

    for _ in range(max_iter):
        if X.grad is not None:
            X.grad.data.fill_(0)
        pred = model(X)
#         pred *= (1-eye)
        flat_pred, flat_truth = pred.view(-1), truth.view(-1)
        loss = jaccard_loss(flat_pred, flat_truth)
        loss.backward()
        optimizer.step()
    return X, loss, pred


relu = nn.ReLU()
def hingeError(logits, target):
    """Element-wise hinge error after mapping both inputs from {0, 1} to {-1, +1}.

    Returns ``relu(1 - (2*logits - 1) * (2*target - 1))``.
    """
    signed_logits = 2 * logits - 1
    signed_target = 2 * target - 1
    margin = 1 - signed_logits * signed_target
    return relu(margin)


def lovasz(error_func=hingeError):
    """Return a loss ``f(logits, target)`` built from per-element errors.

    ``f`` sorts the errors in ascending order. For every cut position it
    evaluates ``jaccardLoss`` on a 0/1 vector in which the elements before the
    cut are 0 and the rest are 1. It then sums each sorted error multiplied by
    the drop in that value across its cut position, from the largest error
    down to the smallest.
    """
    def jaccardLoss(error, target):
        total_error = error.sum()
        if total_error == 0:
            return torch.tensor(0.0)
        # Denominator is the plain sum of error + target (no clamping).
        return total_error / (error + target).sum()

    def f(logits, target):
        error = error_func(logits, target)
        values, indices = torch.sort(error)
        count = values.shape[0]

        # jaccardLosses[cut]: value when the `cut` smallest errors are zeroed.
        jaccardLosses = []
        for cut in range(count + 1):
            thresholded = error.clone()
            thresholded[indices[:cut]] = 0
            thresholded[indices[cut:]] = 1
            jaccardLosses.append(jaccardLoss(thresholded, target))

        loss = 0
        for cut in reversed(range(1, count + 1)):
            step = jaccardLosses[cut - 1] - jaccardLosses[cut]
            loss += values[cut - 1] * step
        return loss
    return f



## test:
def jaccardIndex(pred, target):
    """Ratio of summed ``pred * target`` to summed ``(pred + target)`` clamped to [0, 1].

    Returns a zero tensor when the product sums to 0.
    """
    overlap = (pred * target).sum()
    if overlap == 0:
        return torch.tensor(0.0)
    combined = (pred + target).clamp(0, 1).sum()
    return overlap / combined

# Smoke test: two predictions against an all-ones target, the first matching
# the target and the second not.
logits = torch.tensor([1.0, 0.0], requires_grad=True)
target = torch.tensor([1.0, 1.0])
f = lovasz()
# f = L.lovasz_hinge

# Print the Jaccard index, its complement, and the value of the loss built above.
print('jaccard index:', jaccardIndex(logits, target).item())
print(' jaccard loss:', 1-jaccardIndex(logits, target).item())
print('  lovasz loss:', f(logits, target).item())

jaccard index: 0.5
 jaccard loss: 0.5
  lovasz loss: 0.6666666865348816


In [25]:
# # ## test

# ground_truth = torch.tensor([0.0, 0.0])
# steps = 19
# x,y = np.meshgrid(np.linspace(-2,2,steps).astype('float32'), np.linspace(-2,2,steps).astype('float32'))
# z = []
# for logits in np.c_[x.ravel(), y.ravel()]:
#     logits = torch.tensor(logits, requires_grad=True)
#     loss = f(logits, ground_truth).item()
#     z.append(loss)
# z = np.array(z)

# x = x.reshape([steps,steps])
# y = y.reshape([steps,steps])
# z = z.reshape([steps,steps])

# fig = plt.figure()
# ax = plt.axes(projection='3d')
# ax.plot_surface(x,y,z, cmap='viridis')
# ax.view_init(elev=20.0, azim=210)
# plt.show()

## generate a graph

In [43]:
%%time

# Pick the graph to lay out; the commented lines are alternative inputs.
print('generating graph...')
# G = nx.path_graph(10)
# G = nx.cycle_graph(10)
G = nx.balanced_tree(3,3)
# G = nx.connected_watts_strogatz_graph(10,5,0.5)
# G = file2graph('./facebook/0.edges')

print('calculating all pairs shortest path...')
# D: dense matrix of shortest-path lengths; k2i: node key -> row/column index.
D,k2i = dict2tensor(dict(nx.all_pairs_shortest_path_length(G)))
# Adj: dense matrix with 1 for every adjacency entry (fill=1 overrides the stored values).
Adj,_ = dict2tensor(dict(G.adjacency()), fill=1)

print(len(G.nodes), 'nodes')
print('\n\n')

generating graph...
calculating all pairs shortest path...
40 nodes



CPU times: user 112 ms, sys: 436 µs, total: 112 ms
Wall time: 110 ms


## Optimize via Stochastic Gradient Descent (SGD)

In [48]:
!rm -r fig
!mkdir fig

In [49]:
# Fixed seed so the random initial layout, and with it the run, is reproducible.
torch.manual_seed(0)
# One random 3D position per node; this is the tensor the optimizer updates.
X = torch.rand(len(G.nodes), 3, requires_grad = True, device=device)

# stress_optimizer = optim.SGD([X], lr=0.01)
# neighbor_optimizer = optim.SGD([X], lr=0.03) ## works for cycle graph, jaccard loss
# neighbor_optimizer = optim.SGD([X], lr=0.5)## for jaccard loss
# neighbor_optimizer = optim.SGD([X], lr=0.5)## for bce loss
## originally labelled "for bce loss"; the training cell currently runs it with the jaccard loss
neighbor_optimizer = optim.Adam([X], lr=0.1)
# Loss value of every training iteration, appended by the training cell.
lossHistory = []

# def schedule(i, n):
#     return np.exp(-i/n*2)
# niter = 200
# i=np.linspace(0, niter, niter+1)
# plt.plot(i, schedule(i, niter))

In [50]:
# Training loop: one optimizer step per iteration, plus a diagnostic figure
# written to fig/epoch<i>.png every second iteration.
niter = 150
iterBar = tqdm(range(niter))
for i in iterBar:
    
    ## option 1: stress minimization
#     X, stress = stress_minimization(X, D, Adj, stress_optimizer, max_iter=10)

    ## option 2: neighbor preservation
#     w = torch.tensor([1.0, 1.0, schedule(i, niter), schedule(i, niter)], device=device)
    X, loss, pred = neighbor_preservation_jaccard(X, Adj, neighbor_optimizer, max_iter=1, w=None)
#     X, loss, pred = neighbor_preservation_bce(X, Adj, neighbor_optimizer, max_iter=5, w=None)
    
    if i%10 == 0:
        iterBar.set_postfix({'loss': loss.item()})
    ## debug & vis
    lossHistory.append(loss.item())
#     if i%10==9:

    if i%2==0:
        # Plain numpy copies of the positions and of the gradient from the last step.
        x = X.detach().cpu().numpy()
        grad = X.grad.data.cpu().numpy()
        
        print(f'loss: {loss.item()}')
        print(f'max grad: {np.abs(grad).max()}')
        
        fig = plt.figure(figsize=[14,10])
#         display.clear_output(wait=True)
        
        ## graph
        if x.shape[1] == 2:
            plt.subplot(221)
            nx.draw_networkx(G, pos={k: x[k2i[k],:2] for k in G.nodes}, font_color='white')
            plt.quiver(x[:,0], x[:,1], 
                       -grad[:,0], -grad[:,1], 
                       units='inches', label=f'neg grad (max={np.linalg.norm(grad, axis=1).max():.2e})')
            plt.axis('equal')
            plt.legend()  
        else:
            ax = fig.add_subplot(2,2,1, projection='3d')
            ax = draw_graph_3d(ax, x, G, grad)
        plt.title('epoch: {}'.format(i))

        ## loss
        plt.subplot(222)
        plt.plot(lossHistory)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')

        ## pred vs truth
        plt.subplot(234)
        # Only needed for plotting: skip autograd and convert to numpy so the
        # arithmetic below does not mix a torch tensor with a numpy array.
        with torch.no_grad():
            pdist = pairwise_distances(X).cpu().numpy()
        plt.imshow(pdist.max()-pdist-np.eye(pdist.shape[0]))
        plt.title('max - distance')
        plt.colorbar()

        plt.subplot(235)
        pred = pred.detach().cpu()
        cmap = colorScale2cmap([-1, 0, 1], colors)
        plt.imshow(pred, cmap=cmap, vmin=-1, vmax=1)
        plt.title('Prediction')
        plt.colorbar()
        
        plt.subplot(236)
        cmap = colorScale2cmap([0, 0.5, 1], colors)
        plt.imshow(Adj.detach().cpu(), cmap=cmap)
        plt.colorbar()
        plt.title('Ground Truth')
        
        plt.savefig(f'fig/epoch{i}.png')
        # Close this figure explicitly so figures do not pile up across iterations.
        plt.close(fig)
#         plt.show()
        

HBox(children=(IntProgress(value=0, max=150), HTML(value='')))

loss: 0.6263233423233032
max grad: 0.02267475612461567
loss: 0.45497316122055054
max grad: 0.03127071261405945
loss: 0.33995646238327026
max grad: 0.03460439294576645
loss: 0.2924215495586395
max grad: 0.016093028709292412
loss: 0.27986443042755127
max grad: 0.011121461167931557
loss: 0.2693791687488556
max grad: 0.009518582373857498
loss: 0.260297030210495
max grad: 0.010846502147614956
loss: 0.2529999017715454
max grad: 0.010002418421208858
loss: 0.24558748304843903
max grad: 0.010891210287809372
loss: 0.2384827733039856
max grad: 0.011976135894656181
loss: 0.23078852891921997
max grad: 0.00683437567204237
loss: 0.2230801284313202
max grad: 0.011432328261435032
loss: 0.21376268565654755
max grad: 0.008773665875196457
loss: 0.20523102581501007
max grad: 0.008309494704008102
loss: 0.19592680037021637
max grad: 0.007674236781895161
loss: 0.18805405497550964
max grad: 0.014590760692954063
loss: 0.17981551587581635
max grad: 0.010350048542022705
loss: 0.172564297914505
max grad: 0.0087083

## animation

In [52]:
from PIL import Image
from natsort import natsorted
from glob import glob

# Create the frames, in natural (numeric) filename order
imgs = natsorted(glob('fig/*.png'))
if not imgs:
    # Without this guard the save below fails with an opaque IndexError.
    raise FileNotFoundError("no frames found matching 'fig/*.png'; run the optimization cell first")

frames = [Image.open(img) for img in imgs]

try:
    # Save into a GIF file that loops forever
    frames[0].save('anim6.gif', format='GIF',
                   append_images=frames[1:],
                   save_all=True,
                   duration=60, loop=0)
finally:
    # Release the image files whether or not the save succeeded.
    for frame in frames:
        frame.close()

In [None]:
# import imageio
# from natsort import natsorted
# from glob import glob

# fig = plt.figure(figsize=[14,10])

# ims = []
# for fn in natsorted(glob('fig/epoch*.png')):
#     im = imageio.imread(fn)
#     im = plt.imshow(im, animated=True)
#     ims.append([im])

# ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
#                                 repeat_delay=1000)

# # ani.save('dynamic_images.mp4')

# display.HTML(ani.to_jshtml())
# # plt.show()

In [73]:
%%time

# Rebuild G, D, k2i and Adj for a 100-node cycle; this overwrites the
# variables of the same name created by the earlier graph cell.
print('generating graph...')
# G = nx.path_graph(10)
G = nx.cycle_graph(100)
# G = nx.balanced_tree(3,4)
# G = nx.connected_watts_strogatz_graph(10,5,0.5)
# G = file2graph('./facebook/0.edges')

print('calculating all pairs shortest path...')
# D: dense matrix of shortest-path lengths; k2i: node key -> row/column index.
D,k2i = dict2tensor(dict(nx.all_pairs_shortest_path_length(G)))
# Adj: dense matrix with 1 for every adjacency entry (fill=1 overrides the stored values).
Adj,_ = dict2tensor(dict(G.adjacency()), fill=1)

print(len(G.nodes), 'nodes')
print('\n\n')

generating graph...
calculating all pairs shortest path...
100 nodes



CPU times: user 645 ms, sys: 3.89 ms, total: 649 ms
Wall time: 647 ms


In [74]:
# The import was commented out, so on a fresh kernel this cell failed with
# NameError: UMAP is used directly below.
from umap import UMAP

# Embed the graph from a precomputed node-to-node distance matrix.
model = UMAP(metric='precomputed', unique=True, 
             min_dist=0.9,
             n_neighbors=5, n_epochs=200)
# dist = (1-Adj.cpu().numpy())*(1-np.eye(dist.shape[0]))
# dist = torch.tanh(D).cpu().numpy()
# Shortest-path lengths raised to the power 0.4 are used as the distances.
dist = (D**0.4).cpu().numpy()
xy = model.fit_transform(dist)

  "using precomputed metric; transform will be unavailable for new data and inverse_transform "


In [None]:
# Draw G with each node placed at the first two columns of its row in xy.
fig = plt.figure(figsize=[8,8])
node_positions = {}
for node in G.nodes:
    node_positions[node] = xy[k2i[node], :2]
nx.draw_networkx(G, pos=node_positions, font_color='white')
