In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.layers import Dense, Flatten, Conv2D, Concatenate, Softmax, LayerNormalization, Dropout
from tensorflow.keras import Model
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras import regularizers

import numpy as np

import networkx as nx

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
import itertools
import functools

def neighbours_8(x, y, x_max, y_max):
    """Yield the in-bounds 8-connected neighbours of grid cell ``(x, y)``.

    Offsets are enumerated with the x-offset varying slowest and the
    y-offset fastest (both in the order -1, 0, 1); the centre cell itself
    is never produced.

    Args:
        x, y: Coordinates of the centre cell.
        x_max, y_max: Exclusive upper bounds of the grid along each axis.

    Yields:
        ``(x_new, y_new)`` with ``0 <= x_new < x_max`` and
        ``0 <= y_new < y_max``.
    """
    offsets = (-1, 0, 1)
    for dx in offsets:
        for dy in offsets:
            if dx == 0 and dy == 0:
                # Skip the centre cell.
                continue
            x_new = x + dx
            y_new = y + dy
            if 0 <= x_new < x_max and 0 <= y_new < y_max:
                yield x_new, y_new


def neighbours_4(x, y, x_max, y_max):
    """Yield the in-bounds 4-connected neighbours of grid cell ``(x, y)``.

    Candidates are tried in the fixed order +x, +y, -y, -x.

    Args:
        x, y: Coordinates of the centre cell.
        x_max, y_max: Exclusive upper bounds of the grid along each axis.

    Yields:
        ``(x_new, y_new)`` with ``0 <= x_new < x_max`` and
        ``0 <= y_new < y_max``.
    """
    candidates = (
        (x + 1, y),
        (x, y + 1),
        (x, y - 1),
        (x - 1, y),
    )
    for x_new, y_new in candidates:
        inside_x = 0 <= x_new < x_max
        inside_y = 0 <= y_new < y_max
        if inside_x and inside_y:
            yield x_new, y_new


def get_neighbourhood_func(neighbourhood_fn):
    """Return the neighbour generator registered under ``neighbourhood_fn``.

    Args:
        neighbourhood_fn: Either ``"4-grid"`` (axis-aligned neighbours) or
            ``"8-grid"`` (axis-aligned plus diagonal neighbours).

    Returns:
        ``neighbours_4`` or ``neighbours_8`` respectively.

    Raises:
        ValueError: If the name is not one of the two supported values.
            ``ValueError`` is a subclass of ``Exception``, so handlers that
            catch ``Exception`` keep working.
    """
    if neighbourhood_fn == "4-grid":
        return neighbours_4
    if neighbourhood_fn == "8-grid":
        return neighbours_8
    raise ValueError(f"neighbourhood_fn of {neighbourhood_fn} not possible")

In [3]:
from functools import partial
from collections import namedtuple
import heapq

#DijkstraOutput = namedtuple("DijkstraOutput", ["shortest_path", "is_unique", "transitions"])


def dijkstra(matrix, neighbourhood_fn="8-grid", request_transitions=False):
    """Find a minimum-cost path from the top-left to the bottom-right cell.

    The cost of a path is the sum of the entries of every cell it visits,
    including the start cell ``(0, 0)`` and the end cell ``(-1, -1)``.

    Args:
        matrix: 2-D array-like of per-cell costs. It is converted with
            ``np.asarray`` so objects exposing the array protocol are
            indexed as plain NumPy arrays.
        neighbourhood_fn: ``"4-grid"`` or ``"8-grid"``; forwarded to
            ``get_neighbourhood_func``.
        request_transitions: Accepted for interface compatibility; it is not
            used by this implementation.

    Returns:
        An array with the same shape and dtype as ``matrix`` that is 1 on
        the cells of the selected path and 0 elsewhere.
    """
    matrix = np.asarray(matrix)
    x_max, y_max = matrix.shape
    neighbors_func = partial(get_neighbourhood_func(neighbourhood_fn), x_max=x_max, y_max=y_max)

    # Accumulate in float64 starting from +inf. A sentinel written in the
    # input dtype (e.g. 1e10 in int32 or float16) can overflow and break the
    # relaxation test below.
    costs = np.full(matrix.shape, np.inf, dtype=np.float64)
    costs[0, 0] = matrix[0, 0]
    priority_queue = [(costs[0, 0], (0, 0))]
    certain = set()
    transitions = dict()

    while priority_queue:
        _, (cur_x, cur_y) = heapq.heappop(priority_queue)
        if (cur_x, cur_y) in certain:
            # Stale heap entry: this cell was already settled via a cheaper
            # entry, so expanding it again would only repeat work.
            continue
        certain.add((cur_x, cur_y))

        for x, y in neighbors_func(cur_x, cur_y):
            if (x, y) in certain:
                continue
            candidate_cost = matrix[x, y] + costs[cur_x, cur_y]
            if candidate_cost < costs[x, y]:
                costs[x, y] = candidate_cost
                heapq.heappush(priority_queue, (candidate_cost, (x, y)))
                # Remember the predecessor so the path can be walked back.
                transitions[(x, y)] = (cur_x, cur_y)

    # Retrieve the path by following predecessors from the end cell.
    cur_x, cur_y = x_max - 1, y_max - 1
    on_path = np.zeros_like(matrix)
    on_path[-1, -1] = 1
    while (cur_x, cur_y) != (0, 0):
        cur_x, cur_y = transitions[(cur_x, cur_y)]
        on_path[cur_x, cur_y] = 1.0

    return on_path

In [4]:
import os
# Load the map images, shortest-path labels and true vertex weights from data_dir.
train_prefix = "train"
val_prefix = "test"
data_suffix = "maps"
true_weights_suffix = ""

data_dir = "/mnt/data-c305/mniepert/reason/12x12"

train_data_path = os.path.join(data_dir, train_prefix + "_" + data_suffix + ".npy")

# NOTE: if the training file is missing, nothing below is defined and later
# cells that use these arrays will fail with NameError.
if os.path.exists(train_data_path):
    # --- training split ---
    train_inputs = np.load(train_data_path).astype(np.float32)
    train_labels = np.load(os.path.join(data_dir, f"{train_prefix}_shortest_paths.npy"))
    train_true_weights = np.load(os.path.join(data_dir, f"{train_prefix}_vertex_weights.npy"))

    # Move the last axis to position 1 so that one mean/std is computed per
    # entry of that axis (reduction over axes 0, 2 and 3).
    train_inputs = train_inputs.transpose(0, 3, 1, 2)
    mean = np.mean(train_inputs, axis=(0, 2, 3), keepdims=True)
    std = np.std(train_inputs, axis=(0, 2, 3), keepdims=True)

    # Standardise in place, then restore the original axis order.
    train_inputs -= mean
    train_inputs /= std
    train_inputs = train_inputs.transpose(0, 2, 3, 1)

    # --- validation split (normalised with the training statistics) ---
    val_inputs = np.load(os.path.join(data_dir, f"{val_prefix}_{data_suffix}.npy")).astype(np.float32)
    val_labels = np.load(os.path.join(data_dir, f"{val_prefix}_shortest_paths.npy"))
    val_true_weights = np.load(os.path.join(data_dir, f"{val_prefix}_vertex_weights.npy"))

    val_inputs = val_inputs.transpose(0, 3, 1, 2)
    val_inputs -= mean
    val_inputs /= std
    val_inputs = val_inputs.transpose(0, 2, 3, 1)

    # Labels are used in float arithmetic by the loss, so store them as float32 tensors.
    train_labels = tf.cast(train_labels, tf.float32)
    val_labels = tf.cast(val_labels, tf.float32)

In [5]:
val_inputs[0]

array([[[ 0.88818   ,  1.8510492 , -0.14565916],
        [ 0.4456104 ,  1.7403947 , -0.27828017],
        [ 0.4701976 ,  1.7127311 , -0.27828017],
        ...,
        [ 1.847081  ,  2.072358  ,  2.0028007 ],
        [ 2.0437784 ,  2.2936668 ,  2.2149942 ],
        [ 1.9454297 ,  2.1830125 ,  2.1088974 ]],

       [[ 1.3553369 ,  1.1317954 , -0.46394953],
        [ 0.4701976 ,  1.657404  , -0.30480435],
        [ 0.4947848 ,  1.7403947 , -0.27828017],
        ...,
        [ 1.7979065 ,  2.0170307 ,  1.9497523 ],
        [ 1.5520345 ,  1.7403947 ,  1.6845105 ],
        [ 1.1094649 ,  1.2424499 ,  1.2070749 ]],

       [[ 1.6749705 ,  0.93815017, -0.62309474],
        [ 0.7406568 ,  1.4084315 , -0.41090113],
        [ 0.4701976 ,  1.7680583 , -0.27828017],
        ...,
        [ 2.0683658 ,  2.3213305 ,  2.2415185 ],
        [ 1.256988  ,  1.4084315 ,  1.3662201 ],
        [ 0.34726158,  0.38487816,  0.38482478]],

       ...,

       [[ 2.1421273 ,  1.159459  , -0.5700463 ],
        [ 2

In [6]:
train_labels.shape

TensorShape([10000, 12, 12])

In [7]:
train_true_weights[1]

array([[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 1.2, 1.2],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 1.2, 1.2, 1.2],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 0.8, 0.8, 0.8, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 0.8, 0.8, 0.8, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 0.8, 0.8, 0.8, 0.8, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 0.8, 0.8, 0.8, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 0.8, 0.8, 0.8, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 1.2, 1.2, 0.8],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 7.7, 1.2, 1.2, 1.2, 1.2],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 9.2, 9.2, 1.2],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 9.2, 1.2, 1.2],
       [7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 1.2, 1.2, 1.2, 1.2, 9.2, 1.2]],
      dtype=float16)

In [8]:
train_labels[1]

<tf.Tensor: shape=(12, 12), dtype=float32, numpy=
array([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)>

In [9]:
val_labels[1]

<tf.Tensor: shape=(12, 12), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)>

In [10]:
train_inputs.shape

(10000, 96, 96, 3)

In [11]:
# Training pipeline: (image, label) pairs, shuffled with a 10000-element
# buffer and served in batches of 70.
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((train_inputs, train_labels))
    .shuffle(10000)
    .batch(70)
)

# Evaluation pipeline: (image, label, true vertex weights) triples in their
# stored order, served in batches of 100.
test_ds = (
    tf.data.Dataset
    .from_tensor_slices((val_inputs, val_labels, val_true_weights))
    .batch(100)
)

In [None]:
@tf.custom_gradient
def perturb_and_map_gm(x, distributions, map_states, labels):
    """Return precomputed MAP states with a hand-written backward pass.

    The forward value is ``map_states`` unchanged; ``x`` only takes part in
    the backward pass. The MAP states and distributions are passed in
    precomputed (rather than derived from ``x`` here) for efficiency.

    Args:
        x: The quantity that should receive the custom gradient.
        distributions: Precomputed values subtracted from the mask in the
            backward pass.
        map_states: Precomputed states returned as the forward result.
        labels: Not used in the computation; only echoed back in the
            gradient tuple.

    Returns:
        ``map_states`` together with the gradient function required by
        ``tf.custom_gradient``.
    """

    def custom_grad(dy):
        # 1.0 wherever the upstream gradient is negative, 0.0 elsewhere.
        negative_mask = tf.cast(dy < 0, tf.float32)
        # Gradient for ``x``: -(distributions - negative_mask).
        x_grad = -tf.math.subtract(distributions, negative_mask)
        # One entry per positional input; the three auxiliary inputs are
        # simply returned as their own "gradients".
        return x_grad, distributions, map_states, labels

    return map_states, custom_grad

In [60]:
class BasicBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 convolutions with batch norm plus a shortcut.

    Args:
        filter_num: Number of filters of every convolution in the block.
        stride: Stride of the first convolution. When it differs from 1 the
            shortcut is a strided 1x1 convolution followed by batch norm;
            otherwise the shortcut is the identity.
    """

    def __init__(self, filter_num, stride=1):
        super().__init__()
        keras_layers = tf.keras.layers

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = keras_layers.Conv2D(filters=filter_num, kernel_size=(3, 3),
                                         strides=stride, padding="same")
        self.bn1 = keras_layers.BatchNormalization()
        self.conv2 = keras_layers.Conv2D(filters=filter_num, kernel_size=(3, 3),
                                         strides=1, padding="same")
        self.bn2 = keras_layers.BatchNormalization()

        # Shortcut branch.
        if stride == 1:
            self.downsample = lambda x: x
        else:
            projection = tf.keras.Sequential()
            projection.add(keras_layers.Conv2D(filters=filter_num,
                                               kernel_size=(1, 1),
                                               strides=stride))
            projection.add(keras_layers.BatchNormalization())
            self.downsample = projection

    def call(self, inputs, training=None, **kwargs):
        """Apply the block; ``training`` is forwarded to the batch-norm layers."""
        shortcut = self.downsample(inputs)

        features = self.bn1(self.conv1(inputs), training=training)
        features = tf.nn.relu(features)
        features = self.bn2(self.conv2(features), training=training)

        # Merge the two branches and apply the final non-linearity.
        merged = tf.keras.layers.add([shortcut, features])
        return tf.nn.relu(merged)



def make_basic_block_layer(filter_num, blocks, stride=1):
    """Stack ``BasicBlock`` layers into one ``tf.keras.Sequential`` stage.

    Args:
        filter_num: Number of filters used by every block.
        blocks: Requested number of blocks; at least one block is always
            created.
        stride: Stride of the first block. Every following block uses
            stride 1.

    Returns:
        The ``tf.keras.Sequential`` containing the blocks in order.
    """
    res_block = tf.keras.Sequential()
    # The first block carries the requested stride; the rest keep resolution.
    block_strides = [stride] + [1] * (blocks - 1)
    for block_stride in block_strides:
        res_block.add(BasicBlock(filter_num, stride=block_stride))
    return res_block

      
class ResNet18Inference(tf.keras.Model):
    """Convolutional stem plus one residual stage producing a 12x12 cost map.

    The network applies a 7x7 strided convolution, batch norm, ReLU and max
    pooling, then two ``BasicBlock`` layers with 64 filters, adaptive average
    pooling to 12x12, and finally averages over the last (filter) axis.
    """

    def __init__(self):
        super(ResNet18Inference, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(7, 7),
                                            strides=2,
                                            padding="same",
                                            use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                               strides=2,
                                               padding="same")
        self.layer1 = make_basic_block_layer(filter_num=64, blocks=2)

        # Spatial size of the predicted cost map.
        output_shape = (int(12), int(12))
        self.adaptivepool = tfa.layers.AdaptiveAveragePooling2D(output_shape)

    def set_sample_matrix(self, samples_in):
        """Store ``samples_in`` as a transposed float32 tensor and print its shape."""
        self.samples = tf.transpose(tf.cast(samples_in, tf.float32))
        print(self.samples.shape)

    def call(self, inputs, distributions, map_states, labels, training=None, mask=None):
        """Predict the cost map and route the MAP states through the custom op.

        Args:
            inputs: Batch of input images.
            distributions: Precomputed values forwarded to
                ``perturb_and_map_gm``.
            map_states: Precomputed MAP states forwarded to
                ``perturb_and_map_gm``.
            labels: Forwarded to ``perturb_and_map_gm``.
            training: Forwarded to the batch-norm layers and residual stage.
            mask: Unused.

        Returns:
            ``(x, map_states)``: the predicted cost map and the value
            returned by ``perturb_and_map_gm``.
        """
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.pool1(x)
        x = self.layer1(x, training=training)
        x = self.adaptivepool(x)
        # Average over the last axis to obtain one value per grid cell.
        x = tf.math.reduce_mean(x, axis=3)

        # Attach the hand-written gradient to x. The MAP states themselves
        # are passed in precomputed, not derived from x here.
        # The notebook defines this op as ``perturb_and_map_gm``; the previous
        # name ``perturb_and_map`` is not defined anywhere and raised
        # NameError on a fresh kernel.
        map_states = perturb_and_map_gm(x, distributions, map_states, labels)

        return x, map_states

In [61]:
# Module-level model instance used by train_step, test_step and the loops below.
model_inference = ResNet18Inference()

In [62]:
# Forward pass on a single training example (slice 1:2 keeps the batch axis).
# The label slice is passed for all three auxiliary arguments
# (distributions, map_states, labels).
x, _ = model_inference(train_inputs[1:2], train_labels[1:2], train_labels[1:2], train_labels[1:2])

In [63]:
x

<tf.Tensor: shape=(1, 12, 12), dtype=float32, numpy=
array([[[0.23979837, 0.23264565, 0.23263429, 0.22638993, 0.23241958,
         0.24237883, 0.29788095, 0.32114625, 0.2983883 , 0.29061562,
         0.27851146, 0.29502916],
        [0.24115965, 0.24500299, 0.24042656, 0.23586008, 0.24297407,
         0.25716746, 0.32157487, 0.29084992, 0.28433514, 0.28513294,
         0.28867802, 0.30894622],
        [0.24231896, 0.24439242, 0.23696208, 0.23364656, 0.24358243,
         0.29852647, 0.31885055, 0.29169732, 0.26031494, 0.23653966,
         0.22138055, 0.22315218],
        [0.24656403, 0.24032918, 0.23265685, 0.23091412, 0.24469167,
         0.32106784, 0.28836292, 0.29939717, 0.23022737, 0.18923935,
         0.17710221, 0.16660522],
        [0.24980474, 0.24439427, 0.2353958 , 0.23173685, 0.24387415,
         0.31358042, 0.29470038, 0.2902798 , 0.20873857, 0.18407404,
         0.17859223, 0.16913651],
        [0.24861898, 0.23942156, 0.23248366, 0.23105788, 0.24219738,
         0.3164469

In [64]:
x[0].shape

TensorShape([12, 12])

In [65]:
train_labels[1]

<tf.Tensor: shape=(12, 12), dtype=float32, numpy=
array([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)>

In [66]:
dijkstra(x[0])

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)

In [67]:
def HammingLoss(y_true, y_pred):
    """Mean element-wise disagreement between ``y_true`` and ``y_pred``.

    Each element contributes ``y_pred * (1 - y_true) + (1 - y_pred) * y_true``,
    which for 0/1 inputs is 1 where the two differ and 0 where they agree.
    The mean is taken over all elements.
    """
    # Predicted but not in the target.
    extra = y_pred * (tf.ones_like(y_true) - y_true)
    # In the target but not predicted.
    missing = (tf.ones_like(y_pred) - y_pred) * y_true
    return tf.math.reduce_mean(extra + missing)

In [68]:
# Adam with its default hyper-parameters; applied to the model variables in train_step.
optimizer = tf.keras.optimizers.Adam()

In [69]:
# custom accuracy function
class SameSolutionAccuracy(tf.keras.metrics.Metric):
    """Fraction of samples whose predicted path costs no more than the true path.

    Path costs are measured under the supplied cost matrix. Inputs are
    flattened to 144 entries per sample.
    """

    def __init__(self, name='same_solution_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running number of samples counted as correct.
        self.same_solutions = self.add_weight(name='tp', initializer='zeros')
        # Running number of samples seen.
        self.counter = self.add_weight(name='counter', initializer='zeros')

    def update_state(self, y_true, y_pred, cost_matrix):
        """Accumulate one batch.

        Args:
            y_true: Target path indicators.
            y_pred: Predicted path indicators.
            cost_matrix: Per-cell costs used to score both paths.
        """
        flat_shape = [-1, 144]
        y_true = tf.reshape(y_true, flat_shape)
        y_pred = tf.reshape(y_pred, flat_shape)
        cost_matrix = tf.reshape(tf.cast(cost_matrix, tf.float32), flat_shape)

        # Total cost of each path under the given cost matrix.
        y_true_cost = tf.math.reduce_sum(cost_matrix * y_true, 1)
        y_pred_cost = tf.math.reduce_sum(cost_matrix * y_pred, 1)

        # A prediction counts as correct when its cost is less than or equal
        # to the cost of the target path.
        is_correct = tf.cast(tf.math.less_equal(y_pred_cost, y_true_cost), tf.float32)

        self.same_solutions.assign_add(tf.math.reduce_sum(is_correct))
        self.counter.assign_add(y_true.shape[0])

    def result(self):
        """Return correct samples divided by samples seen."""
        return self.same_solutions / self.counter

In [70]:
# Running mean of the training loss; updated in train_step and reset every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')

# NOTE(review): test_loss is created here but no cell shown updates it.
test_loss = tf.keras.metrics.Mean(name='test_loss')
# Accuracy metric updated by test_step.
test_samesol = SameSolutionAccuracy()

In [71]:
#@tf.function
def train_step(images, distributions, map_states, labels):
    """Run one optimisation step on a batch.

    Args:
        images: Batch of input images.
        distributions: Precomputed values forwarded to the model.
        map_states: Precomputed MAP states forwarded to the model.
        labels: Target path indicators.

    Side effects:
        Updates the variables of ``model_inference`` through ``optimizer``
        and adds the batch loss to the ``train_loss`` metric.
    """
    with tf.GradientTape() as tape:
        # The loss must be built from the MAP states *returned by the model*.
        # Those are the ones connected to the network output on the tape;
        # using the raw ``map_states`` argument leaves the loss disconnected
        # from every trainable variable, so all gradients would be None.
        _, predicted_states = model_inference(images, distributions, map_states, labels, training=True)
        loss = HammingLoss(labels, predicted_states)

    gradients = tape.gradient(loss, model_inference.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model_inference.trainable_variables))

    train_loss(loss)

In [78]:
#@tf.function
def test_step(images, labels, cost_matrix):
    """Evaluate one batch and update the ``test_samesol`` metric.

    Args:
        images: Batch of input images.
        labels: Target path indicators; also passed to the model for its
            three auxiliary arguments.
        cost_matrix: Per-cell costs used by the metric to score the paths.
    """
    predictions, _ = model_inference(images, labels, labels, labels, training=False)

    # Solve the shortest-path problem on every predicted cost map.
    weight_matrix = predictions.numpy()
    map_paths = np.zeros_like(weight_matrix)
    for sample_idx, sample_weights in enumerate(weight_matrix):
        map_paths[sample_idx] = dijkstra(sample_weights)

    test_samesol(labels, map_paths, cost_matrix)

In [73]:
loss_mse = tf.keras.losses.MeanSquaredError()
def HammingLossMSE(y_true, y_pred, cost_matrix):
    """Per-sample squared difference between true and predicted path cost.

    All three inputs are flattened to 144 entries per sample; a path's cost
    is the sum of ``cost_matrix`` entries weighted by the path indicators.

    Returns:
        A tensor with one squared cost difference per sample.
    """
    flat_shape = [-1, 144]
    y_true = tf.reshape(y_true, flat_shape)
    y_pred = tf.reshape(y_pred, flat_shape)
    cost_matrix = tf.reshape(tf.cast(cost_matrix, tf.float32), flat_shape)

    true_cost = tf.math.reduce_sum(cost_matrix * y_true, 1)
    pred_cost = tf.math.reduce_sum(cost_matrix * y_pred, 1)
    return tf.math.square(tf.math.subtract(true_cost, pred_cost))

# Sanity check: gradient of the Hamming loss w.r.t. a Dijkstra path for the
# first training example.
predictions, _ = model_inference(train_inputs[0:1], train_labels[0:1], train_labels[0:1], train_labels[0:1], training=False)
true_weights = train_true_weights[0:1]
true_labels = tf.Variable(train_labels[0])

# Shortest path under the predicted weights, wrapped in a Variable so that
# the tape can differentiate with respect to it.
weights = predictions[0].numpy()
map_path = tf.Variable(dijkstra(weights))

print(map_path)
print(true_labels)

with tf.GradientTape() as tape:
    loss = HammingLoss(true_labels, map_path)

var_grad = tape.gradient(loss, [map_path, true_labels])
print(var_grad[0])

# Idea sketched here earlier and left disabled: take theta = sign(-gradient)
# and re-run dijkstra on weights + k * theta for k in range(20).
        

<tf.Variable 'Variable:0' shape=(12, 12) dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)>
<tf.Variable 'Variable:0' shape=(12, 12) dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
  

In [84]:
# training loop
for epoch in range(10):
    # Start every epoch with a fresh running mean of the loss.
    train_loss.reset_states()

    for images, labels in train_ds:
        # Forward pass in inference mode to read the current cost map; the
        # labels only fill the model's three auxiliary arguments here.
        x, _ = model_inference(images, labels, labels, labels, training=False)
        weight_matrix = x.numpy()

        # map_states holds the shortest path of every sample in the batch.
        # distributions accumulates path indicators per sample. Sampling of
        # perturbed weight matrices is currently disabled, so it ends up
        # equal to map_states.
        map_states = np.zeros_like(weight_matrix)
        distributions = np.zeros_like(weight_matrix)

        for i, sample_weights in enumerate(weight_matrix):
            map_path = dijkstra(sample_weights)
            map_states[i] = map_path
            distributions[i] += map_path

        train_step(images, distributions, map_states, labels)

    template = 'Epoch {}, Loss: {}'
    print(template.format(epoch + 1, train_loss.result()))

Epoch 1, Loss: 0.01045193336904049
Epoch 2, Loss: 0.009578613564372063
Epoch 3, Loss: 0.00895643513649702
Epoch 4, Loss: 0.009052752517163754


KeyboardInterrupt: 

In [None]:
# Profile a single dijkstra call.
# NOTE(review): `perturbed_matrix` is only assigned in commented-out code in
# the training loop, so it must be defined before this cell can run.
%prun dijkstra(perturbed_matrix)

In [85]:
# Evaluate on the whole validation pipeline, starting from a clean metric.
test_samesol.reset_states()

for batch in test_ds:
    # Each batch is an (images, labels, cost_matrix) triple.
    test_step(*batch)

template = 'SameSol: {}'
print(template.format(test_samesol.result()))

SameSol: 0.9480000138282776


In [None]:
val_labels[1]

In [None]:
# Predicted cost maps for the first two validation images.
# `model` is never defined in this notebook; the trained network is
# `model_inference`, whose call also needs the three auxiliary arguments.
# As in test_step, the labels are passed as placeholders for them.
x, _ = model_inference(val_inputs[0:2], val_labels[0:2], val_labels[0:2], val_labels[0:2], training=False)
x[1]
# dijkstra(x[1]) returns the path indicator array for this cost map.