# GraphGeneration: Modeling and design of hierarchical bio-inspired de novo spider web structures using deep learning and additive manufacturing 

## Model 3: Transformer Model - Training and Inference

Spider webs are incredible biological structures, composed of thin but strong silk filaments arranged into highly complex hierarchical architectures with striking mechanical properties (e.g., lightweight but high strength). While simple 2D orb webs can easily be mimicked, the modeling and synthesis of artificial, bio-inspired 3D web structures is challenging, partly due to the rich set of design features. Here we use deep learning to model and synthesize such 3D web structures, where generative models are conditioned on key geometric parameters (including average edge length, number of nodes, average node degree, and others). To identify construction principles, we use inductive representation sampling of large spider web graphs and develop and train three distinct conditional generative models to accomplish this task: 1) an analog diffusion model with sparse neighbor representation, 2) a discrete diffusion model with full neighbor representation, and 3) an autoregressive transformer architecture with full neighbor representation. We find that all three models can produce complex, de novo bio-inspired spider web mimics and successfully construct samples that meet the design conditioning and reflect key geometric features (including the number of nodes, spatial orientation, and edge lengths). We further present an algorithm that assembles inductive samples produced by the generative deep learning models into larger-scale structures based on a series of geometric design targets, including helical forms and parametric curves.

[1] W. Lu, N.A. Lee, M.J. Buehler, "Modeling and design of hierarchical bio-inspired de novo spider web structures using deep learning and additive manufacturing," PNAS, 120 (31) e2305273120, 2023, 
https://www.pnas.org/doi/10.1073/pnas.2305273120 

In [None]:
import os, sys
 
# Expose only GPU 0 to this process; this must run before CUDA is initialised to take effect.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
from tqdm.notebook import trange, tqdm
    
import math
import torch

from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader,Dataset
import pandas as pd
import seaborn as sns
import torchvision
 
import matplotlib.pyplot as plt
import numpy as np
 
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from functools import partial, wraps

import ast
import pandas as pd
import numpy as np

from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric  import transforms

In [None]:
# Pick the compute device once; later cells move tensors/models onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
num_of_gpus = torch.cuda.device_count()  # number of CUDA devices visible to this process
print(f"{num_of_gpus}")

In [None]:
print(f"Torch version: {torch.__version__}")

In [None]:
def params(model):
    """Print the total and the trainable parameter counts of ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count

    print ("Total parameters: ", total," trainable parameters: ", trainable)

### Dataset

In [None]:
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, LinearTransformation
import math
import numbers
import random
from typing import Tuple, Union

class RandomRotateDiffusion(BaseTransform):
    r"""Rotates node positions around a specific axis by a randomly sampled
    angle within a given interval.

    The transform is called directly on a position tensor (not on a ``Data``
    object) and returns the rotated tensor.

    Args:
        degrees (tuple or float): Rotation interval from which the rotation
            angle is sampled. If :obj:`degrees` is a number instead of a
            tuple, the interval is given by :math:`[-\mathrm{degrees},
            \mathrm{degrees}]`.
        axis (int, optional): The rotation axis. (default: :obj:`0`)
    """
    def __init__(self, degrees: Union[Tuple[float, float], float],
                 axis: int = 0):
        if isinstance(degrees, numbers.Number):
            degrees = (-abs(degrees), abs(degrees))
        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
        self.degrees = degrees
        self.axis = axis

    def __call__(self, pos):
        """Rotate ``pos`` by an angle drawn uniformly from ``self.degrees``.

        Args:
            pos: tensor whose last dimension is 2 (2-D positions) or 3
                (3-D positions); rows are right-multiplied by the rotation
                matrix.

        Returns:
            The rotated positions (same shape, dtype and device as ``pos``).
        """
        angle = math.pi * random.uniform(*self.degrees) / 180.0
        sin, cos = math.sin(angle), math.cos(angle)

        # The dimensionality comes from `pos` itself (this previously read the
        # undefined name `data.pos`, raising NameError).
        if pos.size(-1) == 2:
            matrix = [[cos, sin], [-sin, cos]]
        elif self.axis == 0:
            matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
        elif self.axis == 1:
            matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
        else:
            matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]

        # Match the input's dtype/device so the matmul also works for float64
        # or GPU-resident positions (a bare torch.Tensor is CPU float32).
        matrix = torch.tensor(matrix, dtype=pos.dtype, device=pos.device)
        return pos @ matrix

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.degrees}, '
                f'axis={self.axis})')

In [None]:
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
(input, y_data, node_number_list, labels_y_txt,
 max_length, posdim_emb, max_neighbors) = torch.load('dataset_webs_many_medium.pt')

In [None]:
print(input.shape, y_data.shape, sep="\n")
print(node_number_list.shape, max_length)

In [None]:
class RegressionDataset(Dataset):
    """Padded spider-web graphs returned as node features plus an adjacency matrix.

    Column layout as used below (NOTE(review): inferred from the indexing in
    this class -- confirm against the script that built the dataset):
      * ``X_data[g, node, 0]`` is passed through unchanged,
      * ``X_data[g, node, 1:4]`` are node positions,
      * ``X_data[g, node, 4:4 + max_neighbors]`` are 1-based neighbour
        indices, with 0 meaning "no neighbour".

    Args:
        X_data: node-feature tensor indexed ``[graph, node, column]``.
        y_data: per-graph conditioning targets.
        node_number_list: number of real (non-padded) nodes per graph.
        degrees: if > 0, rotate positions about x, then y, then z by angles
            drawn from ``[-degrees, degrees]``.
        jitter: if > 0, add Gaussian noise with std ``jitter`` (clamped to
            ``[-jitter, jitter]``) to each coordinate of the real nodes.
        clamp_neighbors: clamp neighbour indices into ``[1, max_length]``.
        enforce_symm: also set the transposed adjacency entry.

    Reads the module-level globals ``max_length`` and ``max_neighbors``.
    """

    def __init__(self, X_data, y_data, node_number_list, degrees=0, jitter=0, clamp_neighbors=True,
                enforce_symm=True,):
        self.X_data = X_data
        self.y_data = y_data
        self.node_number_list=node_number_list

        self.degrees=degrees
        self.jitter=jitter
        self.enforce_symm=enforce_symm

        self.randomrotatex= RandomRotateDiffusion (degrees=self.degrees, axis=0)
        self.randomrotatey= RandomRotateDiffusion (degrees=self.degrees, axis=1)
        self.randomrotatez= RandomRotateDiffusion (degrees=self.degrees, axis=2)
        self.clamp_neighbors=clamp_neighbors

    def __getitem__(self, index):
        """Return ``(features, y)`` for graph ``index``.

        ``features`` is ``X[:, :4]`` concatenated along dim 1 with a
        ``(max_length, max_length)`` 0/1 adjacency matrix built from the
        neighbour columns, so it has ``4 + max_length`` columns.
        """
        # Copy only the requested graph (previously the entire dataset was
        # cloned on every access) so augmentation never touches X_data.
        x = self.X_data[index].clone()
        # Use the stored node count in every branch. Re-deriving it from the
        # last non-zero x-coordinate gave an index (not a count) and breaks
        # once padded positions have been min-max normalised to non-zero values.
        n_nodes = int(self.node_number_list[index])

        if self.degrees > 0:
            pos = x[:n_nodes, 1:4]
            pos = self.randomrotatex(pos)
            pos = self.randomrotatey(pos)
            pos = self.randomrotatez(pos)
            x[:n_nodes, 1:4] = pos

        if self.jitter > 0:
            # Draw the x, y and z noise vectors in that order before applying them.
            noise = [torch.randn(n_nodes) * self.jitter for _ in range(3)]
            for column, delta in enumerate(noise, start=1):
                x[:n_nodes, column] += torch.clamp(delta, min=-self.jitter, max=self.jitter)

        # Wipe padded rows (positions and neighbour slots alike).
        x[n_nodes:, :] = 0

        # Vectorised adjacency construction (was a Python double loop).
        dist_matrix = torch.zeros(max_length, max_length)
        neigh = x[:max_length, 4:4 + max_neighbors]
        present = neigh != 0  # zeros are padded values
        neigh_idx = neigh.long()
        if self.clamp_neighbors:
            neigh_idx = neigh_idx.clamp(1, max_length)
        rows = torch.arange(neigh.shape[0]).unsqueeze(1).expand_as(neigh_idx)[present]
        cols = neigh_idx[present] - 1  # stored indices are 1-based
        dist_matrix[rows, cols] = 1
        if self.enforce_symm:
            dist_matrix[cols, rows] = 1

        # Was `toranrch.cat`, a NameError on every call.
        output = torch.cat((x[:, :4], dist_matrix), 1)

        return output, self.y_data[index]

    def __len__ (self):
        return len(self.X_data)

def scale_data(image2, maxv, minv):
    """Linearly map ``image2`` from ``[minv, maxv]`` onto ``[-1, 1]``.

    Inverse of ``unscale_data``. Works on scalars, numpy arrays and tensors.
    """
    unit = (image2 - minv) / (maxv - minv)  # now in [0, 1]
    return unit * 2. - 1.0
 
def unscale_data(image2, maxv, minv):
    """Map ``image2`` from ``[-1, 1]`` back onto ``[minv, maxv]`` (inverse of ``scale_data``)."""
    unit = (image2 + 1.) / 2.  # back to [0, 1]
    return unit * (maxv - minv) + minv


def normalize_data (input, y_data, X_min=None, X_max=None, y_min=None, y_max=None,
                    X_max_neigh=None, X_min_neigh=None,
                    Xscale=0):
    """Min-max normalise node positions and conditioning targets **in place**.

    Columns ``1:4`` of ``input`` (indexed ``[graph, node, column]``) are scaled
    with one global min/max onto ``[-(1 - Xscale), 1 - Xscale]``. Each column
    of ``y_data`` is scaled independently onto ``[-1, 1]``.

    Args:
        input: node-feature tensor; modified in place.
        y_data: 2-D target tensor ``(num_graphs, num_targets)``; modified in place.
        X_min, X_max: position bounds; computed from ``input`` when None.
        y_min, y_max: per-column target bounds (any indexable: list, array,
            tensor); computed from ``y_data`` when None.
        X_max_neigh, X_min_neigh: unused, kept for call compatibility.
        Xscale: shrinks the position range to leave a margin inside [-1, 1].

    Returns:
        ``(input, y_data, X_min, X_max, y_min, y_max)``.

    NOTE(review): a zero range (max == min) divides by zero and yields inf/nan.
    """
    # `is None` rather than `== None`: the latter is ambiguous (raises) when
    # the caller supplies bounds as numpy arrays.
    positions = input[:, :, 1:4]
    if X_min is None:
        X_min = positions.min()
    else:
        print ("use provided X_min", X_min)
    if X_max is None:
        X_max = positions.max()
    else:
        print ("use provided X_max", X_max)

    input[:, :, 1:4] = (positions - X_min) / (X_max - X_min) * (2 - 2 * Xscale) - (1 - Xscale)

    num_targets = y_data.shape[1]
    if y_min is None:
        y_min = [y_data[:, i].min() for i in range(num_targets)]
    else:
        print ("use provided y_min", y_min)
    if y_max is None:
        y_max = [y_data[:, i].max() for i in range(num_targets)]
    else:
        print ("use provided y_max", y_max)

    for i in range(num_targets):
        y_data[:, i] = (y_data[:, i] - y_min[i]) / (y_max[i] - y_min[i]) * 2 - 1  # range -1 to 1

    print ("Check y_data after norm  ", y_data.min(), y_data.max())
    return input, y_data, X_min, X_max, y_min, y_max

In [None]:
X_scaled, y_data_scaled, X_min, X_max, y_min, y_max = normalize_data(
    input, y_data, Xscale=0.0,
)

In [None]:
 X_min, X_max 

In [None]:
(y_min, y_max)  # per-column target normalisation bounds

In [None]:
def get_data_loaders(X_scaled, y_data_scaled, node_number_list, split=0.1, batch_size_=16):
    """Split the data into train/test sets and wrap both in DataLoaders.

    The split uses a fixed ``random_state`` so it is reproducible. Both
    datasets are built without augmentation (no rotation, no jitter) and
    without symmetrising the adjacency.

    Returns:
        ``(train_loader, train_loader_noshuffle, test_loader)``.
    """
    (X_train, X_test,
     y_train, y_test,
     node_number_list_train, node_number_list_test) = train_test_split(
        X_scaled, y_data_scaled, node_number_list,
        test_size=split, random_state=235,
    )

    print(f"Shapes= {X_scaled.shape}, {y_data_scaled.shape}")
    for part in (X_train, y_train, X_test, y_test):
        print(part.shape)

    train_dataset = RegressionDataset(
        X_train, y_train, node_number_list_train,
        degrees=0, jitter=0.0, enforce_symm=False,
    )
    test_dataset = RegressionDataset(
        X_test, y_test, node_number_list_test,
        degrees=0, jitter=0.0, enforce_symm=False,
    )

    return (
        DataLoader(dataset=train_dataset, batch_size=batch_size_, shuffle=True),
        DataLoader(dataset=train_dataset, batch_size=batch_size_, shuffle=False),
        DataLoader(dataset=test_dataset, batch_size=batch_size_),
    )

In [None]:
train_loader, train_loader_noshuffle, test_loader = get_data_loaders(
    X_scaled, y_data_scaled, node_number_list, split=0.1, batch_size_=512,
)

### Build transformer model 

In [None]:
#Based on: https://github.com/lucidrains/parti-pytorch

from typing import List
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

DEFAULT_T5_NAME = ""  # empty-string default for a T5 model name

def exists(val):
    """Return True unless ``val`` is None (falsy values such as 0 still exist)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is None."""
    if exists(val):
        return val
    return d

def eval_decorator(fn):
    """Decorate ``fn(model, ...)`` so it runs with ``model`` in eval mode.

    The model's previous ``training`` flag is restored afterwards, even if
    ``fn`` raises. ``functools.wraps`` keeps the wrapped function's name and
    docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
    return inner

# sampling helpers

def log(t, eps = 1e-20):
    """Natural log of ``t + eps``; the shift keeps log(0) finite."""
    shifted = t + eps
    return shifted.log()

def gumbel_noise(t):
    """Sample standard Gumbel noise shaped like ``t``: -log(-log(U)), U ~ Uniform[0, 1)."""
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample an index along ``dim`` using the Gumbel-max trick on ``t / temperature``."""
    perturbed = t / temperature + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

def top_k(logits, thres = 0.5):
    """Keep the largest logits along the last dim and set the rest to -inf.

    ``k = max(int((1 - thres) * n), 1)`` entries are kept, where ``n`` is the
    size of the last dimension. Works for logits of any rank.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    # Scatter along the same (last) dim that topk used; the previous
    # hard-coded dim 1 only worked for 2-D logits.
    probs.scatter_(-1, ind, val)
    return probs

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    """Boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob`` of exactly 1 or 0 returns an all-True / all-False mask without
    drawing random numbers.
    """
    if prob == 1:
        mask = torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        mask = torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        uniform = torch.zeros(shape, device = device).float()
        uniform.uniform_(0, 1)
        mask = uniform < prob
    return mask

# normalization

class LayerNorm(nn.Module):
    """Layer normalisation over the last dim with a learnable gain and a fixed zero bias.

    ``gamma`` is a Parameter. ``beta`` is registered as a buffer, so it is saved
    in the state dict but is not returned by ``parameters()``.
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        normalized_shape = (x.shape[-1],)
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)

# 2d relative positional bias

class RelPosBias2d(nn.Module):
    """Learned 2-D relative position bias for a ``size x size`` grid of tokens.

    Each pair of grid cells maps to one of ``(2*size - 1)**2`` relative
    offsets, and each offset has a learned bias per attention head.
    """
    def __init__(self, size, heads):
        super().__init__()
        span = 2 * size - 1
        self.pos_bias = nn.Embedding(span ** 2, heads)

        coords = torch.arange(size)
        grid = torch.stack(torch.meshgrid(coords, coords, indexing = 'ij'), dim = -1)
        grid = grid.reshape(-1, grid.shape[-1])            # (size*size, 2) cell coordinates
        rel = grid[:, None, :] - grid[None, :, :]          # pairwise (dh, dw) offsets
        rel = rel + (size - 1)                             # shift offsets to be non-negative
        d_h, d_w = rel.unbind(dim = -1)
        self.register_buffer('pos_indices', d_h * span + d_w)

    def forward(self, qk):
        """Return a ``(heads, i, j)`` bias for a similarity tensor with trailing dims ``(i, j)``."""
        n_q, n_k = qk.shape[-2:]
        # One key fewer than `qk` has: the first key column belongs to the
        # null key/value and receives zero bias via the left pad below.
        bias = self.pos_bias(self.pos_indices[:n_q, :(n_k - 1)]).permute(2, 0, 1)
        return F.pad(bias, (n_k - bias.shape[-1], 0), value = 0.)

# feedforward

def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> LayerNorm -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        mult: hidden-size multiplier (hidden size is ``int(dim * mult)``).
        dropout: dropout probability on the block output. Previously this
            argument was accepted but silently ignored.

    Returns:
        An ``nn.Sequential`` mapping ``(..., dim) -> (..., dim)``.
    """
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden, bias = False),
        nn.GELU(),
        LayerNorm(dim_hidden),
        nn.Linear(dim_hidden, dim, bias = False),
        # Appended last so the indices (and state_dict keys) of the existing
        # layers are unchanged; Dropout has no parameters or buffers.
        nn.Dropout(dropout),
    )

# attention

class Attention(nn.Module):
    """Multi-query attention with a learned null key/value (parti-pytorch style).

    Queries use ``heads`` heads of size ``dim_head``. Keys and values share a
    single head (one ``to_kv`` projection to ``dim_head`` is used for both)
    and are broadcast across all query heads. A learned ``null_kv`` vector is
    prepended to the keys/values and is never masked out.

    Args:
        dim: feature size of ``x`` (input and output).
        context_dim: feature size of ``context``; defaults to ``dim``.
        dim_head: per-head size.
        heads: number of query heads.
        causal: if True, apply an autoregressive (upper-triangular) mask.
        dropout: dropout applied to the inputs of the q and kv projections.
        norm_context: apply a LayerNorm to ``context``.
        rel_pos_bias: add a ``RelPosBias2d``; requires ``encoded_fmap_size``.
        encoded_fmap_size: grid side length passed to ``RelPosBias2d``.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        norm_context = False,
        rel_pos_bias = False,
        encoded_fmap_size = None
    ):
        super().__init__()
        self.causal = causal
        self.scale = dim_head ** -0.5
        self.norm = LayerNorm(dim)

        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()

        # (b, n, dim) -> (b, heads, n, dim_head)
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim, inner_dim, bias = False),
            Rearrange('b n (h d) -> b h n d', h = heads)
        )

        # needed for classifier free guidance for transformers
        # by @crowsonkb, adopted by the paper

        self.null_kv = nn.Parameter(torch.randn(dim_head))

        # one-headed key / value attention, from Shazeer's multi-query paper, adopted by Alphacode and PaLM

        self.to_kv = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(context_dim, dim_head, bias = False)
        )

        # (b, heads, n, dim_head) -> (b, n, dim), followed by a LayerNorm
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        # positional bias

        self.rel_pos_bias = None

        if rel_pos_bias:
            assert exists(encoded_fmap_size)
            self.rel_pos_bias = RelPosBias2d(encoded_fmap_size, heads)

    def forward(
        self,
        x,
        context = None,
        context_mask = None
    ):
        """Attend from ``x`` (b, n, dim) to ``context`` (b, m, context_dim), or to ``x`` itself.

        Args:
            x: query tokens, shape (b, n, dim).
            context: key/value tokens; when None, the already-normalised ``x`` is used.
            context_mask: boolean (b, m); False entries are masked out.

        Returns:
            Tensor of shape (b, n, dim).
        """
        batch, device = x.shape[0], x.device

        x = self.norm(x)

        q = self.to_q(x) * self.scale

        # Self-attention falls back to the normalised x as context.
        context = default(context, x)
        context = self.norm_context(context)

        kv = self.to_kv(context)

        # Prepend the null key/value, so kv has m + 1 positions.
        null_kv = repeat(self.null_kv, 'd -> b 1 d', b = batch)
        kv = torch.cat((null_kv, kv), dim = 1)

        # The single kv head is shared by every query head.
        sim = einsum('b h i d, b j d -> b h i j', q, kv)

        if exists(self.rel_pos_bias):
            pos_bias = self.rel_pos_bias(sim)
            sim = sim + pos_bias

        mask_value = -torch.finfo(sim.dtype).max

        if exists(context_mask):
            # Pad with True so the prepended null key is always attendable.
            context_mask = F.pad(context_mask, (1, 0), value = True)
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            # Offset j - i + 1 accounts for the extra null key column: query t
            # may see the null key plus keys 0..t.
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        out = einsum('b h i j, b j d -> b h i d', attn, kv)

        return self.to_out(out)

In [None]:
#https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/positional_encodings.py

class PositionalEncoding1D(nn.Module):
    """Fixed sinusoidal positional encoding along the middle axis of a (batch, x, ch) tensor."""

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        # Round up to an even width so the sin and cos halves are equal in size.
        even_channels = int(np.ceil(channels / 2) * 2)
        self.channels = even_channels
        exponents = torch.arange(0, even_channels, 2).float() / even_channels
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if tensor.dim() != 3:
            raise RuntimeError("The input tensor has to be 3d!")
        batch_size, length, orig_ch = tensor.shape
        positions = torch.arange(length, device=tensor.device).type(self.inv_freq.type())
        phase = positions[:, None] * self.inv_freq[None, :]
        table = torch.zeros((length, self.channels), device=tensor.device).type(tensor.type())
        table[:, :self.channels] = torch.cat((phase.sin(), phase.cos()), dim=-1)
        return table[None, :, :orig_ch].repeat(batch_size, 1, 1)


class PositionalEncodingPermute1D(nn.Module):
    """Channels-first wrapper around ``PositionalEncoding1D``."""

    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x) instead of (batchsize, x, ch)
        """
        super(PositionalEncodingPermute1D, self).__init__()
        self.penc = PositionalEncoding1D(channels)

    def forward(self, tensor):
        channels_last = tensor.permute(0, 2, 1)
        return self.penc(channels_last).permute(0, 2, 1)

    @property
    def org_channels(self):
        return self.penc.org_channels

class PositionalEncoding2D(nn.Module):
    """Fixed sinusoidal positional encoding for a (batch, x, y, ch) tensor.

    The first ``self.channels`` output channels encode the x position and
    the next ``self.channels`` encode the y position (truncated to ``ch``).
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        per_axis = int(np.ceil(channels / 4) * 2)
        self.channels = per_axis
        exponents = torch.arange(0, per_axis, 2).float() / per_axis
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if tensor.dim() != 4:
            raise RuntimeError("The input tensor has to be 4d!")
        batch_size, x, y, orig_ch = tensor.shape
        freq_type = self.inv_freq.type()
        phase_x = torch.arange(x, device=tensor.device).type(freq_type)[:, None] * self.inv_freq[None, :]
        phase_y = torch.arange(y, device=tensor.device).type(freq_type)[:, None] * self.inv_freq[None, :]
        c = self.channels
        emb = torch.zeros((x, y, 2 * c), device=tensor.device).type(tensor.type())
        emb[:, :, :c] = torch.cat((phase_x.sin(), phase_x.cos()), dim=-1)[:, None, :]
        emb[:, :, c:2 * c] = torch.cat((phase_y.sin(), phase_y.cos()), dim=-1)
        return emb[None, :, :, :orig_ch].repeat(batch_size, 1, 1, 1)

class PositionalEncodingPermute2D(nn.Module):
    """Channels-first wrapper around ``PositionalEncoding2D``."""

    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)
        """
        super(PositionalEncodingPermute2D, self).__init__()
        self.penc = PositionalEncoding2D(channels)

    def forward(self, tensor):
        channels_last = tensor.permute(0, 2, 3, 1)
        return self.penc(channels_last).permute(0, 3, 1, 2)

    @property
    def org_channels(self):
        return self.penc.org_channels

class PositionalEncoding3D(nn.Module):
    """Fixed sinusoidal positional encoding for a (batch, x, y, z, ch) tensor.

    Consecutive groups of ``self.channels`` output channels encode x, y and z
    (truncated to ``ch``).
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        per_axis = int(np.ceil(channels / 6) * 2)
        if per_axis % 2:
            per_axis += 1
        self.channels = per_axis
        exponents = torch.arange(0, per_axis, 2).float() / per_axis
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if tensor.dim() != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        batch_size, x, y, z, orig_ch = tensor.shape
        freq_type = self.inv_freq.type()

        def sincos(n):
            # (n, per_axis) table: sin of every phase followed by cos of every phase.
            phase = torch.arange(n, device=tensor.device).type(freq_type)[:, None] * self.inv_freq[None, :]
            return torch.cat((phase.sin(), phase.cos()), dim=-1)

        c = self.channels
        emb = torch.zeros((x, y, z, 3 * c), device=tensor.device).type(tensor.type())
        emb[:, :, :, :c] = sincos(x)[:, None, None, :]
        emb[:, :, :, c:2 * c] = sincos(y)[:, None, :]
        emb[:, :, :, 2 * c:] = sincos(z)
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)


class PositionalEncodingPermute3D(nn.Module):
    """Channels-first wrapper around ``PositionalEncoding3D``."""

    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        channels_last = tensor.permute(0, 2, 3, 4, 1)
        return self.penc(channels_last).permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        return self.penc.org_channels

class FixEncoding(nn.Module):
    """
    Precompute a positional encoding for a fixed input shape and reuse it.

    :param pos_encoder: instance of PositionalEncoding1D, PositionalEncoding2D or PositionalEncoding3D
    :param shape: shape of input, excluding batch and embedding size
    Example:
    p_enc_2d = FixEncoding(PositionalEncoding2D(32), (x, y)) # for where x and y are the dimensions of your image
    inputs = torch.randn(64, 128, 128, 32) # where x and y are 128, and 64 is the batch size
    p_enc_2d(inputs)
    """

    def __init__(self, pos_encoder, shape):
        super(FixEncoding, self).__init__()
        self.shape = shape
        self.dim = len(shape)
        self.pos_encoder = pos_encoder
        self.pos_encoding = pos_encoder(
            torch.ones(1, *shape, self.pos_encoder.org_channels)
        )
        self.batch_size = 0

    def forward(self, tensor):
        """Return the encoding repeated to ``tensor``'s batch size, on ``tensor``'s device.

        Only ``tensor.shape[0]`` and ``tensor.device`` are used. The repeated
        encoding is cached and rebuilt when either of them changes.
        """
        batch_size = tensor.shape[0]
        cached = getattr(self, "repeated_pos_encoding", None)
        # Also rebuild when the input moved to another device (previously a
        # stale-device tensor was returned), or when nothing is cached yet
        # (a first call with batch size 0 used to raise AttributeError).
        if cached is None or self.batch_size != batch_size or cached.device != tensor.device:
            self.repeated_pos_encoding = self.pos_encoding.to(tensor.device).repeat(
                batch_size, *([1] * (self.dim + 1))
            )
            self.batch_size = batch_size
        return self.repeated_pos_encoding

In [None]:
class EmbeddEncDec(nn.Module):
    """Encode integer neighbour lists through an embedding table and decode them back.

    Index 0 is the padding index, so valid neighbour indices are
    ``1..max_length``.

    Args:
        embed_dim_neighbor: embedding width per neighbour slot.
        neigh_emb_trainable: whether the embedding weight receives gradients.
        max_norm: ``nn.Embedding`` max_norm (rows are renormalised on lookup).
        norm_type: p of the norm used for ``max_norm``.
        max_length: largest neighbour index.
        max_neighbors: number of neighbour slots per node.
    """
    def __init__(
        self,
        embed_dim_neighbor=6,
        neigh_emb_trainable=False,
        max_norm=1.,  # embedding layer max_norm
        norm_type=2.,
        max_length=1024,
        max_neighbors=5,
    ):
        super().__init__()

        self.embed_dim_neighbor = embed_dim_neighbor
        self.max_neighbors = max_neighbors
        self.neigh_embs = nn.Embedding(max_length + 1, self.embed_dim_neighbor, padding_idx=0,
                                       max_norm=max_norm, norm_type=norm_type)
        # Setting `requires_grad` on the Module (as before) freezes nothing;
        # the flag must be applied to the weight Parameter itself.
        self.neigh_embs.weight.requires_grad_(bool(neigh_emb_trainable))
        # Kept for backward compatibility with code that reads this attribute.
        self.neigh_embs.requires_grad = neigh_emb_trainable

    def encode(self, output_neighbors, plot_hist=False):
        """Embed every neighbour slot and stack the slots along dim 1.

        Args:
            output_neighbors: integer index tensor indexed
                ``[batch, slot, position]`` with at least ``max_neighbors`` slots.
            plot_hist: if True, show a histogram of all embedding values.

        Returns:
            Tensor of shape ``(batch, max_neighbors * embed_dim_neighbor, position)``;
            slot ``i`` occupies channels ``[i*E, (i+1)*E)``.
        """
        slots = []
        for i in range(self.max_neighbors):
            emb = self.neigh_embs(output_neighbors[:, i, :])   # (batch, position, E)
            slots.append(torch.permute(emb, (0, 2, 1)))        # (batch, E, position)
        output = torch.cat(slots, 1)

        if plot_hist:
            sns.set_style("whitegrid")
            # .cpu() so plotting also works for GPU tensors.
            fig = sns.histplot(output.detach().cpu().numpy().flatten(), bins=100, )
            fig.set_xlabel("Embedding values", fontsize=10)

            plt.show()
        return output

    def decode(self, output):
        """Invert ``encode``: map each slot's embeddings back to neighbour indices.

        NOTE(review): relies on ``invert_embedding``, which is not defined in
        this part of the file; its return type is assumed to be accepted by
        ``torch.Tensor(...)`` -- confirm.

        Returns:
            Tensor with one column per neighbour slot stacked along dim 1, or
            ``[]`` when ``max_neighbors`` is 0.
        """
        ll = self.embed_dim_neighbor
        columns = []
        for i in range(self.max_neighbors):
            out = torch.permute(output[:, i * ll:(i + 1) * ll], (0, 2, 1))
            indices = invert_embedding(out, self.neigh_embs)
            columns.append(torch.Tensor(indices).unsqueeze(1))
        return torch.cat(columns, 1) if columns else []

In [None]:
EncDec = EmbeddEncDec(
    embed_dim_neighbor=3,
    neigh_emb_trainable=False,
    max_norm=True,
    max_length=max_length,
    max_neighbors=5,
)

In [None]:
# classes
def pad_sequence(output_xyz, max_length):
    """Zero-pad the last dimension of ``output_xyz`` up to ``max_length``.

    Args:
        output_xyz: 3-D tensor whose last dimension must not exceed ``max_length``.
        max_length: size of the last dimension of the result.

    Returns:
        A tensor of shape ``(output_xyz.shape[0], output_xyz.shape[1], max_length)``
        in the default dtype on the module-level ``device``. It holds
        ``output_xyz`` in ``[..., :output_xyz.shape[2]]`` and zeros elsewhere.
    """
    # Allocate directly on the target device instead of creating the buffer
    # on the CPU and copying it over with `.to(device)`.
    output = torch.zeros(
        (output_xyz.shape[0], output_xyz.shape[1], max_length), device=device
    )
    output[:, :, :output_xyz.shape[2]] = output_xyz  # just positions for now....
    return output

class GraphWebTransformer(nn.Module):
    """Conditional autoregressive transformer that generates graphs node by node.

    Every sequence position describes one node. The per-node channel layout is
    chosen at construction time:

    * ``predict_neighbors=True`` (default): 3 xyz channels followed by
      ``max_neighbors`` neighbor-index slots; each slot is fed to the model
      through the ``neigh_embs`` embedding (``embed_dim_neighbor`` channels).
    * ``predict_distance_matrix=True``: 3 xyz channels followed by
      ``max_length`` matrix channels (this switches ``predict_neighbors`` off).
    * otherwise: xyz only.

    With ``use_categorical_for_neighbors`` the neighbor slots are predicted as
    logits over ``max_length + 1`` tokens (cross-entropy); otherwise the model
    regresses the full internal representation with MSE.

    The model is conditioned on ``sequences`` (scalars, shape
    ``(batch, n_cond)``) through cross-attention. Conditioning is randomly
    dropped during training (``cond_drop_prob``) so that ``generate`` can apply
    classifier-free guidance with ``cond_scale``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        ff_mult = 4,
        max_length=1024,
        embed_dim_position=64,
        embed_dim_neighbor=6,
        neigh_emb_trainable=False,
        max_norm=1.,
        predict_neighbors=True,
        pos_fourier_graph_dim=128,
        pos_emb_fourier=True,
        pos_emb_fourier_add=False,
        text_embed_dim = 128,
        cond_drop_prob = 0.25,
        max_text_len = 128,
        max_neighbors=5,
        use_categorical_for_neighbors = False, # if False, use fixed embeddings and MSE
        predict_distance_matrix=False, # if True predict positions and distance matrix
    ):
        """Build the model.

        Args:
            dim, depth, dim_head, heads, dropout, ff_mult: transformer width,
                number of (self-attn, cross-attn, feed-forward) layers and their
                attention / feed-forward hyper-parameters.
            max_length: maximum number of nodes; sets the neighbor vocabulary
                (``max_length + 1`` tokens) and the matrix width in
                distance-matrix mode.
            embed_dim_position: width of the positional encoding of the
                conditioning sequence (used when ``pos_emb_fourier``).
            embed_dim_neighbor: embedding width of one neighbor-index slot.
            neigh_emb_trainable: if False the neighbor embedding weight is frozen
                and embedding rows are renormalized to at most ``max_norm``.
            predict_neighbors: include neighbor-index slots per node (forced to
                False by ``predict_distance_matrix``).
            pos_fourier_graph_dim: width of the positional encoding appended to
                every node.
            pos_emb_fourier, pos_emb_fourier_add: give the conditioning a
                positional encoding, concatenated (default) or added.
            text_embed_dim: width each conditioning scalar is projected to.
            cond_drop_prob: probability of dropping the conditioning during
                training; must be > 0.
            max_text_len: conditioning sequences are truncated to this length.
            max_neighbors: number of neighbor slots per node.
            use_categorical_for_neighbors: predict neighbor indices as logits
                (cross-entropy) instead of regressing their embeddings (MSE).
            predict_distance_matrix: predict xyz plus a ``max_length``-wide row.
        """
        super().__init__()

        self.predict_distance_matrix = predict_distance_matrix
        if predict_distance_matrix:
            # A matrix row replaces the neighbor slots entirely.
            predict_neighbors = False
        self.pos_emb_fourier = pos_emb_fourier
        self.pos_emb_fourier_add = pos_emb_fourier_add
        self.embed_dim_neighbor = embed_dim_neighbor
        self.predict_neighbors = predict_neighbors
        self.pos_fourier_graph_dim = pos_fourier_graph_dim
        self.use_categorical_for_neighbors = use_categorical_for_neighbors
        self.max_neighbors = max_neighbors
        self.neigh_emb_trainable = neigh_emb_trainable

        # ---- conditioning: every scalar -> text_embed_dim channels ----
        # NOTE: module creation order below is kept as-is so that random
        # initialization and state_dict layout match earlier checkpoints.
        self.fc1 = nn.Linear(1, text_embed_dim)
        self.GELUact = nn.GELU()
        if self.pos_emb_fourier:
            if not self.pos_emb_fourier_add:
                # concatenated positional encoding widens the cross-attn context
                text_embed_dim = text_embed_dim + embed_dim_position
            else:
                print ("Add pos encoding... ", text_embed_dim, embed_dim_position)
            self.p_enc_1d = PositionalEncoding1D(embed_dim_position)
        self.max_text_len = max_text_len

        # ---- graph tokens ----
        self.max_length = max_length
        # neighbor indices are looked up in a table with max_length + 1 rows
        self.max_tokens = max_length + 1
        if self.predict_neighbors:
            if not self.neigh_emb_trainable:
                self.neigh_embs = nn.Embedding(self.max_tokens, self.embed_dim_neighbor,
                                               max_norm=max_norm, norm_type=2)
                # freeze the weight itself (setting requires_grad on the module would not)
                self.neigh_embs.weight.requires_grad = False
            else:
                self.neigh_embs = nn.Embedding(self.max_tokens, self.embed_dim_neighbor)

        # width of one node in the model's internal (input) representation
        self.pred_dim = (3
                         + self.max_neighbors * self.predict_neighbors * embed_dim_neighbor
                         + self.predict_distance_matrix * self.max_length)

        if self.use_categorical_for_neighbors:
            # xyz + one logit vector over all tokens per neighbor slot
            self.logits_dim = 3 + self.max_neighbors * self.max_tokens
            # generated nodes store xyz + one sampled index per slot
            self.xyz_and_neigbor_dim = 3 + self.max_neighbors
        else:
            # MSE: predict the internal representation directly
            self.logits_dim = self.pred_dim
            self.xyz_and_neigbor_dim = self.pred_dim

        self.p_enc_1d_graph = PositionalEncodingPermute1D(self.pos_fourier_graph_dim)
        self.start_token = nn.Parameter(torch.randn(self.pred_dim + self.pos_fourier_graph_dim))
        print ("Internal pred dim: ", self.pred_dim, "Four graph enc dim: ", pos_fourier_graph_dim,
              "Logits dim: ", self.logits_dim)

        assert cond_drop_prob > 0.
        self.cond_drop_prob = cond_drop_prob  # classifier free guidance for transformers - @crowsonkb

        # ---- transformer trunk ----
        self.init_norm = LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, causal = True, rel_pos_bias = False,
                          dim_head = dim_head, heads = heads, dropout = dropout),
                Attention(dim, context_dim = text_embed_dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout),
            ]))
        self.final_norm = LayerNorm(dim)
        self.to_logits = nn.Linear(dim, self.logits_dim, bias = False)
        self.to_dim = nn.Linear(self.pred_dim + self.pos_fourier_graph_dim, dim, bias = False)

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        sequences=None,#conditioning
        *,
        cond_scale = 3.,
        text_mask=None,
        filter_thres = 0.9,
        temperature = 1.,
        tokens_to_generate=None,
        use_argmax=True,#True= use argmax, otherwise gumbel
    ):
        """Autoregressively generate graphs conditioned on ``sequences``.

        Args:
            sequences: conditioning scalars, shape ``(batch, n_cond)``.
            cond_scale: classifier-free guidance scale, forwarded to
                ``forward_with_cond_scale`` at every step.
            text_mask: boolean mask over the conditioning; all True if None.
            filter_thres, temperature: top-k threshold and Gumbel temperature for
                neighbor sampling (categorical mode with ``use_argmax=False``).
            tokens_to_generate: number of nodes; defaults to ``max_length``.
            use_argmax: categorical mode only -- pick neighbors greedily instead
                of Gumbel sampling.

        Returns:
            Tensor ``(batch, channels, n_nodes)``: xyz in channels 0:3, followed by
            neighbor indices (neighbor modes) or the predicted matrix channels.
        """
        device = next(self.parameters()).device

        if not exists(text_mask):
            text_mask = torch.ones(sequences.shape[:2], dtype = torch.bool).to(device)

        batch = sequences.shape[0]
        seq_len = self.max_length if not exists(tokens_to_generate) else tokens_to_generate

        # generated history, grown by one node per step
        output = torch.empty((batch, self.xyz_and_neigbor_dim, 0), device = device)

        if self.use_categorical_for_neighbors:
            for _ in tqdm(range(seq_len)):
                # feed the whole history so far; keep the prediction for the next node
                logits = self.forward_with_cond_scale(
                    sequences = sequences,
                    text_mask = text_mask,
                    output = output,
                    cond_scale = cond_scale,
                    shift_input_depth=0,
                )[:, :, -1]

                step = [logits[:, :3]]  # regressed xyz
                for i in range(self.max_neighbors):
                    neighbor_logits_i = logits[:, 3 + i * self.max_tokens:3 + (i + 1) * self.max_tokens]
                    if use_argmax:
                        sampled = torch.argmax(neighbor_logits_i, -1)
                    else:
                        filtered_logits = top_k(neighbor_logits_i, thres = filter_thres)
                        sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
                    step.append(rearrange(sampled, 'b -> b 1 '))

                step = torch.cat(step, dim = -1).unsqueeze(2)
                output = torch.cat((output, step), dim = -1)
            return output

        for _ in tqdm(range(seq_len)):
            sampled = self.forward_with_cond_scale(
                sequences = sequences,
                text_mask = text_mask,
                output = output,
                cond_scale = cond_scale,
                encode_graphs=False,  # history is already in the internal layout
                shift_input_depth=0,
            )[:, :, -1]  # prediction for the next node
            output = torch.cat((output, rearrange(sampled, 'b c -> b c 1')), dim = -1)

        if self.predict_neighbors:
            # turn each slot's predicted embedding back into an index
            width = self.embed_dim_neighbor
            index_rows = []
            for i in range(self.max_neighbors):
                slot_emb = torch.permute(output[:, 3 + i * width:3 + (i + 1) * width], (0, 2, 1))
                index_rows.append(torch.Tensor(invert_embedding(slot_emb, self.neigh_embs)).unsqueeze(1))
            ind_list = torch.cat(index_rows, 1)
            output = torch.cat((output[:, 0:3, :], ind_list.to(device)), 1)

        return output

    def forward_with_cond_scale(self, *args, cond_scale = 3, **kwargs):
        """Classifier-free guidance around ``forward``.

        Runs ``forward`` with ``cond_drop_prob=0.`` and, unless ``cond_scale == 1``,
        again with ``cond_drop_prob=1.``; returns
        ``null_logits + (logits - null_logits) * cond_scale``.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
        if cond_scale == 1:
            return logits
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        sequences=None,#conditioning
        output=None,
        text_mask = None,
        cond_drop_prob = None,
        return_loss = False,
        shift_input_depth=1, #since first depth is 1,2,3,4,5...
        encode_graphs=True, #set to False when generating
    ):
        """Run the transformer over a (partial) graph sequence.

        Args:
            sequences: conditioning scalars, shape ``(batch, n_cond)``.
            output: graph data indexed as ``(batch, channels, nodes)``. The first
                ``shift_input_depth`` channels are skipped; then come xyz and,
                with ``encode_graphs=True`` in neighbor mode, ``max_neighbors``
                raw neighbor-index channels that are embedded here. With
                ``encode_graphs=False`` it is already in the internal layout
                (as built by ``generate``).
            text_mask: boolean mask over the conditioning; all True if None.
            cond_drop_prob: per-sample probability of dropping the conditioning;
                defaults to ``self.cond_drop_prob``.
            return_loss: return the next-node training loss instead of logits.
            shift_input_depth: number of leading channels of ``output`` before xyz.
            encode_graphs: see ``output``.

        Returns:
            Logits ``(batch, logits_dim, nodes + 1)`` (a start token is prepended),
            or a scalar loss when ``return_loss``.
        """
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # ---- conditioning: (batch, n_cond) -> (batch, n_cond, channels) ----
        cond_x = self.GELUact(self.fc1(sequences.float().unsqueeze(2)))
        if self.pos_emb_fourier:
            pos_fourier_xy = self.p_enc_1d(cond_x)
            if self.pos_emb_fourier_add:
                cond_x = cond_x + pos_fourier_xy
            else:
                cond_x = torch.cat((cond_x, pos_fourier_xy), 2)

        # ---- graph input -> internal (batch, pred_dim, nodes) layout ----
        if not self.predict_neighbors:
            if self.predict_distance_matrix:
                output = output[:, shift_input_depth:3 + shift_input_depth + self.max_length, :]
            elif encode_graphs:
                output = output[:, shift_input_depth:shift_input_depth + 3, :]
            else:
                output = output[:, 0:3, :]
        elif encode_graphs:
            pos_1 = shift_input_depth
            pos_2 = shift_input_depth + 3
            output_xyz = output[:, pos_1:pos_2, :]
            output_neighbors = output[:, pos_2:pos_2 + self.max_neighbors, :].long()

            parts = [output_xyz]
            for i in range(self.max_neighbors):
                # (batch, nodes) indices -> (batch, nodes, emb) -> (batch, emb, nodes)
                if self.neigh_emb_trainable:
                    x_cc = self.neigh_embs(output_neighbors[:, i, :])
                else:
                    with torch.no_grad():
                        x_cc = self.neigh_embs(output_neighbors[:, i, :])
                parts.append(torch.permute(x_cc, (0, 2, 1)))
            output = torch.cat(parts, 1)

        # allocate helper tensors where the data (and hence the model) lives
        device = output.device

        pos_fourier_graph = self.p_enc_1d_graph(
            torch.ones(output.shape[0], self.pos_fourier_graph_dim, output.shape[2], device = device))
        output = torch.cat((output, pos_fourier_graph), 1)
        output = torch.permute(output, (0, 2, 1))  # -> (batch, nodes, channels)

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = output.shape[0])
        output = torch.cat((start_tokens, output), dim = 1)

        if return_loss:
            # teacher forcing: position t predicts node t + 1
            output, target = output[:, :-1, :], output[:, 1:, :self.logits_dim]

        if not exists(text_mask):
            text_mask = torch.ones(cond_x.shape[:2], dtype = torch.bool, device = device)

        cond_x, text_mask = map(lambda t: t[:, :self.max_text_len], (cond_x, text_mask))

        batch = output.shape[0]
        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        x = self.init_norm(self.to_dim(output))
        for self_attn, cross_attn, ff in self.layers:
            x = self_attn(x) + x
            x = cross_attn(x, context = cond_x, context_mask = text_mask) + x
            x = ff(x) + x
        x = self.final_norm(x)

        logits = torch.permute(self.to_logits(x), (0, 2, 1))  # (batch, logits_dim, positions)

        if not return_loss:
            return logits

        target = torch.permute(target, (0, 2, 1))

        if self.use_categorical_for_neighbors:
            # xyz: MSE; every neighbor slot: cross-entropy over all tokens
            loss_xyz = F.mse_loss(logits[:, :3, :], target[:, :3, :])
            loss_neigh = 0
            for i in range(self.max_neighbors):
                neighbor_logits_i = logits[:, 3 + i * self.max_tokens:3 + (i + 1) * self.max_tokens, :]
                target_i = output_neighbors[:, i, :].long()
                loss_neigh = loss_neigh + F.cross_entropy(neighbor_logits_i, target_i)
            return loss_xyz + loss_neigh

        # MSE over the full internal representation (xyz + embeddings / matrix)
        return F.mse_loss(logits, target)

### Trainer 

In [None]:
def train_loop (model,
                train_loader,test_loader,
                optimizer=None,
                print_every=10,
                epochs= 300,
                start_ep=0,
                start_step=0,
                train_unet_number=1,
                print_loss=1000,
                plot_unscaled=False,
                save_model=False,
                cond_scales=(7.5,), #cond scales used when sampling
                num_samples=2,
                enforce_symmetry=False,
                save_loss_images=False,clamp=False,
                corplot=False,show_neighbors=False,xyz_and_graph=True,
               ):
    """Train ``model``; periodically report loss, sample and checkpoint.

    Each batch ``item`` from ``train_loader`` supplies graph data ``item[0]``
    (permuted with ``(0, 2, 1)`` and passed as ``output``) and conditioning
    ``item[1]`` (passed as ``sequences``); the model returns the training loss.
    Gradients are clipped to norm 0.5 before every optimizer step.

    Every ``print_loss`` steps the mean loss since the previous report is
    printed, appended to the global ``loss_list`` and plotted (saved under the
    global ``prefix`` if ``save_loss_images``), ``sample_loop`` is run on
    ``test_loader``, and the state dict is saved if ``save_model``.

    ``start_ep``, ``train_unet_number`` and ``plot_unscaled`` are accepted for
    compatibility but unused.

    Raises:
        ValueError: if no optimizer is given.
    """
    if optimizer is None:
        raise ValueError("train_loop requires an optimizer")

    # Move batches to wherever the model lives.
    device = next(model.parameters()).device

    steps = start_step
    start = time.time()  # time of the previous loss report
    loss_total = 0
    loss_count = 0       # losses accumulated since the previous report
    for e in range(1, epochs + 1):
        torch.cuda.empty_cache()
        model.train()

        for item in train_loader:
            X_train_batch = torch.permute(item[0].to(device), (0, 2, 1))
            y_train_batch = item[1].to(device)

            optimizer.zero_grad()
            loss = model(
                sequences=y_train_batch,  # conditioning
                output=X_train_batch,
                text_mask=None,
                return_loss=True,
                encode_graphs=True,
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            loss_total += loss.item()
            loss_count += 1

            if steps % print_every == 0:
                print(".", end="")

            if steps > 0 and steps % print_loss == 0:
                # average over the losses actually accumulated in this window
                norm_loss = loss_total / loss_count
                print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")

                loss_list.append(norm_loss)
                loss_total = 0
                loss_count = 0

                plt.plot(loss_list, label='Loss')
                plt.legend()
                if save_loss_images:
                    outname = prefix + f"loss_{e}_{steps}.jpg"
                    plt.savefig(outname, dpi=200)
                plt.show()

                sample_loop(model,
                            test_loader,
                            cond_scales=cond_scales,  # each cond scale is sampled
                            num_samples=num_samples,  # samples produced per test
                            clamp=clamp, corplot=corplot,
                            save_img=save_loss_images, show_neighbors=show_neighbors,
                            flag=steps, xyz_and_graph=xyz_and_graph, enforce_symmetry=enforce_symmetry,
                            )

                print (f"\n\n-------------------\nTime passed for {print_loss} steps at {steps} = {(time.time()-start)/60} mins\n-------------------")
                start = time.time()
                if save_model:
                    fname = f"{prefix}statedict_save-model-epoch_{e}.pt"
                    torch.save(model.state_dict(), fname)
                    print("Model saved: ", fname)

            steps += 1


In [None]:
from sklearn.metrics import r2_score

def _plot_gt_vs_pred_3d(gt_xyz, pred_xyz, outname=None):
    """Show GT (left) and predicted (right) node coordinates as 3D scatter plots."""
    _, ax = plt.subplots(1, 2, figsize=(10, 6), subplot_kw=dict(projection='3d'))
    for axis, (xs, ys, zs), title in ((ax[0], gt_xyz, 'GT'), (ax[1], pred_xyz, 'Prediction')):
        axis.scatter(xs, ys, zs, c='red', s=6, marker="o")
        axis.set_xlabel('X')
        axis.set_ylabel('Y')
        axis.set_zlabel('Z')
        axis.set_xlim(-1, 1)
        axis.set_ylim(-1, 1)
        axis.set_zlim(-1, 1)
        axis.set_title(title)
    if outname is not None:
        plt.savefig(outname, dpi=300)
    plt.show()


def _plot_gt_vs_pred_projection(gt_xyz, pred_xyz, outname=None):
    """Show Y-over-X and Z-over-X projections for GT (left) and prediction (right)."""
    _, ax = plt.subplots(1, 2, figsize=(10, 6))
    for axis, (xs, ys, zs), title in ((ax[0], gt_xyz, 'GT'), (ax[1], pred_xyz, 'Prediction')):
        axis.plot(xs, ys, 'bo', markersize=2, label='Y over X')
        axis.plot(xs, zs, 'ro', markersize=2, label='Z over X')
        axis.set_xlabel('X')
        axis.set_ylabel('Y/Z')
        axis.axis('square')
        axis.set_xlim(-1, 1)
        axis.set_ylim(-1, 1)
        axis.legend()
        axis.set_title(title)
    if outname is not None:
        plt.savefig(outname, dpi=300)
    plt.show()


def _plot_parity(y_gt, y_pred, diagonal, markersize=12, title=None):
    """Scatter predicted vs. GT graph properties with a dashed reference diagonal."""
    plt.plot(y_gt, y_pred, '.', label='Graph properties (GT vs predicted)', markersize=markersize)
    plt.legend()
    plt.xlabel('GT')
    plt.ylabel('Predicted')
    plt.plot(diagonal, diagonal, ls="--", c=".3")
    plt.axis('square')
    if title is not None:
        plt.title(title)
    plt.show()


def sample_loop (model,
                train_loader,
                cond_scales=(7.5,), #cond scales - each one is sampled
                num_samples=2, #how many samples produced every time tested
                 flag=0,clamp=False,
                 corplot=False,
                 save_img=False,
                 show_neighbors=False,
                xyz_and_graph=False,GED=False,
                 filter_thres = 0.9,temperature=1.,
                 enforce_symmetry = False, #if True: make distance matrix symmetric
                 clamp_round_results=True,dist_matrix_threshold=0.25,
               ):
    """Generate graphs with ``model`` for every batch and compare them with the GT.

    For each batch ``item`` (graph data ``item[0]``, conditioning ``item[1]``)
    and each guidance scale in ``cond_scales``, ``model.generate`` is called.
    Channels ``3:3 + max_length`` of the result are optionally clamped to [0, 1]
    and symmetrized, then binarized at ``dist_matrix_threshold``. For the first
    ``num_samples`` graphs, GT and prediction are plotted (3D and projections,
    optionally matrices, graph-property parity plots and a coordinate
    correlation plot). With ``xyz_and_graph`` an overall R^2 of the graph
    properties is printed per cond scale; with ``GED`` the graph edit distance
    is printed as well.

    Relies on notebook globals: ``max_length``, ``prefix``,
    ``construct_xyz_and_graph``, ``scale_data``, ``y_min`` and ``y_max``.
    ``clamp`` is accepted for compatibility but unused.
    """
    device = next(model.parameters()).device
    steps = 0
    for item in train_loader:
        y_train_batch = item[1].to(device)
        # GT graphs padded to max_length on the CPU; per the indexing below,
        # channels 1:4 hold xyz and channels 4:4 + max_length the matrix.
        X_train_batch = pad_sequence(torch.permute(item[0], (0, 2, 1)), max_length).cpu()

        num_samples = min(num_samples, y_train_batch.shape[0])
        print(f"Producing {num_samples} samples...")

        for cond_scale in cond_scales:
            result = model.generate(
                sequences=y_train_batch,  # conditioning
                cond_scale=cond_scale,    # classifier-free guidance scale
                filter_thres=filter_thres,
                temperature=temperature,
                use_argmax=False,
            )

            # Post-process the predicted matrix in place (a view into result).
            matrix = result[:, 3:3 + max_length, :]
            if clamp_round_results:
                matrix.clamp_(0, 1)
            if enforce_symmetry:
                matrix.copy_(0.5 * (matrix + torch.transpose(matrix, 1, 2)))
            if clamp_round_results:
                matrix.clamp_(0, 1)
            matrix[matrix < dist_matrix_threshold] = 0
            matrix[matrix >= dist_matrix_threshold] = 1

            result = result.cpu()

            y_data_coll_pred = []
            y_data_coll_GT = []
            for sample in range(num_samples):
                # crop to one past the largest index occurring in any nonzero matrix entry
                GTroundL = torch.nonzero(X_train_batch[sample, 4:4 + max_length, :]).flatten().max() + 1
                resroundL = torch.nonzero(result[sample, 3:3 + max_length, :]).flatten().max() + 1

                gt_xyz = (X_train_batch[sample, 1, :GTroundL],
                          X_train_batch[sample, 2, :GTroundL],
                          X_train_batch[sample, 3, :GTroundL])
                pred_xyz = (result[sample, 0, :resroundL],
                            result[sample, 1, :resroundL],
                            result[sample, 2, :resroundL])

                _plot_gt_vs_pred_3d(gt_xyz, pred_xyz,
                                    prefix + f"img_{flag}.png" if save_img else None)
                _plot_gt_vs_pred_projection(gt_xyz, pred_xyz,
                                            prefix + f"img_proj_{flag}.png" if save_img else None)

                if show_neighbors:
                    crop = max(GTroundL, resroundL)
                    _, ax = plt.subplots(1, 2, figsize=(10, 6))
                    ax[1].imshow(result[sample, 3:3 + crop, :crop])
                    ax[0].imshow(X_train_batch[sample, 4:4 + crop, :crop])
                    ax[1].set_title('Prediction')
                    ax[0].set_title('GT')
                    plt.show()

                    _, ax = plt.subplots(1, 2, figsize=(10, 6))
                    ax[1].imshow(result[sample, 3:3 + max_length, :])
                    ax[0].imshow(X_train_batch[sample, 4:4 + max_length, :])
                    ax[1].set_title('Prediction')
                    ax[0].set_title('GT')
                    plt.show()

                if xyz_and_graph:
                    G_res, data_res, y_data_pred = construct_xyz_and_graph(
                        result[sample, :3 + resroundL, :resroundL].squeeze(),
                        fname_root=f'{prefix}graph_xyz_{sample}_{flag}_{steps}',
                        label='Prediction',
                        dist_matrix=True)
                    G_GT, data_GT, y_data_GT = construct_xyz_and_graph(
                        X_train_batch[sample, 1:4 + GTroundL, :GTroundL].squeeze(),
                        fname_root=f'{prefix}graph_GT_xyz_{sample}_{flag}_{steps}',
                        label='GT', dist_matrix=True)

                    min_v = min(min(y_data_GT), min(y_data_pred))
                    max_v = max(max(y_data_GT), max(y_data_pred))
                    _plot_parity(y_data_GT, y_data_pred, [min_v, max_v])

                    # rescale each property with scale_data(value, y_max[i], y_min[i])
                    for i in range(len(y_min)):
                        y_data_pred[i] = scale_data(y_data_pred[i], y_max[i], y_min[i])
                        y_data_GT[i] = scale_data(y_data_GT[i], y_max[i], y_min[i])
                    _plot_parity(y_data_GT, y_data_pred, [-1, 1])

                    y_data_coll_pred.append(y_data_pred)
                    y_data_coll_GT.append(y_data_GT)

                    if GED:
                        # separate name: do not overwrite the GED flag for later samples
                        ged_value = nx.graph_edit_distance(G_res, G_GT)
                        print("Graph edit distance=", ged_value)
                    print("")

                if corplot:
                    plt.plot(X_train_batch[sample, 1:4, :].flatten().detach().cpu(),
                             result[sample, 0:3, :].flatten(), '.', label='all', markersize=6)
                    plt.plot(X_train_batch[sample, 1, :].flatten().detach().cpu(),
                             result[sample, 0, :].flatten(), '.', label='x', markersize=2)
                    plt.plot(X_train_batch[sample, 2, :].flatten().detach().cpu(),
                             result[sample, 1, :].flatten(), '.', label='y', markersize=2)
                    plt.plot(X_train_batch[sample, 3, :].flatten().detach().cpu(),
                             result[sample, 2, :].flatten(), '.', label='z', markersize=2)
                    plt.legend()
                    plt.plot([-1, 1], [-1, 1], ls="--", c=".3")
                    plt.axis('square')
                    plt.show()

            # Summary only exists when graph properties were collected.
            if xyz_and_graph and y_data_coll_GT:
                y_data_coll_pred = np.array(y_data_coll_pred).flatten()
                y_data_coll_GT = np.array(y_data_coll_GT).flatten()
                R2 = r2_score(y_data_coll_GT, y_data_coll_pred)
                print("OVERALL R2: ", R2)
                _plot_parity(y_data_coll_GT, y_data_coll_pred, [-1, 1], markersize=3,
                             title="Correlation prediction vs. GT")

        steps += 1

In [None]:
def generate_sample_cond(model,
                         cond_scales=[7.5],  # list of cond scales - each sampled
                         flag=0, clamp=False,
                         corplot=False,
                         save_img=False,
                         show_neighbors=False,
                         xyz_and_graph=False, GED=False,
                         cond=[1., .5, 1.],
                         enforce_symmetry=True,  # if True: make distance matrix symmetric
                         dist_matrix_threshold=0.25,
                         ):
    """Generate graphs from ``model`` for one conditioning vector and visualize them.

    For every entry of ``cond_scales`` the model's ``generate`` method is called
    with ``cond`` as conditioning. Rows ``3 .. 3+max_length-1`` of the output are
    treated as a predicted adjacency (distance) matrix: they are clamped to
    [0, 1], optionally symmetrized, and binarized at ``dist_matrix_threshold``.
    Rows 0-2 of the first sample are plotted as x/y/z node coordinates.

    Args:
        model: generative model exposing
            ``generate(sequences=..., cond_scale=..., use_argmax=...)``.
        cond_scales: conditioning scales; one generation per scale.
        flag: tag inserted into output file names.
        clamp, corplot, GED: accepted for signature compatibility; unused here.
        save_img: if True, save the plots under the global ``prefix``.
        show_neighbors: if True, also plot the binarized adjacency matrix.
        xyz_and_graph: if True, export and measure the graph with
            ``construct_xyz_and_graph``.
        cond: (scaled) conditioning vector.
        enforce_symmetry: average the adjacency matrix with its transpose.
        dist_matrix_threshold: entries >= threshold become edges (1), others 0.

    Returns:
        Tensor laid out as (L, 3 + L) for the last scale: per node xyz followed
        by its adjacency row. L is one past the largest node index that takes
        part in an edge. With ``xyz_and_graph`` the tuple
        ``(tensor, G_res, data_res, y_data_pred)`` is returned instead, where
        ``y_data_pred`` is the property vector measured on the generated graph.

    NOTE: ``construct_xyz_and_graph`` unscales the xyz rows of the slice it is
    given in place. With ``xyz_and_graph=True`` the returned tensor therefore
    holds unscaled coordinates.
    """
    samples = 0  # index of the sample (within the generated batch) that is inspected
    adj_rows = slice(3, 3 + max_length)  # rows holding the predicted adjacency matrix
    m = 6  # marker size

    cond_batch = torch.Tensor(cond).to(device).unsqueeze(0)

    for cond_scale in cond_scales:
        result = model.generate(sequences=cond_batch,  # conditioning
                                cond_scale=cond_scale,
                                use_argmax=False,
                                )
        result[:, adj_rows, :] = torch.clamp(result[:, adj_rows, :], 0, 1)

        if enforce_symmetry:
            adj = result[:, adj_rows, :]
            result[:, adj_rows, :] = torch.clamp(0.5 * (adj + torch.transpose(adj, 1, 2)), 0, 1)

        # Binarize: first zero out sub-threshold entries, then set the rest to 1.
        result[:, adj_rows, :][result[:, adj_rows, :] < dist_matrix_threshold] = 0
        result[:, adj_rows, :][result[:, adj_rows, :] >= dist_matrix_threshold] = 1

        result = result.cpu()

        # Number of nodes = largest node index appearing in any edge, plus one
        # (without the +1 the last connected node would be sliced away).
        resroundL = torch.nonzero(result[samples, adj_rows, :]).flatten().max() + 1

        xs = result[samples, 0, :resroundL]
        ys = result[samples, 1, :resroundL]
        zs = result[samples, 2, :resroundL]

        # 3D scatter of the predicted node positions
        fig, ax = plt.subplots(1, 1, figsize=(5, 6), subplot_kw=dict(projection='3d'))
        ax.scatter(xs, ys, zs, c='red', s=m, marker="o")
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.set_title('Prediction')
        if save_img:
            plt.savefig(prefix + f"img_{flag}.png", dpi=300)
        plt.show()

        # 2D projections: Y over X and Z over X
        fig, ax = plt.subplots(1, 1, figsize=(5, 6))
        ax.plot(xs, ys, 'bo', markersize=m, label='Y over X')
        ax.plot(xs, zs, 'ro', markersize=m, label='Z over X')
        ax.set_xlabel('X')
        ax.set_ylabel('Y/Z')
        ax.axis('square')
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.legend()
        ax.set_title('Prediction')
        if save_img:
            plt.savefig(prefix + f"img_proj_{flag}.png", dpi=300)
        plt.show()

        if show_neighbors:
            # cropped adjacency (used nodes only), then the full max_length block
            for adj_view in (result[samples, 3:3 + resroundL, :resroundL],
                             result[samples, adj_rows, :]):
                fig, ax = plt.subplots(1, 1, figsize=(5, 6))
                ax.imshow(adj_view)
                ax.set_title('Prediction')
                plt.show()

        if xyz_and_graph:
            # Unscale a copy: cond_batch must stay scaled for the next cond_scale.
            GT_y = cond_batch[0, :].clone()
            for i in range(len(y_min)):
                GT_y[i] = unscale_data(GT_y[i], y_max[i], y_min[i])

            G_res, data_res, y_data_pred = construct_xyz_and_graph(
                result[samples, :3 + resroundL, :resroundL].squeeze(),
                fname_root=f'{prefix}graph_xyz_{flag}',
                label='Prediction',
                dist_matrix=True,
                GT_y=GT_y.cpu().numpy())

    output = result[samples, :resroundL + 3, :resroundL].squeeze().permute(1, 0)
    if xyz_and_graph:
        return output, G_res, data_res, y_data_pred
    return output


In [None]:
import networkx as nx
 
from torch_geometric.utils import to_networkx

def add_edge_to_graph(G, e1, e2, w):
    """Add the edge (e1, e2) to graph ``G`` with weight ``w``.

    NOTE(review): the edge is additionally tagged with the attribute
    ``clamp_neighbors=True``. This looks like a leftover keyword rather than
    meaningful edge data; confirm before relying on it.
    """
    edge_attributes = {'weight': w, 'clamp_neighbors': True}
    G.add_edge(e1, e2, **edge_attributes)
def get_properties(item):
    """Measure geometric/graph properties of a graph object.

    ``item`` must provide ``pos`` (node coordinates; columns 0-2 are x/y/z),
    ``edge_index`` (shape [2, num_edges], source/target node indices),
    ``num_edges`` and ``num_nodes``.

    Returns:
        (avg_length, dx, dy, dz, num_nodes, num_edges, node_degree), where
        avg_length is the mean Euclidean edge length, dx/dy/dz are the mean
        signed coordinate differences pos[source] - pos[target] over all edges,
        and node_degree = num_edges / num_nodes. The first four are 0-d numpy
        arrays. For a graph without edges they are NaN (the averages are
        undefined) instead of raising ZeroDivisionError. node_degree is NaN for
        a graph without nodes.
    """
    num_nodes = item.num_nodes
    num_edges = item.num_edges

    src, dst = item.edge_index[0], item.edge_index[1]
    delta = item.pos[src, :3] - item.pos[dst, :3]  # (num_edges, 3) per-edge offsets

    if num_edges > 0:
        edge_lengths = (delta ** 2).sum(dim=1) ** 0.5
        avg_length = edge_lengths.sum() / num_edges
        mean_delta = delta.sum(dim=0) / num_edges
    else:
        avg_length = torch.tensor(float('nan'))
        mean_delta = torch.full((3,), float('nan'))

    node_degree = num_edges / num_nodes if num_nodes else float('nan')

    dx, dy, dz = (mean_delta[k].detach().cpu().numpy() for k in range(3))
    return avg_length.detach().cpu().numpy(), dx, dy, dz, num_nodes, num_edges, node_degree

def construct_xyz_and_graph (result, fname_root='output',
                             clamp_neighbors=True,label='Generated',
                             limits=None,#axis limits for plot, 
                             GT_y=None,
                             dist_matrix=False,
                            ):
    """Export a generated graph to disk, plot it, and measure its properties.

    Writes ``<fname_root>.xyz``, a LAMMPS-style data file ``<fname_root>.dump``,
    ``<fname_root>.png/.svg`` (3D plot) and ``<fname_root>.pt`` (graph object).

    Args:
        result: array of shape (3 + K, N) for N nodes. Rows 0-2 hold scaled
            x/y/z coordinates. If ``dist_matrix`` is True, rows 3 .. 3+N-1 are an
            adjacency matrix (entry > 0.5 means an edge). Otherwise rows
            3 .. 3+max_neighbors-1 hold 1-based neighbor indices, with 0 as padding.
        fname_root: path prefix for every written file.
        clamp_neighbors: (neighbor-list mode only) clamp indices into [0, N-1].
        label: plot title.
        limits: optional (lo, hi) applied to all three 3D axis limits.
        GT_y: optional ground-truth property vector to compare against.
        dist_matrix: choose adjacency-matrix (True) or neighbor-list (False) decoding.

    Returns:
        (G, data, y_data): the networkx graph, the ``Data`` object, and the
        measured vector [avg_length, dx, dy, dz, num_nodes, num_edges, node_degree].

    NOTE: rows 0-2 of ``result`` are unscaled IN PLACE. A caller passing a view
    (a slice or ``permute``) sees unscaled coordinates in its own tensor
    afterwards, and calling this twice on the same data unscales twice.
    """
    print ("##############################################################################")
    print ("Shape of data provided ", result.shape) 
    print (f"Root file: {fname_root}")
    print ("##############################################################################")

    result[:3, :] = unscale_data(result[:3, :], X_max.numpy(), X_min.numpy())

    xs = result[0, :]
    ys = result[1, :]
    zs = result[2, :]
    num_nodes = result.shape[1]

    # Plain Python floats, so that files contain numbers rather than "tensor(...)".
    node_list = [[float(xs[i]), float(ys[i]), float(zs[i])] for i in range(num_nodes)]

    with open(fname_root + '.xyz', 'w') as f:
        f.write('#ID x y z \n')
        for i, (x, y, z) in enumerate(node_list):
            f.write(f'{i+1} {x} {y} {z} \n')

    # Edge list of 0-based node indices [node, neighbor].
    neighbor_list = []
    for i in range(num_nodes):
        if dist_matrix:
            for j in range(num_nodes):
                if result[3 + j, i] > 0.5:  # a value > 0.5 means a neighbor is found
                    neighbor_list.append([i, j])
        else:
            for j in range(max_neighbors):
                neigh_j = result[3 + j, i]
                if neigh_j != 0:  # zeros are padded values
                    neighn = neigh_j - 1  # stored as 1 + node number (0 encodes padding)
                    if clamp_neighbors:
                        neighn = max(0, min(neighn, num_nodes - 1))
                    neighbor_list.append([i, int(neighn)])

    with open(fname_root + '.dump', 'w') as f:
        f.write(f'\n{num_nodes} atoms\n{len(neighbor_list)} bonds  \n\n1 atom types\n1 bond types\n\n')
        f.write('0 500 xlo xhi \n0 500 ylo yhi \n0 500 zlo zhi \n\n')
        f.write('Masses \n\n1 100.00  \n\n')
        f.write('Atoms  \n\n')
        for i, (x, y, z) in enumerate(node_list):
            f.write(f'{i+1} 1 1 {x} {y} {z} \n')
        f.write('\nBonds  \n\n')
        for b, (n1, n2) in enumerate(neighbor_list):
            # Atom IDs are 1-based (see the Atoms section): shift and clamp to [1, N].
            a1 = max(1, min(n1 + 1, num_nodes))
            a2 = max(1, min(n2 + 1, num_nodes))
            f.write(f'{b+1} 1 {a1} {a2}   \n')
        f.write('\n\n')

    # reshape keeps a valid (num_edges, 2) layout even when there are no edges
    neighbor_list = torch.tensor(neighbor_list, dtype=torch.long).reshape(-1, 2)
    print ("neighborlist shape: ", neighbor_list.shape)

    fig = plt.figure(figsize=(8,8))
    m=24
    ax = fig.add_subplot(projection='3d')
    ax.scatter(xs, ys, zs,c='red', s=m, marker="o")
    ax.set_proj_type('ortho')
    ax.set_title (label)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    if neighbor_list.numel() > 0:
        print ("STATS of neighbor list: max ", neighbor_list.max(), "min ",  neighbor_list.min(), "shape ", neighbor_list.shape)

    for n1, n2 in neighbor_list.tolist():
        n2 = max(0, min(n2, num_nodes - 1))  # guard against out-of-range neighbor indices
        ax.plot3D([xs[n1], xs[n2]], [ys[n1], ys[n2]], [zs[n1], zs[n2]], 'k-', linewidth=1)

    if limits is not None:
        ax.set_xlim(limits[0],limits[1])
        ax.set_ylim(limits[0],limits[1])
        ax.set_zlim(limits[0],limits[1])    

    plt.savefig (fname_root+'.png', dpi=400)
    plt.savefig (fname_root+'.svg', dpi=400)
    plt.show()

    print ("Neighborlist shape COO^T format: ", neighbor_list.shape)
    x = torch.tensor(node_list, dtype=torch.float)  # node features [num_nodes, 3]
    edge_index = neighbor_list.permute(1,0)  # COO connectivity [2, num_edges]
    data = Data(x=x, pos=x, edge_index=edge_index, y=None)

    torch.save (data, fname_root+'.pt')

    print(f'Number of nodes: {data.num_nodes}')
    print(f'Number of edges: {data.num_edges}')
    print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
    G = to_networkx(data, to_undirected=False)
    print ("##############################################################################")

    # Measure graph properties (dz was previously overwritten into dx).
    avg_length, dx, dy, dz, num_nodes_m, num_edges, node_degree = get_properties(data)
    y_data = np.array([avg_length, dx, dy, dz, num_nodes_m, num_edges, node_degree])

    print (f"Graph properties measured: {labels_y_txt}\n",y_data)
    if exists (GT_y):
        print (f"Graph properties GT: {labels_y_txt}\n", GT_y)

        # GT on the x axis, predicted on the y axis, matching the axis labels
        plt.plot ( GT_y, y_data, '.',label='Graph properties (GT vs predicted)',markersize=12 )
        plt.legend()
        plt.xlabel ('GT')
        plt.ylabel ('Predicted')
        min_v,max_v=min (min(y_data), min (GT_y)),max (max(y_data), max (GT_y))
        plt.plot([min_v, max_v], [min_v, max_v], ls="--", c=".3")
        plt.axis ('square')
        plt.show()
    return G,data, y_data

### Set up model and training

In [None]:
import time

In [None]:
# Accumulator for training losses.
loss_list = list()


In [None]:
# Output directory for checkpoints, plots and exported graphs.
prefix = './Transformer_Model/'
# makedirs also creates missing parents and does not fail if the folder exists.
os.makedirs(prefix, exist_ok=True)

In [None]:
# Display the neighbor-slot count and the number of conditioning properties
# per sample (the latter is passed as max_text_len to the model).
(max_neighbors, y_data.shape[1])

In [None]:
embed_dim_neighbor=32

# Transformer with full-neighbor (distance-matrix) output, conditioned on the
# graph-property vector (max_text_len = number of conditioning properties).
GWebT = GraphWebTransformer(
        dim=512,
        depth=12,
        dim_head = 64,
        heads = 16,
        dropout = 0.,
        ff_mult = 4,
        max_length=max_length,
        neigh_emb_trainable=False,
        max_norm=1.,  # max norm of the embedding layer
        pos_emb_fourier=True,
        pos_emb_fourier_add=False,
        text_embed_dim = 64,
        embed_dim_position=64,
        embed_dim_neighbor=embed_dim_neighbor,
        predict_neighbors=True,  # whether or not to predict neighbors
        pos_fourier_graph_dim=67,  # fourier pos encoding of entire graph
        use_categorical_for_neighbors = False,
        predict_distance_matrix=True,
        cond_drop_prob = 0.25,
        max_text_len = y_data.shape[1],
).to(device)  # same `device` the sampling code moves conditioning tensors to
params (GWebT)

optimizer = optim.Adam(GWebT.parameters() , lr=0.0001 )

In [None]:
# Train the model; see train_loop for the meaning of each option.
train_loop(
    GWebT,
    train_loader,
    test_loader,
    optimizer=optimizer,
    print_every=10,
    epochs=20,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss=50 * (len(train_loader) - 1),
    plot_unscaled=False,
    save_model=True,
    cond_scales=[1],
    num_samples=4,
    clamp=True,
    corplot=False,
    save_loss_images=False,
    show_neighbors=True,
    xyz_and_graph=True,
)

In [None]:
# Restore a saved checkpoint. map_location places the tensors on the active
# `device`, so a GPU-saved checkpoint can also be loaded on a CPU-only machine.
fname = f'{prefix}statedict_save-model-epoch_772.pt'
GWebT.load_state_dict(torch.load(fname, map_location=device))

In [None]:
# Sample from the model on the test set and compare against ground truth.
sample_loop(
    GWebT,
    test_loader,
    cond_scales=[1],  # list of cond scales - each sampled
    num_samples=16,  # how many samples are produced per test
    clamp=True,
    corplot=False,
    show_neighbors=True,
    xyz_and_graph=True,
    clamp_round_results=True,
    enforce_symmetry=False,
)

### Generate hierarchical structures

In [None]:
def get_length (X1):
    """Return one past the largest node index that has a nonzero adjacency entry.

    ``X1`` is laid out as [nodes, xyz + adjacency]; only columns 3.. are
    inspected, and both row and column indices of nonzero entries count.
    The result is a 0-d integer tensor.
    """
    nonzero_coords = torch.nonzero(X1[:, 3:])
    largest_index = nonzero_coords.max()
    return largest_index + 1
    
def get_length_xyzCOO (result):
    """Return one past the index of the last node whose first neighbor slot is nonzero.

    ``result`` is laid out as [xyz + neighbor slots, nodes]; only row 3 (the
    first neighbor slot) is inspected. The result is a 1-element tensor.
    """
    first_slot = result[3, :]
    occupied = torch.nonzero(first_slot)
    return occupied[-1] + 1
    
def get_xyz_and_dist_matrix_fromxyzCOO (X_data_cl, clamp_neighbors=False ,
                            visualize=False, max_length=None):
    """Convert an xyz + COO-neighbor encoding into an xyz + adjacency-matrix encoding.

    Args:
        X_data_cl: tensor laid out as [3 + neighbor slots, nodes]. Rows 0-2 are
            x/y/z; the remaining rows hold 1-based neighbor indices (0 = padding).
        clamp_neighbors: clamp neighbor indices into [1, max_length].
        visualize: show the resulting adjacency matrix with imshow.
        max_length: number of nodes to keep; derived with get_length_xyzCOO
            when not given.

    Returns:
        Tensor [max_length, 3 + max_length]: xyz followed by a symmetric 0/1
        adjacency matrix.
    """
    if not exists (max_length):
        max_length = get_length_xyzCOO(X_data_cl)

    per_node = X_data_cl.permute(1, 0)  # -> [nodes, xyz + neighbor slots]
    print (per_node.shape)

    n_slots = per_node.shape[1] - 3
    print ("Max neighbors: ", n_slots, "Length: ", max_length)
    adjacency = torch.zeros(max_length, max_length)

    for node in range(max_length):
        for slot_value in per_node[node, 3:]:
            if slot_value == 0:  # zeros are padding
                continue
            partner = slot_value.long()  # 1-based neighbor index
            if clamp_neighbors:
                partner = max(1, min(partner, max_length))
            # undirected: mark both directions
            adjacency[node, partner - 1] = 1
            adjacency[partner - 1, node] = 1

    coords = per_node[:max_length, :3]
    print (adjacency.shape, coords.shape )
    output = torch.cat((coords, adjacency), 1)

    print ("Shape of new matrix (length, x,y,z+length+): ", output.shape)

    if visualize:
        plt.imshow(output[:, 3:3 + max_length])
        plt.show()
    return output
    
def shift (X2, dx, dy, dz):
    """Translate the xyz columns (0, 1, 2) of ``X2`` in place and return ``X2``."""
    for column, offset in enumerate((dx, dy, dz)):
        X2[:, column] = X2[:, column] + offset
    return X2
    
def stack_and_extend (X1, X2, stack_node, dx,dy,dz, 
                     fname_root='file_name',
                      visualize=False,
                      make_graph = False,
                      avg_pos_in_overlap=False,
                     ):
    """Stack graph ``X2`` onto graph ``X1`` with an overlap, producing one larger graph.

    Both inputs are laid out as [nodes, xyz + adjacency]. Node k of X2 becomes
    node ``stack_node + k`` of the result, so nodes ``stack_node .. S1-1`` are
    shared by both graphs. In the overlapping block of the adjacency matrix,
    X2's entries replace X1's.

    Args:
        X1, X2: input graphs. X2 is translated IN PLACE by (dx, dy, dz); pass a
            clone if the caller's tensor must stay unchanged.
        stack_node: index in X1 at which X2's first node is placed.
        dx, dy, dz: translation applied to X2's coordinates.
        fname_root: file prefix used when ``make_graph`` is True.
        visualize: show the adjacency and coordinate matrices.
        make_graph: also export/plot the result with construct_xyz_and_graph
            (which unscales the result's xyz columns in place).
        avg_pos_in_overlap: average X1/X2 positions of shared nodes instead of
            taking X2's positions.

    Returns:
        Tensor [S_new, S_new + 3] with S_new = max(S1, stack_node + S2).
    """
    S1 = int(get_length(X1))
    S2 = int(get_length(X2))

    # max() keeps room for X1 even if X2 lies entirely inside it.
    S_new = max(S1, stack_node + S2)
    stacked = torch.zeros(S_new, S_new + 3)

    X2 = shift(X2, dx, dy, dz)

    # Adjacency: X1 first, then X2 written over the overlapping block.
    stacked[:S1, 3:3 + S1] = X1[:S1, 3:3 + S1]
    stacked[stack_node:stack_node + S2, 3 + stack_node:3 + stack_node + S2] = X2[:S2, 3:3 + S2]

    # Positions (sliced to S1/S2 so that trailing unused rows do not break the shapes).
    stacked[:S1, :3] = X1[:S1, :3]
    if avg_pos_in_overlap:
        stacked[stack_node:stack_node + S2, :3] += X2[:S2, :3]
        overlap_end = min(S1, stack_node + S2)
        stacked[stack_node:overlap_end, :3] /= 2.
    else:
        stacked[stack_node:stack_node + S2, :3] = X2[:S2, :3]

    if visualize:
        plt.imshow(stacked[:S_new, 3:3 + S_new])
        plt.show()

        plt.imshow(X1[:, :3], aspect=.1)
        plt.show()
        plt.imshow(X2[:, :3], aspect=.1)
        plt.show()
        plt.imshow(stacked[:, :3], aspect=.1)
        plt.show()

    if make_graph:
        construct_xyz_and_graph(stacked.permute(1, 0),
                                fname_root=fname_root,
                                label='Stacked and extended',
                                dist_matrix=True,
                                )

    return stacked
    

    
def make_spiral (X1, radius, 
                 slope_z, delta_angle, steps,  
                stagger_fraction = 0.8,
                fname_root='fname',
                 visualize=True,visualize_all=False,
                 shuffle=False,#shuflle graph
                 avg_pos_in_overlap=False,
                 generate_new_every_iteration=False,
                 cond_vector_list=None,
                 length_cond_vector=7,
                 is_COO=False, #whether model predicts COO (sparse) or not
                )   :
    """Assemble a helical structure by repeatedly stacking graph blocks along a helix.

    Copy k (k = 0 .. steps) is translated to
    (radius*cos(k*delta_angle), radius*sin(k*delta_angle), slope_z*k) and
    stacked onto the growing structure with stack_and_extend. Consecutive
    copies overlap by ``stagger_fraction`` of the block's node count.

    Args:
        X1: initial block, laid out as [nodes, xyz + adjacency].
        radius, slope_z, delta_angle: helix parameters (delta_angle in radians).
        steps: number of blocks stacked after the initial one.
        stagger_fraction: fraction of a block's nodes that overlap the structure.
        fname_root: prefix for the files written by construct_xyz_and_graph.
        visualize: show the final adjacency matrix.
        visualize_all: show the adjacency before/after each shuffle.
        shuffle: randomly relabel the block's nodes before each stacking step
            (coordinates and adjacency are permuted together).
        avg_pos_in_overlap: forwarded to stack_and_extend.
        generate_new_every_iteration: sample a fresh block from GWebT every step,
            using ``cond_vector_list[i]`` or a random vector in [-0.75, 0.75).
        cond_vector_list: optional per-step conditioning vectors.
        length_cond_vector: length of the random conditioning vector.
        is_COO: convert generated blocks from xyz+COO to xyz+adjacency.

    Returns:
        The assembled structure, laid out as [nodes, xyz + adjacency].
    """
    def helix_offset(k):
        # x(t) = r cos(t), y(t) = r sin(t), z(t) = a t at t = k * delta_angle
        return (radius * math.cos(k * delta_angle),
                radius * math.sin(k * delta_angle),
                slope_z * k)

    S1 = get_length(X1)
    delta_stagg = S1 * stagger_fraction

    spiral = shift(torch.clone(X1), *helix_offset(0))

    for i in tqdm(range(steps)):

        if generate_new_every_iteration:
            if exists(cond_vector_list):
                cond_v = cond_vector_list[i]
            else:
                cond_v = 1.5 * torch.rand(length_cond_vector) - 0.75
            result, G_res, data_res, y_data = generate_sample_cond(GWebT,
                cond=cond_v,
                cond_scales=[1., ],  # list of cond scales - each sampled
                flag=9992,
                clamp=True, corplot=True, show_neighbors=True,
                xyz_and_graph=True)

            if is_COO:
                # convert xyz-COO coding to xyz-distance matrix
                X1 = get_xyz_and_dist_matrix_fromxyzCOO(result, clamp_neighbors=True,
                                                        visualize=False)
            else:
                X1 = result
            S1 = get_length(X1)
            delta_stagg = S1 * stagger_fraction

        if shuffle:
            if visualize_all:
                plt.imshow(X1[:S1, 3:3 + S1])
                plt.show()

            n_nodes = int(S1)
            perm = torch.randperm(n_nodes)
            columns = torch.cat((torch.arange(3), perm + 3))
            # Relabel nodes: new node k is old node perm[k]. Rows (xyz and adjacency)
            # and adjacency columns are permuted together, so geometry is preserved.
            X1 = X1[:n_nodes][perm][:, columns]

            if visualize_all:
                plt.imshow(X1[:S1, 3:3 + S1])
                plt.show()

        stack_node = int(get_length(spiral) - delta_stagg)

        spiral = stack_and_extend(spiral.clone(), X1.clone(), stack_node, *helix_offset(i + 1),
                                  visualize=False,
                                  make_graph=False, avg_pos_in_overlap=avg_pos_in_overlap,
                                  )
        spiral[:, 3:] = torch.clamp(spiral[:, 3:], 0, 1)

    S1 = get_length(spiral)

    plt.imshow(spiral[:S1, 3:3 + S1], interpolation='none')
    plt.show()

    plt.plot(spiral[:, 0])
    plt.plot(spiral[:, 1])
    plt.plot(spiral[:, 2])
    plt.show()

    plt.imshow(spiral[:S1, :3], interpolation='none')
    plt.show()

    # NOTE: construct_xyz_and_graph unscales spiral's xyz columns in place (view).
    construct_xyz_and_graph(spiral.permute(1, 0),
                            fname_root=fname_root,
                            label='Spiral',
                            dist_matrix=True,
                            )

    if visualize:
        plt.imshow(spiral[:, 3:], interpolation='none')
        plt.show()

    return spiral
    

In [None]:
# Generate one block for a fixed (scaled) conditioning vector.
# The 4th return value is bound to y_data_pred rather than y_data, so that the
# global training array `y_data` (used e.g. for max_text_len) is never rebound here.
output, G_res, data_res, y_data_pred = generate_sample_cond(
    GWebT,
    cond=[-0.7255, -0.2038, 0.5942, 0.3315, 0.1786, 0.1852, 0.0122],
    cond_scales=[1., ],  # list of cond scales - each sampled
    flag=9992,
    clamp=True, corplot=True, show_neighbors=True,
    xyz_and_graph=True,
)

In [None]:
# Layout of the generated block: [nodes, xyz + adjacency].
print(output.size())

In [None]:
# Adjacency part of the generated block (columns after x/y/z).
adjacency_view = output[:, 3:]
plt.imshow(adjacency_view)
plt.show()

get_length(output)

In [None]:
                      
# Stack a copy of the block onto itself: X2's first node is placed at node 10
# of X1, and X2 is translated by (100, 60, 50).
output = stack_and_extend(
    output.clone(),
    output.clone(),
    10,
    100, 60, 50,
    fname_root='file_name',
    visualize=True,
    make_graph=True,
)

In [None]:
# Shape of the stacked structure (no copy needed just to read the shape).
output.shape

In [None]:
# Assemble a helix of 51 blocks, rotating 10 degrees and rising 5 units per block.
spiral = make_spiral(
    output,
    radius=10,
    slope_z=5,
    delta_angle=10 / 360 * 2 * math.pi,
    steps=50,
    stagger_fraction=0.5,
    fname_root='spiral_transf_distmap_215_1',
    shuffle=True,
    avg_pos_in_overlap=False,
)
