In [None]:
##NOTE - INPUT IS REQUIRED IN THIS SECTION
#=============================================================================
#The input CSV files are read from file_path + <file name>. When running in
#Colab they must be saved to your Google Drive; update the path and file
#names below as needed, otherwise the notebook cannot run.

##INPUT DATA
#============
#Root directory, prepended to every input and output file name
#(leave blank for a local PC run):
file_path = ''#'/content/drive/My Drive/Colab Notebooks/GNN_Geomodeling/'
#Input drillhole training data (each is optional, but must at least provide an empty file with headers):
input_rock_unit_filename = 'Input_Data/Folded_Rock_Unit.csv'#'Input_Data/Simp_Drillhole_data.csv'#
input_geologic_contact_filename = 'Input_Data/Folded_Geologic_Contacts.csv'
input_orientations_filename = 'Input_Data/Folded_Orientations.csv'
#Other file names kept for reference:
#'Input_Data/MSOP_drillhole_data.csv'
#'Input_Data/GeoLogic_int_drillhole_data.csv'

#Input data attributes (column names)
#X,Y,Z must have the same name in all 3 input files
input_name_x = "X"
input_name_y = "Y"
input_name_z = "Z"
#rock unit (column holding the rock unit label):
input_name_rock_unit = "RockUnit"
#geologic contacts (column holding the scalar field value):
input_name_geologic_contact = "FieldValue"
#orientations (columns holding the x, y, z components of the orientation vector):
input_name_x_vec = "XVec"
input_name_y_vec = "YVec"
input_name_z_vec = "ZVec"

##MESH SETTINGS
#===============
#Mesh extents in [X,Y,Z]: opposite corners of the box that is meshed.
#Input data outside these extents is filtered out.
min_extents = [1000,2300,1690]#[6200,7000,1400]##[2400,5100,2100]
max_extents = [1100,2480,1770]#[6800,7400,1700]##[3200,5600,2700]
max_volume = 50#1000# #max volume of each tetrahedron in the tetrahedral mesh

#X range of the slice shown when mesh nodes/edges are drawn on the 3D plots
min_viewing_x = 1040#6725#2825#
max_viewing_x = 1050#6745#2845#

##NEURAL NETWORK MODEL SETTINGS
#==============================
#Run on "cuda:0" when CUDA is available and this flag is True, otherwise on the CPU
mUseGPUIfAvailable = True
# Neural network parameters
mNumEpoch = 800  # number of training epochs
mNumHiddenLayerNeuron = 128  # width of the hidden GraphSAGE layers
# Relative weighting of the three input types in the loss function, does not need to sum to 1:
mRockUnitWeightFactor = 0.2
mGeologicContactWeightFactor = 0.2
mOrientationsWeightFactor = 0.6
# Proportions of data for training vs validation dataset, should sum to 1.
# The split is applied by sending every n-th labelled node to validation, where
# n = round((training + validation) / validation), so the realised split is approximate.
mTrainingPercent = 0.85
mValidationPercent = 0.15

##OUTPUT SETTINGS
#=================
#Optionally save output - set flags for whether to save the graph and/or the model.
#File names are relative to file_path.
save_GNN_graph = False
save_GNN_graph_filename = 'Saved_Output/GeoLogic_pytorch_geometric_data.pth'
#Alternative: 'Saved_Output/MSOP_pytorch_geometric_data.pth'
save_model = False
save_model_filename = 'Saved_Output/GeoLogic_pytorch_geometric_model.pth'
#Alternative: 'Saved_Output/MSOP_pytorch_geometric_model.pth'


# **Configure & install dependencies**

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# We assume that PyTorch is already installed
import torch
#torchversion = torch.__version__

# Install PyTorch Scatter, PyTorch Sparse, and PyTorch Geometric
#!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torchversion}.html
#!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torchversion}.html
#!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

#!pip install -q meshpy

In [None]:
# Report CUDA availability, then pick the compute device: the first GPU when
# one is available and the user allows it, the CPU otherwise.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device_name = "cuda:0" if (cuda_available and mUseGPUIfAvailable) else "cpu"
device = torch.device(device_name)
print(device)

In [None]:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from meshpy.tet import MeshInfo, build
from meshpy.geometry import GeometryBuilder, Marker, make_box
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import math
import sys
from sklearn.metrics import confusion_matrix
import seaborn as sns
from collections import Counter
import torch.nn.functional as F
from torch_geometric.data import Data
from torch.nn import Linear, Dropout
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
from scipy.spatial import cKDTree


# **Mesh generation**
___________________

**Meshing Helper Functions:**

In [None]:
#Build the unique edge list of a tetrahedral mesh from its element list
def generate_edges_from_ele(element_array):
    """Return the unique node-index pairs (edges) contained in the elements.

    Each element is a sequence of node indices. Every pair of nodes inside an
    element forms an edge; a pair is stored once, smaller index first.

    Returns:
        numpy array with one row per unique edge.
    """
    unique_edges = set()
    for element in element_array:
        ordered = sorted(element)
        node_count = len(ordered)
        # every (lower index, higher index) combination within this element
        unique_edges.update(
            (ordered[first], ordered[second])
            for first in range(node_count)
            for second in range(first + 1, node_count)
        )
    return np.array(list(unique_edges))

In [None]:
def edge_length_stats(edges,nodes):
  """Print min/max/mean edge length and return the list of edge lengths.

  Args:
    edges: iterable of index pairs into ``nodes``.
    nodes: array of node coordinates, indexable by node index.

  Returns:
    List holding the Euclidean length of every edge, in input order.
  """
  # Euclidean distance between the two end nodes of each edge
  edge_lengths = [np.linalg.norm(nodes[pair[0]] - nodes[pair[1]]) for pair in edges]

  # Summary statistics over all edges
  shortest = np.min(edge_lengths)
  longest = np.max(edge_lengths)
  average = np.mean(edge_lengths)

  print("min edge length: %.3f" % shortest)
  print("max edge length: %.3f" % longest)
  print("mean edge length: %.3f" % average)
  return edge_lengths

In [None]:
#Histogram helper
def histo(data):
  """Show a 30-bin matplotlib histogram of ``data`` (labelled as edge lengths)."""
  histogram_style = dict(bins=30, edgecolor='black')  # Adjust the number of bins as needed
  plt.hist(data, **histogram_style)
  plt.xlabel('Edge Length')
  plt.ylabel('Frequency')
  plt.title('Histogram')
  plt.grid(True)
  plt.show()



In [None]:
def visualize(mesh_points,mesh_edges):
  """Draw a static matplotlib 3D view of the mesh nodes and edges.

  Args:
    mesh_points: node coordinates, one (x, y, z) row per node.
    mesh_edges: iterable of node-index pairs.

  The axis limits are taken from the notebook-level ``min_extents`` and
  ``max_extents`` settings.
  """
  # Use the points that were passed in (previously the argument was overwritten
  # with the global ``mesh``, so the caller's points were ignored).
  mesh_points = np.asarray(mesh_points)

  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')

  # Plot the mesh edges as line segments between their two end nodes
  edge_points = [(mesh_points[edge[0]], mesh_points[edge[1]]) for edge in mesh_edges]
  edge_collection = Line3DCollection(edge_points, color='g')  # You can adjust the color here
  ax.add_collection3d(edge_collection)

  # Plot the mesh points
  ax.scatter(mesh_points[:, 0], mesh_points[:, 1], mesh_points[:, 2], color='r', marker='o')

  # Set the extents of the plot to the configured mesh box
  ax.set_xlim(min_extents[0], max_extents[0])
  ax.set_ylim(min_extents[1], max_extents[1])
  ax.set_zlim(min_extents[2], max_extents[2])

  ax.set_xlabel('X')
  ax.set_ylabel('Y')
  ax.set_zlabel('Z')
  plt.show()

**Meshing main:**

In [None]:
# Describe the modelling box spanned by min_extents/max_extents as points and facets.
# facet_markers is returned by make_box but is not used in this cell.
points, facets, _, facet_markers = make_box(min_extents, max_extents)

# Hand the box geometry to the mesher
mesh_info = MeshInfo()
mesh_info.set_points(points)
mesh_info.set_facets(facets)

# Build the tetrahedral mesh, limiting element size with max_volume.
# NOTE(review): insert_points=1 is passed as-is; confirm its meaning against meshpy.tet.build.
mesh = build(mesh_info, max_volume=max_volume,volume_constraints=True, attributes=False, insert_points=1)

In [None]:
#convert the mesh to numpy arrays, derive the unique edge list and report edge statistics
mesh_points = np.array(mesh.points)
mesh_elements = np.array(mesh.elements)
mesh_faces = np.array(mesh.faces)
mesh_edges = generate_edges_from_ele(mesh_elements)
edge_lengths = edge_length_stats(mesh_edges, mesh_points)
histo(edge_lengths)

In [None]:
#print the size of the initial mesh
for summary_template, mesh_items in (
    ("%d points", mesh_points),
    ("%d edges", mesh_edges),
    ("%d tetrahedra", mesh_elements),
    ("%d exterior faces", mesh_faces),
):
    print(summary_template % len(mesh_items))

# **Import Geo Data**
_______________
Three types of input data are used: rock units, geologic contacts, and orientations.

In [None]:
#import input data: one CSV per input type, located at file_path + file name
df_rock_unit = pd.read_csv(file_path + input_rock_unit_filename)
df_geologic_contacts = pd.read_csv(file_path + input_geologic_contact_filename)
df_orientations = pd.read_csv(file_path + input_orientations_filename)

In [None]:
#Preview the first 10 rows of each input table
for preview_title, preview_frame in (
    ("ROCK UNIT DATA PREVIEW:", df_rock_unit),
    ("GEOLOGIC CONTACTS DATA PREVIEW:", df_geologic_contacts),
    ("ORIENTATIONS DATA PREVIEW:", df_orientations),
):
    print(preview_title)
    print(preview_frame.head(10))

In [None]:
#Accepts a Pandas dataframe and keeps the rows inside the configured min and max x,y,z
def FilterInputData(df):
  """Return the rows of ``df`` whose X, Y and Z lie within the mesh extents.

  Bounds are inclusive on both ends. Column names come from
  ``input_name_x/y/z`` and limits from ``min_extents``/``max_extents``.
  """
  x_values = df[input_name_x]
  y_values = df[input_name_y]
  z_values = df[input_name_z]

  inside_x = (x_values >= min_extents[0]) & (x_values <= max_extents[0])
  inside_y = (y_values >= min_extents[1]) & (y_values <= max_extents[1])
  inside_z = (z_values >= min_extents[2]) & (z_values <= max_extents[2])

  return df[inside_x & inside_y & inside_z]

In [None]:
# Keep only the input rows that fall inside the specified extents...
df_rock_unit = FilterInputData(df_rock_unit)
df_geologic_contacts = FilterInputData(df_geologic_contacts)
df_orientations = FilterInputData(df_orientations)

# ...and report how many rows of each input type remain
print("%d input rock unit data points within specified limits" % len(df_rock_unit))
print("%d input geologic contact measurements within specified limits" % len(df_geologic_contacts))
print("%d input orientation measurements within specified limits" % len(df_orientations))

# **Visualize Drillhole Data with Mesh**

In [None]:
#plot interactively with an optional mesh slice
def interactive_visualize(rock_unit_data, geologic_contact_data=None, orientation_data=None, mesh_points=None, mesh_edges=None, predicted_labels=None):
  """Show an interactive Plotly 3D scatter of the input data, optionally with a mesh slice.

  Args:
    rock_unit_data: dataframe of rock unit points, coloured by rock unit.
    geologic_contact_data: optional dataframe of contact points ('x' markers).
    orientation_data: optional dataframe of orientation points (open diamonds).
    mesh_points: optional (n, 3) array of mesh node coordinates. Only nodes with
      min_viewing_x <= x <= max_viewing_x are drawn.
    mesh_edges: optional iterable of node-index pairs; an edge is drawn only when
      both of its end nodes lie inside the viewing slice.
    predicted_labels: optional per-node values used to colour the mesh nodes;
      a torch tensor or anything numpy can convert. Defaults to zeros.
  """
  #Plot rock unit data
  fig = px.scatter_3d(rock_unit_data, x=input_name_x, y=input_name_y, z=input_name_z, color=input_name_rock_unit)

  #Plot geologic contact data (an empty optional input is skipped)
  if geologic_contact_data is not None and len(geologic_contact_data) > 0:
    geologic_contact_trace = px.scatter_3d(geologic_contact_data, x=input_name_x, y=input_name_y, z=input_name_z, color=input_name_geologic_contact).data[0]
    geologic_contact_trace.marker.symbol = 'x'  # Change marker symbol to 'x'
    geologic_contact_trace.marker.size = 4
    fig.add_trace(geologic_contact_trace)

  #Plot orientation measurement data (an empty optional input is skipped)
  if orientation_data is not None and len(orientation_data) > 0:
    orientation_trace = px.scatter_3d(orientation_data, x=input_name_x, y=input_name_y, z=input_name_z).data[0]
    orientation_trace.marker.symbol = 'diamond-open'
    orientation_trace.marker.size = 4
    fig.add_trace(orientation_trace)

  #Plot mesh points in viewing slice
  if mesh_points is not None:
    mesh_points = np.asarray(mesh_points)

    # Normalise the node values to a numpy array. `is None` is required here:
    # `== None` is evaluated element-wise on arrays, and the zeros fallback has
    # no .cpu() method, so the old code failed whenever no labels were passed.
    if predicted_labels is None:
      node_values = np.zeros(len(mesh_points))
    elif hasattr(predicted_labels, 'detach'):
      node_values = predicted_labels.detach().cpu().numpy()
    else:
      node_values = np.asarray(predicted_labels)

    # Creating a DataFrame from the mesh points
    df_mesh = pd.DataFrame({
        'x': mesh_points[:, 0],
        'y': mesh_points[:, 1],
        'z': mesh_points[:, 2],
        input_name_rock_unit: node_values
    })

    # Keep only the nodes inside the x viewing slice
    mesh_slice = df_mesh[(df_mesh['x'] >= min_viewing_x) & (df_mesh['x'] <= max_viewing_x)]

    # Adding new points to the existing figure
    fig.add_trace(
        px.scatter_3d(mesh_slice, x='x', y='y', z='z', color=input_name_rock_unit).data[0]
    )
    fig.update_traces(marker=dict(size=5))  # Change the marker size here
    fig.update_layout(title='Interactive 3D Plot')

  #plot mesh edges in viewing slice
  if mesh_points is not None and mesh_edges is not None:
    # All edges go into ONE line trace; a None entry breaks the line between
    # consecutive segments. One trace per edge makes the figure very slow.
    x_vals, y_vals, z_vals = [], [], []
    for edge in mesh_edges:
      start = mesh_points[edge[0]]
      end = mesh_points[edge[1]]
      start_in_slice = min_viewing_x <= start[0] <= max_viewing_x
      end_in_slice = min_viewing_x <= end[0] <= max_viewing_x
      if start_in_slice and end_in_slice:
        x_vals.extend([start[0], end[0], None])
        y_vals.extend([start[1], end[1], None])
        z_vals.extend([start[2], end[2], None])

    if x_vals:
      fig.add_trace(go.Scatter3d(
          x=x_vals, y=y_vals, z=z_vals,
          mode='lines',
          line=dict(color='grey', width=2),
          name='',
          showlegend=False
      ))

  fig.show()

In [None]:
#visualize the filtered input data only (no mesh arguments are passed, so no mesh slice is drawn)
interactive_visualize(rock_unit_data=df_rock_unit, geologic_contact_data=df_geologic_contacts, orientation_data=df_orientations)

In [None]:
#visualize slice of the initial mesh before adjusting for input data
#interactive_visualize(filtered_data, mesh_points, mesh_edges)


# **Set up Torch Geometric dataset**

In [None]:
def euclidean_dist_sq(x1,y1,z1,x2,y2,z2):
  """Return the squared Euclidean distance between (x1, y1, z1) and (x2, y2, z2)."""
  delta_x = x2-x1
  delta_y = y2-y1
  delta_z = z2-z1
  return (delta_x*delta_x + delta_y*delta_y + delta_z*delta_z)

In [None]:
#Method adjusts the mesh such that the closest mesh node to every input data location is adjusted to coincide with the input data location
#Method also assigns training data and validation data using masks
def AdjustAndLabelMesh(tree, mesh_points, input_data_locations, input_labels, labels, test_mask, train_mask):
    """Snap the nearest mesh node onto each input point and label it.

    All array arguments except ``input_data_locations``/``input_labels`` are
    modified in place.

    Args:
        tree: KD-tree used to find the nearest mesh node. It is queried as
            given and is not rebuilt after nodes have been moved.
        mesh_points: (n, 3) node coordinates; a matched node is moved onto the
            input point.
        input_data_locations: iterable of (x, y, z) input points.
        input_labels: one label per input point (scalar or vector).
        labels: per-node label array that receives the input labels.
        test_mask: per-node boolean mask; every ``skip_train``-th labelled node
            is flagged here (validation set).
        train_mask: per-node boolean mask; all other labelled nodes are flagged
            here.

    A node is labelled by at most one input point per call; further input
    points whose nearest node is already taken are skipped and counted.
    """
    if mValidationPercent > 0:
        # assign every nth labelled node to the validation dataset instead of the train dataset
        skip_train = max(1, round((mTrainingPercent + mValidationPercent) / mValidationPercent))
    else:
        skip_train = sys.maxsize

    labelled_nodes = np.zeros(len(mesh_points), dtype=bool)

    labelled_count = 0
    skipped_count = 0
    total_adjust_dist = 0.0

    for data_point,data_label in zip(input_data_locations, input_labels):
        min_dist, min_index = tree.query(data_point, k=1)  # Find the nearest neighbor index and distance

        if labelled_nodes[min_index]:
            # nearest node already carries a label from an earlier input point
            skipped_count += 1
            continue

        # Move the node onto the data location and record its label
        mesh_points[min_index] = [data_point[0], data_point[1], data_point[2]]
        labels[min_index] = data_label
        labelled_nodes[min_index] = True
        labelled_count += 1

        #update train and test masks
        if labelled_count % skip_train == 0:
            test_mask[min_index] = True
        else:
            train_mask[min_index] = True
        total_adjust_dist += min_dist

    # Guard against empty input: no labelled node means no adjustment to average
    mean_adjust_dist = total_adjust_dist / labelled_count if labelled_count > 0 else 0.0

    print("Total number of nodes labelled: {0}".format(labelled_count))

    print("Total number of input points skipped: {0}".format(skipped_count))
    #temporary until meshing is updated to honor all input data points during creation

    print("Mean mesh adjustment: {:.2f}".format(mean_adjust_dist))

In [None]:
#Adjust mesh and assign graph labels and masks

# Build a KD-tree from the mesh node coordinates.
# copy_data=True: AdjustAndLabelMesh moves nodes by writing into mesh_points in
# place, and a tree that merely references that array could be corrupted by it.
tree = cKDTree(mesh_points, copy_data=True)

#ROCK UNIT
# Input data - convert to np arrays
locations_rock_unit_input = df_rock_unit[[input_name_x,input_name_y,input_name_z]].values
labels_rock_unit_input = df_rock_unit[input_name_rock_unit].values
#Graph labels and masks (-1 marks a node without a rock unit label)
labels_rock_unit = np.ones(len(mesh_points))*-1
train_mask_rock_unit = np.zeros(len(mesh_points), dtype=bool)
test_mask_rock_unit = np.zeros(len(mesh_points), dtype=bool)

# Rock unit input is optional like the other two inputs, so skip it when empty
if len(labels_rock_unit_input) > 0:
    print("")
    print("Mesh adjustment for rock unit measurements:")

    AdjustAndLabelMesh(tree,mesh_points,locations_rock_unit_input,labels_rock_unit_input,
            labels_rock_unit,test_mask_rock_unit,train_mask_rock_unit)

#SCALAR FIELD
# Input data - convert to np arrays
locations_scalar_field_input = df_geologic_contacts[[input_name_x,input_name_y,input_name_z]].values
labels_scalar_field_input = df_geologic_contacts[input_name_geologic_contact].values

#Scale scalar field labels to [-1,1]
labels_scalar_field_input_scaled = labels_scalar_field_input  # stays defined when there is no contact data
if len(labels_scalar_field_input) > 0:
    max_scalar_field = np.max(labels_scalar_field_input)
    min_scalar_field = np.min(labels_scalar_field_input)
    center_scalar_field = (max_scalar_field + min_scalar_field)/2
    range_scalar_field = max_scalar_field - min_scalar_field
    if range_scalar_field > 0:
        labels_scalar_field_input_scaled = (labels_scalar_field_input - center_scalar_field) / range_scalar_field * 2
    else:
        # All contacts share one value: centring gives 0 everywhere. Dividing by
        # the zero range would turn every label into NaN.
        labels_scalar_field_input_scaled = np.zeros(len(labels_scalar_field_input))

#Graph labels and masks (-99 marks a node without a scalar field label)
labels_scalar_field = np.ones(len(mesh_points))*-99
train_mask_scalar_field = np.zeros(len(mesh_points), dtype=bool)
test_mask_scalar_field = np.zeros(len(mesh_points), dtype=bool)

if len(labels_scalar_field_input) > 0:
    print("")
    print("Mesh adjustment for scalar field measurements:")

    AdjustAndLabelMesh(tree,mesh_points,locations_scalar_field_input,labels_scalar_field_input_scaled,
            labels_scalar_field,test_mask_scalar_field,train_mask_scalar_field)

#ORIENTATIONS
# Input data as numpy arrays: measurement locations and their orientation vectors
orientation_location_columns = [input_name_x,input_name_y,input_name_z]
orientation_vector_columns = [input_name_x_vec,input_name_y_vec,input_name_z_vec]
locations_orientations_input = df_orientations[orientation_location_columns].values
labels_orientations_input = df_orientations[orientation_vector_columns].values
#Graph labels (one 3-vector per node, -1 where unlabelled) and masks
train_mask_orientations = np.zeros(len(mesh_points), dtype=bool)
test_mask_orientations = np.zeros(len(mesh_points), dtype=bool)
labels_orientations = np.ones((len(mesh_points),3))*-1

if len(labels_orientations_input) > 0:
    print("")
    print("Mesh adjustment for orientation measurements:")

    AdjustAndLabelMesh(tree,mesh_points,locations_orientations_input,labels_orientations_input,
            labels_orientations,test_mask_orientations,train_mask_orientations)

In [None]:
#visualize the drillhole data and mesh after mesh adjustment
#interactive_visualize(rock_unit_data=df_rock_unit, geologic_contact_data=df_geologic_contacts, orientation_data=df_orientations,mesh_points=mesh_points,mesh_edges=mesh_edges)

In [None]:
#populate model vectors: convert the numpy mesh, label and mask arrays to torch tensors
# node features are the (adjusted) x,y,z coordinates of every mesh node
node_features_ts = torch.tensor(mesh_points, dtype = torch.float)
# edge list transposed to one column per edge (row 0 and row 1 hold the two end nodes)
# NOTE(review): generate_edges_from_ele stores each edge once (lower index first), so every
# edge appears here in one direction only - confirm whether both directions are intended
# for message passing.
edge_ts = torch.tensor(mesh_edges, dtype=torch.long).t().contiguous()
#Rock Unit
labels_rock_unit_ts = torch.tensor(labels_rock_unit, dtype = torch.long)
train_mask_rock_unit_ts = torch.tensor(train_mask_rock_unit, dtype = torch.bool)
test_mask_rock_unit_ts = torch.tensor(test_mask_rock_unit, dtype = torch.bool)
#Scalar Field
labels_scalar_field_ts = torch.tensor(labels_scalar_field, dtype = torch.float)
train_mask_scalar_field_ts = torch.tensor(train_mask_scalar_field, dtype = torch.bool)
test_mask_scalar_field_ts = torch.tensor(test_mask_scalar_field, dtype = torch.bool)
#Orientations (no dtype given, so the tensor keeps the dtype of the numpy array)
labels_orientations_ts = torch.tensor(labels_orientations)
train_mask_orientations_ts = torch.tensor(train_mask_orientations, dtype = torch.bool)
test_mask_orientations_ts = torch.tensor(test_mask_orientations, dtype = torch.bool)

In [None]:
#scale features x,y,z: centre on the middle of the mesh box and divide by its longest side
# NOTE: this cell overwrites node_features_ts, so running it twice scales the features twice.
max_extents = np.array(max_extents)
min_extents = np.array(min_extents)
major_extents = np.max(max_extents - min_extents)  # longest side of the box
center = torch.tensor((max_extents + min_extents)/2)
# NOTE(review): dividing by major_extents and then by 2 maps points inside the box to about
# [-0.25, 0.25]; confirm whether dividing by (major_extents / 2), i.e. [-1, 1], was intended.
node_features_ts = ((node_features_ts - center) / major_extents / 2).to(torch.float)

In [None]:
# Print statistics about the graph.
def GraphSummaryStats(data):
    """Print the graph object followed by node, edge, mask and structure statistics.

    ``data`` must expose num_nodes, num_edges, train_mask, val_mask and the
    has_isolated_nodes / has_self_loops / is_undirected methods.
    """
    print(data)
    print('==============================================================')

    node_total = data.num_nodes
    print(f'Number of nodes: {node_total}')
    edge_total = data.num_edges
    print(f'Number of edges: {edge_total}')
    # each stored edge touches two nodes, hence the factor 2
    print(f'Average node degree: {(2*edge_total) / node_total:.2f}')
    print(f'Number of training nodes: {data.train_mask.sum()}')
    train_fraction = int(data.train_mask.sum()) / data.num_nodes
    print(f'Training node label rate: {train_fraction:.4f}')
    print(f'Number of validation nodes: {data.val_mask.sum()}')
    print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
    print(f'Contains self-loops: {data.has_self_loops()}')
    print(f'Is undirected: {data.is_undirected()}')

In [None]:
#create torch graph datasets - scalar field
#(the train/validation masks are attached as extra keyword attributes)
data_scalar_field = Data(
    x=node_features_ts,
    edge_index=edge_ts,
    y=labels_scalar_field_ts,
    train_mask=train_mask_scalar_field_ts,
    val_mask=test_mask_scalar_field_ts,
)
GraphSummaryStats(data_scalar_field)

In [None]:
#create torch graph datasets - rock unit
#(the train/validation masks are attached as extra keyword attributes)
data_rock_unit = Data(
    x=node_features_ts,
    edge_index=edge_ts,
    y=labels_rock_unit_ts,
    train_mask=train_mask_rock_unit_ts,
    val_mask=test_mask_rock_unit_ts,
)
GraphSummaryStats(data_rock_unit)

In [None]:
#create torch graph datasets - orientations
#(the train/validation masks are attached as extra keyword attributes)
data_orientations = Data(
    x=node_features_ts,
    edge_index=edge_ts,
    y=labels_orientations_ts,
    train_mask=train_mask_orientations_ts,
    val_mask=test_mask_orientations_ts,
)
GraphSummaryStats(data_orientations)

In [None]:
#optionally save the graph to file_path + save_GNN_graph_filename (e.g. on Google Drive)
#NOTE(review): only data_scalar_field is saved; the rock unit and orientation
#labels/masks live in separate Data objects and are not written - confirm this is intended.
if save_GNN_graph:
  torch.save(data_scalar_field, file_path + save_GNN_graph_filename)

In [None]:
#to load data:
#data = torch.load('/content/drive/My Drive/Colab Notebooks/MSOP_pytorch_geometric_data.pth')

In [None]:
#mini-batch - NOT USING FOR NOW
#The two loaders below are created but not passed to GraphSAGE.fit, which trains on the
#full graph in every epoch.
#Both draw their seed nodes from the rock unit masks, 16 per batch.
train_loader = NeighborLoader(
    data_rock_unit, #COME BACK AND CONFIRM THIS IS CORRECT
    num_neighbors=[5,20],
    batch_size=16,
    input_nodes=data_rock_unit.train_mask,
)

val_loader = NeighborLoader(
    data_rock_unit,
    num_neighbors=[5,20],
    batch_size=16,
    input_nodes=data_rock_unit.val_mask,
)

In [None]:
# Define custom loss function for orientation
# Method computes orientation of the scalar field in the graph neural network using taylor series approximation,
# then computes loss based on the difference with the user input orientation data
# Returns: orientation loss and average angular difference
class CustomOrientationLoss(torch.nn.Module):
    """Loss comparing the predicted scalar-field orientation with measured orientations.

    For every node selected by the mask, the orientation of the predicted
    scalar field is approximated from the node's one-hop neighbourhood as
    ``Zv = Pv @ Sv`` (relative neighbour positions times relative scalar field
    values). The per-node loss is ``1 - |cos(angle between Zv and the measured
    vector)|``, so parallel and anti-parallel vectors both give zero loss.
    """
    def __init__(self):
        super(CustomOrientationLoss, self).__init__()

    def forward(self, out_scalar_field, orientations_data, mask):
        """Return ``(loss, avg_ang_diff)`` over the nodes selected by ``mask``.

        Args:
            out_scalar_field: 1-D tensor with the predicted scalar field per node.
            orientations_data: graph object providing ``x`` (node coordinates),
                ``y`` (measured orientation vector per node) and ``edge_index``.
            mask: boolean tensor selecting the nodes to evaluate.

        Returns:
            loss: 0-dim tensor, mean of ``1 - |cos|`` over the selected nodes.
                It stays attached to the autograd graph of ``out_scalar_field``
                so that it contributes gradients (0 when no node is selected).
            avg_ang_diff: float, mean angle in degrees between predicted and
                measured orientation (diagnostic only).
        """
        node_ids = torch.nonzero(torch.as_tensor(mask), as_tuple=False).flatten().tolist()
        edge_indexes = orientations_data.edge_index
        coords = orientations_data.x

        loss = out_scalar_field.new_zeros(())
        total_ang_diff = 0.0
        eps = 1e-12  # keeps the cosine finite when a vector has zero length

        for node in node_ids:
            # Edges that touch this node, in either direction
            edges_with_given_node = (edge_indexes == node).any(dim=0)
            endpoints = edge_indexes[:, edges_with_given_node]
            # One-hop neighbours: the other end of each of those edges, in ascending order
            neighbours = torch.unique(endpoints[endpoints != node])

            #Pv: matrix of x,y,z coords of one-hop neighborhood Nv nodes relative to current node
            Pv = (coords[neighbours] - coords[node]).T
            #Sv: scalar field values output from neural network for the neighbours relative to current node
            Sv = out_scalar_field[neighbours] - out_scalar_field[node]
            Zv = Pv @ Sv # taylor series approximation of scalar field orientation at current node
            Av = orientations_data.y[node].to(Zv.dtype) # measured orientation at current node

            # Cosine of the angle between predicted and measured orientation (L2 norms).
            # Kept as a tensor: converting to a Python float here would cut the loss
            # off from autograd and the orientation data would not train the network.
            norm_product = (torch.norm(Av) * torch.norm(Zv)).clamp_min(eps)
            cos_ang_diff = (torch.dot(Av, Zv) / norm_product).clamp(-1.0, 1.0)

            # Angle in degrees, for reporting only; the clamp above keeps acos in its domain
            total_ang_diff += math.degrees(math.acos(cos_ang_diff.item()))
            # Add to incremental loss
            loss = loss + (1 - cos_ang_diff.abs())

        node_count = len(node_ids)
        avg_ang_diff = 0.0

        if node_count > 0:
            avg_ang_diff = total_ang_diff / node_count
            loss = loss / node_count

        return loss, avg_ang_diff

In [None]:
#GNN
class GraphSAGE(torch.nn.Module):
  """GraphSAGE network with two outputs: a scalar field and rock unit classes.

  Three SAGEConv layers regress one scalar field value per node; a fourth
  SAGEConv layer maps that scalar field to ``dim_out`` rock unit classes.
  """
  def __init__(self, dim_in, dim_h, dim_out):
    """
    Args:
      dim_in: number of input features per node.
      dim_h: number of neurons in the hidden layers.
      dim_out: number of rock unit classes.
    """
    super().__init__()
    # part 1 of the network - scalar field regression
    self.sage1 = SAGEConv(dim_in, dim_h)
    self.sage2 = SAGEConv(dim_h, dim_h)
    self.sage3 = SAGEConv(dim_h, 1) #scalar field output

    # part 2 of the network - rock unit classification
    self.sage4 = SAGEConv(1, dim_out)
    self.optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.005,
                                      weight_decay=5e-4)
    self.dim_h = dim_h
    self.dim_out = dim_out

  def forward(self, x, edge_index):
    """Return ``(out_rock_unit, out_scalar_field)`` for all nodes.

    out_rock_unit holds log-probabilities per class (log_softmax over dim 1);
    out_scalar_field is the 1-D scalar field prediction.
    """
    h = self.sage1(x, edge_index)
    h = torch.relu(h)
    h = self.sage2(h, edge_index)
    h = torch.relu(h)
    h = self.sage3(h, edge_index)

    out_scalar_field = h[:,0]

    h = self.sage4(h, edge_index)

    out_rock_unit = F.log_softmax(h, dim=1)

    return out_rock_unit, out_scalar_field

  def fit(self, data_rock_unit, data_scalar_field, data_orientations, epochs):
    """Train on the full graph and return the per-epoch loss histories.

    Args:
      data_rock_unit: graph with rock unit labels and train/val masks; its
        ``x`` and ``edge_index`` are used for the forward pass.
      data_scalar_field: graph with scalar field labels and train/val masks.
      data_orientations: graph with orientation labels and train/val masks.
      epochs: number of training epochs.

    Returns:
      (train_losses, val_losses): lists with one weighted total loss per epoch.

    The three losses are combined with the notebook-level weight factors.
    Metrics are printed every 10 epochs. Validation metrics are computed from
    the same forward pass as the training loss, i.e. before the parameter
    update of that epoch.
    """
    criterion_rock_unit = torch.nn.CrossEntropyLoss() # rock unit
    criterion_scalar_field = torch.nn.MSELoss() # scalar field
    criterion_orientation = CustomOrientationLoss() # orientation
    optimizer = self.optimizer
    train_losses = []
    val_losses = []
    self.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        out_rock_unit, out_scalar_field = self(data_rock_unit.x, data_rock_unit.edge_index)  # Perform a single forward pass.

        # Compute the loss solely based on the training nodes.
        loss_scalar_field = criterion_scalar_field(out_scalar_field[data_scalar_field.train_mask], data_scalar_field.y[data_scalar_field.train_mask])
        if math.isnan(loss_scalar_field):
          loss_scalar_field = 0 #handle case of no contact input data
        acc_scalar_field = accuracy_scalar_field(out_scalar_field[data_scalar_field.train_mask],data_scalar_field.y[data_scalar_field.train_mask])

        loss_rock_unit = criterion_rock_unit(out_rock_unit[data_rock_unit.train_mask], data_rock_unit.y[data_rock_unit.train_mask])
        if math.isnan(loss_rock_unit):
          loss_rock_unit = 0 #handle case of no rock unit input data, same as for contacts
        acc_rock_unit = accuracy_rock_unit(out_rock_unit[data_rock_unit.train_mask].argmax(dim=1),data_rock_unit.y[data_rock_unit.train_mask])

        # Pass scalar field to compute orientations, return orientation loss and average angular difference
        loss_orientation, avg_ang_diff = criterion_orientation(out_scalar_field, data_orientations, data_orientations.train_mask)

        total_loss = loss_rock_unit * mRockUnitWeightFactor + loss_scalar_field * mGeologicContactWeightFactor + loss_orientation * mOrientationsWeightFactor

        # With no usable training data at all the total is a plain number and there is nothing to optimise
        if torch.is_tensor(total_loss) and total_loss.requires_grad:
          total_loss.backward()  # Derive gradients.
          optimizer.step()  # Update parameters based on gradients

        #validation
        val_loss_scalar_field = criterion_scalar_field(out_scalar_field[data_scalar_field.val_mask], data_scalar_field.y[data_scalar_field.val_mask])
        if math.isnan(val_loss_scalar_field):
          val_loss_scalar_field = 0 #handle case of no contact input data
        val_acc_scalar_field = accuracy_scalar_field(out_scalar_field[data_scalar_field.val_mask],data_scalar_field.y[data_scalar_field.val_mask])
        val_loss_rock_unit = criterion_rock_unit(out_rock_unit[data_rock_unit.val_mask], data_rock_unit.y[data_rock_unit.val_mask])
        if math.isnan(val_loss_rock_unit):
          val_loss_rock_unit = 0 #handle case of no rock unit validation data
        val_acc_rock_unit = accuracy_rock_unit(out_rock_unit[data_rock_unit.val_mask].argmax(dim=1),data_rock_unit.y[data_rock_unit.val_mask])
        val_loss_orientation, val_avg_ang_diff = criterion_orientation(out_scalar_field, data_orientations, data_orientations.val_mask)

        val_loss = val_loss_rock_unit * mRockUnitWeightFactor + val_loss_scalar_field * mGeologicContactWeightFactor + val_loss_orientation * mOrientationsWeightFactor

        # float() accepts both 0-dim tensors and the plain numbers used for missing inputs
        train_losses.append(float(total_loss))
        val_losses.append(float(val_loss))

        # Print header
        if epoch == 0:
          print(f'      || TRAIN                                                                           || VALIDATION')
          print(f'Epoch || Loss Rock Unit | Loss Scalar | Loss Ori | Acc Rock Unit | MSE Scalar | Ang Diff || Loss Rock Unit | Loss Scalar | Loss Ori | Acc Rock Unit | MSE Scalar | Ang Diff')

        # Print metrics every 10 epochs.
        # The accuracy helpers return fractions; scale by 100 because the columns are printed with a % sign.
        if epoch % 10 == 0:
          s = '%5.0f || %14.3f | %11.3f | %8.3f | %12.2f%% | %9.2f%% | %5.0fdeg || '%(epoch,loss_rock_unit,loss_scalar_field,loss_orientation,acc_rock_unit*100,acc_scalar_field*100,avg_ang_diff)
          s += '%14.3f | %11.3f | %8.3f | %12.2f%% | %9.2f%% | %5.0fdeg'%(val_loss_rock_unit,val_loss_scalar_field,val_loss_orientation,val_acc_rock_unit*100,val_acc_scalar_field*100,val_avg_ang_diff)
          print(s)

    return train_losses,val_losses


def accuracy_rock_unit(pred_y, y):
    """Return the fraction of predictions equal to the labels (0 for empty input)."""
    sample_count = len(y)
    if sample_count == 0:
      return 0
    correct_count = (pred_y == y).sum()
    return (correct_count / sample_count).item()

def accuracy_scalar_field(pred_y, y):
    """Return the mean squared error of ``pred_y`` relative to the range of ``y``.

    Errors are divided by ``max(y) - min(y)`` before squaring. When all labels
    are equal (for example a single label) the range is zero; the errors are
    then left unscaled instead of dividing by zero. Returns 0 for empty input.
    """
    if len(y) == 0:
      return 0

    label_range = torch.max(y) - torch.min(y)
    if label_range == 0:
      # constant labels: normalising would give inf/NaN, so fall back to the plain MSE
      label_range = 1.0
    return (((pred_y - y)/label_range) ** 2).mean().item()

def test(model, data):
    """Return the rock unit classification accuracy on the validation nodes.

    The model returns ``(out_rock_unit, out_scalar_field)``. The class scores
    are the FIRST element; the previous code took the second (the 1-D scalar
    field) and called ``argmax(dim=1)`` on it.
    """
    model.eval()
    with torch.no_grad():  # evaluation only, no gradients needed
        out_rock_unit, _ = model(data.x, data.edge_index)
    acc = accuracy_rock_unit(out_rock_unit.argmax(dim=1)[data.val_mask], data.y[data.val_mask])
    return acc

def evaluate(model, data):
    """Switch the model to evaluation mode and return its two outputs for ``data``."""
    model.eval()
    first_output, second_output = model(data.x, data.edge_index)
    return first_output, second_output

In [None]:
# Scratch check of %-style number formatting (field width / decimals), the same style
# used for the metrics table printed during training. The value of s is not read by
# any other cell visible in this notebook.
s = "Price: $ %8.2f, second value: %16.3f"% (356.08977,524354243.12324)
print(s)

# **Train GNN Model**

In [None]:
# Create GraphSAGE
num_features = data_rock_unit.num_features
# Number of rock unit classes = highest label + 1, so every label is a valid class index
# for the cross-entropy loss. Unlabelled nodes carry -1 and are excluded here; counting
# unique values would count that sentinel as a class and fails for labels that are not
# a contiguous range.
labelled_rock_units = data_rock_unit.y[data_rock_unit.y >= 0]
num_classes = int(labelled_rock_units.max()) + 1 if labelled_rock_units.numel() > 0 else 1
print(f'Number of features: {num_features}')
print(f'Number of classes: {num_classes}')
graphsage = GraphSAGE(num_features,mNumHiddenLayerNeuron,num_classes)
graphsage = graphsage.to(device)
print(graphsage)

In [None]:
# Train: move the three graphs to the selected device and fit for mNumEpoch epochs.
# fit returns the per-epoch training and validation losses used by the loss plot below.
train_losses,val_losses = graphsage.fit(data_rock_unit.to(device), data_scalar_field.to(device), data_orientations.to(device), mNumEpoch)

# Test (disabled)
#print(f'\nGraphSAGE test accuracy: {test(graphsage, data_rock_unit)*100:.2f}%\n')

In [None]:
#OPTIONAL: Save model / load existing model to skip above steps

#Save the PyTorch model weights (state_dict only) to file_path + save_model_filename
if save_model:
  torch.save(graphsage.state_dict(), file_path + save_model_filename)

In [None]:
# # Create an instance of your GNN model
# graphsage = GraphSAGE(num_features,128,num_classes)

# # Load the saved state dictionary into your model
# graphsage.load_state_dict(torch.load(file_path + save_model_filename))
# graphsage.eval()  # Put the model in evaluation mode if needed

# **Validate & Visualize Results**

In [None]:
#plot of training and cross validation loss

# Moving average of the losses. The window is at most 10 epochs and never longer than
# the loss history: with fewer than 10 epochs a fixed 10-wide kernel would be longer
# than the data and np.convolve(mode='valid') would no longer return a moving average.
moving_avg_window = max(1, min(10, len(train_losses)))
moving_avg_kernel = np.ones(moving_avg_window)/moving_avg_window
moving_avg_train = np.convolve(train_losses, moving_avg_kernel, mode='valid')
moving_avg_val = np.convolve(val_losses, moving_avg_kernel, mode='valid')

epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(8, 6))
#plt.plot(epochs, train_losses, 'b', label='Training Loss')
# Each averaged value is drawn at the first epoch of its window
plt.plot(epochs[:len(moving_avg_train)], moving_avg_train, 'b--', label='Training Moving Avg')
#plt.plot(epochs, val_losses, 'r', label='Validation Loss')
plt.plot(epochs[:len(moving_avg_val)], moving_avg_val, 'r--', label='Validation Moving Avg')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
#Create confusion matrix

# Lists to store ground truth labels and predictions
list_labels = []
list_predictions = []

# Predict for every node: out1 = rock unit class scores, out2 = scalar field
out1,out2 = evaluate(graphsage,data_scalar_field)
predicted_rock_unit = out1.argmax(dim=1)
predicted_scalar_field = out2

# Nodes that carry a rock unit label (training or validation).
# Combined with the torch | operator so the result is a boolean mask tensor.
combined_mask = (data_rock_unit.train_mask | data_rock_unit.val_mask).cpu()

filtered_predictions = predicted_rock_unit[combined_mask]
filtered_labels = data_rock_unit.y[combined_mask]

list_predictions.extend(filtered_predictions.cpu().numpy())
list_labels.extend(filtered_labels.cpu().numpy())

# Classes that occur in the labels or the predictions, in ascending order. They define
# both the rows/columns of the matrix and the tick labels, so the plot is correct for
# any number of rock units (the tick labels used to be hard-coded to three classes).
class_ids = np.unique(np.concatenate([list_labels, list_predictions]))

# Create a confusion matrix
conf_matrix = confusion_matrix(list_labels, list_predictions, labels=class_ids)

# Visualize the confusion matrix using seaborn
class_labels = [str(class_id) for class_id in class_ids]
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

# Print a summary of predicted values by class (all mesh nodes, not only labelled ones)
predicted_counts = Counter(predicted_rock_unit.cpu().numpy())
predicted_counts = {k: v for k, v in sorted(predicted_counts.items(), key=lambda item: item[0])}
print("Predicted values summary by Min Code:")
for class_label, count in predicted_counts.items():
    print(f"Min Code {class_label}: {count} nodes")

In [None]:
#Interactive plot of rock unit predictions in a mesh slice
#(mesh nodes inside the x viewing slice are coloured by their predicted rock unit)
interactive_visualize(rock_unit_data=df_rock_unit, geologic_contact_data=df_geologic_contacts,
                       orientation_data=df_orientations, mesh_points=mesh_points, predicted_labels=predicted_rock_unit.cpu())
                       #mesh_edges=mesh_edges #mesh edges off for performance

In [None]:
#Interactive plot of scalar field predictions in a mesh slice
#(mesh nodes inside the x viewing slice are coloured by the predicted scalar field value;
#detach() drops the autograd graph before the values are plotted)
interactive_visualize(rock_unit_data=df_rock_unit, geologic_contact_data=df_geologic_contacts,
                      orientation_data=df_orientations, mesh_points=mesh_points, predicted_labels=predicted_scalar_field.detach()) 
                      #mesh_edges=mesh_edges #mesh edges off for performance

In [None]:
#Plot the labelled nodes with the labels as larger circles and, for mislabelled nodes only,
#the predictions as smaller circles within

filtered_predictions_cpu = filtered_predictions.cpu().detach().numpy()
filtered_labels_cpu = filtered_labels.cpu().detach().numpy()
# True where the predicted rock unit differs from the label
mislabel_mask = filtered_labels_cpu != filtered_predictions_cpu

# Locations come from the node features, i.e. the scaled coordinates rather than the
# original X,Y,Z values
filtered_locations = data_rock_unit.x[combined_mask].cpu().detach().numpy()

# Large markers: every labelled node, coloured by its label
fig = px.scatter_3d(x=filtered_locations[:,0], y=filtered_locations[:,1], z=filtered_locations[:,2])
fig.update_traces(marker=dict(color=filtered_labels_cpu, size=7, line=dict(width=2, color=filtered_labels_cpu)), selector=dict(mode='markers'))

# Small markers: mislabelled nodes only, coloured by the prediction
smaller_size_trace = px.scatter_3d(
    x=filtered_locations[mislabel_mask, 0],
    y=filtered_locations[mislabel_mask, 1],
    z=filtered_locations[mislabel_mask, 2],
    color=filtered_predictions_cpu[mislabel_mask]
)
smaller_size_trace.update_traces(
    marker=dict(size=4),  # Adjust the size as needed
    selector=dict(mode='markers')
)

# NOTE(review): only the first trace is taken; confirm this still works when no node is
# mislabelled and the arrays above are empty.
fig.add_trace(smaller_size_trace.data[0])

fig.update_layout(title='Interactive 3D Plot')

fig.show()