# GraphGeneration: Modeling and design of hierarchical bio-inspired de novo spider web structures using deep learning and additive manufacturing 

Spider webs are incredible biological structures, comprising thin but strong silk filaments arranged into highly complex hierarchical architectures with striking mechanical properties (e.g., lightweight but high strength). While simple 2D orb webs can easily be mimicked, the modeling and synthesis of artificial, bio-inspired 3D-based web structures is challenging, partly due to the rich set of design features. Here we use deep learning as a way to model and synthesize such 3D web structures, where generative models are conditioned based on key geometric parameters (including average edge length, number of nodes, average node degree, and others). To identify construction principles, we use inductive representation sampling of large spider web graphs and develop and train three distinct conditional generative models to accomplish this task: 1) an analog diffusion model with sparse neighbor representation, 2) a discrete diffusion model with full neighbor representation, and 3) an autoregressive transformer architecture with full neighbor representation. We find that all three models can produce complex, de novo bio-inspired spider web mimics and successfully construct samples that meet the design conditioning that reflects key geometric features (including the number of nodes, spatial orientation, and edge lengths). We further present an algorithm that assembles inductive samples produced by the generative deep learning models into larger-scale structures based on a series of geometric design targets, including helical forms and parametric curves. 

[1] W. Lu, N.A. Lee, M.J. Buehler, "Modeling and design of hierarchical bio-inspired de novo spider web structures using deep learning and additive manufacturing," PNAS, 120 (31) e2305273120, 2023, https://www.pnas.org/doi/10.1073/pnas.2305273120 

## Model 2: Analog diffusion model with full neighbor representation

In [None]:
import os,sys
import math

# Restrict which CUDA devices this process can see. The variable is set before
# `torch` is imported further down in the notebook.
# Alternative: setting the value to "-1" hides all GPUs.
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True


In [None]:
import torch

from sklearn.model_selection import train_test_split

In [None]:
# Select the compute device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() 
                                  else "cpu")

In [None]:
#from tqdm import tqdm, trange
from tqdm.notebook import trange, tqdm

In [None]:
# Notebook echo: display the selected device.
device

In [None]:
# Build one `torch.cuda.device` handle per CUDA device that torch reports,
# then echo the list (empty when no GPU is visible).
available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
available_gpus

In [None]:
# Number of CUDA devices visible to this process.
num_of_gpus = torch.cuda.device_count()
print(num_of_gpus)

### Load dataset

In [None]:
from torch.utils.data import DataLoader,Dataset
import pandas as pd
import seaborn as sns
import time

In [None]:
 
import torchvision
 
import matplotlib.pyplot as plt
import numpy as np
 
from torch import nn
from torch import optim, Tensor
import torch.nn.functional as F
from torchvision import datasets, transforms, models

import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from functools import partial, wraps

In [None]:
# Record the installed PyTorch version in the notebook output.
print("Torch version:", torch.__version__) 

In [None]:
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import RobustScaler
import ast
import pandas as pd
import numpy as np

### Load data

In [None]:
from torch_geometric.utils import degree
from torch_geometric  import transforms

In [None]:
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, LinearTransformation
import math
import numbers
import random
from typing import Tuple, Union

 
class RandomRotateLoc(BaseTransform):
    r"""Randomly rotate the node positions of a graph about one axis.

    The rotation angle is drawn uniformly from an interval on every call and
    the resulting matrix is applied through :class:`LinearTransformationLoc`.

    Args:
        degrees (tuple or float): Interval (in degrees) from which the angle
            is sampled. A single number ``d`` is interpreted as the interval
            :math:`[-|\mathrm{d}|, |\mathrm{d}|]`.
        axis (int, optional): Rotation axis used for 3D positions: ``0``,
            ``1``, or anything else for the third axis. (default: :obj:`0`)
    """
    def __init__(self, degrees: Union[Tuple[float, float], float],
                 axis: int = 0):
        # A bare number is expanded into a symmetric interval around zero.
        if isinstance(degrees, numbers.Number):
            bound = abs(degrees)
            degrees = (-bound, bound)
        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
        self.degrees = degrees
        self.axis = axis

    def _rotation_rows(self, dim, sin, cos):
        """Return the rotation matrix (nested lists) for ``dim``-D positions."""
        if dim == 2:
            return [[cos, sin], [-sin, cos]]
        if self.axis == 0:
            return [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
        if self.axis == 1:
            return [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
        return [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]

    def __call__(self, data: Data) -> Data:
        # Sample the angle in degrees and convert it to radians.
        low, high = self.degrees
        angle = math.pi * random.uniform(low, high) / 180.0
        rows = self._rotation_rows(data.pos.size(-1),
                                   math.sin(angle), math.cos(angle))
        rotate = LinearTransformationLoc(torch.tensor(rows))
        return rotate(data)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return (f'{name}({self.degrees}, '
                f'axis={self.axis})')


In [None]:
from torch_geometric.data import Data, HeteroData
class LinearTransformationLoc(BaseTransform):
    r"""Apply a fixed square matrix to the node positions :obj:`pos` of a graph.

    Args:
        matrix (Tensor): Tensor with shape :obj:`[D, D]` where :obj:`D`
            is the dimensionality of the node positions. Non-tensor input is
            converted with :func:`torch.tensor`.
    """
    def __init__(self, matrix: Tensor):
        matrix = matrix if isinstance(matrix, Tensor) else torch.tensor(matrix)
        assert matrix.dim() == 2, (
            'Transformation matrix should be two-dimensional.')
        n_rows, n_cols = matrix.size(0), matrix.size(1)
        assert n_rows == n_cols, (
            f'Transformation matrix should be square (got {matrix.size()})')

        # Keep the transpose so that `__call__` can post-multiply the
        # `[N, D]` position matrix and preserve its shape.
        self.matrix = matrix.t()

    def __call__(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        for store in data.node_stores:
            if hasattr(store, 'pos'):
                pos = store.pos
                # Treat a flat position vector as N one-dimensional points.
                if pos.dim() == 1:
                    pos = pos.view(-1, 1)
                assert pos.size(-1) == self.matrix.size(-2), (
                    'Node position matrix and transformation matrix have '
                    'incompatible shape')
                transform = self.matrix.to(pos.device, pos.dtype)
                store.pos = pos @ transform

        return data

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f'{name}(\n{self.matrix.cpu().numpy()}\n)'

In [None]:

# Load the preprocessed dataset tuple from disk and unpack its seven entries.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# dataset files that come from a trusted source.
# NOTE(review): the name `input` shadows the Python builtin; it is kept because
# later cells refer to the tensor by this name.
input, y_data,node_number_list,labels_y_txt, max_length , posdim_emb, max_neighbors =\
                                        torch.load('dataset_webs_medium.pt')



In [None]:
#Data format for input
#0: ordinal numbering of nodes
#1,2,3: x,y,z coordinates
#4,5,6,7,8,9 - list of neighbors
#10,11,12,13,14,15 - list of distances to neighbors

In [None]:
# Sanity check: shapes of the loaded feature tensor and conditioning tensor.
print (input.shape)
print (y_data.shape)

In [None]:

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, LinearTransformation
import math
import numbers
import random
from typing import Tuple, Union

class RandomRotateDiffusion(BaseTransform):
    r"""Randomly rotate a plain tensor of node positions about one axis.

    In contrast to a graph transform, ``__call__`` receives and returns the
    position tensor itself (last dimension 2 or 3), so it can be applied to
    slices of a feature tensor.

    Args:
        degrees (tuple or float): Interval (in degrees) from which the
            rotation angle is sampled. A single number ``d`` is interpreted
            as the interval :math:`[-|\mathrm{d}|, |\mathrm{d}|]`.
        axis (int, optional): Rotation axis used for 3D positions: ``0``,
            ``1``, or anything else for the third axis. (default: :obj:`0`)
    """
    def __init__(self, degrees: Union[Tuple[float, float], float],
                 axis: int = 0):
        if isinstance(degrees, numbers.Number):
            degrees = (-abs(degrees), abs(degrees))
        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
        self.degrees = degrees
        self.axis = axis

    def __call__(self, pos):
        """Return ``pos`` post-multiplied by a freshly sampled rotation matrix.

        Args:
            pos (Tensor): Positions whose last dimension has size 2 or 3.

        Returns:
            Tensor: Rotated positions with the same shape as ``pos``.
        """
        # Sample the angle in degrees and convert it to radians.
        degree = math.pi * random.uniform(*self.degrees) / 180.0
        sin, cos = math.sin(degree), math.cos(degree)

        # Fix: the dimensionality is read from `pos` itself. The previous code
        # referenced `data.pos`, a name that does not exist in this method.
        if pos.size(-1) == 2:
            matrix = [[cos, sin], [-sin, cos]]
        else:
            if self.axis == 0:
                matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
            elif self.axis == 1:
                matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
            else:
                matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]

        # Match the device and dtype of `pos` so the product also works for
        # GPU or float64 positions.
        matrix = torch.tensor(matrix).to(pos.device, pos.dtype)
        return pos @ matrix

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.degrees}, '
                f'axis={self.axis})')

### Normalize data and prepare dataloader

In [None]:
class RegressionDataset(Dataset):
    """Dataset that turns padded node tables into (features + adjacency, condition) pairs.

    Each sample of ``X_data`` is a ``[rows, features]`` table whose columns are
    read as: 0 = node ordinal, 1..3 = x/y/z coordinates, 4..4+max_neighbors-1 =
    1-based neighbor ids (0 means "no neighbor"). ``__getitem__`` returns the
    first four columns concatenated with a dense ``max_length x max_length``
    neighbor matrix (+1 = edge, -1 = no edge), transposed to
    ``[4 + max_length, max_length]``.

    Relies on the module-level globals ``max_length`` and ``max_neighbors``.

    Args:
        X_data: Tensor of node tables, indexed as ``X_data[index]``.
        y_data: Conditioning values, indexed as ``y_data[index]``.
        node_number_list: Per-sample number of valid (non-padded) rows.
            NOTE(review): interpreted as a count because rows from this index
            onward are zeroed; confirm against the dataset builder.
        degrees: If > 0, rotate the valid coordinates randomly about x, y and z
            within ``[-degrees, degrees]``.
        jitter: If > 0, add Gaussian noise of this scale to the valid
            coordinates, clipped to ``[-jitter, jitter]``.
        clamp_neighbors: Clamp neighbor ids into ``[1, max_length]``.
        enforce_symm: Also set the mirrored entry for every edge.
    """

    def __init__(self, X_data, y_data, node_number_list, degrees=0, jitter=0, clamp_neighbors=True,
                enforce_symm=True,):
        self.X_data = X_data
        self.y_data = y_data
        self.node_number_list=node_number_list

        self.degrees=degrees
        self.jitter=jitter
        self.enforce_symm=enforce_symm

        self.randomrotatex= RandomRotateDiffusion (degrees=self.degrees, axis=0)
        self.randomrotatey= RandomRotateDiffusion (degrees=self.degrees, axis=1)
        self.randomrotatez= RandomRotateDiffusion (degrees=self.degrees, axis=2)
        self.clamp_neighbors=clamp_neighbors

    def __getitem__(self, index):
        # Copy only the requested sample. Cloning the whole dataset on every
        # access (as before) costs O(dataset size) per item; augmentation
        # never touches the stored data either way.
        sample = self.X_data[index].clone()

        # One consistent node count for rotation, jitter and padding. The
        # rotation branch used to re-derive it as the last non-zero index of
        # the x column, which is an index rather than a count and breaks once
        # normalisation has made the padded coordinates non-zero.
        n_nodes = int(self.node_number_list[index])

        if self.degrees > 0:
            pos = sample[:n_nodes, 1:4]
            pos = self.randomrotatex(pos)
            pos = self.randomrotatey(pos)
            pos = self.randomrotatez(pos)
            sample[:n_nodes, 1:4] = pos

        if self.jitter > 0:
            # Independent noise per coordinate, clipped to the jitter scale.
            noise = torch.randn(n_nodes, 3) * self.jitter
            noise = torch.clamp(noise, min=-self.jitter, max=self.jitter)
            sample[:n_nodes, 1:4] = sample[:n_nodes, 1:4] + noise

        # Padding rows are forced to zero (coordinates and neighbor ids).
        sample[n_nodes:, :] = 0

        # Dense neighbor matrix: -1 everywhere, +1 where a neighbor id is listed.
        dist_matrix = -torch.ones(max_length, max_length)

        neighbors = sample[:max_length, 4:4 + max_neighbors]
        rows, slots = torch.nonzero(neighbors, as_tuple=True)  # zeros are padding
        targets = neighbors[rows, slots].long()
        if self.clamp_neighbors:
            targets = targets.clamp(1, max_length)

        # Neighbor ids are 1-based, matrix indices are 0-based.
        dist_matrix[rows, targets - 1] = 1
        if self.enforce_symm:
            dist_matrix[targets - 1, rows] = 1

        output = torch.cat((sample[:, :4], dist_matrix), 1)
        output = output.permute(1, 0)

        return output, self.y_data[index]

    def __len__ (self):
        return len(self.X_data)

def scale_data(image2, maxv, minv):
    """Linearly map values from the range [minv, maxv] onto [-1, 1]."""
    span = maxv - minv
    unit = (image2 - minv) / span
    return unit * 2. - 1.0
 
def unscale_data(image2, maxv, minv):
    """Linearly map values from [-1, 1] back onto the range [minv, maxv]."""
    span = maxv - minv
    unit = (image2 + 1.) / 2.
    return unit * span + minv


def normalize_data (input, y_data, X_min=None, X_max=None, y_min=None, y_max=None,

                    X_max_neigh=None, X_min_neigh=None,
                    Xscale=0):
    """Min-max normalise coordinates and conditioning values, in place.

    Columns 1..3 of ``input`` are scaled with one global min/max to
    ``[-(1 - Xscale), 1 - Xscale]``; every column of ``y_data`` is scaled
    separately to ``[-1, 1]``. Both tensors are modified in place and also
    returned.

    Args:
        input: Tensor indexed as ``input[:, :, 1:4]`` for the coordinates.
        y_data: 2D tensor of conditioning values (one column per property).
        X_min, X_max: Optional coordinate bounds; computed from ``input`` when
            ``None``.
        y_min, y_max: Optional per-column bounds (list, tensor or NumPy array);
            computed from ``y_data`` when ``None``.
        X_max_neigh, X_min_neigh: Accepted for compatibility; not used.
        Xscale: Margin that shrinks the coordinate target range (0 -> [-1, 1]).

    Returns:
        Tuple ``(input, y_data, X_min, X_max, y_min, y_max)``.

    Note:
        A column with ``y_max == y_min`` (or ``X_max == X_min``) divides by
        zero, as in the original implementation.
    """
    # `is None` instead of `== None`: comparing a NumPy array with `==` yields
    # an element-wise array whose truth value is ambiguous, which made passing
    # previously computed array bounds raise ValueError.
    if X_min is None:
        X_min = input[:, :, 1:4].min()
    else:
        print ("use provided X_min", X_min)
    if X_max is None:
        X_max = input[:, :, 1:4].max()
    else:
        print ("use provided X_max", X_max)

    # Normalise the coordinate range to -1..1 (shrunk by Xscale on both ends).
    input[:, :, 1:4] = (input[:, :, 1:4] - X_min) / (X_max - X_min) * (2 - 2 * Xscale) - (1 - Xscale)

    print ("Check X after norm  ", input[:,:,1:4].min(), input[:,:,1:4].max())

    print ("Check X_neigh before norm  ", input[:,:,4:9].min(), input[:,:,4:9].max())

    n_columns = y_data.shape[1]

    # Bounds are collected before any column is rescaled.
    if y_min is None:
        y_min = [y_data[:, i].min() for i in range(n_columns)]
    else:
        print ("use provided y_min", y_min)
    if y_max is None:
        y_max = [y_data[:, i].max() for i in range(n_columns)]
    else:
        print ("use provided y_max", y_max)

    for i in range(n_columns):
        # Normalise each conditioning column to the range -1..1.
        y_data[:, i] = (y_data[:, i] - y_min[i]) / (y_max[i] - y_min[i]) * 2 - 1

    print ("Check y_data after norm  ", y_data.min(), y_data.max())
    return input, y_data, X_min, X_max, y_min, y_max
 

In [None]:
# Normalise the loaded data. normalize_data works in place, so `X_scaled` and
# `y_data_scaled` refer to the same tensors as `input` and `y_data`.
X_scaled, y_data_scaled, X_min, X_max, y_min, y_max = normalize_data (input, y_data,  Xscale=0. )

In [None]:
# Notebook echo: shape of the normalised feature tensor.
X_scaled.shape

In [None]:
def pad_sequence_end (output_xyz, max_length_l):
    """Zero-pad ``output_xyz`` along its second-to-last axis up to ``max_length_l``.

    The padded tensor is created on the module-level ``device``; the original
    entries occupy the leading positions and the remainder stays zero.
    """
    n_batch = output_xyz.shape[0]
    n_filled = output_xyz.shape[-2]
    n_features = output_xyz.shape[2]

    padded = torch.zeros((n_batch, max_length_l, n_features))
    padded = padded.to(device)
    padded[:, :n_filled, :] = output_xyz
    return padded.to(device)


In [None]:
# Notebook echo: the max_neighbors value loaded with the dataset.
max_neighbors

In [None]:
def get_data_loaders (X_scaled,  y_data_scaled, node_number_list, split=0.1, batch_size_=16):
    """Split the data and wrap both parts in ``RegressionDataset`` loaders.

    The split uses ``train_test_split`` with the fixed seed 235, so it is
    reproducible. Datasets are built without rotation or jitter and without
    enforcing a symmetric neighbor matrix.

    Args:
        X_scaled: Normalised node tables.
        y_data_scaled: Normalised conditioning values.
        node_number_list: Per-sample node counts, split alongside the data.
        split: Fraction of samples assigned to the test set.
        batch_size_: Batch size used for all three loaders.

    Returns:
        ``(train_loader, train_loader_noshuffle, test_loader)`` where the first
        shuffles the training set and the second iterates it in order.
    """
    parts = train_test_split(X_scaled,
                             y_data_scaled,
                             node_number_list,
                             test_size=split, random_state=235)
    X_train, X_test, y_train, y_test, node_number_list_train, node_number_list_test = parts

    print (f"Shapes= {X_scaled.shape}, {y_data_scaled.shape}")

    for part in (X_train, y_train, X_test, y_test):
        print(part.shape)

    # Identical, augmentation-free settings for both datasets.
    dataset_options = dict(degrees=0, jitter=0.0, enforce_symm=False)
    train_dataset = RegressionDataset(X_train, y_train, node_number_list_train, **dataset_options)
    test_dataset = RegressionDataset(X_test, y_test, node_number_list_test, **dataset_options)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size_, shuffle=True)
    train_loader_noshuffle = DataLoader(dataset=train_dataset, batch_size=batch_size_, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size_)

    return train_loader, train_loader_noshuffle, test_loader

In [None]:
# Build the loaders: 10% of the samples held out for testing, batches of 128.
train_loader,train_loader_noshuffle, test_loader= get_data_loaders (X_scaled,  y_data_scaled,node_number_list, split=0.1, 
                                                                    batch_size_=128)

In [None]:
# Notebook echo: normalisation bounds returned by normalize_data.
X_min, X_max, y_min, y_max

### Training loop

In [None]:
def train_loop (model,
                train_loader,test_loader,
                optimizer=None,
                print_every=10,
                epochs= 300,
                start_ep=0,
                start_step=0,
                train_unet_number=1,
                print_loss=1000,
                plot_unscaled=False,
                save_model=False,
                cond_scales=[7.5], #list of cond scales - each sampled...
                num_samples=2, #how many samples produced every time tested.....
                timesteps=10,
                save_loss_images=False,clamp=False,
                corplot=False,show_neighbors=False,
                xyz_and_graph=False,
                 dist_matrix_threshold=0.99,
                discretize=True,
                get_length_from_result=True,
                only_continuous=True,

               ):
    """Train ``model`` and periodically report loss, sample and checkpoint.

    Every batch is moved to the module-level ``device``; the loss is obtained
    as ``model(y_batch, X_batch)``, gradients are clipped to a norm of 0.5 and
    the optimizer is stepped. Every ``print_loss`` steps (except step 0) the
    mean loss since the last report is appended to the module-level
    ``loss_list`` and plotted, ``sample_loop`` is run on ``test_loader``, and
    the model state dict is optionally saved.

    Relies on the module-level names ``device``, ``loss_list``, ``prefix``
    and ``sample_loop``.

    Args:
        model: Module whose forward pass returns the training loss.
        train_loader: Iterable of ``(X_batch, y_batch)`` pairs.
        test_loader: Loader handed to ``sample_loop`` at each report.
        optimizer: Optimizer for ``model``; required.
        print_every: Print a progress dot every this many steps.
        epochs: Number of passes over ``train_loader``.
        start_ep: Accepted for compatibility; not used here.
        start_step: Initial value of the global step counter.
        train_unet_number: Accepted for compatibility; not used here.
        print_loss: Reporting interval, in steps.
        plot_unscaled: Accepted for compatibility; not used here.
        save_model: Save ``model.state_dict()`` at every report.
        cond_scales, num_samples, timesteps, clamp, corplot, show_neighbors,
        xyz_and_graph, dist_matrix_threshold, discretize,
        get_length_from_result, only_continuous: Forwarded to ``sample_loop``.
        save_loss_images: Save the loss plot (and sample images) to disk.

    Raises:
        ValueError: If no optimizer is provided.
    """
    # Fail fast with a clear message; previously a missing optimizer surfaced
    # as an AttributeError on None inside the first training step.
    if optimizer is None:
        raise ValueError("train_loop requires an optimizer (got optimizer=None).")

    steps=start_step
    loss_total=0

    for e in range(1, epochs+1):
            start = time.time()

            torch.cuda.empty_cache()

            # TRAINING
            model.train()

            for item  in train_loader:

                X_train_batch= item[0].to(device)
                y_train_batch=item[1].to(device)

                optimizer.zero_grad()
                # The model takes the conditioning first, then the data.
                loss=model (  y_train_batch , X_train_batch)
                loss.backward( )

                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

                optimizer.step()

                loss_total=loss_total+loss.item()

                if steps % print_every == 0:
                    print(".", end="")

                # Periodic report; step 0 is skipped.
                if steps>0 and steps % print_loss == 0:
                    norm_loss=loss_total/print_loss
                    print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}")

                    loss_list.append (norm_loss)
                    loss_total=0

                    plt.plot (loss_list, label='Loss')
                    plt.legend()

                    if save_loss_images:
                        outname = prefix+ f"loss_{e}_{steps}.jpg"
                        plt.savefig(outname, dpi=200)
                    plt.show()

                    sample_loop (model,device ,
                        test_loader,
                        cond_scales=cond_scales, #list of cond scales - each sampled...
                        num_samples=num_samples, #how many samples produced every time tested.....
                        timesteps=timesteps,clamp=clamp,corplot=corplot,
                        save_img=save_loss_images,show_neighbors=show_neighbors,
                        flag=steps,xyz_and_graph=xyz_and_graph,
                        clamp_round_results=True, enforce_symmetry=False,
                        dist_matrix_threshold= dist_matrix_threshold , discretize=discretize,
                        get_length_from_result=get_length_from_result,
                        only_continuous=only_continuous,
                                )

                    # NOTE: the interval measured here is `print_loss` steps,
                    # although the message text says "epochs".
                    print (f"\n\n-------------------\nTime passed for {print_loss} epochs at {steps} = {(time.time()-start)/60} mins\n-------------------")
                    start = time.time()
                    if save_model:

                        fname=f"{prefix}statedict_save-model-epoch_{e}.pt"
                        torch.save(model.state_dict(), fname)
                        print (f"Model saved: ", fname)
                steps=steps+1




In [None]:
# Convert the per-column conditioning bounds into NumPy arrays.
y_min, y_max=np.array (y_min), np.array (y_max)

In [None]:
# Notebook echo: conditioning bounds after conversion to arrays.
y_min, y_max

In [None]:
# Use seaborn's "whitegrid" style for all following plots.
sns.set_style("whitegrid")

In [None]:
def delete_isolated (a, clean_symmetrical=True):
    """Return a copy of matrix ``a`` with columns that follow an index gap cleared.

    The column indices holding at least one non-zero entry are collected in
    ascending order. Whenever two successive indices are not adjacent, the
    column of the later index is set to zero; with ``clean_symmetrical`` the
    row of that index is cleared as well. The input is left untouched.
    """
    cleaned = a.clone()

    # Sorted, de-duplicated column indices that contain a non-zero value.
    used_columns = torch.unique(torch.nonzero(cleaned)[:, 1])
    print (used_columns)

    for node_id, node_id_pone in zip(used_columns[:-1], used_columns[1:]):
        if node_id_pone == node_id + 1:
            continue  # adjacent indices, nothing to clear
        cleaned[:, node_id_pone] = 0
        print ("#########################", node_id, node_id_pone)
        if clean_symmetrical:
            cleaned[node_id_pone, :] = 0

    return cleaned
def select_first_continuous (a, clean_symmetrical=True):
    """Return a copy of matrix ``a`` truncated at the first gap in used columns.

    The column indices holding at least one non-zero entry are collected in
    ascending order. At every place where two successive indices are not
    adjacent, all columns from the later index onward are set to zero; with
    ``clean_symmetrical`` the rows from that index onward are cleared too.
    The input is left untouched.
    """
    trimmed = a.clone()

    # Sorted, de-duplicated column indices that contain a non-zero value.
    used_columns = torch.unique(torch.nonzero(trimmed)[:, 1])
    print (used_columns)

    for node_id, node_id_pone in zip(used_columns[:-1], used_columns[1:]):
        has_gap = node_id_pone != node_id + 1
        if has_gap:
            trimmed[:, node_id_pone:] = 0
            print ("#########################", node_id, node_id_pone)
            if clean_symmetrical:
                trimmed[node_id_pone:, :] = 0

    return trimmed
                        

In [None]:
from sklearn.metrics import r2_score

def _sample_loop_scatter3d(ax, xs, ys, zs, title, marker_size=6):
    """Draw one 3D scatter panel of node positions on the [-1, 1] cube."""
    ax.scatter(xs, ys, zs, c='red', s=marker_size, marker="o")
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)
    ax.set_xlim(-1,1)
    ax.set_ylim(-1,1)
    ax.set_zlim(-1,1)


def _sample_loop_projection(ax, xs, ys, zs, title, marker_size=2):
    """Draw one 2D panel showing Y-over-X and Z-over-X projections."""
    ax.plot(xs, ys, 'bo', markersize=marker_size, label='Y over X')
    ax.plot(xs, zs, 'ro', markersize=marker_size, label='Z over X')
    ax.set_xlabel('X')
    ax.set_ylabel('Y/Z')
    ax.set_title(title)
    ax.axis('square')
    ax.set_xlim(-1,1)
    ax.set_ylim(-1,1)
    ax.legend()


def sample_loop (model,device,
                train_loader,
                cond_scales=[7.5], #list of cond scales - each sampled...
                num_samples=2, #how many samples produced every time tested.....
                timesteps=100,
                 flag=0,clamp=False,
                 corplot=False,
                 save_img=False,
                 show_neighbors=False,
                xyz_and_graph=False,GED=False,
                 max_length_considered=None,#consider predictions only up to a certain token number
                     enforce_symmetry = False, #if True: make distance matrix symmetric
                 clamp_round_results=True,dist_matrix_threshold=0.25,
                 discretize = True, #whether or not to discretize to 0 or 1
                 get_length_from_result=False,
                 delete_isolated_nodes=True,
                 only_continuous=True,
                 show_colorbar=False,
               ):
    """Sample from ``model`` for every batch of a loader and plot against ground truth.

    For each batch and each conditioning scale, ``model.sample`` is called with
    the batch's conditioning values. In the result, channels 0..2 are plotted
    as x/y/z coordinates and channels ``3:3+max_length`` are treated as the
    neighbor matrix; in the ground-truth batch the same quantities sit one
    channel later (1..3 and ``4:4+max_length``). The neighbor matrix is
    optionally symmetrised, clamped to [0, 1], thresholded to {0, 1} and
    cleaned per sample before plotting.

    Relies on the module-level names ``max_length``, ``prefix``, ``y_min``,
    ``y_max``, ``exists``, ``scale_data``, ``delete_isolated``,
    ``select_first_continuous``, ``construct_xyz_and_graph``, ``plt``, ``np``,
    ``nx`` and ``r2_score``.

    Args:
        model: Object exposing ``sample(cond, device, cond_scale=, timesteps=, clamp=)``.
        device: Device the conditioning batch is moved to.
        train_loader: Iterable of ``(X_batch, y_batch)`` pairs (any loader).
        cond_scales: Conditioning scales; one sampling run per entry.
        num_samples: Maximum number of samples plotted per batch.
        timesteps, clamp: Forwarded to ``model.sample``.
        flag: Tag embedded in the names of saved files.
        corplot: Plot predicted against ground-truth coordinates.
        save_img: Save the scatter and projection figures.
        show_neighbors: Show the neighbor matrices as images.
        xyz_and_graph: Build graphs through ``construct_xyz_and_graph`` and
            compare graph properties, including an overall R2 score.
        GED: Also print the graph edit distance (needs ``xyz_and_graph``).
        max_length_considered: Accepted for compatibility; not used here.
        enforce_symmetry: Average the predicted matrix with its transpose.
        clamp_round_results: Clamp predicted and ground-truth matrices to
            [0, 1]; the ground-truth batch tensor is modified in place.
        dist_matrix_threshold: Threshold used when ``discretize`` is set.
        discretize: Set matrix entries to 0 or 1 by the threshold.
        get_length_from_result: Derive the plotted length from the highest
            non-zero column; otherwise use ``max_length``.
            NOTE(review): a matrix without any non-zero entry still raises
            here, as in the original code.
        delete_isolated_nodes: Apply ``delete_isolated`` per sample.
        only_continuous: Apply ``select_first_continuous`` per sample.
        show_colorbar: Add colour bars to the full neighbor-matrix images.
    """
    steps=0

    for item  in train_loader:

            X_train_batch= item[0]
            y_train_batch=item[1].to(device)

            print ("###", X_train_batch.shape)
            GT=y_train_batch.cpu().detach().unsqueeze(1)

            num_samples = min (num_samples,y_train_batch.shape[0] )
            print (f"Producing {num_samples} samples...")

            for cond_scale in cond_scales:
                result=model.sample ( y_train_batch,device,
                                         cond_scale=cond_scale,
                                         timesteps=timesteps,clamp=clamp
                                          )
                result=result.cpu()

                if enforce_symmetry:
                    result[:, 3:3+max_length, :]=0.5*(result[:, 3:3+max_length, :]+\
                                                      torch.transpose (result[:, 3:3+max_length, :], 1, 2 ) )
                if clamp_round_results:
                    result[:, 3:3+max_length, :]=torch.clamp(result[:, 3:3+max_length, :], 0, 1)
                    X_train_batch[:, 4:4+max_length, :]=torch.clamp (X_train_batch[:, 4:4+max_length, :], 0, 1)

                if discretize:
                    # `adjacency` is a view, so the masked writes update `result`.
                    adjacency = result[:, 3:3+max_length, :]
                    adjacency[adjacency < dist_matrix_threshold]=0
                    adjacency[adjacency >= dist_matrix_threshold]=1

                print (f"sample result for cond_scale={cond_scale}....", result.shape, "GT shape ", GT.shape)

                # Always initialised, so the summary below cannot hit an
                # undefined name when xyz_and_graph is False.
                y_data_coll_pred=[]
                y_data_coll_GT=[]

                for samples in range  (num_samples):
                    if delete_isolated_nodes:
                        result[samples, 3:3+max_length, :]=delete_isolated (result[samples, 3:3+max_length, :])

                    if only_continuous:
                         result[samples, 3:3+max_length, :]= select_first_continuous(result[samples, 3:3+max_length, :])

                    if get_length_from_result:
                        # Highest column index with a non-zero entry, plus one.
                        GTroundL = torch.nonzero(X_train_batch[samples, 4:4+max_length, :])[:,1].flatten().max()+1
                        resroundL = torch.nonzero(result[samples, 3:3+max_length, :])[:,1].flatten().max()+1
                    else:
                        GTroundL=max_length
                        resroundL=max_length

                    # --- 3D scatter: ground truth (left) and prediction (right)
                    fig, ax = plt.subplots(1, 2, figsize=(10, 6) , subplot_kw=dict(projection='3d'))
                    _sample_loop_scatter3d(ax[0],
                                           X_train_batch[samples, 1, :GTroundL],
                                           X_train_batch[samples, 2, :GTroundL],
                                           X_train_batch[samples, 3, :GTroundL],
                                           'GT')
                    _sample_loop_scatter3d(ax[1],
                                           result[samples, 0, :resroundL],
                                           result[samples, 1, :resroundL],
                                           result[samples, 2, :resroundL],
                                           'Prediction')
                    if save_img:
                            outname = prefix+ f"img_{flag}.png"
                            plt.savefig(outname, dpi=300)

                    plt.show()

                    # --- 2D projections of the same points
                    fig, ax = plt.subplots(1, 2, figsize=(10, 6)  )
                    _sample_loop_projection(ax[0],
                                            X_train_batch[samples, 1, :GTroundL],
                                            X_train_batch[samples, 2, :GTroundL],
                                            X_train_batch[samples, 3, :GTroundL],
                                            'GT')
                    _sample_loop_projection(ax[1],
                                            result[samples, 0, :resroundL],
                                            result[samples, 1, :resroundL],
                                            result[samples, 2, :resroundL],
                                            'Prediction')
                    if save_img:
                            outname = prefix+ f"img_proj_{flag}.png"
                            plt.savefig(outname, dpi=300)

                    plt.show()

                    if show_neighbors:
                        # Neighbor matrices cropped to the larger used length.
                        shown = max(GTroundL,resroundL)
                        fig, ax = plt.subplots(1, 2, figsize=(10, 6)  )

                        ax[1].imshow (result [samples, 3:3+shown, :shown])
                        ax[0].imshow (X_train_batch [samples, 4:4+shown, :shown])
                        ax[1].set_title('Prediction')
                        ax[0].set_title('GT')

                        plt.show()

                        # Full-size neighbor matrices.
                        fig, ax = plt.subplots(1, 2, figsize=(10, 6)  )

                        shw1=ax[1].imshow (result [samples, 3:3+max_length, :])
                        shw0=ax[0].imshow (X_train_batch [samples, 4:4+max_length, :])
                        ax[1].set_title('Prediction')
                        ax[0].set_title('GT')

                        if show_colorbar:
                            bar0 = plt.colorbar(shw0)
                            bar0.set_label('GT')
                            bar1 = plt.colorbar(shw1)
                            bar1.set_label('Prediction')
                        plt.show()

                    if xyz_and_graph:
                        G_res, data_res, y_data_pred =construct_xyz_and_graph (result[samples,:3+resroundL,:resroundL].squeeze(),
                                                 fname_root=f'{prefix}graph_xyz_{samples}_{flag}_{steps}',
                                                                label='Prediction',
                                                                dist_matrix=True)
                        G_GT, data_GT, y_data_GT=construct_xyz_and_graph (X_train_batch[samples, 1:4+GTroundL, :GTroundL].squeeze(),
                                                 fname_root=f'{prefix}graph_GT_xyz_{samples}_{flag}_{steps}',
                                                              label='GT',dist_matrix=True)

                        # Graph properties in their original units.
                        plt.plot ( y_data_GT, y_data_pred, '.',label='Graph properties (GT vs predicted)',markersize=12 )
                        plt.legend()
                        plt.xlabel ('GT')
                        plt.ylabel ('Predicted')
                        min_v,max_v=min (min(y_data_GT), min (y_data_pred)),max (max(y_data_GT), max (y_data_pred))
                        plt.plot([min_v, max_v], [min_v, max_v], ls="--", c=".3")
                        plt.axis ('square')
                        plt.show()

                        # Same properties rescaled with the training bounds.
                        for i in range (len(y_min)):
                            y_data_pred[i]=scale_data(y_data_pred[i], y_max[i],y_min[i])
                            y_data_GT[i]=scale_data(y_data_GT[i], y_max[i],y_min[i])

                        plt.plot ( y_data_GT, y_data_pred, '.',label='Graph properties (GT vs predicted)',markersize=12 )
                        plt.legend()
                        plt.xlabel ('GT')
                        plt.ylabel ('Predicted')
                        plt.plot([-1, 1], [-1, 1], ls="--", c=".3")
                        plt.axis ('square')
                        plt.show()

                        y_data_coll_pred.append (y_data_pred)
                        y_data_coll_GT.append (y_data_GT)

                        if GED:
                            # Kept in its own variable: assigning the distance
                            # to `GED` used to overwrite the flag, so a distance
                            # of 0 disabled the computation for later samples.
                            ged_value=nx.graph_edit_distance (G_res, G_GT)
                            print ("Graph edit distance=", ged_value)
                        print (f"")

                    print (X_train_batch.shape,result.shape )
                    if corplot:
                        plt.plot (X_train_batch[samples, 1:4, :].flatten().detach().cpu(),
                                  result[samples, 0:3, :].flatten() , '.',label='all',markersize=6 )
                        plt.plot (X_train_batch[samples, 1, :].flatten().detach().cpu(),
                                  result[samples, 0, :].flatten() , '.',label='x',markersize=2 )
                        plt.plot (X_train_batch[samples, 2, :].flatten().detach().cpu(),
                                  result[samples, 1, :].flatten() , '.',label='y',markersize=2 )
                        plt.plot (X_train_batch[samples, 3, :].flatten().detach().cpu(),
                                  result[samples, 2, :].flatten() , '.',label='z',markersize=2 )

                        plt.legend()

                        plt.plot([-1, 1], [-1, 1], ls="--", c=".3")
                        plt.axis ('square')
                        plt.show()

                # Overall correlation summary. It only exists when graph
                # properties were collected; before, this part raised a
                # NameError whenever xyz_and_graph was False.
                if xyz_and_graph and y_data_coll_pred:
                    y_data_coll_pred=np.array(y_data_coll_pred).flatten()
                    y_data_coll_GT=np.array(y_data_coll_GT).flatten()

                    print ("collected shape ", y_data_coll_pred.shape, y_data_coll_GT.shape)

                    print ("collected  ", y_data_coll_pred, y_data_coll_GT)

                    R2=r2_score(y_data_coll_GT, y_data_coll_pred)
                    print ("OVERALL R2: ", R2)
                    plt.plot ( y_data_coll_GT, y_data_coll_pred, '.', label='Graph properties (GT vs predicted)',markersize=3 )
                    plt.legend()
                    plt.xlabel ('GT')
                    plt.ylabel ('Predicted')
                    plt.plot([-1, 1], [-1, 1], ls="--", c=".3")
                    plt.axis ('square')
                    plt.title ("Correlation prediction vs. GT")
                    plt.show()

            steps=steps+1

        

In [None]:
def generate_sample_cond (model,
                cond_scales=[7.5], #list of cond scales - each sampled...
                timesteps=100,
                flag=0,clamp=False,
                corplot=False,
                save_img=False,
                show_neighbors=False,
                xyz_and_graph=False,GED=False,
                cond=[1., .5, 1.],
                enforce_symmetry = False, #if True: make distance matrix symmetric
                clamp_round_results=True,dist_matrix_threshold=0.25,
                discretize = True, #whether or not to discretize to 0 or 1
                get_length_from_result=False,
                delete_isolated_nodes=True,
                only_continuous=True,
                show_colorbar=False,
                save_neigh_img=False,
               ):
    """Sample a graph from ``model`` for one conditioning vector and plot it.

    The sampled tensor is indexed ``result[batch, channel, node]``; channels 0-2
    are plotted as x, y, z and channels ``3:3+max_length`` are treated as the
    neighbor matrix. Uses the module-level names ``device``, ``max_length``,
    ``prefix``, ``delete_isolated``, ``select_first_continuous`` and
    ``construct_xyz_and_graph``.

    Args:
        model: object exposing ``sample(cond, device, cond_scale=, timesteps=, clamp=)``.
        cond_scales: guidance scales; one sample is drawn and plotted per entry,
            but only the sample of the LAST entry is returned. Never mutated.
        timesteps, clamp: forwarded to ``model.sample``.
        flag: tag inserted into the names of saved images / graph files.
        corplot: accepted for call compatibility; not used here.
        save_img: save the two coordinate plots under ``prefix``.
        show_neighbors: also plot the neighbor matrix (cropped and full).
        xyz_and_graph: build the graph via ``construct_xyz_and_graph`` and return it.
        GED: if True (and ``xyz_and_graph``), print the graph edit distance to a
            module-level ground-truth graph ``G_GT`` when one is defined.
        cond: conditioning vector (sequence or tensor). Never mutated.
        enforce_symmetry: average the neighbor matrix with its transpose.
        clamp_round_results: clamp neighbor-matrix entries to [0, 1].
        dist_matrix_threshold, discretize: entries below the threshold become 0,
            all others 1.
        get_length_from_result: crop to the highest node index that still has a
            neighbor entry instead of using ``max_length``.
        delete_isolated_nodes, only_continuous: post-process the neighbor matrix
            with ``delete_isolated`` / ``select_first_continuous``.
        show_colorbar: add a colorbar to the full neighbor-matrix plot.
        save_neigh_img: save the neighbor-matrix plots under ``prefix``.

    Returns:
        If ``xyz_and_graph``: ``(sample, G_res, data_res, y_data_pred)`` where
        ``sample`` is permuted to ``[node, 3+node]``; its coordinates were rescaled
        in place by ``construct_xyz_and_graph``. Otherwise the cropped sample
        with shape ``[3+length, length]``.

    Raises:
        ValueError: if ``cond_scales`` is empty.
    """
    if len (cond_scales) == 0:
        raise ValueError ("cond_scales must contain at least one guidance scale")

    steps=0   #only used to label the files written by construct_xyz_and_graph
    samples=0 #a single sample is generated; it lives at batch index 0
    m=6       #marker size for the coordinate plots

    #conditioning vector as a batch of size one
    y_train_batch=torch.Tensor (cond).to(device).unsqueeze(0)

    for cond_scale in cond_scales:

        print ("y_train_batch ", y_train_batch.shape)

        result=model.sample ( y_train_batch,device,
                                 cond_scale=cond_scale,
                                 timesteps=timesteps,clamp=clamp,
                                  )
        result=result.cpu()

        if enforce_symmetry:
            result[:, 3:3+max_length, :]=0.5*(result[:, 3:3+max_length, :]+\
                                              torch.transpose (result[:, 3:3+max_length, :], 1, 2 ) )
        if clamp_round_results:
            result[:, 3:3+max_length, :]=torch.clamp(result[:, 3:3+max_length, :], 0, 1)

        if discretize:
            #two in-place passes on the view: below threshold -> 0, the rest -> 1
            result[:, 3:3+max_length, :] [result[:, 3:3+max_length, :]< dist_matrix_threshold]=0
            result[:, 3:3+max_length, :] [result[:, 3:3+max_length, :]>= dist_matrix_threshold]=1

        print (f"sample result for cond_scale={cond_scale}....", result.shape )

        if delete_isolated_nodes:
            result[samples, 3:3+max_length, :]=delete_isolated (result[samples, 3:3+max_length, :])

        if only_continuous:
            result[samples, 3:3+max_length, :]= select_first_continuous(result[samples, 3:3+max_length, :])

        if get_length_from_result:
            #length = highest column (node) index holding a non-zero entry, plus one
            resroundL = int (torch.nonzero(result[samples, 3:3+max_length, :])[:,1].max())+1
        else:
            resroundL=max_length

        print (f"sample result for cond_scale={cond_scale}....", result.shape )
        print ("Conditiioning: ", y_train_batch)

        xs=result[samples, 0, :resroundL]
        ys=result[samples, 1, :resroundL]
        zs=result[samples, 2, :resroundL]

        #3D scatter of the node positions
        fig, ax = plt.subplots(1, 1, figsize=(5, 6) , subplot_kw=dict(projection='3d'))
        ax.scatter(xs, ys, zs, c='red', s=m, marker="o")
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-1,1)
        ax.set_ylim(-1,1)
        ax.set_zlim(-1,1)
        ax.set_title('Prediction')
        if save_img:
            outname = prefix+ f"img_{flag}.png"
            plt.savefig(outname, dpi=300)
        plt.show()

        #2D projections: Y and Z plotted over X
        fig, ax = plt.subplots(1, 1, figsize=(5, 6)  )
        ax.plot(xs, ys ,'bo', markersize=m, label='Y over X')
        ax.plot(xs, zs,'ro', markersize=m, label='Z over X')
        ax.set_xlabel('X')
        ax.set_ylabel('Y/Z')
        ax.axis('square')
        ax.set_xlim(-1,1)
        ax.set_ylim(-1,1)
        ax.legend()
        ax.set_title('Prediction')
        if save_img:
            outname = prefix+ f"img_proj_{flag}.png"
            plt.savefig(outname, dpi=300)
        plt.show()

        if show_neighbors:
            #neighbor matrix cropped to the detected length
            fig, ax = plt.subplots(1, 1, figsize=(5, 6)  )
            ax.imshow (result [samples, 3:3+resroundL, :resroundL])
            if save_neigh_img:
                plt.axis('off')
                outname = prefix+ f"neigh_proj_1_{flag}.png"
                plt.savefig(outname, dpi=200)
            ax.set_title('Prediction')
            plt.show()

            #full (uncropped) neighbor matrix
            fig, ax = plt.subplots(1, 1, figsize=(5, 6)  )
            shw1=ax.imshow (result [samples, 3:3+max_length, :])
            if show_colorbar:
                #the handle of the image drawn above is shw1
                bar1 = plt.colorbar(shw1)
                bar1.set_label('Prediction')
            if save_neigh_img:
                plt.axis('off')
                outname = prefix+ f"neigh_proj_2_{flag}.png"
                plt.savefig(outname, dpi=200)
            ax.set_title('Prediction')
            plt.show()

        if xyz_and_graph:
            #NOTE: the slice is a view, so the in-place coordinate rescaling done by
            #construct_xyz_and_graph is visible in `result` (and in the returned sample)
            G_res, data_res, y_data_pred =construct_xyz_and_graph (result[samples,:3+resroundL,:resroundL].squeeze(),
                                     fname_root=f'{prefix}graph_xyz_{samples}_{flag}_{steps}',
                                                    label='Prediction',
                                                    dist_matrix=True)

            if GED:
                #there is no ground truth in conditional generation; only compare
                #when the notebook defines a reference graph G_GT
                gt_graph=globals().get ('G_GT')
                if gt_graph is None:
                    print ("Graph edit distance skipped: no ground-truth graph G_GT defined")
                else:
                    ged_value=nx.graph_edit_distance (G_res, gt_graph)
                    print ("Graph edit distance=", ged_value)
            print (f"")

    #only the sample of the last cond_scale is returned
    sample_out=result[samples,:3+resroundL,:resroundL].squeeze()
    if xyz_and_graph:
        return sample_out.permute (1,0), G_res, data_res, y_data_pred
    return sample_out


In [None]:
import networkx as nx
 
from torch_geometric.utils import to_networkx

In [None]:
# Notebook cell: evaluating the bare name echoes the current value of max_length.
max_length

In [None]:
import networkx as nx
 
from torch_geometric.utils import to_networkx

def add_edge_to_graph(G, e1, e2, w):
    """Add the edge (e1, e2) to G via ``G.add_edge``.

    The edge is given two keyword attributes: ``weight`` (set to ``w``) and
    ``clamp_neighbors`` (always ``True``).
    """
    edge_attributes = {'weight': w, 'clamp_neighbors': True}
    G.add_edge(e1, e2, **edge_attributes)
def get_properties(item):
    """Compute simple geometric statistics of a graph.

    Args:
        item: object with ``pos`` (node coordinates, indexed ``pos[node, axis]``),
            ``edge_index`` (indexed ``edge_index[0 or 1, edge]``), ``num_edges``
            and ``num_nodes``.

    Returns:
        Tuple ``(avg_length, dx, dy, dz, num_nodes, num_edges, node_degree)``:
        mean Euclidean edge length, mean per-axis displacement
        ``pos[source] - pos[target]`` over all edges (each as a 0-d numpy array),
        the node and edge counts, and ``num_edges / num_nodes``.
        A graph without edges yields zeros for the four edge statistics
        instead of raising ZeroDivisionError.
    """
    num_nodes=item.num_nodes
    num_edges=item.num_edges
    node_degree=num_edges / num_nodes

    if num_edges == 0:
        #nothing to average over
        avg_length, dx, dy, dz = (torch.zeros (()).numpy() for _ in range (4))
        return avg_length, dx, dy, dz, num_nodes , num_edges, node_degree

    #per-edge displacement vectors, computed for all edges at once
    source=item.edge_index[0, :num_edges]
    target=item.edge_index[1, :num_edges]
    deltas=item.pos[source, :3]-item.pos[target, :3]

    avg_length=deltas.pow (2).sum (dim=1).sqrt ().sum () / num_edges
    dx, dy, dz = deltas.sum (dim=0) / num_edges

    return avg_length.numpy(),dx.numpy(), dy.numpy(), dz.numpy(), num_nodes , num_edges, node_degree

def construct_xyz_and_graph (result, fname_root='output',
                             clamp_neighbors=True,label='Generated',
                             limits=None,#axis limits for plot, 
                             GT_y=None,
                             dist_matrix=False,
                             
                            ):
    """Turn a sampled tensor into files, a 3D plot and a graph object.

    ``result`` is indexed ``result[channel, node]``: rows 0-2 are x, y, z and the
    rows from 3 on hold the neighbor information. Uses the module-level names
    ``unscale_data``, ``X_max``, ``X_min``, ``max_neighbors``, ``labels_y_txt``,
    ``exists``, ``Data`` and ``to_networkx``.

    NOTE: rows 0-2 of ``result`` are overwritten IN PLACE with the output of
    ``unscale_data``; callers that pass a view see the rescaled coordinates.

    Args:
        result: tensor ``[3 + neighbor rows, num_nodes]``.
        fname_root: path prefix of the written ``.xyz``, ``.dump``, ``.png``,
            ``.svg`` and ``.pt`` files.
        clamp_neighbors: (sparse mode only) clamp neighbor indices to valid nodes.
        label: title of the 3D plot.
        limits: optional ``[lo, hi]`` applied to all three plot axes.
        GT_y: optional ground-truth property vector to print and plot against.
        dist_matrix: if True rows ``3:3+num_nodes`` are a dense neighbor matrix
            (entry > 0.5 means connected); otherwise rows ``3:3+max_neighbors``
            hold 1-based neighbor indices with 0 as padding.

    Returns:
        ``(G, data, y_data)``: the networkx graph, the ``Data`` object and the
        array ``[avg_length, dx, dy, dz, num_nodes, num_edges, node_degree]``.
    """
    print ("##############################################################################")
    print ("Shape of data provided ", result.shape) 
    print (f"Root file: {fname_root}")
    
    result[ :3, :]=unscale_data(result[:3,: ], X_max.numpy(),X_min.numpy())

    num_nodes=result.shape[1]
    xs=result[ 0, :]
    ys=result[ 1, :]
    zs=result[ 2, :]
    
    with open(fname_root+'.xyz', 'w') as f:
        f.write('#ID x y z \n')
        for i in range (num_nodes):
            f.write(f'{i+1} {xs[i]} {ys[i]} {zs[i]} \n')
    
    node_list=[[xs[i], ys[i], zs[i] ] for i in range (num_nodes)]

    #collect edges as [node, neighbor] pairs of 0-based node indices
    if dist_matrix:
        #entry [3+j, i] > 0.5 marks j as a neighbor of i; transposing makes
        #nonzero() list the pairs ordered by i, then j
        adjacency=result[3:3+num_nodes, :num_nodes]
        neighbor_list=torch.nonzero (adjacency.t() > 0.5).tolist()
    else:
        neighbor_list=[]
        for i in range (num_nodes):
            for j in range (max_neighbors):
                neigh_j=result[ 3+j, i] 
                if neigh_j !=0: #zeros are padded values...
                    #stored value is 1+node number (0 encodes padding)
                    neighn=int (neigh_j-1)
                    if clamp_neighbors:
                        neighn=max( 0, min (neighn, num_nodes-1) )
                    neighbor_list.append ([i,neighn] )

    with open(fname_root+'.dump', 'w') as f:
        f.write(f'\n{num_nodes} atoms\n{len (neighbor_list)} bonds  \n\n1 atom types\n1 bond types\n\n')
        f.write(f'{xs.min()*1.1} {xs.max()*1.1} xlo xhi \n{ys.min()*1.1} {ys.max()*1.1} ylo yhi \n{zs.min()*1.1} {zs.max()*1.1} zlo zhi \n\n')
        f.write(f'Masses \n\n1 100.00  \n\n')
        f.write(f'Atoms  \n\n')

        for i in range (num_nodes):
            f.write(f'{i+1} 1 1 {xs[i]} {ys[i]} {zs[i]} \n')
        f.write(f'\nBonds  \n\n')

        #atom ids in the file are 1-based, so BOTH ends of a bond get +1
        for bond_id, (n1, n2) in enumerate (neighbor_list, start=1):
            f.write(f'{bond_id} 1 {n1+1} { max( 1, min (n2+1, num_nodes) )}   \n')
        f.write(f'\n\n')
    
    neighbor_list=torch.Tensor (neighbor_list).long()
    print ("neighborlist shape: ", neighbor_list.shape)
    
    fig = plt.figure(figsize=(8,8))
    m=24
    ax = fig.add_subplot(projection='3d')
    ax.scatter(xs, ys, zs,c='red', s=m, marker="o")
    ax.set_proj_type('ortho')
    ax.set_title (label)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    
    print ("STATS of neighbor list: max ", neighbor_list.max(), "min ",  neighbor_list.min(), "shape ", neighbor_list.shape)
    
    #draw every edge as a line segment between its two nodes
    for i in range (len (neighbor_list)):
        N1=neighbor_list[i][0] #array index, not node number
        N2=max( 0, min (neighbor_list[i][1], num_nodes-1) ) 
        ax.plot3D ([xs[N1], xs[N2]], [ys[N1], ys[N2]] , [zs[N1], zs[N2]], 'k-', linewidth=1, )
            
    if limits is not None:
        ax.set_xlim(limits[0],limits[1])
        ax.set_ylim(limits[0],limits[1])
        ax.set_zlim(limits[0],limits[1])    
        
    plt.savefig (fname_root+'.png', dpi=400)
    plt.savefig (fname_root+'.svg', dpi=400)
    plt.show()
    
    print ("Neighborlist shape COO^T format: ", neighbor_list.shape)
    x = torch.tensor(node_list, dtype=torch.float)  #node positions, shape [num_nodes, 3]
    edge_index =neighbor_list.permute(1,0)  #COO connectivity, shape [2, num_edges]
    y = None #no graph-level target is attached
    data = Data(x=x,pos=x, edge_index=edge_index, y=y)
    
    torch.save (data, fname_root+'.pt')
    
    print(f'Number of nodes: {data.num_nodes}')
    print(f'Number of edges: {data.num_edges}')
    print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
    G = to_networkx(data, to_undirected=False)
    
    print ("##############################################################################")
    
    #Now calculate graph properties
    avg_length, dx, dy, dz, n_nodes, n_edges, node_degree=get_properties(data)
    
    y_data=np.array([avg_length, dx, dy, dz, n_nodes, n_edges, node_degree])
    
    print (f"Graph properties measured: {labels_y_txt}\n",y_data)
    if exists (GT_y):
        print (f"Graph properties GT: {labels_y_txt}\n", GT_y)
        
        #x axis = ground truth, y axis = measured on the generated graph
        plt.plot ( GT_y, y_data, '.',label='Graph properties (GT vs predicted)',markersize=12 )
        plt.legend()
        plt.xlabel ('GT')
        plt.ylabel ('Predicted')
        min_v,max_v=min (min(y_data), min (GT_y)),max (max(y_data), max (GT_y))
        plt.plot([min_v, max_v], [min_v, max_v], ls="--", c=".3")
        plt.axis ('square')
        plt.show()
    return G,data, y_data

### Trainer

In [None]:
# Module-level list, initialised empty here (NOTE(review): presumably filled by the training loop - confirm).
loss_list=[]

In [None]:
# Output directory for files written by this notebook; created if it is missing.
prefix='./AnalogDiffusionFull/'
# exist_ok avoids the check-then-create race and also creates missing parents
os.makedirs (prefix, exist_ok=True)

In [None]:
 # Notebook cell: evaluating the bare name echoes the current value of max_length.
 max_length

In [None]:
# Notebook cell: evaluating the bare name echoes the current value of max_neighbors.
max_neighbors

In [None]:
# Override max_length with the value 64 (also kept as max_length_diff),
# then echo it as the cell output.
max_length = max_length_diff = 64
max_length

In [None]:
# Select seaborn's "whitegrid" style for the plots created afterwards.
sns.set_style("whitegrid")

In [None]:
from GraphDiffusion import AnalogDiffusionFull, count_parameters,pad_sequence 

# Model configuration: the network predicts 3 coordinate channels plus one
# neighbor-matrix channel per node slot (3 + max_length in total).
predict_neighbors = True
pred_dim = 3 + max_length

# Width of the conditioning vector = number of columns of y_data.
context_embedding_max_length = y_data.shape[1]

model = AnalogDiffusionFull(
    max_length=max_length,
    pred_dim=pred_dim,
    channels=256,
    unet_type='cfg',
    context_embedding_max_length=context_embedding_max_length,
    pos_emb_fourier=True,
    pos_emb_fourier_add=False,
    text_embed_dim=256,
    embed_dim_position=256,
    predict_neighbors=predict_neighbors,
)
model = model.to(device)

count_parameters(model)

optimizer = optim.Adam(model.parameters(), lr=0.0002)

In [None]:
# Start training; every keyword below is forwarded to train_loop unchanged.
train_loop(
    model,
    train_loader,
    test_loader,
    optimizer=optimizer,
    print_every=100,
    epochs=3000000,
    start_ep=0,
    start_step=0,
    train_unet_number=1,
    print_loss=100 * (len(train_loader) - 1),
    plot_unscaled=False,  # if unscaled data is plotted
    save_model=True,
    cond_scales=[1],
    num_samples=4,
    timesteps=150,
    clamp=True,
    corplot=False,
    save_loss_images=False,
    show_neighbors=True,
    xyz_and_graph=True,
    dist_matrix_threshold=0.99,
    discretize=True,
    get_length_from_result=True,
)

In [None]:
# Restore the saved weights (lowest-loss model). prefix already ends with a
# separator, so join it instead of inserting a second '/'.
fname=os.path.join (prefix, 'statedict_save-model-epoch_2001.pt')
# map_location lets a checkpoint saved on another device load onto `device`
model.load_state_dict(torch.load(fname, map_location=device))

In [None]:
# Sample from the trained model on the test loader and plot the results.
sample_loop(
    model,
    device,
    test_loader,
    cond_scales=[1],   # list of cond scales - each sampled...
    num_samples=16,    # how many samples produced every time tested.....
    timesteps=150,
    clamp=True,
    corplot=False,
    show_neighbors=True,
    xyz_and_graph=True,
    clamp_round_results=True,
    enforce_symmetry=False,
    discretize=True,
    get_length_from_result=True,
    dist_matrix_threshold=0.25,
)

In [None]:
# Generate a single sample for a hand-picked conditioning vector.
# NOTE(review): this cond has 3 entries while other calls in this notebook pass
# 7 values - confirm it matches the model's conditioning width.
generate_sample_cond(
    model,
    cond=[-.5, 0.7, 0.65],
    cond_scales=[1.1],   # list of cond scales - each sampled...
    flag=9992,
    timesteps=100,
    clamp=True,
    corplot=True,
    show_neighbors=True,
    xyz_and_graph=True,
)

### Generate hierarchical structures

In [None]:
def get_length (X1):
    """Return 1 + the largest row or column index of a non-zero entry in ``X1[:, 3:]``.

    Column indices are counted relative to column 3 of ``X1``. The result is a
    0-d tensor.
    """
    nonzero_indices = torch.nonzero(X1[:, 3:])
    largest_index = nonzero_indices.reshape(-1).max()
    return largest_index + 1
    

def get_length_xyzCOO (result):
    """Return 1 + the last position with a non-zero value in row 3 of ``result``.

    Only row 3 (the first neighbor row) is inspected. The result is a
    one-element tensor.
    """
    first_neighbor_row = result[3, :]
    occupied_positions = torch.nonzero(first_neighbor_row)
    return occupied_positions[-1] + 1
    
def get_xyz_and_dist_matrix_fromxyzCOO (X_data_cl, clamp_neighbors=False ,
                            visualize=False, max_length=None):
    """Convert an xyz + neighbor-index encoding into xyz + dense neighbor matrix.

    ``X_data_cl`` is indexed ``[channel, node]`` on entry; rows 0-2 are kept as
    coordinates and each non-zero value in the remaining rows is read as a
    1-based neighbor index (0 = padding). The connection is written
    symmetrically into a ``max_length x max_length`` matrix.

    Args:
        X_data_cl: input tensor ``[3 + max_neighbors, nodes]``.
        clamp_neighbors: clamp neighbor indices into ``[1, max_length]``.
        visualize: show the resulting neighbor matrix with ``plt.imshow``.
        max_length: number of nodes to convert; derived with
            ``get_length_xyzCOO`` when not given.

    Returns:
        Tensor ``[max_length, 3 + max_length]``: coordinates followed by the
        dense neighbor matrix.
    """
    if not exists (max_length):
        max_length = get_length_xyzCOO (X_data_cl)

    #bring to [max_length, xyz+max_neighbors]
    X_data_cl = X_data_cl.permute (1, 0)
    print (X_data_cl.shape)

    max_neighbors = X_data_cl.shape[1] - 3
    print ("Max neighbors: ", max_neighbors, "Length: ", max_length)
    dist_matrix = torch.zeros (max_length, max_length)

    for node in range (max_length):
        for column in range (3, 3 + max_neighbors):
            entry = X_data_cl[node, column]
            if entry == 0:
                #zeros are padded values
                continue

            partner = entry.long()
            if clamp_neighbors:
                partner = max (1, min (partner, max_length))

            #stored indices are 1-based; mark the connection in both directions
            dist_matrix[node, partner - 1] = 1
            dist_matrix[partner - 1, node] = 1

    coordinates = X_data_cl[:max_length, :3]
    print (dist_matrix.shape, coordinates.shape)
    output = torch.cat ((coordinates, dist_matrix), 1)

    print ("Shape of new matrix (length, x,y,z+length+): ", output.shape)

    if visualize:
        plt.imshow (output[:, 3:3+max_length])
        plt.show()
    return output
    
def shift (X2, dx, dy, dz):
    """Translate columns 0, 1, 2 of ``X2`` by dx, dy, dz in place and return ``X2``."""
    for axis, offset in enumerate ((dx, dy, dz)):
        X2[:, axis] = X2[:, axis] + offset
    return X2
    
def stack_and_extend (X1, X2, stack_node, dx,dy,dz, 
                     fname_root='file_name',
                      visualize=False,
                      make_graph = False,
                      avg_pos_in_overlap=False,
                      
                     ): #format: (dist matrix x distmatrix+xyz)
    """Stack graph ``X2`` onto graph ``X1`` so that they overlap from ``stack_node`` on.

    Both inputs are indexed ``[node, 3 + node]`` (x, y, z followed by the
    neighbor matrix). Node ``k`` of ``X2`` becomes node ``stack_node + k`` of the
    result; in the overlapping block the neighbor entries of ``X2`` replace
    those of ``X1``.

    NOTE: ``X2`` is shifted by (dx, dy, dz) IN PLACE; pass a clone to keep the
    original.

    Args:
        X1, X2: graphs to combine; their used sizes come from ``get_length``.
        stack_node: node index of ``X1`` at which ``X2`` starts.
        dx, dy, dz: translation applied to the coordinates of ``X2``.
        fname_root: file prefix handed to ``construct_xyz_and_graph``.
        visualize: show the combined neighbor matrix and the coordinate blocks.
        make_graph: build and plot the graph with ``construct_xyz_and_graph``,
            which rescales the coordinates of the returned tensor in place.
        avg_pos_in_overlap: average the two coordinates of overlapping nodes
            instead of taking those of ``X2``.

    Returns:
        Tensor ``[S_new, 3 + S_new]`` with ``S_new = max(S1, stack_node + S2)``.
    """
    S1=int (get_length(X1))
    S2=int (get_length(X2))

    #large enough for both graphs, also when X2 ends before X1 does
    S_new=max (S1, stack_node+S2)
    stack_end=stack_node+S2

    dist_matrix=torch.zeros (S_new, S_new+3)

    X2[:, 0] =X2[:, 0] + dx
    X2[:, 1] =X2[:, 1] + dy
    X2[:, 2] =X2[:, 2] + dz

    #neighbor matrices: X1 first, then X2 written over the overlapping block
    dist_matrix[:S1, 3:3+S1]= X1 [:S1, 3:3+S1]
    dist_matrix[stack_node:stack_end, 3+stack_node:3+stack_end]= X2 [:S2, 3:3+S2]

    #coordinates: only the first S1 / S2 rows are used, padding rows are ignored
    dist_matrix[:S1, :3]= X1[:S1, :3]
    if avg_pos_in_overlap:
        dist_matrix[stack_node:stack_end, :3]= dist_matrix[stack_node:stack_end, :3]+X2[:S2, :3]
        #only rows that received coordinates from BOTH graphs are averaged
        overlap_end=min (S1, stack_end)
        dist_matrix[stack_node:overlap_end, :3]= dist_matrix[stack_node:overlap_end, :3]/2.
    else:
        dist_matrix[stack_node:stack_end, :3]=  X2[:S2, :3]

    if visualize:
        plt.imshow (dist_matrix[:S_new, 3:3+S_new])
        plt.show()
        
        plt.imshow(X1[:,:3], aspect=.1)
        plt.show()
        plt.imshow(X2[:,:3], aspect=.1)
        plt.show()
        plt.imshow(dist_matrix[:,:3], aspect=.1)
        plt.show()
        
    if make_graph:
        construct_xyz_and_graph (dist_matrix.permute(1,0),
                                 fname_root=fname_root,
                                 label='Stacked and extended',
                                 dist_matrix=True,
                                )
    
    return dist_matrix 
    
    

    
def make_spiral (X1, radius, 
                 slope_z, delta_angle, steps,  
                stagger_fraction = 0.8,
                fname_root='fname',
                 visualize=True,visualize_all=False,
                 shuffle=False,#shuffle graph
                 
                 avg_pos_in_overlap=False,
                 
                 generate_new_every_iteration=False,
                 cond_vector_list=None,
                 length_cond_vector=7,
                 is_COO=False, #whether model predicts COO (sparse) or not
                )   :
    """Assemble copies of a graph along a helix: x = r cos(t), y = r sin(t), z = a t.

    Segment ``k`` (k = 0 is the initial copy of ``X1``) is translated by
    ``(radius*cos(k*delta_angle), radius*sin(k*delta_angle), slope_z*k)`` and
    stacked with ``stack_and_extend`` so that it starts
    ``stagger_fraction * segment_length`` nodes before the current end.

    Args:
        X1: starting graph, indexed ``[node, 3 + node]``.
        radius, slope_z, delta_angle: helix radius, height gain per step and
            angle increment (radians) per step.
        steps: number of segments appended after the initial one.
        stagger_fraction: fraction of a segment that overlaps the structure
            built so far.
        fname_root: file prefix handed to ``construct_xyz_and_graph``.
        visualize: show the final neighbor matrix once more at the end.
        visualize_all: show the segment's neighbor matrix before/after shuffling.
        shuffle: randomly permute rows and columns of the segment's neighbor
            matrix (coordinates are left in place) before stacking.
        avg_pos_in_overlap: forwarded to ``stack_and_extend``.
        generate_new_every_iteration: sample a fresh segment from the
            module-level ``model`` in every step instead of reusing ``X1``.
        cond_vector_list: per-step conditioning vectors; random vectors in
            [-0.75, 0.75) of size ``length_cond_vector`` are used when omitted.
        is_COO: convert freshly sampled segments with
            ``get_xyz_and_dist_matrix_fromxyzCOO`` first.

    Returns:
        The assembled tensor ``[nodes, 3 + nodes]``; its coordinates were
        rescaled in place by the final ``construct_xyz_and_graph`` call.
    """
    S1=int (get_length(X1))
    delta_stagg=S1*stagger_fraction
    
    #initial segment sits at angle 0 and height 0
    dx= radius * math.cos (0*delta_angle)
    dy= radius * math.sin (0*delta_angle)
    dz= slope_z *0
    spiral = shift (torch.clone (X1), dx, dy, dz)

    for i in tqdm (range (steps)):
        
        if generate_new_every_iteration:
            if exists (cond_vector_list):
                cond_v=cond_vector_list[i]
            else:
                cond_v=1.5*torch.rand (length_cond_vector)-0.75
            #sample from the notebook's `model`
            result, _, _, _= generate_sample_cond (model,
                cond=cond_v,
                cond_scales=[1.,  ], #list of cond scales - each sampled...
                flag=9992,
              clamp=True, corplot=True,show_neighbors=True,
            xyz_and_graph=True)
            
            if is_COO:
                #convert xyz-COO coding to xyz-distance matrix
                X1=get_xyz_and_dist_matrix_fromxyzCOO (result, clamp_neighbors=True ,
                                            visualize=False)
            else:
                X1=result
            S1=int (get_length(X1))
            delta_stagg=S1*stagger_fraction

        if shuffle:
            
            if visualize_all:
                plt.imshow (X1[:S1, 3:3+S1])
                plt.show()               

            #apply one random permutation to both rows and columns of the
            #neighbor matrix; the coordinate columns stay where they are
            c=torch.randperm(S1)
            shuffled=X1[:S1, 3:3+S1][c][:, c]
            X1=torch.cat( ( X1[:S1,:3],  shuffled), 1)

            if visualize_all:
                plt.imshow (X1[:S1, 3:3+S1])
                plt.show()     
            
        #node of the current structure at which the next segment starts
        stack_node=int (get_length(spiral)-delta_stagg )
    
        dx= radius * math.cos ((i+1)*delta_angle)
        dy= radius * math.sin ((i+1)*delta_angle)
        dz= slope_z *(i+1)
        
        spiral=stack_and_extend (spiral.clone(), X1.clone(), stack_node, dx,dy,dz, 
                      visualize=False,
                      make_graph = False,avg_pos_in_overlap=avg_pos_in_overlap,
                     )
        spiral[:, 3: ]=torch.clamp(spiral[:, 3: ], 0, 1) 
        
    S1=int (get_length(spiral))
    
    plt.imshow (spiral[:S1, 3:3+S1],interpolation='none')
    plt.show()     
    
    plt.plot (spiral[:, 0])
    plt.plot (spiral[:, 1])
    plt.plot (spiral[:, 2])
    plt.show()
    
    plt.imshow (spiral[:S1, :3],interpolation='none')
    plt.show()     
    
    construct_xyz_and_graph (spiral.permute(1,0),
                             fname_root=fname_root,
                             label='Spiral',
                             dist_matrix=True,
                            )
    
    if visualize:
        plt.imshow (spiral[:, 3:],interpolation='none')
        plt.show()
    
    return spiral
    

In [None]:
# Sample the base graph for the hierarchical assembly below, with post-processing
# enabled (symmetrized, clamped, thresholded at 0.25, cropped to the detected length).
output, G_res, data_res, y_data = generate_sample_cond(
    model,
    cond=[-0.7255, -0.2038, 0.5942, 0.3315, 0.1786, 0.1852, 0.0122],
    cond_scales=[1.],   # list of cond scales - each sampled...
    flag=9992,
    clamp=True,
    corplot=True,
    show_neighbors=True,
    xyz_and_graph=True,
    clamp_round_results=True,
    enforce_symmetry=True,
    discretize=True,
    get_length_from_result=True,
    dist_matrix_threshold=0.25,
)

In [None]:
# Inspect the sampled graph: print its shape and show every column from index 3 on
# (the columns that get_length treats as the neighbor matrix).
print(output.shape)
plt.imshow (output[:,3:])
plt.show()

# Last expression of the cell: its value (the detected graph length) is echoed.
get_length (output)

In [None]:
                      
# Stack a shifted copy of the sampled graph onto itself, starting at node 10.
# Clones are passed because stack_and_extend shifts its second argument in place.
output2 = stack_and_extend(
    output.clone(),
    output.clone(),
    stack_node=10,
    dx=100,
    dy=60,
    dz=50,
    fname_root='file_name200',
    visualize=True,
    make_graph=True,
)

In [None]:
# Build a helix from 80 further copies of the sampled graph: 10 degrees and 5
# height units per step, each segment overlapping half of the previous one.
spiral = make_spiral(
    output,
    radius=20,
    slope_z=5,
    delta_angle=10/360*2*math.pi,
    steps=80,
    stagger_fraction=0.5,
    fname_root='spiral_v900',
    shuffle=False,
    avg_pos_in_overlap=False,
)

### Build attractor structure 

In [None]:

def get_xyz_attractor (t, px=1.5, py=1.5, pz=1.5, dx=2, dy= 2., scale=1):
    """Evaluate a parametric 3D curve at parameter ``t``.

    x = (dx + cos(px*t)) * cos(t), y = (dy + cos(py*t)) * sin(t), z = sin(pz*t);
    all three are multiplied by ``scale``.

    Returns:
        Tuple ``(x, y, z)`` of floats.
    """
    radial_x = dx + math.cos(px * t)
    radial_y = dy + math.cos(py * t)

    x = radial_x * math.cos(t)
    y = radial_y * math.sin(t)
    z = math.sin(pz * t)
    return x * scale, y * scale, z * scale


In [None]:
# Sample a parametric test curve on t in [0, 50) with step 0.005 and collect
# the coordinates per axis.
pts_1 = dict(x=[], y=[], z=[])
for t in np.arange(0, 50, 0.005):
    x, y, z = (
        math.sin(3*t) + math.cos(t),
        math.sin(t) * math.sin(5*t),
        math.sin(2*t),
    )
    pts_1['x'].append(x)
    pts_1['y'].append(y)
    pts_1['z'].append(z)


In [None]:
# Plot the sampled curve pts_1 on 3D axes.
ax = plt.axes(projection='3d')
# Data for a three-dimensional line
ax.plot3D(pts_1['x'], pts_1['y'], pts_1['z'], 'gray')
# Data for three-dimensional scattered points, colored by their z value
ax.scatter3D(pts_1['x'], pts_1['y'], pts_1['z'], c=pts_1['z'], cmap='Greens');
plt.show()

In [None]:
# Sample the attractor curve (scale 20) on t in [0, 10) with step 0.1 and
# scatter-plot the points in 3D, colored by their z value.
pts_2 = dict(x=[], y=[], z=[])
t_r = np.arange(0, 10, 0.1)
print (t_r.shape)
for t in t_r:
    x, y, z = get_xyz_attractor(t, px=1.5, py=1.5, pz=1.5, dx=2, dy=2., scale=20)
    pts_2['x'].append(x)
    pts_2['y'].append(y)
    pts_2['z'].append(z)

ax = plt.axes(projection='3d')

# trailing ';' keeps the notebook from echoing the return value
ax.scatter3D(pts_2['x'], pts_2['y'], pts_2['z'], c=pts_2['z'], cmap='Greens');

In [None]:
def make_attractor(X1, t, scale=5,
                   stagger_fraction=0.8,
                   fname_root='fname',
                   visualize=True, visualize_all=False,
                   shuffle=False,  # shuffle graph
                   avg_pos_in_overlap=False,
                   generate_new_every_iteration=False,
                   cond_vector_list=None,
                   length_cond_vector=7,
                   is_COO=False,  # whether model predicts COO (sparse) or not
                   ):
    """Assemble a larger structure by stacking samples along the attractor curve.

    A copy of ``X1`` is shifted to the curve point for ``t[0]``. For every
    following parameter value ``t[i + 1]`` another unit is appended with
    ``stack_and_extend`` using the curve point returned by
    ``get_xyz_attractor`` as (dx, dy, dz). After the loop the result is
    plotted and passed to ``construct_xyz_and_graph``.

    Args:
        X1: Sample tensor. Columns ``[:3]`` are handled separately from
            columns ``[3:]``, which are clamped to [0, 1] after each stacking
            step (presumably xyz coordinates followed by a neighbor /
            distance matrix -- TODO confirm against the data pipeline).
        t: Sequence of curve parameters; ``len(t) - 1`` units are stacked
            onto the initial one.
        scale: Forwarded to ``get_xyz_attractor`` as its ``scale``.
        stagger_fraction: Fraction of the unit length (``get_length(X1)``)
            subtracted from the current structure length to obtain the
            ``stack_node`` argument of ``stack_and_extend``.
        fname_root: Forwarded to ``construct_xyz_and_graph``.
        visualize: If true, additionally show ``spiral[:, 3:]`` at the end.
            The three diagnostic plots after the loop are shown regardless.
        visualize_all: With ``shuffle``, show the matrix block before and
            after every shuffle.
        shuffle: If true, apply one random permutation per step to the rows
            and to the columns of the ``[3:]`` block of ``X1``; the first
            three columns are left in their original row order. The
            permuted ``X1`` carries over to the next step.
        avg_pos_in_overlap: Forwarded to ``stack_and_extend``.
        generate_new_every_iteration: If true, draw a fresh sample with
            ``generate_sample_cond`` at every step, using the module-level
            ``model`` (it is not a parameter of this function).
        cond_vector_list: Optional per-step conditioning vectors, indexed by
            step. When absent, a random vector with entries in
            [-0.75, 0.75) is drawn for each step.
        length_cond_vector: Length of the randomly drawn conditioning vector.
        is_COO: If true, convert each freshly generated sample with
            ``get_xyz_and_dist_matrix_fromxyzCOO`` before stacking.

    Returns:
        The assembled tensor, in the same column layout as ``X1``.
    """
    S1 = get_length(X1)
    delta_stagg = S1 * stagger_fraction

    # Place the first unit at the curve point for t[0].
    dx, dy, dz = get_xyz_attractor(t[0], px=1.5, py=1.5, pz=1.5,
                                   dx=2, dy=2., scale=scale)
    spiral = shift(torch.clone(X1), dx, dy, dz)

    for i in tqdm(range(len(t) - 1)):

        if generate_new_every_iteration:
            if exists(cond_vector_list):
                # e.g. [-0.7255, -0.2038,  0.5942,  0.3315,  0.286,  0.1852,  0.522]
                cond_v = cond_vector_list[i]
            else:
                cond_v = 1.5 * torch.rand(length_cond_vector) - 0.75

            # Only the first returned item is used below.
            result, _, _, _ = generate_sample_cond(
                model,
                cond=cond_v,
                cond_scales=[1.],  # list of cond scales - each sampled...
                flag=9992,
                clamp=True, corplot=True, show_neighbors=True,
                xyz_and_graph=True,
                clamp_round_results=True, enforce_symmetry=True, discretize=True,
                get_length_from_result=True, dist_matrix_threshold=0.25,
            )

            if is_COO:
                # convert xyz-COO coding to xyz-distance matrix
                X1 = get_xyz_and_dist_matrix_fromxyzCOO(result, clamp_neighbors=True,
                                                        visualize=False)
            else:
                X1 = result
            S1 = get_length(X1)
            delta_stagg = S1 * stagger_fraction

        if shuffle:
            if visualize_all:
                plt.imshow(X1[:S1, 3:3 + S1])
                plt.show()

            # torch.arange replaces the deprecated torch.range, which warned
            # on every iteration. Indexing with the permutation directly is
            # equivalent to the earlier gather through identity index tensors.
            node_perm = torch.randperm(S1)
            col_perm = torch.cat((torch.arange(3), node_perm + 3))

            # Permute rows of the [3:] block; first three columns keep their order.
            X1 = torch.cat((X1[:, :3], X1[node_perm, 3:]), 1)
            # Permute columns of the [3:] block; this also limits X1 to 3 + S1 columns.
            X1 = X1[:, col_perm]

            if visualize_all:
                plt.imshow(X1[:S1, 3:3 + S1])
                plt.show()

        # Node index at which the next unit is attached.
        stack_node = int(get_length(spiral) - delta_stagg)

        dx, dy, dz = get_xyz_attractor(t[i + 1], px=1.5, py=1.5, pz=1.5,
                                       dx=2, dy=2., scale=scale)

        spiral = stack_and_extend(spiral.clone(), X1.clone(), stack_node, dx, dy, dz,
                                  visualize=False,
                                  make_graph=False, avg_pos_in_overlap=avg_pos_in_overlap,
                                  )
        # Keep the [3:] block within [0, 1] after stacking.
        spiral[:, 3:] = torch.clamp(spiral[:, 3:], 0, 1)

    S1 = get_length(spiral)

    # Diagnostic plots of the assembled structure (always shown).
    plt.imshow(spiral[:S1, 3:3 + S1], interpolation='none')
    plt.show()

    plt.plot(spiral[:, 0])
    plt.plot(spiral[:, 1])
    plt.plot(spiral[:, 2])
    plt.show()

    plt.imshow(spiral[:S1, :3], interpolation='none')
    plt.show()

    # Called for its effects; the returned values are not used here.
    construct_xyz_and_graph(spiral.permute(1, 0),
                            fname_root=fname_root,
                            label='Spiral',  # limits=[X_min*1.2, X_max*1.2],
                            dist_matrix=True,
                            )

    if visualize:
        plt.imshow(spiral[:, 3:], interpolation='none')
        plt.show()

    return spiral

In [None]:
# Curve parameters for the attractor assembly: t in [0, 10), step 0.05.
t_r = np.arange(0, 10, 0.05)
print(t_r.shape)

# Sample the attractor curve (scale=5) at those parameters for a preview plot.
# An earlier variant used x=(6+cos(3.5t))cos(t), y=(6+cos(3.5t))sin(t), z=sin(1.5t).
pts_2 = dict(x=[], y=[], z=[])
for t in t_r:
    x, y, z = get_xyz_attractor(
        t, px=1.5, py=1.5, pz=1.5, dx=2, dy=2., scale=5,
    )
    pts_2['x'].append(x)
    pts_2['y'].append(y)
    pts_2['z'].append(z)

ax = plt.axes(projection='3d')

# Trailing semicolon kept: it suppresses the notebook's echo of the return value.
ax.scatter3D(pts_2['x'], pts_2['y'], pts_2['z'], c=pts_2['z'], cmap='Greens');

In [None]:
# Build the attractor-shaped structure from the existing sample, reusing it at
# every step and shuffling it before each stacking step.
# Left as a bare expression so the notebook displays the returned tensor.
make_attractor(
    output,
    t_r,
    scale=10,
    stagger_fraction=0.5,
    fname_root='attractor_v901',
    visualize=True,
    visualize_all=False,
    shuffle=True,  # shuffle graph
    avg_pos_in_overlap=False,
    generate_new_every_iteration=False,
    cond_vector_list=None,
    length_cond_vector=7,
    is_COO=False,  # whether model predicts COO (sparse) or not
)