In [None]:
from matplotlib import pyplot as plt
import numpy as np
from scipy import ndimage
import torch as th
import networkx as nx

import snake
import gradImSnake

In [None]:
def create_X_graph(nnodes,scale,offset):
    """Build an X-shaped graph: one central node with four diagonal arms.

    Parameters
    ----------
    nnodes : int
        Number of nodes on each of the four arms (the centre is extra).
    scale : number
        Multiplier applied to the unit diagonal steps, i.e. the per-axis
        spacing between consecutive nodes of an arm.
    offset : number or array-like
        Added to every coordinate after scaling; the central node ends up
        at ``[0, 0] * scale + offset``.

    Returns
    -------
    networkx.Graph
        Graph with ``1 + 4 * nnodes`` integer-labelled nodes. Each node
        carries a ``pos`` attribute holding a 2-element numpy array.
    """
    graph=nx.Graph()
    # node 0 is the centre of the X, shared by all four arms
    graph.add_node(0,pos=np.array([0.0,0.0])*scale+offset)
    directions=[( 1, 1), (-1, 1), (-1,-1), ( 1,-1)]
    for dir0,dir1 in directions:
        # every arm starts at the centre and grows outwards
        last=0
        for dist in range(1,nnodes+1):
            # the current node count is the next unused integer label
            new_id=len(graph)
            coords=np.array([dir0*float(dist),dir1*float(dist)])
            graph.add_node(new_id,pos=coords*scale+offset)
            graph.add_edge(last,new_id)
            last=new_id
    return graph

def drawLine(lbl,begPoint,endPoint):
    """Rasterise the straight segment from begPoint to endPoint into lbl.

    The segment is sampled with unit steps along its dominant axis (the
    axis with the largest absolute displacement); each sample is rounded
    to the nearest pixel/voxel and that element of ``lbl`` is set to 1.
    Samples whose rounded index falls outside ``lbl`` are skipped and a
    warning is printed instead.

    Parameters
    ----------
    lbl : np.ndarray
        Array the line is rendered into; modified in place.
    begPoint, endPoint : np.ndarray
        1-D coordinate arrays with one entry per dimension of ``lbl``.

    Returns
    -------
    np.ndarray
        The same ``lbl`` object that was passed in.
    """
    sz=np.array(lbl.shape) # an array holding a shape not an array of shape

    def _mark(pos):
        # Round first and bounds-check the *rounded* index: checking the raw
        # position lets e.g. sz-0.4 through, which then rounds to sz and
        # raises IndexError. The builtin int replaces the np.int alias that
        # was removed in NumPy 1.24.
        idx=np.round(pos).astype(int)
        if np.all(idx>=0) and np.all(idx<sz):
            lbl[tuple(idx)]=1
        else:
            print("warning: reqested point",pos,"but the volume size is",sz)

    d=endPoint-begPoint
    mi=np.argmax(np.fabs(d))
    if d[mi]==0: # beginning and end points the same
        # a degenerate segment marks the same pixel a longer one would start at
        _mark(begPoint)
    else:
        # displacement per step: exactly +-1 along the dominant axis
        unit=d/abs(d[mi])
        numsteps=int(abs(d[mi]))+1
        for t in range(numsteps):
            _mark(begPoint+unit*t)
    return lbl

def show(img,title,graph=None):
    """Display an image in a new 10x10 inch figure, optionally with a graph on top.

    Parameters
    ----------
    img : array-like
        Image handed to ``plt.imshow``.
    title : str
        Text used as the figure's super-title.
    graph : networkx.Graph, optional
        Graph whose nodes carry a ``pos`` attribute. The overlay is skipped
        when this is ``None`` or otherwise falsy.
    """
    plt.figure(figsize=(10,10))
    plt.suptitle(title)
    plt.imshow(img)
    if not graph:
        return
    # Each node position is reversed before plotting -- presumably because
    # positions are stored in array-index order while matplotlib expects
    # (x, y); confirm against the snake modules.
    layout={}
    for node in graph.nodes():
        layout[node]=graph.nodes[node]['pos'][::-1]
    nx.draw_networkx(graph, pos=layout, node_size=16, node_color='gray',
                     edge_color='g', font_size=12, font_color='black')

In [None]:
# a 2D example
#
# Builds an X-shaped graph, renders it into a label image, turns that image
# into a distance-transform "energy" image, shifts the graph away from the
# energy minimum and lets a GradImSnake move it back.

nnodes=6    # nodes per arm of the X
scale=16    # per-axis spacing between consecutive nodes of an arm
margin=16   # distance kept between the arm tips and the image border
offset=margin+nnodes*scale  # coordinate of the centre of the X along each axis
sz=2*offset                 # side length of the square image

# create a graph and the energy image
lbl=np.zeros((sz,sz))
g=create_X_graph(nnodes,scale,offset)
# rasterise every edge of the graph into lbl (drawLine sets the pixels to 1)
for e in g.edges:
    drawLine(lbl,g.nodes[e[0]]["pos"],g.nodes[e[1]]["pos"])
# distance from every pixel to the nearest drawn pixel; zero on the graph itself
enim=ndimage.distance_transform_edt(1-lbl)
show(enim,"low energy position",g)

# perturb graph
# shift every node by 30 along the second coordinate; the += modifies the
# position arrays stored in g in place
delta=np.array([0,30])
for n in g.nodes:
    g.nodes[n]["pos"]+=delta
show(enim,"perturbed graph",g)

# get the energy gradient image
# NOTE(review): the arguments 1.0 and 2 are presumably the filter's standard
# deviation and the number of dimensions (the other cells pass fltrstdev and
# ndims in these positions) -- confirm in gradImSnake
fltr=gradImSnake.makeGaussEdgeFltr(1.0,2)
fltrt=th.from_numpy(fltr)
# two leading singleton axes are added, giving shape (1,1,sz,sz)
# (presumably batch and channel -- confirm in gradImSnake.cmptGradIm)
enimt=th.from_numpy(enim[None,None])
# keep only the first entry along the leading axis of the result
gradientImage=gradImSnake.cmptGradIm(enimt,fltrt)[0]
#show(gradientImage[0],"gradient image - vertical")
#show(gradientImage[1],"gradient image - horizontal")

# crop window spanning exactly the image extent along both axes
crop=[slice(0,sz),slice(0,sz)]
# parameters forwarded to GradImSnake; their exact meaning is defined in
# gradImSnake (stepsz is presumably the update step size, alpha and beta
# presumably regularisation weights -- TODO confirm)
stepsz=1.0
alpha=0.0
beta=1.0
ndims=2

s=gradImSnake.GradImSnake(g,crop,stepsz,alpha,beta,ndims,gradientImage)
h=s.getGraph()
show(enim,"initial snake (note the perturbed graph is cropped)",h)

# run niter snake updates, then fetch and display the resulting graph
niter=100
s.optim(niter)
h=s.getGraph()
show(enim,"final snake",h)

# NOTE(review): disabled rendering demo. As written it would pass enim rather
# than the freshly created canvas to renderSnake -- check before re-enabling.
#canvas=th.zeros(enim.shape)
#s.renderSnake(enim)
#show(canvas,"rendered snake")

In [None]:
# backpropagation through snake updates
# for optimization of the energy image

# NOTE(review): these imports are local to this cell; the rest of the
# notebook imports everything in its first cell.
from math import sin,pi
from torch import optim

# the energy image is initialized to all zeros
# it will be optimized to make randomly initialized snakes converge to a predefined shape
# shape (1,1,100,100), double precision, tracked by autograd
enimg=th.zeros((1,1,100,100),dtype=th.double,requires_grad=True)
# the ground truth graph is a sinusoid
# 50 nodes labelled 25..74 forming a chain; the first coordinate runs through
# one full sine period (amplitude 25, centred on 50), the second equals the label
g=nx.Graph()
for i in range(25,75):
    g.add_node(i,pos=np.array([sin(i*2*pi/50.0)*25+50,i]))
    if i>25:
        g.add_edge(i,i-1)

stepsz=0.2
# NOTE(review): extparam is assigned twice in this cell and not used in it
extparam=1
alpha=0.0
beta=1.0
# crop window reaching well beyond the 100x100 image on every side
# (presumably so that no part of the graph is cropped -- confirm in gradImSnake)
crop=[slice(-100,200), slice(-100,200)]
fltrstdev=5
extparam=1
ndims=2

# plain SGD (learning rate 1, no momentum) whose only parameter is the energy image
opt = optim.SGD([enimg], lr=1, momentum=0.0)

fltr =gradImSnake.makeGaussEdgeFltr(fltrstdev,ndims)
fltrt=th.from_numpy(fltr)
    
# the snake is created without a gradient image (None); one is assigned to
# s.gimg at the start of every training iteration below
s=gradImSnake.GradImSnake(g,crop,stepsz,alpha,beta,ndims,None)
# copy of the snake's initial node positions, used as the regression target
gt=s.getPos().clone()

show(enimg[0][0].detach().numpy(),"initial energy image and the ground truth graph",s.getGraph())

# NOTE(review): no random seed is set in the visible cells, so the noise drawn
# below -- and therefore the learnt image -- differs from run to run.
for i in range(500):
    
    # recompute the gradient image from the current energy image so that the
    # snake updates are connected to enimg in the autograd graph
    gimg=gradImSnake.cmptGradIm(enimg,fltrt)[0]
    s.gimg=gimg
    # perturb the snake
    # ground truth plus element-wise Gaussian noise (mean 0, standard deviation 10)
    s.s=gt+th.normal(th.zeros_like(gt),10*th.ones_like(gt))
    # make it converge
    s.optim(10)
    # the loss
    # norm of the difference between the evolved snake and the ground truth
    l=th.norm(s.s-gt)
    l.backward()
    # update the energy image, then clear its gradient for the next iteration
    opt.step()
    opt.zero_grad()
    
# put the snake back on the ground truth so that getGraph() draws the target shape
s.s=gt
show(enimg[0][0].detach().numpy(),"learnt energy image and the ground truth graph",s.getGraph())
    


In [None]:
# a simple 3D example
#
# A straight chain of nodes is placed next to a straight line drawn in a 3D
# volume and a GradImSnake is run on the distance transform of that volume.

print("this 3D snake is a line")
print("initially its trajectory is fixed at [25,...,16]")
print("and the snake converges to [20,...,20]")
# label volume with a single straight line along the second axis at [20,:,20]
lbl = np.zeros((40,40,40))
lbl[20,:,20] = 1

# energy: distance from every voxel to the nearest voxel of the line
enimg = ndimage.distance_transform_edt(1-lbl)

# initial snake: a chain of nodes at [25,i,16] for i=3,8,...,33
G =nx.Graph()
start=3
stop=36
step=5
for i in range(start,stop,step):
    G.add_node(i,pos=np.array([25,i,16]))
    if i>start:
        G.add_edge(i-step,i)

stepsz=0.2
extparam=1   # not used in this cell
alpha=0.0
beta=1.0
# do not crop
crop=[slice(-100,100), slice(-100,100),slice(-100,100)]
fltrstdev=0.5
nupdates=100
ndims=3

fltr=gradImSnake.makeGaussEdgeFltr(fltrstdev,ndims)
fltrt=th.from_numpy(fltr)
# Two leading singleton axes are added, giving shape (1,1,40,40,40).
# The builtin float replaces the np.float alias (removed in NumPy 1.24);
# both denote float64, so the resulting tensor is unchanged.
enimgt=th.from_numpy(enimg[np.newaxis,np.newaxis].astype(float))
# keep only the first entry along the leading axis of the result
gimg=gradImSnake.cmptGradIm(enimgt,fltrt)[0]
s=gradImSnake.GradImSnake(G,crop,stepsz,alpha,beta,ndims,gimg)
print("starting position")
print(s.s)

#s.cuda()
s.optim(nupdates)
print("end position")
print(s.s)