In [53]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from glob import glob
# import openmesh as om
import os
import openmesh as om
from sklearn.neighbors import KDTree

In [17]:
x = tf.constant([[[23, 1, 3], [78, 78, 45]], [[23, 1, 3], [78, 78, 45]]])
x = tf.map_fn(fn=lambda t: tf.reshape(tf.gather(t, [0, 0, 0, 1, 1], axis=0), [-1]), elems=x)
with tf.compat.v1.Session() as sess:
    print(x.eval())

[[23  1  3 23  1  3 23  1  3 78 78 45 78 78 45]
 [23  1  3 23  1  3 23  1  3 78 78 45 78 78 45]]


In [None]:
### Left to change to tensorflow2
def preprocess_spiral(face, seq_length, vertices=None, dilation=1):
    from .generate_spiral_seq import extract_spirals
    assert face.shape[1] == 3
    if vertices is not None:
        mesh = om.TriMesh(np.array(vertices), np.array(face))
    else:
        n_vertices = face.max() + 1
        mesh = om.TriMesh(np.ones([n_vertices, 3]), np.array(face))
    spirals = extract_spirals(mesh, seq_length=seq_length, dilation=dilation)
    return spirals

In [52]:
import openmesh as om
from sklearn.neighbors import KDTree
import numpy as np


def _next_ring(mesh, last_ring, other):
    res = []

    def is_new_vertex(idx):
        return (idx not in last_ring and idx not in other and idx not in res)

    for vh1 in last_ring:
        vh1 = om.VertexHandle(vh1)
        after_last_ring = False
        for vh2 in mesh.vv(vh1):
            if after_last_ring:
                if is_new_vertex(vh2.idx()):
                    res.append(vh2.idx())
            if vh2.idx() in last_ring:
                after_last_ring = True
        for vh2 in mesh.vv(vh1):
            if vh2.idx() in last_ring:
                break
            if is_new_vertex(vh2.idx()):
                res.append(vh2.idx())
    return res


def extract_spirals(mesh, seq_length, dilation=1):
    # output: spirals.size() = [N, seq_length]
    spirals = []
    for vh0 in mesh.vertices():
        reference_one_ring = []
        for vh1 in mesh.vv(vh0):
            reference_one_ring.append(vh1.idx())
        spiral = [vh0.idx()]
        one_ring = list(reference_one_ring)
        last_ring = one_ring
        next_ring = _next_ring(mesh, last_ring, spiral)
        spiral.extend(last_ring)
        while len(spiral) + len(next_ring) < seq_length * dilation:
            if len(next_ring) == 0:
                break
            last_ring = next_ring
            next_ring = _next_ring(mesh, last_ring, spiral)
            spiral.extend(last_ring)
        if len(next_ring) > 0:
            spiral.extend(next_ring)
        else:
            kdt = KDTree(mesh.points(), metric='euclidean')
            spiral = kdt.query(np.expand_dims(mesh.points()[spiral[0]],
                                              axis=0),
                               k=seq_length * dilation,
                               return_distance=False).tolist()
            spiral = [item for subspiral in spiral for item in subspiral]
        spirals.append(spiral[:seq_length * dilation][::dilation])
    return spirals


def preprocess_spiral(face, seq_length, vertices=None, dilation=1):
    from .generate_spiral_seq import extract_spirals
    assert face.shape[1] == 3
    if vertices is not None:
        mesh = om.TriMesh(np.array(vertices), np.array(face))
    else:
        n_vertices = face.max() + 1
        mesh = om.TriMesh(np.ones([n_vertices, 3]), np.array(face))
    spirals = extract_spirals(mesh, seq_length=seq_length, dilation=dilation)
    return spirals

class SpiralConv(layers.Layer):
    
    def __init__(self, in_channels, out_channels, indices, dim=1):
        super(SpiralConv, self).__init__()
        self.dim = dim
        self.indices = indices
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layer = layers.Dense(out_channels)
        
    def call(self, x):
        n_nodes, _ = self.indices.shape
        x = tf.map_fn(fn=lambda t: tf.reshape(tf.gather(t, tf.reshape(self.indices, [-1]), axis=0), [n_nodes, -1]), elems=x)
        x = self.layer(x)
        return x

with tf.compat.v1.Session() as sess:
    inputs = keras.Input(shape=(4,3))
    outputs = SpiralConv(3, 5, tf.constant([[1, 2], [0, 1]]))(inputs)

    model = keras.Model(inputs, outputs)

In [25]:
OUT_IMAGE_SIZE = (224, 224)
inputs = tf.keras.Input(shape=(*OUT_IMAGE_SIZE, 3), name="normalized_image")

In [26]:
resnet50 = tf.keras.applications.ResNet50(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=None, pooling=None, classes=1000)
x = resnet50(inputs)

In [27]:
x = tf.keras.layers.Dense(64, activation=tf.keras.activations.relu , name="FC1")(x)

In [28]:
x = tf.keras.layers.Dense(51*48, activation=tf.keras.activations.relu, name="FC2")(x)

In [29]:
x = tf.keras.layers.Reshape((51, 48), name="reshape_to_mesh")(x)

In [30]:
x = tf.keras.layers.UpSampling1D(size=2, name="UP1")(x)

In [31]:
indices = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
spiral_conv1 = SpiralConv(48, 16, indices)

x = spiral_conv1(x)

TypeError: 'int' object is not callable