In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import networkx as nx
import torch.nn as nn
import scipy.ndimage
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
"""
def force_cudnn_initialization():
    s = 32
    dev = torch.device('cuda')
    torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
"""

"\ndef force_cudnn_initialization():\n    s = 32\n    dev = torch.device('cuda')\n    torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))\n"

In [2]:
#!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
#!pip install torchmetrics
#!pip install numpy==1.21 --user
#!pip install matplotlib --user
print(torch.__version__, device)

1.13.1 cuda


In [3]:
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    sort_edge_index,
    add_remaining_self_loops,
    to_undirected,
    degree,
    
)

In [4]:
graph_size = 128
BATCH_SIZE = 4
NUM_PATHS = 4

gs = torch.Tensor(graph_size).to(device)
G = nx.grid_graph([graph_size, graph_size])
"""
G.add_edges_from([
    ((x, y), (x+1, y+1))
    for x in range(3)
    for y in range(3)
] + [
    ((x+1, y), (x, y+1))
    for x in range(3)
    for y in range(3)
], weight=1.4)
"""
edges = list(G.edges)
nedges = []
for i in range(len(edges)):
    nedges.append([edges[i][0][0] * graph_size + edges[i][0][1]
                   , edges[i][1][0] * graph_size + edges[i][1][1]])
edge_maps = np.asarray(nedges).astype(np.int32).T
edge_maps = torch.LongTensor(edge_maps)
edge_maps = to_undirected(edge_maps).numpy()
self_connects = np.arange(graph_size*graph_size, dtype=np.int32)
self_connects = np.tile(np.expand_dims(self_connects, axis=0), [2,1])
edge_maps = np.concatenate([edge_maps, self_connects], axis=1)


In [5]:
edge_maps

array([[    0,     0,     1, ..., 16381, 16382, 16383],
       [    1,   128,     0, ..., 16381, 16382, 16383]], dtype=int64)

In [6]:
#edges_maps = np.concatenate([edge_maps, edge_maps[::-1]], axis=1)

In [7]:
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_networkx, coalesce
def dist(a, b):
    (x1, y1) = a//graph_size, a%graph_size
    (x2, y2) = b//graph_size, b%graph_size
    return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5


In [8]:
depth = np.asarray([2])
rem_depth = np.maximum((np.log2(graph_size // (2**depth))).astype(np.int32) - 2, 0)
reduction_rate = (2**depth)
enc_size = graph_size // reduction_rate


In [9]:
import torch_geometric
import cv2
from torch_geometric.transforms.add_positional_encoding import AddLaplacianEigenvectorPE
from numpy.random import RandomState
class RandomImageDataset(torch.utils.data.Dataset):
    def __init__(self, graph_size=64, mfill_rate=0.5, sfill_rate=0.135, steps=128):
        super().__init__()
        self.graph_size = graph_size
        self.mfill_rate = mfill_rate
        self.sfill_rate = sfill_rate
        self.steps = steps
        #self.transform = AddLaplacianEigenvectorPE(4, None, is_undirected=True)

    def len(self):
        return self.steps#len(self.img_labels)
    
    def __len__(self):
        return self.len()
    
    def __getitem__(self, idx):
        graph_size = self.graph_size
        mfill_rate = self.mfill_rate
        sfill_rate = self.sfill_rate
        i_ = 0
        t0 = time.time()
        while True:
            tt = time.time()
            fill_rate = np.random.random()*0.75 + 0.1
            r = (np.random.choice(2, graph_size*graph_size,
            p=[1-fill_rate, fill_rate])).reshape(graph_size*graph_size, 1)
            map = r.reshape(graph_size, graph_size, 1)
            label, num = scipy.ndimage.label(1 - map)
            #print("map gen", time.time() - tt)
            probs = np.zeros((num+1,))
            #for i in range(1,num):
            #    probs[i] = (label == i).sum()
            u, c = np.unique(label, return_counts=True)
            #print("unique counts:", u, c)
            umask = u>0
            c = c[umask]
            u = u[umask]
            probs[u.astype(np.int32)] = c
            probs = probs * (probs > 200)
            #print("map gen 1 ", time.time() - tt)
            if probs.sum() > 0:
                break
            t1 = time.time()
            #print(f"loop {i_} {t1 - t0}: ", probs, fill_rate, r.sum())
            i_ += 1
        #print("loop out: ", time.time() - t0)
        
        
        probs = probs/probs.sum()
        start = np.zeros((NUM_PATHS, 2))
        end = np.zeros((NUM_PATHS, 2))        
        for i in range(NUM_PATHS):
            chosen = np.random.choice(range(1,probs.shape[0]+1), p=probs) - 1
            clust = np.argwhere(label == chosen)[:,:-1]
            link = np.arange(clust.shape[0])
            np.random.shuffle(link)
            cur_pos = clust[link[0]]
            rem_pos = clust[link[1:]]
            oabs_dist = np.abs(rem_pos[:,0] - cur_pos[0]) + np.abs(rem_pos[:,1] - cur_pos[1])
            abs_dist = (oabs_dist)
            abs_dist = abs_dist / abs_dist.sum()
            end_pos = np.random.choice(np.arange(rem_pos.shape[0]), p=abs_dist)
            #print(oabs_dist[end_pos], oabs_dist.max())
            end_pos = clust[link[1 + end_pos]]
            start[i] = cur_pos
            end[i] = end_pos
        start = torch.LongTensor(start)
        end = torch.LongTensor(end)
        map = 1 - map
        
        movinglabels = np.zeros((map.shape[0], map.shape[1], 1), dtype=np.int64)
        for i in range(enc_size[0]):
            for j in range(enc_size[0]):
                #print(i, reduction_rate[0], i*reduction_rate[0])
                map_snip = map[i*reduction_rate[0]:(i+1)*reduction_rate[0], 
                               j*reduction_rate[0]:(j+1)*reduction_rate[0]]
                labelc, num = scipy.ndimage.measurements.label(map_snip)
                movinglabels[i*reduction_rate[0]:(i+1)*reduction_rate[0], 
                               j*reduction_rate[0]:(j+1)*reduction_rate[0]][labelc>0] = labelc[labelc>0] + movinglabels.max()

        
        
        
        map = torch.Tensor(map)
        
        label2 = torch.LongTensor(label)
        label = torch.LongTensor(movinglabels)
        
        #print("TIME TO YIELD: ", time.time() - t0)
        
        """
        idx = np.where(r)[0]
        gen_grid = r
        #print(idx.shape, r.shape, np.isin(edge_maps, idx).shape)
        edges = edge_maps[:,~np.isin(edge_maps, idx).any(axis=0)]
        r = edges
        #print(np.unique(edges[0,:]))
        z = np.zeros((self.graph_size*self.graph_size, 1))
        onesTens = torch.Tensor(z)
        r = torch.LongTensor(r.astype(np.int32))
        
        dat = Data(x=onesTens, edge_index=r)
        dat.num_features = 1
        dat.num_nodes = dat.x.shape[0]
        #print(dat, , dat.num_nodes)
        net = to_networkx(dat)
        cur_node = cur_pos[1] + cur_pos[0] * graph_size
        end_node = end_pos[1] + end_pos[0] * graph_size
        path = nx.shortest_path(net, cur_node, end_node, weight="cost")
        y = np.zeros((self.graph_size*self.graph_size, 1))
        path = [int(p) for p in path]
        y[path] = 1
        z = 1-gen_grid
        dat.x = torch.Tensor(z)
        dat.start = torch.LongTensor([cur_node.astype(np.int64)])
        dat.end = torch.LongTensor([end_node.astype(np.int64)])
        #print(end_node.astype(np.int64))
        #print(dat.end)
        dat.y = torch.Tensor(y)
        #dat = self.transform(dat)
        #pos_encs = dat.x[:,1:]
        direction = end_pos - cur_pos
        direction = direction / np.sqrt(np.square(direction).sum())
        pos = torch.Tensor(direction)
        #print(pos.shape)
        pos = pos.unsqueeze(0)
        
        dat.startend = pos#.unsqueeze(0)
        """
        return map, start, end, label, label2 #image_edges


class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, path, graph_size=512, batch_size=2):
        super().__init__()
        self.path = path
        self.images = [os.path.join(path, p) for p in os.listdir(path)]
        #self.steps = steps
        self.gs = graph_size
        self.length = len(self.images)
        self.bs = batch_size
        self.i = 0
        self.idx = np.arange(len(self.images))
        print(len(self.idx))
        #self.prng = RandomState(1234567890)

        #self.transform = AddLaplacianEigenvectorPE(4, None, is_undirected=True)

    def len(self):
        return self.length#len(self.img_labels)
    
    def __len__(self):
        return self.len()
    
    def __getitem__(self, idx):
        i_ = 0
        t0 = time.time()
        #prng = self.prng
        while True:
            tt = time.time()
            image_idx = self.idx[self.i]
            self.i += 1
            if self.i >= self.length:
                self.i = 0 
                #np.random.shuffle(self.idx)
            #image_idx = np.random.randint(len(self.images))
            map = cv2.imread(self.images[image_idx],0)
            map = 1*(cv2.resize(map, (self.gs, self.gs)) > 127)
            map = map.reshape(graph_size, graph_size, 1)
            #print(map.shape)
            label, num = scipy.ndimage.label(map)
            #print("map gen", time.time() - tt)
            probs = np.zeros((num+1,))
            #for i in range(1,num):
            #    probs[i] = (label == i).sum()
            u, c = np.unique(label, return_counts=True)
            #print("unique counts:", u, c)
            umask = u>0
            c = c[umask]
            u = u[umask]
            probs[u.astype(np.int32)] = c
            probs = probs * (probs > 100)
            #print("map gen 1 ", time.time() - tt)
            if probs.sum() > 0:
                
                #print(image_idx, self.images[image_idx], num)
                break
            t1 = time.time()
            #print(f"loop {i_} {t1 - t0}: ", probs, fill_rate, r.sum())
            i_ += 1
        #print("loop out: ", time.time() - t0)
        
        
        probs = probs/probs.sum()
        start = np.zeros((NUM_PATHS, 2))
        end = np.zeros((NUM_PATHS, 2))        
        for i in range(NUM_PATHS):
            chosen = np.random.choice(range(1,probs.shape[0]+1), p=probs) - 1
            clust = np.argwhere(label == chosen)[:,:-1]
            link = np.arange(clust.shape[0])
            np.random.shuffle(link)
            cur_pos = clust[link[0]]
            end_pos = clust[link[1]]
            start[i] = cur_pos
            end[i] = end_pos
        start = torch.LongTensor(start)
        end = torch.LongTensor(end)
        
        movinglabels = np.zeros((map.shape[0], map.shape[1], 1), dtype=np.int64)
        for i in range(enc_size[0]):
            for j in range(enc_size[0]):
                #print(i, reduction_rate[0], i*reduction_rate[0])
                map_snip = map[i*reduction_rate[0]:(i+1)*reduction_rate[0], 
                               j*reduction_rate[0]:(j+1)*reduction_rate[0]]
                labelc, num = scipy.ndimage.measurements.label(map_snip)
                movinglabels[i*reduction_rate[0]:(i+1)*reduction_rate[0], 
                               j*reduction_rate[0]:(j+1)*reduction_rate[0]][labelc>0] = labelc[labelc>0] + movinglabels.max()

        
        
        
        map = torch.Tensor(map)
        
        label2 = torch.LongTensor(label)
        label = torch.LongTensor(movinglabels)
        
        #print("TIME TO YIELD: ", time.time() - t0)
        
        """
        idx = np.where(r)[0]
        gen_grid = r
        #print(idx.shape, r.shape, np.isin(edge_maps, idx).shape)
        edges = edge_maps[:,~np.isin(edge_maps, idx).any(axis=0)]
        r = edges
        #print(np.unique(edges[0,:]))
        z = np.zeros((self.graph_size*self.graph_size, 1))
        onesTens = torch.Tensor(z)
        r = torch.LongTensor(r.astype(np.int32))
        
        dat = Data(x=onesTens, edge_index=r)
        dat.num_features = 1
        dat.num_nodes = dat.x.shape[0]
        #print(dat, , dat.num_nodes)
        net = to_networkx(dat)
        cur_node = cur_pos[1] + cur_pos[0] * graph_size
        end_node = end_pos[1] + end_pos[0] * graph_size
        path = nx.shortest_path(net, cur_node, end_node, weight="cost")
        y = np.zeros((self.graph_size*self.graph_size, 1))
        path = [int(p) for p in path]
        y[path] = 1
        z = 1-gen_grid
        dat.x = torch.Tensor(z)
        dat.start = torch.LongTensor([cur_node.astype(np.int64)])
        dat.end = torch.LongTensor([end_node.astype(np.int64)])
        #print(end_node.astype(np.int64))
        #print(dat.end)
        dat.y = torch.Tensor(y)
        #dat = self.transform(dat)
        #pos_encs = dat.x[:,1:]
        direction = end_pos - cur_pos
        direction = direction / np.sqrt(np.square(direction).sum())
        pos = torch.Tensor(direction)
        #print(pos.shape)
        pos = pos.unsqueeze(0)
        
        dat.startend = pos#.unsqueeze(0)
        """
        return map, start, end, label, label2 #image_edges
    

TOTAL_IMAGES = 1024 // NUM_PATHS
#STEPS = 
train_data = RandomImageDataset(graph_size, steps=TOTAL_IMAGES)
val_data = RandomImageDataset(graph_size)
#train_data = CustomImageDataset("bg512-png", graph_size, batch_size = BATCH_SIZE)
#val_data = CustomImageDataset("bg512-png", graph_size)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_data, )

In [10]:
fill_rate = np.random.random()*0.75 + 0.1
r = (np.random.choice(2, graph_size*graph_size,
p=[1-fill_rate, fill_rate])).reshape(graph_size*graph_size, 1)
map = r.reshape(graph_size, graph_size, 1)
label, num = scipy.ndimage.label(1 - map)

In [11]:
u, c = np.unique(label, return_counts=True)
print(u, c)

[   0    1    2 ... 2052 2053 2054] [12861     1     3 ...     1     1     1]


In [12]:
"""
graph_size = 128#self.graph_size
mfill_rate = 0.4#self.mfill_rate
sfill_rate = 0.04#self.sfill_rate

fill_rate = min([max([np.random.normal(mfill_rate, sfill_rate), 0.05]), 0.95])
while True:
    r = (np.random.choice(2, graph_size*graph_size,
    p=[1-fill_rate, fill_rate])).reshape(graph_size*graph_size, 1)
    map = r.reshape(graph_size, graph_size, 1)
    label, num = scipy.ndimage.measurements.label(1 - map)
    probs = np.zeros((num,))
    for i in range(1,num):
        probs[i] = (label == i).sum()
    probs = probs * (probs > 1)
    if probs.sum() > 0:
        break

img = batch[0][1].numpy()
label, num = scipy.ndimage.measurements.label(img)
image_snips = np.zeros((enc_size[0], enc_size[0], reduction_rate[0], reduction_rate[0]))
image_edges = []
dat = []
for i in range(image_snips.shape[0]):
    for j in range(image_snips.shape[0]):
        image_snips[i,j] = label[image_snips.shape[2]*i:image_snips.shape[2]*(i+1),
                                 image_snips.shape[2]*j:image_snips.shape[2]*(j+1), 0]
        
nc = 0        
for i in range(image_snips.shape[0]):
    for j in range(image_snips.shape[0]): 
        loc_data = {}
        for uc, un_cluster in enumerate(np.unique(image_snips[i,j].flatten()).astype(np.int32)):
            dat = [i,j]
            if un_cluster == 0:
                continue
            loc_data[un_cluster] = []
            if i > 0:
                cond = (image_snips[i,j, 0, :] == un_cluster) & \
                        (image_snips[i-1, j, -1, :] == un_cluster)
                if cond.any():
                    loc_data[un_cluster].append([i*image_snips.shape[0] + j, (i-1)*image_snips.shape[0] + j])
            
            if j > 0:
                cond = (image_snips[i,j, :, 0] == un_cluster) & \
                        (image_snips[i, j-1, :, -1] == un_cluster)
                if cond.any():
                    loc_data[un_cluster].append([i*image_snips.shape[0] + j, i*image_snips.shape[0] + (j - 1)])
                    
            if i < (image_snips.shape[0] - 1):
                cond = (image_snips[i,j, -1, :] == un_cluster) & \
                        (image_snips[i+1, j, 0, :] == un_cluster)
                if cond.any():
                    loc_data[un_cluster].append([i*image_snips.shape[0] + j, (i+1)*image_snips.shape[0] + j])
            
            if j < (image_snips.shape[0] - 1):
                cond = (image_snips[i,j, :, -1] == un_cluster) & \
                        (image_snips[i, j+1, :, 0] == un_cluster)
                if cond.any():
                    loc_data[un_cluster].append([i*image_snips.shape[0] + j, i*image_snips.shape[0] + (j + 1)])
        image_edges.append(loc_data)

new_ie = []
mappers = []
nc = 0
for i in range(len(image_edges)):
    mappers.append({k:v+nc for v, k in enumerate(image_edges[i])})
    nc += len(image_edges[i])

assigner = {}
for i, map_ in enumerate(mappers):
    for key in map_:
        assigner[(i, key)] = map_[key]
    
for i in range(len(image_edges)):
    for j in image_edges[i]:
        cont = image_edges[i][j]
        new_ie.append([mappers[i][j], mappers[i][j]])
        for item in cont:
            new_ie.append([mappers[item[0]][j], mappers[item[1]][j]])
"""

'\ngraph_size = 128#self.graph_size\nmfill_rate = 0.4#self.mfill_rate\nsfill_rate = 0.04#self.sfill_rate\n\nfill_rate = min([max([np.random.normal(mfill_rate, sfill_rate), 0.05]), 0.95])\nwhile True:\n    r = (np.random.choice(2, graph_size*graph_size,\n    p=[1-fill_rate, fill_rate])).reshape(graph_size*graph_size, 1)\n    map = r.reshape(graph_size, graph_size, 1)\n    label, num = scipy.ndimage.measurements.label(1 - map)\n    probs = np.zeros((num,))\n    for i in range(1,num):\n        probs[i] = (label == i).sum()\n    probs = probs * (probs > 1)\n    if probs.sum() > 0:\n        break\n\nimg = batch[0][1].numpy()\nlabel, num = scipy.ndimage.measurements.label(img)\nimage_snips = np.zeros((enc_size[0], enc_size[0], reduction_rate[0], reduction_rate[0]))\nimage_edges = []\ndat = []\nfor i in range(image_snips.shape[0]):\n    for j in range(image_snips.shape[0]):\n        image_snips[i,j] = label[image_snips.shape[2]*i:image_snips.shape[2]*(i+1),\n                                 i

In [13]:
from torch_geometric.nn import TAGConv


In [14]:
import torch
import torch.nn.functional as F
from torch_sparse import spspmm

from torch_geometric.nn import GCNConv, TopKPooling, SAGPooling, BatchNorm

from torch_geometric.utils.repeat import repeat

def TagConv(channels):
    return nn.ModuleList([TAGConv(channels, channels, 7, improved=True),
                   nn.ReLU(),
                   nn.Dropout(0.0)])

def StackedConv(channels, depth=3):
    return nn.ModuleList([TagConv(channels) for _ in range(depth-1)] + [TAGConv(channels, channels, 7, improved=True)])
class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture.
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead
            summation. (default: :obj:`True`)
        act (torch.nn.functional, optional): The nonlinearity to use.
            (default: :obj:`torch.nn.functional.relu`)
    """
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.5, sum_res=True, act=F.relu, reps=1):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res
        self.reps = reps
        self.limiter = nn.Parameter(torch.Tensor(np.zeros((1))), requires_grad=True)
        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        
        self.down_convs.append(TAGConv(in_channels, channels, 7, improved=True))
        for rep in range(self.reps - 1):
            self.down_convs.append(TAGConv(channels, channels, 7, improved=True))
            self.down_convs.append(TAGConv(channels, channels, 7, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            for rep in range(self.reps):
                self.down_convs.append(TAGConv(channels, channels, 7, improved=True))



        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            for rep in range(self.reps):
                self.up_convs.append(TAGConv(channels, channels, 7, improved=True))
            
        for rep in range(self.reps - 1):
            self.up_convs.append(TAGConv(channels, channels, 7, improved=True))
        self.up_convs.append(TAGConv(in_channels, out_channels, 7, improved=True))
        self.linear0 = nn.Linear(2, channels)
        self.linear1 = nn.Linear(channels, channels)
        self.linear2 = nn.Linear(channels, channels)
        self.reset_parameters()
        self.channels = channels
    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()
        

    def forward(self, x, edge_index, startend, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))
        
        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)
        for rep in range(1, self.reps):
            x = self.down_convs[rep](x, edge_index, edge_weight)
            x = self.act(x)

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            for rep in range(self.reps):
                x = self.down_convs[rep + i*self.reps](x, edge_index, edge_weight)
                x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]
            
        h = self.linear0(startend)
        h = self.act(h)
        h = self.linear1(h)
        h = self.act(h)
        h = self.linear2(h)
        #print(x.shape)
        h = h.squeeze(1)
        #print(batch.shape, x.shape)
        h = torch.gather(h, 0 , batch.unsqueeze(-1).repeat([1, self.channels]))
        x = x + h
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            
            for rep in range(self.reps):
                x = self.up_convs[rep + i*self.reps](x, edge_index, edge_weight)
                x = x if ((i == self.depth-1) and (rep == self.reps-1)) else self.act(x)
                #x =  if (i < self.depth - 1) else x
            
        x = x*self.limiter
        return x


    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight


    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')

In [15]:

import math
from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F



In [16]:
neighbor_filter = torch.zeros(1, 1, (2*(graph_size + 1)) + 1)
neighbor_filter[:,:,1] = 1
neighbor_filter[:,:,[graph_size, graph_size+2]] = 1
neighbor_filter[:,:,(2*graph_size+1)] = 1
neighbor_filter = neighbor_filter.to(device)

neighbor_kernel = torch.zeros(1, 1, 3, 3)
neighbor_kernel[:,:,:,1] = 1
neighbor_kernel[:,:,1,:] = 1
neighbor_kernel[:,:,1,1] = 0

neighbor_kernel = neighbor_kernel.to(device)

diff_kernel = torch.zeros(1, 1, 3, 3)
diff_kernel[:,:,:,1] = 1
diff_kernel[:,:,1,:] = 1
diff_kernel[:,:,1,1] = 0

diff_kernel = diff_kernel.to(device)

In [17]:
neighbor_kernel

tensor([[[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]]], device='cuda:0')

In [18]:
def backtrack(start_maps: torch.tensor, goal_maps: torch.tensor,
              parents: torch.tensor, current_t: int) -> torch.tensor:
    """
    Backtrack the search results to obtain paths
    Args:
        start_maps (torch.tensor): one-hot matrices for start locations
        goal_maps (torch.tensor): one-hot matrices for goal locations
        parents (torch.tensor): parent nodes
        current_t (int): current time step
    Returns:
        torch.tensor: solution paths
    """

    num_samples = start_maps.shape[0]
    parents = parents.type(torch.long)
    goal_maps = goal_maps.type(torch.long)
    #start_maps = start_maps.type(torch.long)
    path_maps = goal_maps.type(torch.long)
    num_samples = len(parents)
    loc = (parents * goal_maps.view(-1)).sum(-1)
    for _ in range(current_t):
        path_maps.view(-1)[loc] = 1
        loc = parents[loc]
    return path_maps


In [19]:
def get_heuristic(goal: torch.tensor,
                  size : torch.LongTensor,
                  tb_factor: float = 0.001) -> torch.tensor:
    """
    Get heuristic function for A* search (chebyshev + small const * euclidean)
    Args:
        goal_maps (torch.tensor): one-hot matrices of goal locations
        tb_factor (float, optional): small constant weight for tie-breaking. Defaults to 0.001.
    Returns:
        torch.tensor: heuristic function matrices
    """
    idx = torch.arange(size, device=device)
    rc = torch.remainder(idx, graph_size).unsqueeze(-1)
    rr = torch.div(idx, graph_size, rounding_mode="floor").unsqueeze(-1)
    xy = torch.cat([rr, rc], axis=-1)
    xy = xy.unsqueeze(0).repeat([goal.shape[0], 1, 1])
    goal = goal.unsqueeze(1).repeat([1,xy.shape[1],1])
    dxdy = torch.abs(xy - goal)
    h = dxdy.sum(dim=-1) - dxdy.min(dim=-1)[0]
    euc = torch.sqrt(((xy - goal)**2).sum(-1))
    h = (euc)# + tb_factor * h)#.reshape_as(goal_maps)

    return h

In [20]:
#torch.autograd.set_detect_anomaly(True)

In [21]:



def _st_softmax_noexp(val: torch.tensor) -> torch.tensor:
    """
    Softmax + discretized activation
    Used a detach() trick as done in straight-through softmax
    Args:
        val (torch.tensor): exponential of inputs.
    Returns:
        torch.tensor: one-hot matrices for input argmax.
    """

    val_ = val
    y = val_ / (val_.sum(dim=-1, keepdim=True))
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y)
    y_hard[range(y_hard.shape[0]), ind] = 1
    y_hard = y_hard.reshape_as(val)
    y = y.reshape_as(val)
    return ((y_hard - y).detach() + y)

def _st_softmax_noexp_ret(val: torch.tensor) -> torch.tensor:
    """
    Softmax + discretized activation
    Used a detach() trick as done in straight-through softmax
    Args:
        val (torch.tensor): exponential of inputs.
    Returns:
        torch.tensor: one-hot matrices for input argmax.
    """

    val_ = val
    y = val_ / (val_.sum(dim=-1, keepdim=True))
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y)
    y_hard[range(y_hard.shape[0]), ind] = 1
    y_hard = y_hard.reshape_as(val)
    y = y.reshape_as(val)
    return ((y_hard - y).detach() + y), y


def expand(x: torch.tensor, neighbor_filter: torch.tensor) -> torch.tensor:
    """
    Expand neighboring node 
    Args:
        x (torch.tensor): selected nodes
        neighbor_filter (torch.tensor): 3x3 filter to indicate 8 neighbors
    Returns:
        torch.tensor: neighboring nodes of x
    """

    x = x.unsqueeze(1)#.unsqueeze(0)
    #num_samples = x.shape[1]
    #print(x.shape, neighbor_filter.shape)
    y = F.conv1d(x, neighbor_filter, padding="same")#.squeeze(1)
    y = y.squeeze(1)#.squeeze(0)
    return y

def expand2d(x: torch.tensor, neighbor_filter: torch.tensor) -> torch.tensor:
    """
    Expand neighboring node 
    Args:
        x (torch.tensor): selected nodes
        neighbor_filter (torch.tensor): 3x3 filter to indicate 8 neighbors
    Returns:
        torch.tensor: neighboring nodes of x
    """

    x = x.unsqueeze(1)#.unsqueeze(0)
    #num_samples = x.shape[1]
    #print(x.shape, neighbor_filter.shape)
    y = F.conv2d(x, neighbor_filter, padding="same")#.squeeze(1)
    y = y.squeeze(1)#.squeeze(0)
    return y


def find_path(cost_maps, start, end, img, back=True):
    #cost_maps = out.squeeze(-1)*0

    #self.open_maps *= 0
    #self.goal_maps *= 0
    #self.histories *= 0
    #self.g *= 0
    #self.stp *= 0
    #self.stp += 1

    start_maps = start#data.start#[:, 0]
    goal_point = end#data.end#[:, 0]
    obstacles_maps = img#data.x[:,0]#[:, 0]


    num_samples = cost_maps.shape[0]
    #neighbor_filter = self.neighbor_filter
    #neighbor_filter = torch.repeat_interleave(neighbor_filter, num_samples,
    #                                          0)
    size = cost_maps.shape[1]
    obstacles_maps = obstacles_maps.reshape(-1, size)
    open_maps = torch.zeros_like(cost_maps)
    open_maps[range(num_samples), start_maps] = 1
    goal_maps = torch.zeros_like(cost_maps)#self.goal_maps
    goal_maps[range(num_samples), goal_point] = 1
    cost_maps = cost_maps.squeeze(-1)

    histories = torch.zeros_like(cost_maps)#self.histories
    intermediate_results = []

    #h = co#torch.arange(size, device=device)*0 #self.get_heuristic(goal_maps)

    goal = torch.cat([(end//graph_size).unsqueeze(1), (end%graph_size).unsqueeze(1)], axis=1)


    h = get_heuristic(goal, size)
    h = h * (1+(cost_maps)*1e-2)#.squeeze(-1)
    g = torch.zeros_like(cost_maps)#self.g


    parents = (torch.ones_like(cost_maps) * goal_maps.max(-1, keepdim=True)[-1])

    training = True
    Tmax = 1#0.25 if training else 1.
    Tmax = int(Tmax * size)
    g_ratio = 0.5

    t_exp = 0

    for t in range(Tmax):

        # select the node that minimizes cost
        #print(g.shape, h.shape)
        f = g_ratio * g + (1 - g_ratio) * h
        f_exp = torch.exp(-1 * f / size)
        f_exp = f_exp * open_maps
        selected_node_maps = _st_softmax_noexp(f_exp)
        #print(selected_node_maps.shape)
        #if store_intermediate_results:
        #    intermediate_results.append({
        #        "histories":
        #        histories.unsqueeze(1).detach(),
        #        "paths":
        #        selected_node_maps.unsqueeze(1).detach()
        #    })

        # break if arriving at the goal
        dist_to_goal = (selected_node_maps * goal_maps).sum(axis=-1, keepdim=True)
        is_unsolved = (dist_to_goal < 1e-8).float()

        histories = histories + selected_node_maps
        histories = torch.clamp(histories, 0, 1)
        open_maps = open_maps - is_unsolved * selected_node_maps
        open_maps = torch.clamp(open_maps, 0, 1)

        # open neighboring nodes, add them to the openlist if they satisfy certain requirements
        #raise Exception()
        #t1 = time.time()
        gcpy = selected_node_maps.reshape(num_samples, graph_size, graph_size)
        gcpy = F.pad(gcpy, (1, 1, 1, 1))#.reshape(num_samples, -1)
        neighbor_nodes = expand2d(gcpy, neighbor_kernel)
        #neighbor_nodes = neighbor_nodes.reshape(num_samples, graph_size+2, graph_size+2)
        neighbor_nodes = neighbor_nodes[:, 1:-1, 1:-1].reshape(num_samples, -1)
        neighbor_nodes = neighbor_nodes * obstacles_maps


        # update g if one of the following conditions is met
        # 1) neighbor is not in the close list (1 - histories) nor in the open list (1 - open_maps)
        # 2) neighbor is in the open list but g < g2
        cost_expand = (g + (cost_maps) * 1e-2) * selected_node_maps
        cost_expand = cost_expand.reshape(num_samples, graph_size, graph_size)
        cost_expand = F.pad(cost_expand, (1, 1, 1, 1))
        cost_expand = expand2d(cost_expand, neighbor_kernel)
        cost_expand = cost_expand[:, 1:-1, 1:-1].reshape(num_samples, -1)
        g2 = cost_expand#expand(, neighbor_filter)
        idx = (1 - open_maps) * (1 - histories) + open_maps * (g > g2)
        idx = idx * neighbor_nodes
        idx = idx.detach()
        g = g2 * idx + g * (1 - idx)
        g = g.detach()
        # update open maps
        open_maps = torch.clamp(open_maps + idx, 0, 1)
        open_maps = open_maps.detach()

        # for backtracking
        idx = idx.reshape(num_samples, -1)
        snm = selected_node_maps.reshape(num_samples, -1)
        new_parents = snm.max(-1, keepdim=True)[-1]
        #print(parents, new_parents * idx, idx.dtype)
        parents = new_parents * idx + parents * (1 - idx)
        if torch.all(is_unsolved.flatten() == 0):
            break

    if not back:
        return histories, None
    #t_e = time.time()
    parents = parents.type(torch.long)
    goal_maps = goal_maps.type(torch.long)
    #start_maps = start_maps.type(torch.long)
    path_maps = goal_maps.type(torch.long)
    num_samples = len(parents)
    loc = ((parents * goal_maps).sum(-1))
    for _ in range(t):
        path_maps[range(path_maps.shape[0]), loc] = 1
        loc = parents[range(parents.shape[0]), loc]
    #print("TIME BACKING: ", time.time() - t_e)
    return histories, path_maps


In [22]:
from torch_geometric.utils import scatter
from torch_geometric.nn import (MessageNorm, GraphNorm, PANPooling,
ASAPooling, XConv, PointTransformerConv, PDNConv, PANConv, MessagePassing, GATv2Conv, EdgePooling)
from torch_geometric.utils import softmax
from torch.nn.init import xavier_normal_, zeros_

from torch.nn.parameter import Parameter


class Expander(MessagePassing):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super().__init__(aggr='add')
        
    def forward(self, x, edge_index):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        
        return self.propagate(edge_index=edge_index, x=x, aggr='add') # [N, out_channels, label_dim]

    def message(self, x_j):
        return x_j

def expand_graph(x, edge_index):
    cur_x = scatter(x, edge_index[1], 1, reduce="add")
    return cur_x


def find_path_graph(cost_maps, start, end, pos, edge_index, back=True):
    #cost_maps = out.squeeze(-1)*0

    #self.open_maps *= 0
    #self.goal_maps *= 0
    #self.histories *= 0
    #self.g *= 0
    #self.stp *= 0
    #self.stp += 1

    start_maps = start#data.start#[:, 0]
    goal_point = end#data.end#[:, 0]
    
    
    num_samples = cost_maps.shape[0]
    #neighbor_filter = self.neighbor_filter
    #neighbor_filter = torch.repeat_interleave(neighbor_filter, num_samples,
    #                                          0)
    size = cost_maps.shape[1]
    open_maps = torch.zeros_like(cost_maps)
    open_maps[range(num_samples), start_maps] = 1
    goal_maps = torch.zeros_like(cost_maps)#self.goal_maps
    goal_maps[range(num_samples), goal_point] = 1
    cost_maps = cost_maps.squeeze(-1)

    histories = torch.zeros_like(cost_maps)#self.histories
    intermediate_results = []

    #h = co#torch.arange(size, device=device)*0 #self.get_heuristic(goal_maps)

    goal = torch.cat([(end//graph_size).unsqueeze(1), (end%graph_size).unsqueeze(1)], axis=1)


    h = torch.sqrt(torch.square(goal-pos).sum(axis=-1))#get_heuristic(goal, size)
    h = h * (1+(cost_maps)*1e-2)#.squeeze(-1)
    g = torch.zeros_like(cost_maps)#self.g


    parents = (torch.ones_like(cost_maps) * goal_maps.max(-1, keepdim=True)[-1])

    training = True
    Tmax = 1#0.25 if training else 1.
    Tmax = int(Tmax * size)
    g_ratio = 0.5

    t_exp = 0

    for t in range(Tmax):

        # select the node that minimizes cost
        #print(g.shape, h.shape)
        f = g_ratio * g + (1 - g_ratio) * h
        f_exp = torch.exp(-1 * f / size)
        f_exp = f_exp * open_maps
        selected_node_maps = _st_softmax_noexp(f_exp)
        #print(selected_node_maps.shape)
        #if store_intermediate_results:
        #    intermediate_results.append({
        #        "histories":
        #        histories.unsqueeze(1).detach(),
        #        "paths":
        #        selected_node_maps.unsqueeze(1).detach()
        #    })

        # break if arriving at the goal
        dist_to_goal = (selected_node_maps * goal_maps).sum(axis=-1, keepdim=True)
        is_unsolved = (dist_to_goal < 1e-8).float()

        histories = histories + selected_node_maps
        histories = torch.clamp(histories, 0, 1)
        open_maps = open_maps - is_unsolved * selected_node_maps
        open_maps = torch.clamp(open_maps, 0, 1)

        # open neighboring nodes, add them to the openlist if they satisfy certain requirements
        #raise Exception()
        #t1 = time.time()
        neighbor_nodes = expand_graph(selected_node_maps, edge_index)
        #neighbor_nodes = neighbor_nodes.reshape(num_samples, graph_size+2, graph_size+2)


        # update g if one of the following conditions is met
        # 1) neighbor is not in the close list (1 - histories) nor in the open list (1 - open_maps)
        # 2) neighbor is in the open list but g < g2
        cost_expand = (g + (cost_maps) * 1e-2) * selected_node_maps
        cost_expand = expand_graph(cost_expand, edge_index)
        g2 = cost_expand#expand(, neighbor_filter)
        idx = (1 - open_maps) * (1 - histories) + open_maps * (g > g2)
        idx = idx * neighbor_nodes
        idx = idx.detach()
        g = g2 * idx + g * (1 - idx)
        g = g.detach()
        # update open maps
        open_maps = torch.clamp(open_maps + idx, 0, 1)
        open_maps = open_maps.detach()

        # for backtracking
        idx = idx.reshape(num_samples, -1)
        snm = selected_node_maps.reshape(num_samples, -1)
        new_parents = snm.max(-1, keepdim=True)[-1]
        #print(parents, new_parents * idx, idx.dtype)
        parents = new_parents * idx + parents * (1 - idx)
        if torch.all(is_unsolved.flatten() == 0):
            break

    if not back:
        return histories, None
    #t_e = time.time()
    parents = parents.type(torch.long)
    goal_maps = goal_maps.type(torch.long)
    #start_maps = start_maps.type(torch.long)
    path_maps = goal_maps.type(torch.long)
    num_samples = len(parents)
    loc = ((parents * goal_maps).sum(-1))
    for _ in range(t):
        path_maps[range(path_maps.shape[0]), loc] = 1
        loc = parents[range(parents.shape[0]), loc]
    #print("TIME BACKING: ", time.time() - t_e)
    return histories, path_maps

In [23]:
def find_path_cost(cost_maps, start, end, img, back=True):
    #cost_maps = out.squeeze(-1)*0

    #self.open_maps *= 0
    #self.goal_maps *= 0
    #self.histories *= 0
    #self.g *= 0
    #self.stp *= 0
    #self.stp += 1

    start_maps = start#data.start#[:, 0]
    goal_point = end#data.end#[:, 0]
    obstacles_maps = img#data.x[:,0]#[:, 0]


    num_samples = cost_maps.shape[0]
    #neighbor_filter = self.neighbor_filter
    #neighbor_filter = torch.repeat_interleave(neighbor_filter, num_samples,
    #                                          0)
    size = cost_maps.shape[1]
    obstacles_maps = obstacles_maps.reshape(-1, size)
    open_maps = torch.zeros_like(cost_maps)
    open_maps[range(num_samples), start_maps] = 1
    open_maps.requires_grad_(True)
    goal_maps = torch.zeros_like(cost_maps)#self.goal_maps
    goal_maps[range(num_samples), goal_point] = 1
    goal_maps.requires_grad_(True)
    
    cost_maps = cost_maps.squeeze(-1)

    histories = torch.zeros_like(cost_maps, requires_grad=True)#self.histories
    histories_y = torch.zeros_like(cost_maps, requires_grad=True)#self.histories
    intermediate_results = []

    #h = co#torch.arange(size, device=device)*0 #self.get_heuristic(goal_maps)

    goal = torch.cat([(end//graph_size).unsqueeze(1), (end%graph_size).unsqueeze(1)], axis=1)


    h = get_heuristic(goal, size)
    h = h +  (1+(cost_maps))#.squeeze(-1)
    g = torch.zeros_like(cost_maps)#self.g


    parents = (torch.ones_like(cost_maps) * goal_maps.max(-1, keepdim=True)[-1])

    training = True
    Tmax = 1#0.25 if training else 1.
    Tmax = int(Tmax * size)
    g_ratio = 0.1

    t_exp = 0

    for t in range(Tmax):

        # select the node that minimizes cost
        #print(g.shape, h.shape)
        f = g_ratio * g + (1 - g_ratio) * h
        f_exp = torch.exp((f/size)) + 1e-5
        f_exp = f_exp * open_maps
        selected_node_maps, y_grid = _st_softmax_noexp_ret(f_exp)
        if (y_grid < 0).any():
            raise Exception("LESS THAN ZERO")
        histories_y = histories_y + torch.exp(y_grid) * open_maps
        #print(selected_node_maps.shape)
        #if store_intermediate_results:
        #    intermediate_results.append({
        #        "histories":
        #        histories.unsqueeze(1).detach(),
        #        "paths":
        #        selected_node_maps.unsqueeze(1).detach()
        #    })

        # break if arriving at the goal
        dist_to_goal = (selected_node_maps * goal_maps).sum(axis=-1, keepdim=True)
        is_unsolved = (dist_to_goal < 1e-8).float()

        histories = histories + selected_node_maps
        histories = torch.clamp(histories, 0, 1)
        open_maps = open_maps - is_unsolved * selected_node_maps
        open_maps = torch.clamp(open_maps, 0, 1)

        # open neighboring nodes, add them to the openlist if they satisfy certain requirements
        #raise Exception()
        #t1 = time.time()
        gcpy = selected_node_maps.reshape(num_samples, graph_size, graph_size)
        gcpy = F.pad(gcpy, (1, 1, 1, 1))#.reshape(num_samples, -1)
        neighbor_nodes = expand2d(gcpy, neighbor_kernel)
        #neighbor_nodes = neighbor_nodes.reshape(num_samples, graph_size+2, graph_size+2)
        neighbor_nodes = neighbor_nodes[:, 1:-1, 1:-1].reshape(num_samples, -1)
        neighbor_nodes = neighbor_nodes * obstacles_maps


        # update g if one of the following conditions is met
        # 1) neighbor is not in the close list (1 - histories) nor in the open list (1 - open_maps)
        # 2) neighbor is in the open list but g < g2
        cost_expand = (g + (cost_maps)) * selected_node_maps
        cost_expand = cost_expand.reshape(num_samples, graph_size, graph_size)
        cost_expand = F.pad(cost_expand, (1, 1, 1, 1))
        cost_expand = expand2d(cost_expand, neighbor_kernel)
        cost_expand = cost_expand[:, 1:-1, 1:-1].reshape(num_samples, -1)
        g2 = cost_expand#expand(, neighbor_filter)
        idx = (1 - open_maps) * (1 - histories) + open_maps * (g > g2)
        idx = idx * neighbor_nodes
        idx = idx.detach()
        g = g2 * idx + g * (1 - idx)
        g = g.detach()
        # update open maps
        open_maps = torch.clamp(open_maps + idx, 0, 1)
        open_maps = open_maps.detach()

        # for backtracking
        idx = idx.reshape(num_samples, -1)
        snm = selected_node_maps.reshape(num_samples, -1)
        new_parents = snm.max(-1, keepdim=True)[-1]
        #print(parents, new_parents * idx, idx.dtype)
        parents = new_parents * idx + parents * (1 - idx)
        if torch.all(is_unsolved.flatten() == 0):
            break

    if not back:
        return histories, None
    #t_e = time.time()
    parents = parents.type(torch.long)
    goal_maps = goal_maps.type(torch.long)
    #start_maps = start_maps.type(torch.long)
    path_maps = goal_maps.type(torch.long)
    num_samples = len(parents)
    loc = ((parents * goal_maps).sum(-1))
    for _ in range(t):
        path_maps[range(path_maps.shape[0]), loc] = 1
        loc = parents[range(parents.shape[0]), loc]
    #print("TIME BACKING: ", time.time() - t_e)
    return histories, path_maps, histories_y

In [24]:


class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=2, alpha=0.75):
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


In [25]:
import torch_geometric.nn as pyn

class DiscreteSingleEdgeMessaging():
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super(DiscreteSpatialGraphConv, self).__init__(aggr='add')
        self.dropout = dropout
        self.kernel_size = k
        self.choice = Parameter((torch.rand((in_channels, 1, k,k))-0.5), requires_grad=True)
        self.kernel = Parameter((torch.rand((in_channels, out_channels, k,k))-0.5), requires_grad=True)
        
        #self.bias = Parameter(torch.zeros((out_channels)), requires_grad=True)
        self.in_channels = in_channels

    def forward(self, x, edge_index, pos):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        
        print(edge_index, x.size(0))
        edge_index, _ = remove_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        
        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')  # [N, out_channels, label_dim]
    def message(self, pos_i, pos_j, x_j, x_i, edge_index):
        """
        pos_i [num_edges, coors]
        pos_j [num_edges, coors]
        x_j [num_edges, label_dim]
        """
        
        edge_index = coalesce(edge_index, sort_by_row=False)
        row, col = edge_index
        
        #x_i, x_j = x[row], x[col]
        #pos_i, pos_j = pos[row], pos[col]
        
        relative_pos = pos_j - pos_i  # [n_edges, hidden_size * in_channels]
        #print("dreiction:", relative_pos.max(), relative_pos.min())

        magnitude = (torch.abs(relative_pos).max(axis=-1)[0]).unsqueeze(-1) + 1e-5
        direction = relative_pos / magnitude
        is_self = (magnitude < 0.5).type(torch.int64)

        
        quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        
        y_i = torch.einsum("ij, jki -> ik", x_i, self.choice[:, :, 0, 0])# / torch.sqrt(1+magnitude)
        y_j = torch.einsum("ij, jki -> ik", x_j, self.choice[:, :, 0, 0])# / torch.sqrt(1+magnitude)
        
        ey = y_j - y_i
        
        
        #print("result:", result.shape)
        #spatial_scaling = F.relu(self.lin_in(relative_pos))  # [n_edges, hidden_size * in_channels]

        #n_edges = spatial_scaling.size(0)
        # [n_edges, in_channels, ...] * [n_edges, in_channels, 1]
        #result = spatial_scaling.reshape(n_edges, self.in_channels, -1) * x_j.unsqueeze(-1)
        return result#result.view(n_edges, -1)

In [26]:
from torch_geometric.utils import softmax

class DiscreteSpatialGraphConv(MessagePassing):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super().__init__(aggr='add')
        self.dropout = dropout
        self.kernel_size = k
        self.kernel = Parameter(torch.Tensor(in_channels, in_channels, k, k), requires_grad=True)
        self.bias = Parameter(torch.Tensor(out_channels), requires_grad=True)
        self.kernel.retain_grad()
        self.in_channels = in_channels
        self.kernel = xavier_normal_(self.kernel)
        self.bias = zeros_(self.bias)
        print(self.kernel)
    def forward(self, x, edge_index, pos):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        #print(edge_index, x.size(0))
        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        #print(x.shape, pos.shape, edge_index.shape)
        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add') + self.bias  # [N, out_channels, label_dim]

    def message(self, pos, x_j):
        """
        pos_i [num_edges, coors]
        pos_j [num_edges, coors]
        x_j [num_edges, label_dim]
        """
        relative_pos = pos # [n_edges, hidden_size * in_channels]
        #print("dreiction:", relative_pos.max(), relative_pos.min())

        magnitude = torch.sqrt((torch.square(relative_pos)).sum(axis=-1)).unsqueeze(-1) + 1e-5
        direction = relative_pos / magnitude
        is_self = (magnitude < 0.5).type(torch.int64)

        
        quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        quantized = quantized.detach()
        #print("QUANT:" ,quantized.shape, self.kernel[:, :, quantized[:,0], quantized[:,1]].shape)
        #print("direction: ", direction, "\n rounded dir: ", torch.round(direction), "\n quant: ", quantized)
        #raise Exception("")
        #print(x_j.shape, self.kernel[:, :, quantized[:,0], quantized[:,1]].shape)
        result = torch.einsum("ij, jki -> ik", x_j, self.kernel[:, :, quantized[:,0], quantized[:,1]])# / (1+magnitude)
        
        #ai = x_j.unsqueeze(1)
        #bi = self.kernel[:, :, quantized[:,0], quantized[:,1]].permute(2, 0, 1)
        #result = torch.bmm(ai, bi).squeeze(1)
        #print("result:", result.shape)
        #spatial_scaling = F.relu(self.lin_in(relative_pos))  # [n_edges, hidden_size * in_channels]

        #n_edges = spatial_scaling.size(0)
        # [n_edges, in_channels, ...] * [n_edges, in_channels, 1]
        #result = spatial_scaling.reshape(n_edges, self.in_channels, -1) * x_j.unsqueeze(-1)
        return result#result.view(n_edges, -1)

class DiscreteSpatialEdgeConv(nn.Module):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super().__init__()
        self.dropout = dropout
        self.kernel_size = k
        self.kernel = Parameter(torch.Tensor(in_channels, in_channels, k, k), requires_grad=True)
        self.bias = Parameter(torch.Tensor(out_channels), requires_grad=True)
        self.scaler = nn.Linear(in_channels*2, out_channels)
        self.kernel.retain_grad()
        self.in_channels = in_channels
        self.kernel = xavier_normal_(self.kernel)
        self.bias = zeros_(self.bias)

    def forward(self, x, edge_index, pos):
        src, dst = edge_index
        x_i, x_j = x[src], x[dst]
        
        
        relative_pos = pos  # [n_edges, hidden_size * in_channels]
        #print("dreiction:", relative_pos.max(), relative_pos.min())

        magnitude = torch.sqrt((torch.square(relative_pos)).sum(axis=-1)).unsqueeze(-1) + 1e-5
        direction = relative_pos / magnitude
        is_self = (magnitude < 0.5).type(torch.int64)

        
        quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        quantized = quantized.detach()
        
        x_i = torch.einsum("ij, jki -> ik", x_i, self.kernel[:, :,
                                                             self.kernel_size-quantized[:,0]-1,
                                                             self.kernel_size-quantized[:,1]-1])
        x_j = torch.einsum("ij, jki -> ik", x_j, self.kernel[:, :,
                                                             quantized[:,0],
                                                             quantized[:,1]])
        
        x = self.scaler(torch.cat([x_i,x_j], axis=-1))
        #print(edge_index, x.size(0))
        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        #print(x.shape, pos.shape, edge_index.shape)
        
        
        return softmax(x, src)

class PathConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super().__init__()
        self.dropout = dropout
        self.kernel_size = k
        #self.kernel = Parameter(torch.Tensor(in_channels, in_channels, k, k), requires_grad=True)
        #self.bias = Parameter(torch.Tensor(out_channels), requires_grad=True)
        self.scaler = nn.Linear(in_channels*3, out_channels)
        #self.kernel.retain_grad()
        self.in_channels = in_channels
        #self.kernel = xavier_normal_(self.kernel)
        #self.bias = zeros_(self.bias)

    def forward(self, x, path_index):
        src, mid, dst = path_index
        x_i, x_m, x_j = x[src], x[mid], x[dst]
        
        
        
        
        
        
        #degrees = degree(x_j)
        #reps = torch.gather(degrees, 0, x_i)
        #cur_x = scatter(torch.ones_like(x_j), x_j, 0, reduce="mean")

        #relative_pos = pos  # [n_edges, hidden_size * in_channels]
        #print("dreiction:", relative_pos.max(), relative_pos.min())

        #magnitude = torch.sqrt((torch.square(relative_pos)).sum(axis=-1)).unsqueeze(-1) + 1e-5
        #direction = relative_pos / magnitude
        #is_self = (magnitude < 0.5).type(torch.int64)

        
        #quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        #quantized = quantized.detach()
        
        #x_i = torch.einsum("ij, jki -> ik", x_i, self.kernel[:, :,
        #                                                     self.kernel_size-quantized[:,0]-1,
        #                                                     self.kernel_size-quantized[:,1]-1])
        #x_j = torch.einsum("ij, jki -> ik", x_j, self.kernel[:, :,
        #                                                     quantized[:,0],
        #                                                     quantized[:,1]])
        
        x0 = self.scaler(torch.cat([x_i, x_m, x_j], axis=-1))
        x1 = self.scaler(torch.cat([x_j, x_m, x_i], axis=-1))
        #print(edge_index, x.size(0))
        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        #print(x.shape, pos.shape, edge_index.shape)
        
        
        return x0 + x1#softmax(x, src)

   
    

class ScoredEdgePooling():
    def __init__(self, min_score=0.1):
        self.min_score = min_score

        

    def forward(self, score, edge_index):
        
        return 



class DiscreteSingleMessageGraphConv(MessagePassing):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super(DiscreteSingleMessageGraphConv, self).__init__(aggr='add')
        self.dropout = dropout
        self.kernel_size = k
        self.choice = Parameter(data=(torch.rand((in_channels, 1, k,k))-0.5), requires_grad=True)
        self.kernel = Parameter(data=(torch.rand((in_channels, out_channels, k,k))-0.5), requires_grad=True)
        #self.bias = Parameter(torch.zeros((out_channels)), requires_grad=True)
        self.in_channels = in_channels
        self._score = None

    def forward(self, x, pos, edge_index):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        out = self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')
        return out#, self._score  # [N, out_channels, label_dim]

    def message(self, pos_i, pos_j, x_i, x_j, index, ptr, size_i):
        """
        pos_i [num_edges, coors]
        pos_j [num_edges, coors]
        x_j [num_edges, label_dim]
        """
        relative_pos = pos_j - pos_i

        magnitude = (torch.abs(relative_pos).max(axis=-1)[0]).unsqueeze(-1) + 1e-5
        direction = relative_pos / magnitude
        is_self = (magnitude < 0.5).type(torch.int64)

        
        quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        


        score = torch.einsum("ij, jki -> ik", x_i+x_j, self.choice[:, :, quantized[:,0], quantized[:,1]])# / torch.sqrt(1+magnitude)
        result = torch.einsum("ij, jki -> ik", x_i+x_j, self.kernel[:, :, quantized[:,0], quantized[:,1]])# / torch.sqrt(1+magnitude)
        alpha = softmax(score, index, ptr, size_i)
        
        self._score = score


        return result * (alpha + 0.05)
    
class DiscreteSpatialGraphConv2(MessagePassing):
    def __init__(self, in_channels, out_channels, k=3, dropout=0):
        """
        coors - dimension of positional descriptors (e.g. 2 for 2D images)
        in_channels - number of the input channels (node features)
        out_channels - number of the output channels (node features)
        hidden_size - number of the inner convolutions
        dropout - dropout rate after the layer
        """
        super(DiscreteSpatialGraphConv2, self).__init__(aggr='add')
        self.dropout = dropout
        self.kernel_size = k
        self.kernel = Parameter((torch.rand((in_channels, out_channels, k,k))-0.5), requires_grad=True)
        #self.bias = Parameter(torch.zeros((out_channels)), requires_grad=True)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x, pos, edge_index):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        x = torch.einsum("ij, jkwh -> ikwh", x, self.kernel).reshape(-1, self.out_channels*self.kernel_size*self.kernel_size)

        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')  # [N, out_channels, label_dim]

    def message(self, pos_i, pos_j, x_j):
        """
        pos_i [num_edges, coors]
        pos_j [num_edges, coors]
        x_j [num_edges, label_dim]
        """
        x_j = x_j.reshape(-1, self.out_channels, self.kernel_size, self.kernel_size)
        relative_pos = pos_j - pos_i  # [n_edges, hidden_size * in_channels]
        #print("dreiction:", relative_pos.max(), relative_pos.min())

        magnitude = (torch.abs(relative_pos).max(axis=-1)[0]).unsqueeze(-1) + 1e-5
        direction = relative_pos / magnitude
        is_self = (magnitude < 0.5).type(torch.int64)

        
        quantized = (torch.clip(torch.round(direction), -self.kernel_size//2, self.kernel_size//2) + self.kernel_size//2).type(torch.int64)* (1 - is_self) + is_self
        #print("QUANT:" ,quantized.shape, self.kernel[:, :, quantized[:,0], quantized[:,1]].shape)
        #print(direction, torch.round(direction), quantized)
        #raise Exception("")
        result = x_j[:,:,quantized[:,0], quantized[:,1]]
        
        #print("result:", result.shape)
        #spatial_scaling = F.relu(self.lin_in(relative_pos))  # [n_edges, hidden_size * in_channels]

        #n_edges = spatial_scaling.size(0)
        # [n_edges, in_channels, ...] * [n_edges, in_channels, 1]
        #result = spatial_scaling.reshape(n_edges, self.in_channels, -1) * x_j.unsqueeze(-1)
        return result#result.view(n_edges, -1)


def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' Padding
    b,c,h,w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    x = F.pad(x, (pad_row//2, pad_row - pad_row//2, pad_col//2, pad_col - pad_col//2))
    
    # Extract patches
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0,4,5,1,2,3).contiguous()
    
    return patches.view(b,-1,patches.shape[-2], patches.shape[-1])

def hard_sigmoid(x):
    det = ((x >= 0.5).type(x.dtype) - x).detach()
    return x + det

class Mapper(torch.nn.Module):
    def __init__(self, channels=4, depth=5, reps=3, scalers = [], dropout=0.25):
        super().__init__()
        self.dropout = dropout
        inp_channels = channels
        self.conv1 = nn.Conv2d(1, channels, 3, padding="same", bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        self.depth = depth
        self.reps = reps
        self.cnv = []
        self.bn = []
        self.bttc = []
        self.bttb = []
        
        for d in range(depth):
            inchannels = channels
            channels = channels#2 * channels
            for rep in range(reps-1):
                self.cnv.append(nn.Conv2d(inchannels, channels, 3, 1, padding="same", bias=True))
                self.bn.append(nn.BatchNorm2d(channels))
                inchannels = channels
            self.cnv.append(nn.Conv2d(inchannels, channels, 3, 1, dilation=2, padding="same", bias=True))
            self.bn.append(nn.BatchNorm2d(channels))
        
        
        self.cnv_l = nn.Conv2d(channels, channels, 3, 1, padding="same", bias=True)
        self.bn_l = nn.BatchNorm2d(channels)
        self.cnv = nn.ModuleList(self.cnv)
        self.bn = nn.ModuleList(self.bn)
        
    def forward(self, x, training=True):
        im = x
        x = self.conv1(x)
        #x = im * x
        x = self.bn1(x)
        x = F.leaky_relu(x)
        x = F.dropout(x, self.dropout, training=training)
        
        
        x = x
        #self.scales = []
        for d in range(self.depth):
            #self.scales.append(x)
            #h = x
            for r in range(self.reps):
                layer = self.cnv[d*self.reps+r]
                x = layer(x)
                #x = im * x
                x = self.bn[d*self.reps+r](x)
                x = F.leaky_relu(x)
                x = F.dropout(x, self.dropout, training=training)
            #x = F.max_pool2d(x, (2,2))
            #x = h + x
        x = self.cnv_l(x)
        #x = im * x
        x = self.bn_l(x)
        x = F.leaky_relu(x)
        x = F.dropout(x, self.dropout, training=training)
        return x
"""

class Encoder(torch.nn.Module):

    def __init__(self, channels=32, depth=5, reps=3):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, channels, 3, padding="same")
        self.bn1 = nn.BatchNorm2d(channels)
        
        self.reps = reps
        self.depth = depth
        self.channels = channels

        self.convs = []
        self.cbns = []

        self.gcns = []
        self.sagpool = []
        
        for rep in range(self.reps):
            self.convs.append(nn.Conv2d(1, channels, 3, padding="same"))
            self.cbns.append(nn.BatchNorm2d(channels))
        
        for depth in range(self.depth):
            for rep in range(self.reps):
                self.gcns.append(GCNConv(channels, channels, improved=True))
            self.sagpool.append(SAGPooling(channels, min_score=0.1))

"""


def Convolution(in_channels, out_channels):
    return PointTransformerConv(in_channels, out_channels)
def Convolution2(in_channels, out_channels):
    return PointTransformerConv(in_channels, out_channels)
def FinalConvolution(in_channels, out_channels):
    return PointTransformerConv(in_channels, out_channels)


class Pather(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.25, sum_res=True, act=F.leaky_relu, reps=1):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res
        self.reps = reps
        self.limiter = nn.Parameter(torch.Tensor(np.zeros((1))), requires_grad=True)
        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.down_bns = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        
        self.down_convs.append(Convolution(in_channels, channels))
        self.down_bns.append(nn.BatchNorm1d(channels))
        for rep in range(self.reps - 1):
            self.down_convs.append(Convolution(channels, channels))
            self.down_bns.append(nn.BatchNorm1d(channels))
            
        for i in range(depth):
            self.pools.append(SAGPooling(channels, self.pool_ratios[i]))
            for rep in range(self.reps):
                self.down_convs.append(Convolution(channels, channels))
                self.down_bns.append(nn.BatchNorm1d(channels))
                



        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        self.up_bns = torch.nn.ModuleList()
        for i in range(depth):
            for rep in range(self.reps):
                self.up_convs.append(Convolution(channels, channels))
                self.up_bns.append(nn.BatchNorm1d(channels))
            
        for rep in range(self.reps - 1):
            self.up_convs.append(Convolution2(channels, channels))
            self.up_bns.append(nn.BatchNorm1d(channels))
            
        self.up_convs.append(FinalConvolution(channels, out_channels))
        self.up_bns.append(nn.BatchNorm1d(out_channels))
        #self.linear1 = nn.Linear(channels, channels)
        #self.linear2 = nn.Linear(channels, channels)
        #self.reset_parameters()
        self.channels = channels
    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()
        for bn in self.up_bns:
            bn.reset_parameters()
        for bn in self.down_bns:
            bn.reset_parameters()
        

    def forward(self, x, edge_index, batch=None, pos=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))
        #print(x.shape)
        x = self.down_convs[0](x, pos, edge_index)
        #print(x.shape)
        #print("POSITION", pos)
        x = self.down_bns[0](x)
        x = self.act(x)
        
        for rep in range(1, self.reps):
            h = x
            x = self.down_convs[rep](x, pos, edge_index)
            x = self.down_bns[rep](x)
            x = self.act(x)
            x = h + x
        #"""
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []
        pos_i = [pos]
        batchs = [batch]
        pos_x = pos
        for i in range(1, self.depth + 1):
            
            #edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
            #                                           x.size(0))
            #print(x.shape, edge_index.shape)
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            #print(x.shape, edge_index.shape, batch.shape)
            pos_x = pos_x[perm]
            for rep in range(self.reps):
                x = self.down_convs[rep + i*self.reps](x, pos_x, edge_index)
                #print(x.shape)
                x = self.down_bns[rep + i*self.reps](x)
                x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
                pos_i += [pos_x]
                batchs += [batch]
            perms += [perm]
            
            
        #h = self.linear0(startend)
        #h = self.act(h)
        #h = self.linear1(h)
        #h = self.act(h)
        #h = self.linear2(h)
        #print(x.shape)
        #h = h.squeeze(1)
        #print(batch.shape, x.shape)
        #h = torch.gather(h, 0 , batch.unsqueeze(-1).repeat([1, self.channels]))
        #x = x + h
        for i in range(self.depth):
            j = self.depth - 1 - i
            #print(x.shape, pos_x.shape, batch.shape)
            print(x.shape, pos_x.shape)
            for rep in range(self.reps):
                x = self.up_convs[rep + i*self.reps](x, pos_x, edge_index)
                x = self.up_bns[rep + i*self.reps](x)
                x = self.act(x)
            
            #unp = perms[j]
            #x, edge_index, batch = self.pools[j].unpool(x, unp)
            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            pos_x = pos_i[j]
            perm = perms[j]
            batch = batchs[j]

            up = torch.zeros_like(res)
            up[perm] = x
            #up = x
            
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
        
        h = x
        for rep in range(self.reps - 1):
            x = self.up_convs[rep + self.depth*self.reps](x, pos, edge_index)#[0]
            x = self.up_bns[rep + self.depth*self.reps](x)
            x = self.act(x)
        #"""
        x = h + x
        x = self.up_convs[-1](x, pos, edge_index)#[0]

        return x


    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        #print(edge_index, edge_weight, num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight


    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')
    
def pos_2_direction(pos, edge_index):
    return pos[edge_index[0, :]] - pos[edge_index[1, :]]

def GraphNormalizer(channels):
    return pyn.BatchNorm(channels)

class Masker(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.25, sum_res=True, act=F.leaky_relu, reps=1, dropout=0.25):
        super().__init__()
        assert depth >= 1
        self.dropout = dropout
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        if act is None:
            self.act = torch.nn.Identity()
        self.sum_res = sum_res
        self.reps = reps
        self.limiter = nn.Parameter(torch.Tensor(np.zeros((1))), requires_grad=True)
        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.down_bns = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(Convolution(in_channels, channels))
        self.down_bns.append(GraphNormalizer(channels))
        in_chans = channels
        for rep in range(self.reps - 1):
            self.down_convs.append(Convolution(in_chans, channels))
            self.down_bns.append(GraphNormalizer(channels))
            in_chans = channels
        channels = channels
        for i in range(depth):
            self.pools.append(ASAPooling(in_chans, self.pool_ratios[i]))
            for rep in range(self.reps):
                self.down_convs.append(Convolution(in_chans, channels))
                print(i, in_chans, channels)
                self.down_bns.append(GraphNormalizer(channels))
                in_chans = channels
            channels = channels#2 * channels
                
        channels = channels# // 2
        #print("Res", channels)

        #in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        self.up_bns = torch.nn.ModuleList()
        inchannels = channels
        for i in range(depth):
            outchannels = inchannels# // 2
            for rep in range(self.reps):
                self.up_convs.append(Convolution(inchannels, outchannels))
                print(i, inchannels, outchannels)
                self.up_bns.append(GraphNormalizer(outchannels))
                inchannels = outchannels
            
        for rep in range(self.reps - 1):
            self.up_convs.append(Convolution2(outchannels, outchannels))
            self.up_bns.append(GraphNormalizer(outchannels))
            
        self.up_convs.append(FinalConvolution(outchannels, out_channels))
        self.up_bns.append(GraphNormalizer(out_channels))
        
        self.final_edge = DiscreteSpatialEdgeConv(outchannels, 1)
        #self.linear1 = nn.Linear(channels, channels)
        #self.linear2 = nn.Linear(channels, channels)
        self.smd_conv = PathConvolution(out_channels, 1)
        #self.reset_parameters()
        self.channels = channels
    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()
        for bn in self.up_bns:
            bn.reset_parameters()
        for bn in self.down_bns:
            bn.reset_parameters()
        

    def forward(self, x, edge_index, loc=None, edge_data=None, batch=None, smds=None, training=True):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if edge_data is None:
            edge_weight = x.new_ones(edge_index.size(1))
        else:
            edge_weight = edge_data
        
        if not loc is None:
            if loc.shape[-1] == 2:
                loc = F.pad(loc, (0,1))
        #print(x.shape)
        pos = pos_2_direction(loc, edge_index)
        x = self.down_convs[0](x, loc, edge_index)
        #print(x.shape)
        #print("POSITION", pos)
        x = self.down_bns[0](x)
        x = self.act(x)
        h = x
        for rep in range(1, self.reps):
            #h = x
            x = self.down_convs[rep](x, loc, edge_index)
            x = self.down_bns[rep](x)
            x = self.act(x)
            x = F.dropout(x, self.dropout, training=training)
        #x = h + x
        #"""
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        edge_data = [edge_data]
        locs = [loc]
        perms = []
        batchs = [batch]
        for i in range(1, self.depth + 1):
            
            #edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
            #                                           x.size(0))
            #print(x.shape, edge_index.shape)
            
            x, edge_index, edge_weight, batch, perm = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            loc = loc[perm]
            pos = pos_2_direction(loc, edge_index)
            #print(x.shape, edge_index.shape, batch.shape)
            for rep in range(self.reps):
                x = self.down_convs[rep + i*self.reps](x, loc, edge_index)
                #print(x.shape)
                x = self.down_bns[rep + i*self.reps](x)
                x = self.act(x)
                x = F.dropout(x, self.dropout, training=training)

            if i < self.depth:
                #print("added", x.shape)
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
                batchs += [batch]
                locs += [loc]
            perms += [perm]
            
            
        #h = self.linear0(startend)
        #h = self.act(h)
        #h = self.linear1(h)
        #h = self.act(h)
        #h = self.linear2(h)
        #print(x.shape)
        #h = h.squeeze(1)
        #print(batch.shape, x.shape)
        #h = torch.gather(h, 0 , batch.unsqueeze(-1).repeat([1, self.channels]))
        #x = x + h
        for i in range(self.depth):
            j = -(i + 1) #self.depth - 1 - i
            for rep in range(self.reps):
                #print(x.shape, loc.shape, batch.shape)
                x = self.up_convs[rep + i*self.reps](x, loc, edge_index)
                x = self.up_bns[rep + i*self.reps](x)
                x = self.act(x)
                x = F.dropout(x, self.dropout, training=training)
            
            #unp = perms[j]
            #x, edge_index, batch = self.pools[j].unpool(x, unp)
            res = xs[j]
            #print(res.shape, x.shape)
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]
            batch = batchs[j]
            loc = locs[j]
            pos = pos_2_direction(loc, edge_index)
            up = torch.zeros_like(res)
            up[perm] = x
            #up = x
            
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
        
        h = x
        for rep in range(self.reps - 1):
            x = self.up_convs[rep + self.depth*self.reps](x, loc, edge_index)#[0]
            x = self.up_bns[rep + self.depth*self.reps](x)
            x = self.act(x)
            x = F.dropout(x, self.dropout, training=training)
        #"""
        #x = h + x
        x = self.up_convs[-1](x, loc, edge_index)#[0]
        x = self.up_bns[-1](x)
        #x = self.act(x)
        #x = F.dropout(x, self.dropout, training=training)
        edge_x = self.final_edge(x, edge_index, pos)
        path_vals = None
        if not (smds is None):
            path_vals = self.smd_conv(x, smds)
        x = self.max_edge(edge_x, edge_index)
        
        
        return x, edge_x, path_vals
    
    def max_edge(self, x, edge_index):
        src, dst = edge_index
        x = scatter(x, dst, 0, reduce="mean")
        return x

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        #print(edge_index, edge_weight, num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight


    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')

class CNNMasker(torch.nn.Module):
    def __init__(self, channels=32, depth=5, reps=3):
        super().__init__()

        self.depth = depth
        self.reps = reps
        
        self.bttc = []
        self.bttb = []
        
    
        self.conv_query = nn.Conv2d(4, channels, 3, padding="same", bias=False)
        self.bn_q = nn.BatchNorm2d(channels)
        
        self.rem_depth = max([(int(np.log2(graph_size // (2**depth))) - 2), 0])

        for rep in range(reps):
            self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
        
        for d in range(self.rem_depth):
            for rep in range(reps-1):
                self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
                self.bttb.append(nn.BatchNorm2d(channels))
            self.bttc.append(nn.Conv2d(channels, channels, 3, 2, padding=1, bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
            
        for d in range(self.rem_depth):
            self.bttc.append(nn.ConvTranspose2d(channels, channels, 4, 2, padding=1, bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
            for rep in range(reps-1):
                self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
                self.bttb.append(nn.BatchNorm2d(channels))
        
        self.bttc.append(nn.Conv2d(channels, 1, 3, 1, padding="same", bias=False))
        self.bttb.append(nn.BatchNorm2d(1))
        
        self.bttc = nn.ModuleList(self.bttc)
        self.bttb = nn.ModuleList(self.bttb)
        
        
    def forward(self, x, query):

        q = self.conv_query(query)
        #q = self.bn_q(q)
        
        x = x + q
        #h = x
        for r in range(self.reps):
            x = self.bttc[r](x)
            #x = self.bttb[r](x)
            x = F.leaky_relu(x)
            #print(r)
        #x  = h + x
        
        a = x
        splits = []
        for r in range(self.rem_depth):
            splits.append(a)
            for d in range(self.reps):
                #print(r, d, self.reps + r * self.reps + d)
                a = self.bttc[self.reps+ r * self.reps + d](a)
                #a = self.bttb[self.reps+ r * self.reps + d](a)
                a = F.leaky_relu(a)
            
        for r in range(self.rem_depth):
            for d in range(self.reps):
                #print(r, d, self.reps + (rem_depth * self.reps) + r * self.reps + d)
                a = self.bttc[self.reps + (self.rem_depth * self.reps) + r * self.reps + d](a)
                #a = self.bttb[self.reps + (self.rem_depth * self.reps) + r * self.reps + d](a)
                a = F.leaky_relu(a)
            #print(a.shape, [split.shape for split in splits])
            a = a + splits[-(r+1)]
            #splits.append(a)
        
            
            
        a = self.bttc[-1](a)
        #a = self.bttb[-1](a)
        mask = torch.sigmoid(a)
        mask = hard_sigmoid(mask)
        #x = x * mask
        
        mask_up = mask
        for d in range(self.depth):
            mask_up = torch.repeat_interleave(torch.repeat_interleave(mask_up, 2, dim=2), 2, dim=3)
        
        return a, mask, mask_up
    
    
    
class CNNPather(torch.nn.Module):
    def __init__(self, channels=32, depth=5, reps=3):
        super().__init__()

        self.depth = depth
        self.reps = reps
        
        self.bttc = []
        self.bttb = []
        
    
        self.conv_query = nn.Conv2d(3, channels, 3, padding="same", bias=False)
        self.bn_q = nn.BatchNorm2d(channels)
        
        for rep in range(reps):
            self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
        
        for d in range(self.depth):
            for rep in range(reps-1):
                self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
                self.bttb.append(nn.BatchNorm2d(channels))
            self.bttc.append(nn.Conv2d(channels, channels, 3, 2, padding=1, bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
            
        for d in range(self.depth):
            self.bttc.append(nn.ConvTranspose2d(channels, channels, 4, 2, padding=1, bias=False))
            self.bttb.append(nn.BatchNorm2d(channels))
            for rep in range(reps-1):
                self.bttc.append(nn.Conv2d(channels, channels, 3, 1, padding="same", bias=False))
                self.bttb.append(nn.BatchNorm2d(channels))
        
        self.bttc.append(nn.Conv2d(channels, 1, 3, 1, padding="same", bias=False))
        self.bttb.append(nn.BatchNorm2d(1))
        
        self.bttc = nn.ModuleList(self.bttc)
        self.bttb = nn.ModuleList(self.bttb)
        
        
    def forward(self, x):
        x = self.conv_query(x)
        x = self.bn_q(x)
        #h = x
        for r in range(self.reps):
            x = self.bttc[r](x)
            x = self.bttb[r](x)
            x = F.leaky_relu(x)
            #print(r)
        #x  = h + x
        
        a = x
        splits = []
        for r in range(self.depth):
            splits.append(a)
            for d in range(self.reps):
                #print(r, d, self.reps + r * self.reps + d)
                a = self.bttc[self.reps+ r * self.reps + d](a)
                a = self.bttb[self.reps+ r * self.reps + d](a)
                a = F.leaky_relu(a)
            
        for r in range(self.depth):
            for d in range(self.reps):
                #print(r, d, self.reps + (rem_depth * self.reps) + r * self.reps + d)
                a = self.bttc[self.reps + (self.depth * self.reps) + r * self.reps + d](a)
                a = self.bttb[self.reps + (self.depth * self.reps) + r * self.reps + d](a)
                a = F.leaky_relu(a)
            #print(a.shape, [split.shape for split in splits])
            a = a + splits[-(r+1)]
            #splits.append(a)
        
            
            
        a = self.bttc[-1](a)
        
        return a

In [27]:
import torch

a = torch.LongTensor(np.random.randint(0, 10, 100))

In [28]:
degrees = degree(a)

In [29]:
a

tensor([2, 0, 1, 0, 9, 7, 7, 1, 7, 9, 8, 3, 7, 0, 4, 1, 4, 5, 6, 7, 1, 2, 6, 0,
        4, 8, 1, 9, 8, 8, 0, 1, 3, 7, 1, 7, 4, 7, 7, 3, 2, 9, 7, 1, 1, 4, 9, 8,
        8, 7, 1, 9, 5, 0, 0, 2, 4, 6, 8, 4, 8, 6, 7, 6, 6, 4, 9, 5, 6, 5, 5, 9,
        7, 0, 9, 2, 4, 3, 0, 8, 6, 6, 9, 8, 7, 0, 5, 0, 0, 6, 6, 2, 0, 2, 2, 9,
        7, 9, 8, 1])

In [30]:
a

tensor([2, 0, 1, 0, 9, 7, 7, 1, 7, 9, 8, 3, 7, 0, 4, 1, 4, 5, 6, 7, 1, 2, 6, 0,
        4, 8, 1, 9, 8, 8, 0, 1, 3, 7, 1, 7, 4, 7, 7, 3, 2, 9, 7, 1, 1, 4, 9, 8,
        8, 7, 1, 9, 5, 0, 0, 2, 4, 6, 8, 4, 8, 6, 7, 6, 6, 4, 9, 5, 6, 5, 5, 9,
        7, 0, 9, 2, 4, 3, 0, 8, 6, 6, 9, 8, 7, 0, 5, 0, 0, 6, 6, 2, 0, 2, 2, 9,
        7, 9, 8, 1])

In [31]:
degrees

tensor([13., 11.,  8.,  4.,  9.,  6., 11., 15., 11., 12.])

In [32]:
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall

chunk_filters = [torch.ones(1, 1, red, red).to(device) for red in reduction_rate]

encoders = [Mapper(channels=8, depth=d, reps=3).to(device) for d in depth]
maskers = [Masker(in_channels=10, hidden_channels=8, out_channels=8, depth=d, reps=3).to(device) for d in depth]


#encoder = Mapper(channels=8, depth=depth[0], reps=2).to(device)

#pather = CNNPather(channels=128, depth=depth, reps=5).to(device)
#pather = Pather(in_channels=10, hidden_channels=8, out_channels=1, depth=2, reps=2).to(device)

acc = BinaryAccuracy(threshold=0.5, validate_args=False).to(device)
pres = BinaryPrecision(threshold=0.5, validate_args=False).to(device)
rec = BinaryRecall(threshold=0.5, validate_args=False).to(device)

l1_func = nn.L1Loss().to(device)

0 8 8
0 8 8
0 8 8
1 8 8
1 8 8
1 8 8
0 8 8
0 8 8
0 8 8
1 8 8
1 8 8
1 8 8


In [33]:
try:
    mapper.load_state_dict(torch.load("mapper.pt"))
    decoder.load_state_dict(torch.load("decoder.pt"))
except:
    print("No previous weights!")
loss_chunk = FocalLoss(torch.nn.BCEWithLogitsLoss(reduction="none")).to(device)
loss_func = FocalLoss(torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([1]), reduction="none")).to(device)


No previous weights!


In [None]:
import matplotlib.pyplot as plt
from IPython import display
from torch.profiler import profile, record_function, ProfilerActivity
import torchvision.transforms.functional as VF
from tqdm import tqdm
from pyastar2d import astar_path
from pymastar2d import mastar_path

def generate_goal(img, start, end, reduction_rate):
    startrel = (start%reduction_rate)/reduction_rate
    endrel = (end%reduction_rate)/reduction_rate
    startr = start // reduction_rate
    endr = end // reduction_rate
    startrel = startrel.reshape(-1, 2)
    endrel = endrel.reshape(-1, 2)
    startr = startr.reshape(-1, 2)
    endr = endr.reshape(-1, 2)

    goal_map = torch.zeros((img.shape[0]*NUM_PATHS, 4,
                        img.shape[2]//(reduction_rate),
                        img.shape[3]//(reduction_rate)))
    #for p in range(NUM_PATHS):
    goal_map[range(img.shape[0] * NUM_PATHS), :2, startr[:,0], startr[:,1]] += startrel
    goal_map[range(img.shape[0] * NUM_PATHS), :2, endr[:,0], endr[:,1]] += endrel
    goal_map[range(img.shape[0] * NUM_PATHS), 2, startr[:,0], startr[:,1]] += 1#startrel
    goal_map[range(img.shape[0] * NUM_PATHS), 3, endr[:,0], endr[:,1]] += 1#endrel
    goal_map = goal_map.to(device)
    return goal_map
def reduce_graph(edge_index):
    valid_nodes = edge_index[0]
    unique_nodes = np.unique(edge_index)
    mp = np.arange(np.max(unique_nodes)+1)
    new_nodes = np.zeros_like(edge_index)
    key = np.arange(unique_nodes.shape[0])
    mp[unique_nodes] = key
    new_nodes[0,:] = mp[edge_index[0]]
    new_nodes[1,:] = mp[edge_index[1]]
    return new_nodes, unique_nodes

import time
%matplotlib inline  
optimizer = torch.optim.Adam(sum([list(encoder.parameters()) for encoder in encoders], []) +
                                sum([list(masker.parameters()) for masker in maskers], []), lr=1e-2)
#
#optimizer2 = torch.optim.RMSprop(list(encoder.parameters()) + list(pather.parameters()), lr=1e-3)

def data_list_batch(data):
    dat = []
    edge_list = []
    batch = []
    edge_data = []
    smds = []
    en = 0
    for i, d in enumerate(data):
        dat.append(d.x)
        edge_list.append(d.edge_index + en)
        edge_data.append(d.edge_data)
        smds.append(d.smds + en)
        en += d.x.shape[0]
        batch_dim = (torch.zeros((d.x.shape[0]), dtype=torch.int64) + i)
        batch.append(batch_dim)
        
    return Data(x=torch.cat(dat, axis=0).to(device),
                edge_index=torch.cat(edge_list, axis=0).to(device),
                batch=torch.cat(batch, axis=0).to(device),
                edge_data = torch.cat(edge_data, axis=0).to(device),
                smds = torch.cat(smds, axis=0).to(device))

def batch_list_data(x, batch):
    dat = []
    for i in range(batch.max()+1):
        dat.append(x[batch==i])
        
    return dat
def nodedist(a, b):
    (x1, y1) = a//graph_size, a % graph_size
    (x2, y2) = b//graph_size, b % graph_size
    return (((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5)
def wrapped_nodedist(pos_list):
    
    def _nodedist(a, b):
        (x1, y1) = a//graph_size, a % graph_size
        (x2, y2) = b//graph_size, b % graph_size
        return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
    return _nodedist
encoders[0].dropout = 0.1
maskers[0].dropout = 0.1
#find_path = ASTAR(8, graph_size)
graph_sqr = graph_size * graph_size
def visualize(data=None, start=None, end=None, path=None, history=None):

    map = np.zeros((graph_size*graph_size, 3))
    #edges = data.edge_index[0,:]
    #edges = edges[edges<(graph_size*graph_size)]
    #map[edges, :] = 1
    data = data.reshape(-1, 1)
    map[:,:] = data
    #print(path.shape, history.shape, data.x[:graph_size*graph_size].shape)
    path = path * data[:,0]
    history = history * data[:,0]
    empty_map = map[:,:1]
    map[path>0] = [0,0,0]

    history = history>0.5
    map[history] = [1,0.5,0]

    map[path>0] = map[path>0] + [0, 0.5, 0]

    #map = map * (empty_map)
    

    map = map.reshape(graph_size, graph_size, 3)
    
    map[start[0], start[1], :] = [0, 0, 1]
    map[end[0], end[1], :] = [1, 0, 0]
    map = np.clip(map, 0, 1)
    
    return map

#model.train()
buffer = []
noise_effect = 1
for epoch in range(20000):
    tepoch = time.time()
    lss_accum = 0
    lssc_accum = 0
    acc_accum = 0
    pres_accum = 0
    pres_base = 0
    rec_accum = 0
    acc_accum1 = 0
    pres_accum1 = 0
    rec_accum1 = 0
    path_lss = 0
    path_lens = []
    path_olen = []
    path_nosc = []
    path_pure = []
    interfaces = []
    #if len(buffer) < 100:
    for idx, batch in tqdm(enumerate(train_loader)):
        loss = 0
        loss1 = 0
        loss2 = 0
        t0 = time.time()
        optimizer.zero_grad()
        #optimizer2.zero_grad()
        #with torch.autocast("cuda"):
        img = batch[0].to(device)
        start = batch[1]#.to(device)
        end = batch[2]#.to(device)
        labels = batch[3].numpy()
        labels2 = batch[4].numpy()
        assigners = []
        new_ies = []
        smds = [] 
        mapperss = []
        new_pos = []
        interfaces = []
        
        for l, label in enumerate(labels):
            image_snips = np.zeros((enc_size[0], enc_size[0], reduction_rate[0], reduction_rate[0]))
            image_edges = []
            interface = {}
            dat = []
            for i in range(image_snips.shape[0]):
                for j in range(image_snips.shape[0]):
                    image_snips[i,j] = label[image_snips.shape[2]*i:image_snips.shape[2]*(i+1),
                                             image_snips.shape[2]*j:image_snips.shape[2]*(j+1), 0]

            nc = 0        
            for i in range(image_snips.shape[0]):
                for j in range(image_snips.shape[0]): 
                    loc_data = {}
                    for uc, un_cluster in enumerate(np.unique(image_snips[i,j].flatten()).astype(np.int32)):
                        interface[un_cluster] = {"info": [i, j]}
                        dat = [i,j]
                        if un_cluster == 0:
                            continue
                        loc_data[un_cluster] = []
                        if i > 0:
                            filt = image_snips[i,j, 0, :] == un_cluster
                            cond = image_snips[i-1, j, -1, :] * filt
                            pos = (i-1) * image_snips.shape[0] + j
                            locs_co = np.zeros((image_snips.shape[2], 2))
                            locs_co[:,0] = i * image_snips.shape[2]
                            locs_co[:,1] = j * image_snips.shape[2] + np.arange(image_snips.shape[2])
                            
                            for c in np.unique(cond).astype(np.int32):
                                if c > 0:
                                    loc_data[un_cluster].append((c, pos))
                                    if not (c in interface[un_cluster]):
                                        interface[un_cluster][c] = []
                                    interface[un_cluster][c].append(locs_co[cond==c, :])
                        if j > 0:
                            filt = image_snips[i,j, :, 0] == un_cluster
                            cond = image_snips[i, j-1, :, -1] * filt
                            pos = (i) * image_snips.shape[0] + j-1
                            locs_co = np.zeros((image_snips.shape[2], 2))
                            locs_co[:,0] = i * image_snips.shape[2] + np.arange(image_snips.shape[2])
                            locs_co[:,1] = j * image_snips.shape[2] 
                            for c in np.unique(cond).astype(np.int32):
                                if c > 0:
                                    loc_data[un_cluster].append((c, pos))
                                    if not (c in interface[un_cluster]):
                                        interface[un_cluster][c] = []
                                    interface[un_cluster][c].append(locs_co[cond==c, :])
                                    
                        if i < (image_snips.shape[0] - 1):
                            filt = image_snips[i,j, -1, :] == un_cluster
                            cond = image_snips[i+1, j, 0, :] * filt
                            pos = (i+1) * image_snips.shape[0] + j
                            locs_co = np.zeros((image_snips.shape[2], 2))
                            locs_co[:,0] = (i+1) * image_snips.shape[2] - 1
                            locs_co[:,1] = j * image_snips.shape[2] + np.arange(image_snips.shape[2])
                            for c in np.unique(cond).astype(np.int32):
                                if c > 0:
                                    loc_data[un_cluster].append((c, pos))
                                    if not (c in interface[un_cluster]):
                                        interface[un_cluster][c] = []
                                    interface[un_cluster][c].append(locs_co[cond==c, :])
                        if j < (image_snips.shape[0] - 1):
                            filt = image_snips[i,j, :, -1] == un_cluster
                            cond = image_snips[i, j+1, :, 0] * filt
                            pos = (i) * image_snips.shape[0] + j + 1
                            locs_co = np.zeros((image_snips.shape[2], 2))
                            locs_co[:,0] = i * image_snips.shape[2] + np.arange(image_snips.shape[2])
                            locs_co[:,1] = (j+1) * image_snips.shape[2] - 1 
                            for c in np.unique(cond).astype(np.int32):
                                if c > 0:
                                    loc_data[un_cluster].append((c, pos))
                                    if not (c in interface[un_cluster]):
                                        interface[un_cluster][c] = []
                                    interface[un_cluster][c].append(locs_co[cond==c, :])
                    image_edges.append(loc_data)
            interfaces.append(interface)
            new_ie = []
            smd = []
            pos = []
            nc = 0


            for i in range(len(image_edges)):
                for j in image_edges[i]:
                    cont = image_edges[i][j]
                    new_ie.append([j, j])
                    pos_i = i % image_snips.shape[0]
                    pos_j = i // image_snips.shape[0]
                    pos.append([pos_i, pos_j])
                    for k, item in enumerate(cont):
                        new_ie.append([j, item[0]])
                        for l, it2 in enumerate(cont):
                            if k == l:
                                continue
                            smd.append([item[0], j, it2[0]])
                        #xpos_i = item[1] % image_snips.shape[0]
                        #xpos_j = item[1] // image_snips.shape[0]
                        #pos.append([pos_i - xpos_i, pos_j - xpos_j])
            assigners.append(image_edges)
            new_ies.append(new_ie)
            new_pos.append(np.asarray(pos))
            smds.append(np.asarray(smd))

        img = img.reshape(-1, 1, graph_size, graph_size)
        #new_pos = np.asarray(new_pos)
        #start = torch.cat([(data.start//graph_size).unsqueeze(1),
        #                   (data.start%graph_size).unsqueeze(1)], axis=1)

        #goal = torch.cat([(data.end//graph_size).unsqueeze(1),
        #                  (data.end%graph_size).unsqueeze(1)], axis=1)
        
        #keeps = start.cpu()
        #keepe = end.cpu()
        #print("START TIME: ", time.time()-t0)
        enc = encoders[0](img)
        #print("ENCODE TIME: ", time.time()-t0)
        enc = enc.permute(0, 2, 3, 1)
        #enc = enc.reshape(enc.shape[0], -1, enc.shape[-1])
        #enc1 = mappers[1](img)
        #scores = encoder(img)
        startrel = (start%reduction_rate[0])/reduction_rate[0]
        endrel = (end%reduction_rate[0])/reduction_rate[0]
        startr = start // reduction_rate[0]
        endr = end // reduction_rate[0]
        data = []
        loc = []
        start_flat = []
        end_flat = []
        for b in range(enc.shape[0]):
            bdata = []          
            benc = enc[b]
            label = torch.Tensor(labels[b]).to(device)[:,:,0].type(torch.int64)
            #for i in range(1, label.max()+1):
            #    cx = torch.sum(benc[(label == i)], dim=0, keepdim=True)
            #    #print(cx.shape)
            #    bdata.append(cx)
            #cur_x = torch.cat(bdata, axis=0)#enc[b]
            
            index = label.reshape(-1,)#.repeat(1, 8)
            bbenc = benc.reshape(-1, benc.shape[-1])
            #cur_x = torch.zeros((index.max()+1, benc.shape[-1]), dtype=bbenc.dtype).to(device)
            cur_x = scatter(bbenc, index, 0, reduce="mean")
            cur_x = cur_x[1:]
            #raise Exception("calculated")
            
            #print(b, label.max(), cur_x.shape, len(new_pos[b]))
            
            for p in range(NUM_PATHS):
                sstartr = startr[b,p, 0] * image_snips.shape[0] + startr[b,p, 1]
                sendr = endr[b,p, 0] * image_snips.shape[0] + endr[b,p, 1]
                cl = labels[b][start[b,p,0], start[b,p,1]]
                el = labels[b][end[b,p,0], end[b,p,1]]
                goal_selector = torch.zeros((cur_x.shape[0], 2), dtype=cur_x.dtype)
                #print(cl, b, p)
                start_flat.append(cl-1)
                end_flat.append(el-1)
                goal_selector[cl-1, 0] = 1
                goal_selector[el-1, 1] = 1
                dat = Data(x=torch.cat([cur_x, goal_selector.to(device)], axis=1),
                           edge_index=torch.LongTensor(new_ies[b]).to(device)-1,
                           edge_data=torch.Tensor(new_pos[b]),
                          smds = torch.LongTensor(smds[b]).to(device)-1)
                data.append(dat)
        datas = data
        data = data_list_batch(data)
            
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
            a, edge_vals, path_dist = maskers[0](data.x, data.edge_index.T,
                                      data.edge_data, batch=data.batch, smds=data.smds.T)
        #print(edge_vals.shape)
        a_per_batch = batch_list_data(a, data.batch)
        start = start.reshape(-1, 2)
        end = end.reshape(-1, 2)
        flatstart = (start[:, 1] + start[:,0]*graph_size).to(device)
        flatend = (end[:, 1] + end[:,0]*graph_size).to(device)
        new_a = []
        
        small_paths = []
        small_paths_m = []
        pure_paths = []
        costs = []
        #print("time before traverse: ", time.time() - t0)
        edge_cut = 0
        edge_cuts = []
        
        lengths = torch.zeros((data.smds.shape[0])) - 1
        
        for b, ab in enumerate(a_per_batch):
            #print(ab.min())
            ab_n = ab#torch.sigmoid(ab)
            new_edges = datas[b].edge_index.cpu().numpy().T
            G = nx.from_edgelist([(new_edges[0, e], new_edges[1,e])
                                  for e in range(new_edges.shape[1])])
            src, dst = new_edges
            end_cut = edge_cut + new_edges.shape[1]
            edge_cuts.append(edge_vals[edge_cut:end_cut])
            dir_weights = 1 - edge_vals[edge_cut:end_cut].detach().cpu().numpy()#1 - 0.99*((ab_n[dst] - ab_n[src])[:,0].detach().cpu().numpy())
            #if end_cut == edge_vals.shape[0]:
            #    print("done 1")
            
            dir_weights = dir_weights# - dir_weights.min()
            dir_weights = {(new_edges[0, i], new_edges[1,i]):dir_weights[i] for i in range(dir_weights.shape[0])}
            nx.set_edge_attributes(G, dir_weights, "direction")
            #print(G.nodes, int(start_flat[b]), int(end_flat[b]))
            L = nx.shortest_path(G, int(start_flat[b]), int(end_flat[b]), weight="direction")
            zer = torch.zeros_like(ab)
            zer[L] = 1
            zer = zer.detach()
            small_paths.append(zer)
            new_a.append(ab_n * zer + 0.0 * ab_n * (1-zer))
            #EI = new_edges.T
            #pairs = np.asarray([[L[i], L[i+1]] for i in range(len(L)-1)])
            #res_0 = torch.zeros((EI.shape[0],))
            #for i, pair in enumerate(pairs):
            #    loc = np.where((EI == pair).all(axis=-1))[0]
                #print(pair, loc)
            #    res_0[loc] = 1
            #raise Exception("")
            #small_paths_m.append(res_0)
            L = nx.shortest_path(G, int(start_flat[b]), int(end_flat[b]))
            cost_goal = ((datas[b].edge_data - datas[b].edge_data[int(end_flat[b])])**2).sum(axis=-1)
            cost_start = ((datas[b].edge_data - datas[b].edge_data[int(start_flat[b])])**2).sum(axis=-1)
            cost = torch.minimum(cost_start, cost_goal)
            costs.append(cost)
            zer = torch.zeros_like(ab)
            zer[L] = 1
            zer = zer.detach()
            pure_paths.append(zer)
            #print(zer.shape)
            #ab = torch.sigmoid(ab)
            edge_cut = end_cut
        #a_per_batch = new_a
            
        mask_up = torch.zeros((enc.shape[0] * NUM_PATHS, graph_size, graph_size)).to(device)
        temp_map = torch.zeros((reduction_rate[0], reduction_rate[0])).to(device)
        for b in range(mask_up.shape[0]):
            label = torch.Tensor(labels[b//NUM_PATHS]).to(device)
            ab = new_a[b]
            src = torch.cat([ab[:1]*0, ab], axis=0)
            #raise Exception("")
            #for i in range(ab.shape[0]):
            #    mask_temp = (label==(i+1))[:,:,0]
            index = label.type(torch.int64).reshape(-1, 1)
            mask_up[b] = torch.gather(src, 0, index.to(device)).reshape(graph_size, graph_size)
        #mask_up = VF.gaussian_blur(mask_up, 15, 3)
        mask_up = mask_up.reshape(img.shape[0] * NUM_PATHS, graph_sqr)
        
        
        mask_small = torch.zeros((enc.shape[0] * NUM_PATHS, graph_size, graph_size)).to(device)
        temp_map = torch.zeros((reduction_rate[0], reduction_rate[0])).to(device)
        for b in range(mask_up.shape[0]):
            label = torch.Tensor(labels[b//NUM_PATHS]).to(device)
            ab = small_paths[b]
            src = torch.cat([ab[:1]*0, ab], axis=0)
            index = label.type(torch.int64).reshape(-1, 1)
            mask_small[b] = torch.gather(src, 0, index).reshape(graph_size, graph_size)
            #for i in range(ab.shape[0]):
            #    mask_temp = (label==(i+1))[:,:,0]
            #    mask_small[b,mask_temp] = ab[i]
        
        mask_small = mask_small.reshape(img.shape[0] * NUM_PATHS, graph_sqr).cpu()
        
        mask_pure = torch.zeros((enc.shape[0] * NUM_PATHS, graph_size, graph_size)).to("cpu")
        temp_map = torch.zeros((reduction_rate[0], reduction_rate[0])).to("cpu")
        for b in range(mask_up.shape[0]):
            label = torch.Tensor(labels[b//NUM_PATHS]).to("cpu")
            ab = pure_paths[b].to("cpu")
            src = torch.cat([ab[:1]*0, ab], axis=0)
            index = label.type(torch.int64).reshape(-1, 1)
            mask_pure[b] = torch.gather(src, 0, index).reshape(graph_size, graph_size)
            #for i in range(ab.shape[0]):
            #    mask_temp = (label==(i+1))[:,:,0]
            #    mask_pure[b,mask_temp] = (ab[i].to("cpu"))
        
        mask_pure = mask_pure.reshape(img.shape[0] * NUM_PATHS, graph_sqr).cpu()
        #mask_up1 = mask_up1.reshape(img.shape[0] * NUM_PATHS, graph_sqr)
        
        img = img.reshape(BATCH_SIZE, 1, graph_sqr).repeat([1, NUM_PATHS, 1])
        
        #print("BEFORE PATH: ", time.time()- t0)
        
        pmask_up = torch.zeros((mask_up.shape[0], graph_sqr))
        pfollowed = torch.zeros((mask_up.shape[0], graph_sqr))
        edge_res = []
        path_best = []
        path_algo = []
        path_weights = []
        bp = time.time()
        smd_res = []
        smd_masks = []
        for p in range(mask_up.shape[0]):
            label = torch.Tensor(labels[p//img.shape[1]])#.reshape(-1,)#.to(device)
            edge_values = edge_cuts[p]
            new_edges = datas[p].edge_index.cpu().numpy().T
            #raise Exception()
            cl = labels2[p//img.shape[1]][start[p,0], start[p,1]]
            el = labels2[p//img.shape[1]][end[p,0], end[p,1]]
            #print(start[p], end[p], cl, el)
            #grid = labels2[p//img.shape[1],:,:,0] == cl
            #inf_mask = grid == False
            #grid = (grid*1.0).astype(np.float32)
            #grid[inf_mask] = np.inf
            #ab_ = (mask_up[p]).unsqueeze(-1)
            #dist_vals = ab_#[dst] * ( 1 - 1*(ab_[dst] == ab_[src]))
            #dir_weights = ((1 - (dist_vals)[:,0].detach().cpu().numpy())*100) ** 2
            #dir_weights = dir_weights.reshape(grid.shape[0], grid.shape[1])
            #locs = np.where((labels2[p//img.shape[1],:,:,0] == cl).reshape(-1,))[0]
            #mask = np.isin(edge_maps, locs).all(axis=0)
            #new_edges = edge_maps[:,mask]
            #src, dst = new_edges
            
            #dir_weights = {(new_edges[0, i], new_edges[1,i]):dir_weights[i] for i in range(dir_weights.shape[0])}
            #G = nx.from_edgelist([(new_edges[0, e], new_edges[1,e]) for e in range(new_edges.shape[1])])
            #nx.set_edge_attributes(G, dir_weights, "direction")
            #PL = astar_path(grid, np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            PL = astar_path((label.reshape(graph_size, graph_size).numpy()>0).astype(np.int32),
                       new_edges.T.astype(np.int32),
                       1-0*edge_values[:,0].detach().cpu().numpy(),
                       np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            PL = PL[np.where((PL == np.asarray(start[p])).all(axis=-1))[0][0]:]
            L = PL[:,0] * graph_size + PL[:,1]
            path_nosc.append(len(L))
            path_best.append(len(L))
            #raise Exception("")
            
            #grid = grid + dir_weights
            
            PL = astar_path(label[:,:,0].numpy().astype(np.int32),
                       new_edges.T.astype(np.int32),
                       (1*(1-edge_values[:,0].detach().cpu().numpy()))**2,
                       np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            PL = PL[np.where((PL == np.asarray(start[p])).all(axis=-1))[0][0]:]
            #PL = astar_path(grid, np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            L = PL[:,0] * graph_size + PL[:,1]
            
            #L = nx.astar_path(G, int(flatstart[p].cpu()), int(flatend[p].cpu()), nodedist, weight="direction")
            pairs = np.asarray([[L[i], L[i+1]] for i in range(len(L)-1)])
            EI = datas[p].edge_index
            res_0 = torch.zeros((EI.shape[0],))
            #if len(pairs) < 2:
            #    print(pairs)
            #datas[p].smds
            label = label.reshape(-1)
            for i, pair in enumerate(pairs):
                c_0 = label[pair[0]]
                c_1 = label[pair[1]]
                loc = torch.where((EI.cpu() == torch.Tensor([c_0, c_1])).all(axis=-1))[0]
                
                res_0[loc] = 1
                #raise Exception("")
                #EI[:,0] == c_0) &
            path_lens.append(len(L))
            pmask_up[p, L] = 1
            #path_nosc.append(len(L))
            #path_olen.append(len(L))
            #path_pure.append(len(L))
            edge_res.append(res_0)
            #"""
            
            grid = (labels2[p//img.shape[1],:,:,0] == cl)
            grid = mask_small[p].reshape(graph_size, graph_size).cpu().numpy() * grid
            #inf_mask = grid == False
            grid = (grid*1).astype(np.int32)
            #grid[inf_mask] = np.inf
            #locs = np.where( * ((labels2[p//img.shape[1],:,:,0]) == cl).reshape(-1,))[0]
            #mask = np.isin(edge_maps, locs).all(axis=0)
            #new_edges = edge_maps[:,mask]
            
            #G = nx.from_edgelist([(new_edges[0, e], new_edges[1,e]) for e in range(new_edges.shape[1])])
            
            #L = nx.astar_path(G, int(flatstart[p].cpu()), int(flatend[p].cpu()), nodedist)
            PL = astar_path(grid,
                       new_edges.T.astype(np.int32),
                       1-0*edge_values[:,0].detach().cpu().numpy(),
                       np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            PL = PL[np.where((PL == np.asarray(start[p])).all(axis=-1))[0][0]:]
            #PL = astar_path(grid, np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            L = PL[:,0] * graph_size + PL[:,1]
            path_olen.append(len(L))
            path_algo.append(len(L))
            pfollowed[p, L] = 1
            
            grid = (labels2[p//img.shape[1],:,:,0] == cl)
            grid = mask_pure[p].reshape(graph_size, graph_size).cpu().numpy() * grid
            #inf_mask = grid == False
            grid = (grid*1).astype(np.int32)
            #grid[inf_mask] = np.inf
            
            PL = astar_path(grid,
                       new_edges.T.astype(np.int32),
                       1-0*edge_values[:,0].detach().cpu().numpy(),
                       np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            PL = PL[np.where((PL == np.asarray(start[p])).all(axis=-1))[0][0]:]
            #PL = astar_path(grid, np.asarray(start[p]), np.asarray(end[p]), allow_diagonal=False)
            L = PL[:,0] * graph_size + PL[:,1]
            #locs = np.where(mask_pure[p] * ((labels2[p//img.shape[1],:,:,0]) == cl).reshape(-1,))[0]
            #mask = np.isin(edge_maps, locs).all(axis=0)
            #new_edges = edge_maps[:,mask]
            
            #G = nx.from_edgelist([(new_edges[0, e], new_edges[1,e]) for e in range(new_edges.shape[1])])
            
            #L = nx.astar_path(G, int(flatstart[p].cpu()), int(flatend[p].cpu()), nodedist)
            path_pure.append(len(L))
            path_weights.append(np.abs(path_best[p] - path_algo[p]) * ((res_0*0 + 1)))
            #"""
            iface = interfaces[p//img.shape[1]]
            cur_smd = datas[p].smds
            label = torch.Tensor(labels[p//img.shape[1]])
            smd_length = torch.zeros((cur_smd.shape[0]))
            smd_mask = torch.zeros((cur_smd.shape[0]))
            for s, smd in enumerate(cur_smd):
                if np.random.uniform() < 0.1:
                    cif = iface[int(smd[1])+1]
                    ix, iy = cif["info"]
                    starts = np.concatenate(cif[int(smd[0])+1], axis=0).astype(np.int32)
                    ends = np.concatenate(cif[int(smd[2])+1], axis=0).astype(np.int32)
                    starts[:,0] = starts[:,0] - ix * image_snips.shape[2]
                    starts[:,1] = starts[:,1] - iy * image_snips.shape[2]
                    
                    ends[:,0] = ends[:,0] - ix * image_snips.shape[2]
                    ends[:,1] = ends[:,1] - iy * image_snips.shape[2]
                    
                    snip = label[ix*image_snips.shape[2]:(ix+1)*image_snips.shape[2],
                                 iy*image_snips.shape[2]:(iy+1)*image_snips.shape[2]].numpy()
                    snip_edges = np.ones((1, 2))
                    snip_values = 0*np.random.random(size=1).astype(np.float32)#(np.arange(1,3) / 2).astype(np.float32)#np.zeros((1,))
                    input_map = (snip.reshape(image_snips.shape[2],
                                                   image_snips.shape[2])==(int(smd[1])+1)).astype(np.int32)
                    PL = mastar_path(input_map,
                                      snip_edges.T.astype(np.int32),
                                      1-snip_values,
                                      np.asarray(starts).reshape(-1,2),
                                      np.asarray(ends).reshape(-1,2),
                                      allow_diagonal=False)
                    smd_length[s] = float(len(PL)) / 10
                    smd_mask[s] = 1
                    #raise Exception("")
            smd_res.append(smd_length)
            smd_masks.append(smd_mask)
        #print("AFTER PATH: ", time.time()- bp)
        pmask_up = pmask_up.reshape(-1, 1, graph_size, graph_size)
        
        path_ups = []
        temp_map = torch.zeros((reduction_rate[0], reduction_rate[0])).to(device)
        for b in range(mask_up.shape[0]):
            #path_up = torch.zeros_like(a_per_batch[b])
            label = torch.Tensor(labels[b//NUM_PATHS]).to("cpu").type(torch.int64)
            pb = pmask_up[b].to("cpu")
            
            index = label.reshape(-1,)#.repeat(1, 8)
            bbenc = pb.reshape(-1, 1)#benc.reshape(-1, benc.shape[-1])
            #cur_x = torch.zeros((index.max()+1, benc.shape[-1]), dtype=bbenc.dtype).to(device)
            cur_x = scatter(bbenc, index, 0, reduce="max")
            cur_x = cur_x[1:]
            #for i in range(path_up.shape[0]):
            #    mask_temp = (label == i+1)[:,:,0]
            #    path_up[i] = 1*((mask_temp * pb).any())
            path_ups.append(cur_x > 0)
            
        if True:
            pmask_up1 = torch.zeros((enc.shape[0] * NUM_PATHS, graph_size, graph_size)).to("cpu")
            temp_map = torch.zeros((reduction_rate[0], reduction_rate[0])).to("cpu")
            for b in range(pmask_up.shape[0]):
                label = torch.Tensor(labels[b//NUM_PATHS]).to("cpu")
                ab = path_ups[b].to("cpu")
                src = torch.cat([ab[:1]*0, ab], axis=0)
                index = label.type(torch.int64).reshape(-1, 1)
                pmask_up1[b] = torch.gather(src, 0, index).reshape(graph_size, graph_size)
                
                #for i in range(ab.shape[0]):
                #    mask_temp = (label==(i+1))[:,:,0]
                #    pmask_up1[b,mask_temp] = (ab[i]*1.0).to("cpu")
                    
        #raise Exception("")
        diff = edge_res#path_img.reshape(-1, graph_sqr).type(torch.float32)# * (goal_map[:,2:].sum(axis=1, keepdim=True) < 0.5)
        diffx = torch.cat(diff, axis=0).type(torch.float32).to(device).unsqueeze(-1)
        pw = torch.cat(path_weights, axis=0).type(torch.float32).to(device).unsqueeze(-1)
        #chunk_probs = torch.cat(small_paths_m, axis=0).to(device)#torch.sigmoid(a)
        #goal_mask = ((data.x[:,-2:] > 0)*1.0).max(axis=-1, keepdim=True).values
        #diff_ng = (diff - goal_mask)
        a_n = edge_vals# *  (1 - 0.75*(1 - chunk_probs))#torch.cat(new_a, axis=0)
        #costs = torch.sqrt(torch.cat(costs, axis=0)/16).to(device)
        loss1_ = (diffx * torch.log(torch.clip(a_n, 1e-6, 1 - 1e-6)))#loss_chunk(a_n, diff)# * costs
        loss1 = -1*(loss1_ * (0+(pw))).sum() / data.x.shape[0]# + loss_chunk(a1, path_chunk1)
        
        # reg loss
        smd_distances = torch.cat(smd_res, axis=0).type(torch.float32).unsqueeze(-1).to(device)
        smd_mask = torch.cat(smd_masks, axis=0).type(torch.float32).unsqueeze(-1).to(device)
        loss2 = torch.mean(torch.square(path_dist[smd_mask==1] - smd_distances[smd_mask==1]))
        
        loss = 0*loss1 + loss2
        #loss2 = 
        #print("reached loss")
        loss.backward()
        optimizer.step()
        
        diff = path_ups# * (goal_map[:,2:].sum(axis=1, keepdim=True) < 0.5)
        diff = torch.cat(diff, axis=0).type(torch.float32).to(device)
        chunk_probs = torch.cat(small_paths, axis=0)#torch.sigmoid(a)
        #chunk_probs1 = torch.sigmoid(a1)
        lssc_accum += loss1
        path_lss += loss2
        acc_curr = acc(a_n[diffx==1], diffx[diffx==1])
        #print(acc_curr)
        acc_accum1 += acc_curr# + acc(chunk_probs1, path_chunk1) 
        pres_accum1 += pres(a_n[diffx==1], diffx[diffx==1])# + pres(chunk_probs1, path_chunk1)
        rec_accum1 += rec(a_n[diffx==1], diffx[diffx==1])# + rec(chunk_probs1, path_chunk1)
        #print("AFTER OPTIM: ", time.time() - t0)
        #raise Exception("")
    path_lens = np.asarray(path_lens)
    path_olen = np.asarray(path_olen)
    path_nosc = np.asarray(path_nosc)
    path_pure = np.asarray(path_pure)
    print(f"Epoch Chunk: {epoch} Time: {time.time() - tepoch:.2f} \
    Loss: {lssc_accum/(idx+1):.4f} \
    MPL: {path_lss/(idx+1):.4f} \
    Precision: {pres_accum1/(idx+1):.2f} Recall: {rec_accum1/(idx+1):.2f} \
    Scored Path: {np.mean(path_nosc/path_lens):.3f} \
    Masked Path: {np.mean(path_nosc/path_olen):.3f} \
    A* Masked Path: {np.mean(path_nosc/path_pure):.3f}" )
    #print(f"Epoch: {epoch} Loss: {lss_accum/(l+1)} Accuracy: {acc_accum/(l+1)} Precision: {pres_accum/(l+1)} Recall: {rec_accum/(l+1)}")
    #if True:
    #    vismap = visualize(img[-1,0].cpu(), start[-1].cpu(), end[-1].cpu(), path_sec.cpu()[-1], probs.cpu()[-1])
    #    plt.imshow(vismap)
    #    plt.show()
    #    print("updated")
        #raise Exception("")
    noise_effect *= 0.97
    #raise Exception("")
    

    

  path = pymastar2d.mastar.mastar(
64it [10:33,  9.89s/it]


Epoch Chunk: 0 Time: 633.01     Loss: 0.1734     MPL: 0.0498     Precision: 1.00 Recall: 0.08     Scored Path: 0.995     Masked Path: 0.965     A* Masked Path: 0.945


64it [10:47, 10.11s/it]


Epoch Chunk: 1 Time: 647.19     Loss: 0.1932     MPL: 0.0314     Precision: 1.00 Recall: 0.08     Scored Path: 0.994     Masked Path: 0.960     A* Masked Path: 0.940


64it [10:35,  9.92s/it]


Epoch Chunk: 2 Time: 635.19     Loss: 0.1736     MPL: 0.0309     Precision: 1.00 Recall: 0.09     Scored Path: 0.994     Masked Path: 0.965     A* Masked Path: 0.942


64it [10:52, 10.20s/it]


Epoch Chunk: 3 Time: 652.76     Loss: 0.1774     MPL: 0.0296     Precision: 1.00 Recall: 0.08     Scored Path: 0.994     Masked Path: 0.964     A* Masked Path: 0.943


64it [11:00, 10.33s/it]


Epoch Chunk: 4 Time: 660.97     Loss: 0.1849     MPL: 0.0298     Precision: 1.00 Recall: 0.09     Scored Path: 0.994     Masked Path: 0.964     A* Masked Path: 0.941


64it [10:32,  9.88s/it]


Epoch Chunk: 5 Time: 632.21     Loss: 0.1837     MPL: 0.0294     Precision: 1.00 Recall: 0.09     Scored Path: 0.993     Masked Path: 0.963     A* Masked Path: 0.940


64it [10:17,  9.66s/it]


Epoch Chunk: 6 Time: 617.97     Loss: 0.1827     MPL: 0.0293     Precision: 1.00 Recall: 0.10     Scored Path: 0.995     Masked Path: 0.964     A* Masked Path: 0.943


64it [10:52, 10.20s/it]


Epoch Chunk: 7 Time: 652.59     Loss: 0.1672     MPL: 0.0290     Precision: 1.00 Recall: 0.09     Scored Path: 0.995     Masked Path: 0.966     A* Masked Path: 0.941


54it [08:20,  8.66s/it]

In [None]:
torch.square(path_dist[smd_mask==1] - smd_distances[smd_mask==1])

In [None]:
smd_distances[smd_mask==1]

In [None]:
path_dist[smd_mask==1]

In [None]:
plt.imshow(input_map)

In [None]:
PL

In [None]:
grid,
                       new_edges.T.astype(np.int32),
                       1-0*edge_values[:,0].detach().cpu().numpy(),
                       np.asarray(start[p]), np.asarray(end[p])

In [None]:
from pymastar2d import mastar_path
PL = mastar_path(grid,
                new_edges.T.astype(np.int32),
                1-0*edge_values[:,0].detach().cpu().numpy(),
                np.asarray(start[p]).reshape(1,2), np.asarray(end[p]).reshape(1,2), allow_diagonal=False)

In [None]:
PL = mastar_path(input_map,
                                      snip_edges.T.astype(np.int32),
                                      1-(snip_values/2).astype(np.float32),
                                      np.asarray(starts).reshape(-1,2),
                                      np.asarray(ends).reshape(-1,2),
                                      allow_diagonal=False)

In [None]:
1 - snip_values

In [None]:
end[p]

In [None]:
weights = grid
edges = new_edges.T.astype(np.int32)
trans = 1-0*edge_values[:,0].detach().cpu().numpy()
st = np.asarray(start[p:p+1])
goal = np.asarray(end[p:p+1])
height, width = weights.shape
start_idx = np.ravel_multi_index(st.T, (height, width))
goal_idx = np.ravel_multi_index(goal.T, (height, width))
m = edges.shape[0]
#res = pyastar2d.astar.multi_point_astar(weights.flatten(), edges.flatten(), trans.flatten(),
#                                  height, width, m, start_idx, goal_idx, 1, 1, False, 0)

In [None]:
start_idx.shape

In [None]:
start

In [None]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


In [None]:
exp = [path_nosc, path_lens, path_olen, path_pure]
exp = [np.expand_dims(e, axis=-1) for e in exp]
con_path = np.concatenate(exp, axis=1)

In [None]:
con_path

In [None]:
fig = plt.figure(figsize=(12, 128))
gs = fig.add_gridspec(mask_up.shape[0], 4, hspace=0, wspace=0)
axs = gs.subplots(sharex=True, sharey=True)
for b in range(mask_up.shape[0]):
    axs[b,0].imshow(mask_up[b].reshape(graph_size, graph_size).detach().cpu())
    axs[b,1].imshow(pfollowed[b].reshape(graph_size, graph_size).detach().cpu())
    axs[b,2].imshow(pmask_up[b].reshape(graph_size, graph_size).detach().cpu())
    axs[b,3].imshow(pmask_up1[b].reshape(graph_size, graph_size).detach().cpu())
plt.show()

In [None]:
edge_res[0].shape

In [None]:
data.x.shape

In [None]:
(labels[0] == labels[3]).all()

In [None]:
len(list(mappers[0].parameters()))

In [None]:
a = torch.arange(10)*2

In [None]:
torch.gather(a, 0, torch.LongTensor([0,0, 6]))

In [None]:
data

In [None]:
plt.imshow(out.reshape(8, 64, 64).detach().cpu()[0])

In [None]:
map = cv2.imread("bg512-png\AR0011SR.png",0)
map = 1*(cv2.resize(map, (graph_size, graph_size)) > 127)
#map = r.reshape(graph_size, graph_size, 1)
#label, num = scipy.ndimage.label(map)

In [None]:
plt.imshow(map)

In [None]:
d = batch_list_data(a, data.batch)

In [None]:
len(d)

In [None]:
plt.imshow(valid_mask.reshape(64, 64))

In [None]:
out[batch_dim == 0].shape

In [None]:
edge_map

In [None]:
torch.save(mapper.state_dict(), "mapper.pt")
torch.save(decoder.state_dict(), "decoder.pt")


In [None]:
img_graph.shape

In [None]:
path_chunk.shape

In [None]:
plt.imshow(detached[4].reshape(graph_size, graph_size).cpu().detach())

In [None]:
plt.imshow(probs[4].reshape(graph_size, graph_size).cpu().detach())

In [None]:
plt.imshow(chunk_probs[4, 0].cpu().detach() > 0.5)

In [None]:
a_per_batch

In [None]:
torch.masked_select(out, out>-2)

In [None]:

ig = extract_image_patches(img, reduction_rate, stride=reduction_rate)
ig = ig.reshape(img.shape[0], reduction_rate, reduction_rate, enc_size, enc_size)

        
class MaskSelect(torch.nn.Module):
    def __init__(self, ):
        super().__init__()
        
    def forward(self, img_patched, mask):
        batch_size = img_patched.shape[0]
        num_paths = mask.shape[0] // batch_size
        
        path_chunk = chunk_probs.unsqueeze(2).unsqueeze(2)
        
        out = torch.masked_select(ig, chunk_probs>0.5).reshape(path_chunk, path_chunk, -1).permute([2, 0, 1])
        
        out = out.reshape(out.shape[0], -1)
        

In [None]:
plt.imshow((path_chunk[0]>0.5).cpu().detach().reshape(64, 64))

In [None]:
goal_map.shape

In [None]:
enc_up.shape

In [None]:
plt.imshow(path[0].cpu().detach().reshape(32, 32))

In [None]:
plt.imshow((probs[0]>0.5).cpu().detach().reshape(32, 32))

In [None]:
data=data.cpu()
path=pure_hist[0].cpu()
history = probs[0].cpu()
map = np.zeros((graph_size*graph_size, 3))
#edges = data.edge_index[0,:]
#edges = edges[edges<(graph_size*graph_size)]
#map[edges, :] = 1
map[:,:] = data.x[:graph_size*graph_size]
#print(path.shape, history.shape, data.x[:graph_size*graph_size].shape)
path = path * data.x[:graph_size*graph_size,0]
history = history * data.x[:graph_size*graph_size,0]
empty_map = map[:,:1]
map[path>0] = [0,0,0]

history = history>0.5
map[history] = [1,0.5,0]

map[path>0] = map[path>0] + [0, 0.5, 0]

#map = map * (empty_map)
map[data.start[0], :] = [0, 0, 1]
map[data.end[0], :] = [1, 0, 0]
map = np.clip(map, 0, 1)

map = map.reshape(graph_size, graph_size, 3)

In [None]:
def dist(a, b):
    (x1, y1) = a
    (x2, y2) = b
    return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5

t0 =time.time()
glob_grid = nx.grid_2d_graph(512, 512)
glob_grid.remove_nodes_from(np.where(1-img[0,0].reshape(-1).cpu())[0])
nx.astar_path(glob_grid, (427, 166), (105, 480), dist)
print(time.time() - t0)

In [None]:
filters = torch.randn(1, 1, 103)
inputs = torch.randn(32, 1, 2500)

filters = filters.to(device)
inputs = inputs.to(device)

F.conv1d(inputs, filters, padding="same")

In [None]:
torch.Tensor(filters).to(device)

In [None]:
x = selected_node_maps.unsqueeze(1)

In [None]:
n, d= reduce_graph(edge_map.cpu())

In [None]:
n

In [None]:
import matplotlib.pyplot as plt
from IPython import display
import time
%matplotlib inline  
optimizer = torch.optim.RMSprop(list(mapper.parameters()) + list(decoder.parameters()), lr=1e-2)
#find_path = ASTAR(8, graph_size)
TRAIN_LIMIT = 10
def visualize(data=None, start=None, end=None, path=None, history=None):

    map = np.zeros((graph_size*graph_size, 3))
    #edges = data.edge_index[0,:]
    #edges = edges[edges<(graph_size*graph_size)]
    #map[edges, :] = 1
    data = data.reshape(-1, 1)
    map[:,:] = data
    #print(path.shape, history.shape, data.x[:graph_size*graph_size].shape)
    path = path * data[:,0]
    history = history * data[:,0]
    empty_map = map[:,:1]
    map[path>0] = [0,0,0]

    history = history>0.5
    map[history] = [1,0.5,0]

    map[path>0] = map[path>0] + [0, 0.5, 0]

    #map = map * (empty_map)
    

    map = map.reshape(graph_size, graph_size, 3)
    
    map[start[0], start[1], :] = [0, 0, 1]
    map[end[0], end[1], :] = [1, 0, 0]
    map = np.clip(map, 0, 1)
    
    return map

#model.train()
noise_effect = 1
for epoch in range(20000):
    lss_accum = 0
    acc_accum = 0
    pres_accum = 0
    pres_base = 0
    rec_accum = 0
    for l, batch in enumerate(train_loader):
        loss = 0
        optimizer.zero_grad()
        img = batch[0].to(device)
        start = batch[1].to(device)
        end = batch[2].to(device)
        em = batch[3].to(device)
        t0 = time.time()
        img = img.reshape(-1, 1, graph_size, graph_size)
        goal_map = torch.zeros((img.shape[0], 2,
                                img.shape[2]//(reduction_rate),
                                img.shape[3]//(reduction_rate))).to(device)
        
        #start = torch.cat([(data.start//graph_size).unsqueeze(1),
        #                   (data.start%graph_size).unsqueeze(1)], axis=1)
        
        #goal = torch.cat([(data.end//graph_size).unsqueeze(1),
        #                  (data.end%graph_size).unsqueeze(1)], axis=1)
        startrel = (start%reduction_rate)/reduction_rate
        endrel = (end%reduction_rate)/reduction_rate
        startr = start // reduction_rate
        endr = end // reduction_rate
        #keeps = start.cpu()
        #keepe = end.cpu()
        enc = mapper(img)
        if epoch>TRAIN_LIMIT:
            for p in range(NUM_PATHS):
                goal_map = torch.zeros((img.shape[0], 2,
                                    img.shape[2]//(reduction_rate),
                                    img.shape[3]//(reduction_rate))).to(device)
                goal_map[range(img.shape[0]), :, startr[:, p,0], startr[:,p,1]] += startrel[:,p,:] + 1
                goal_map[range(img.shape[0]), :, endr[:,p,0], endr[:,p,1]] += endrel[:,p,:] + 1
                out, chunk_preds = decoder(enc, goal_map, em.squeeze(0), False)
                #out, chunk_preds = model(img, goal_map)
                out = out.reshape(1, -1)
                out = out# + noise_effect*torch.normal(mean=out*0, std=out*0+1).detach()
                probs = torch.sigmoid(out)
                t1 = time.time()
                flatstart = start[:,p, 1] + start[:,p,0]*graph_size
                flatend = end[:,p, 1] + end[:,p,0]*graph_size
                #history, path_fin = find_path(probs, data, True)
                pure_hist, path_sec = find_path(0.01*(1-probs).detach(),
                                                flatstart.detach(),
                                                flatend.detach(),
                                                img.detach(), True)

                path_img = path_sec.reshape(-1, 1, graph_size, graph_size)
                path_chunk = F.conv2d(path_img.type(torch.float32), chunk_filter, padding="valid",
                                      stride = reduction_rate)
                path_chunk = (path_chunk > 0.9).type(torch.float32).detach()
                t2 = time.time()
                #print("TIMES: ", t1-t0, t2-t1)
                detached = path_sec.detach().type(torch.float32)
                loss += 100*loss_func(out, detached) + 100*loss_func(chunk_preds, path_chunk)
            #print(out.shape, data.y.shape)
            acc_accum += acc(probs, detached)
            pres_accum += pres(probs, detached)
            rec_accum += rec(probs, detached)
            lss_accum += loss
            loss.backward()
            optimizer.step()
        else:
            #zero_grid = torch.zeros((1, graph_size*graph_size)).to(device)
            for p in range(NUM_PATHS):
                goal_map = torch.zeros((img.shape[0], 2,
                                    img.shape[2]//(reduction_rate),
                                    img.shape[3]//(reduction_rate))).to(device)
                goal_map[range(img.shape[0]), :, startr[:, p,0], startr[:,p,1]] += startrel[:,p,:] + 1
                goal_map[range(img.shape[0]), :, endr[:,p,0], endr[:,p,1]] += endrel[:,p,:] + 1
                chunk_preds = decoder(enc, goal_map, em.squeeze(0), True)
                out = torch.sigmoid(chunk_preds)
                for d in range(depth):
                    out = torch.repeat_interleave(torch.repeat_interleave(out, 2, dim=2), 2, dim=3)
                #out, chunk_preds = model(img, goal_map)
                out = out.reshape(1, -1)
                #out = out# + noise_effect*torch.normal(mean=out*0, std=out*0+1).detach()
                probs = torch.sigmoid(chunk_preds)
                t1 = time.time()
                flatstart = start[:,p, 1] + start[:,p,0]*graph_size
                flatend = end[:,p, 1] + end[:,p,0]*graph_size
                #history, path_fin = find_path(probs, data, True)
                
                pure_hist, path_sec = find_path(0.1*(1-out).detach(),
                                                flatstart.detach(),
                                                flatend.detach(),
                                                img.detach(), True)

                path_img = path_sec.reshape(-1, 1, graph_size, graph_size)
                path_chunk = F.conv2d(path_img.type(torch.float32), chunk_filter, padding="valid",
                                      stride = reduction_rate)
                path_chunk = (path_chunk > 0.9).type(torch.float32).detach()
                t2 = time.time()
                #print("TIMES: ", t1-t0, t2-t1)
                detached = path_sec.detach().type(torch.float32)
                loss += 100*loss_func(chunk_preds, path_chunk)
            #print(out.shape, data.y.shape)
            acc_accum += acc(probs, path_chunk)
            pres_accum += pres(probs, path_chunk)
            rec_accum += rec(probs, path_chunk)
            lss_accum += loss
            loss.backward()
            optimizer.step()
        #raise Exception("")
    print(f"Epoch: {epoch} Loss: {lss_accum/(l+1)} Accuracy: {acc_accum/(l+1)} Precision: {pres_accum/(l+1)} Recall: {rec_accum/(l+1)}")
    if epoch>TRAIN_LIMIT:
        vismap = visualize(img[0].cpu(), start[0,-1].cpu(), end[0,-1].cpu(), pure_hist.cpu()[0], probs.cpu()[0])
        plt.imshow(vismap)
        plt.show()
        print("updated")
    noise_effect *= 0.97
    