In [None]:
import torch
from torch_geometric.data import Data
import os
from torch.utils.data import Dataset, DataLoader
from coarsen import *
from scipy import io
from torch_geometric.transforms import FaceToEdge
from models_gpu import *
from visualize import *
from evaluate import *

In [None]:

def get_graph(mat,index):
    '''
    get_graph: Reads a single data point from already-loaded matlab data
    
    mat - The dictionary of values read from a .mat file; every field is read as
          mat[field][index, 0] (fields: 'nodes', 'elem', 'stress', 'dt', 'sdf')
    index - The index of the data point
    
    Returns - The Data() representation of 'mat':
              x = [nodes, dt] per node (float), y = stress (float),
              edge_index built from the elements by FaceToEdge, and sdf (float)
    '''
    nodes = mat['nodes'][index, 0].T
    # MATLAB element indices are 1-based. Cast to int64 *before* shifting so that
    # double-valued or unsigned-integer .mat arrays all give a valid long `face`
    # tensor. The face array keeps the layout it has in the .mat file.
    faces = mat['elem'][index, 0].astype(np.int64) - 1
    stress = mat['stress'][index, 0]
    dt = mat['dt'][index, 0]
    sdf = mat['sdf'][index, 0].T

    x = torch.tensor(np.concatenate((nodes, dt), axis=1), dtype=torch.float)
    y = torch.tensor(stress, dtype=torch.float)

    data = Data(x=x, face=torch.tensor(faces, dtype=torch.long), y=y)
    # Build edge_index from the faces and drop the `face` attribute afterwards.
    data = FaceToEdge(remove_faces=True)(data)
    data.sdf = torch.tensor(sdf, dtype=torch.float)
    return data

def load_matlab_dataset(filename, scale = 10000, sdf_scale = 10):
    '''
    load_matlab_dataset: Loads a scalar field dataset from a .mat file  
    
    Inputs:
    - filename - The .mat dataset consisting of meshes, the scalar field and SDF at each node, and an SDF array
    - scale - The number to divide each scalar field value by, defaults to 10000   
    - sdf_scale - The number to multiply each SDF array value by, defaults to 10
    
    Returns:
    - The dataset as a list of Data() objects (one per entry of mat['nodes']).
      Each data.sdf gets two leading singleton axes (sdf[None, None, :, :]),
      presumably batch/channel axes for a conv encoder -- TODO confirm in models_gpu.
    '''
    mat = io.loadmat(filename)
    dataset = []
    for i in range(len(mat['nodes'])):
        data = get_graph(mat, i)
        data.y /= scale
        data.sdf = data.sdf[None, None, :, :] * sdf_scale
        dataset.append(data)
    return dataset


class StressDataset(Dataset):
    '''
    StressDataset: In-memory dataset of graphs loaded from one or more .mat files

    Every file is read eagerly with load_matlab_dataset when the dataset is built;
    each item is passed through coarsen_data(graph, **kwargs) every time it is indexed.

    Inputs:
    - zipfiles - a single .mat path, or an iterable of .mat paths (despite the
      name, each one is handed to load_matlab_dataset, i.e. scipy.io.loadmat)
    - **kwargs - options forwarded to coarsen_data
    '''
    def __init__(self, zipfiles, **kwargs):
        # A lone path would otherwise be iterated character by character.
        if isinstance(zipfiles, (str, os.PathLike)):
            zipfiles = [zipfiles]
        self.zipfiles = list(zipfiles)
        self.options = kwargs
        self.data = []
        for path in self.zipfiles:
            self.data += load_matlab_dataset(path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Coarsening is redone on every access; nothing is cached.
        return coarsen_data(self.data[idx], **self.options)
    
def split_dataset(dataset, train_frac=0.8, seed=216):
    '''
    split_dataset: Randomly splits a dataset into a training and a testing subset

    Inputs:
    - dataset - any sized, indexable dataset
    - train_frac - fraction of samples put in the training set (rounded down), defaults to 0.8
    - seed - seed of the torch.Generator driving the split, so the split is reproducible

    Returns:
    - (train_dataset, test_dataset), as returned by torch.utils.data.random_split

    Raises:
    - ValueError if train_frac is outside [0, 1]
    '''
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")
    train_size = int(train_frac*len(dataset))
    test_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size], generator=generator)
    print(f"Splitting dataset with seed {seed}: {train_size} training, {test_size} testing")
    return train_dataset, test_dataset
    
class StressDataLoader(DataLoader):
    '''
    StressDataLoader: DataLoader that batches samples with merge_coarsened_data

    Every keyword argument (batch_size, shuffle, ...) is forwarded to the base DataLoader.
    '''
    def __init__(self, dataset, **kwargs):
        # NOTE(review): torch.utils.data.DataLoader has no `dimension` parameter and
        # raises TypeError on it. This only works if the DataLoader in scope (possibly
        # replaced via one of the star imports) accepts it. If the intent is
        # merge_coarsened_data(batch, dimension=2), bind it with functools.partial
        # instead -- confirm against merge_coarsened_data's signature.
        fixed_options = {"collate_fn": merge_coarsened_data, "dimension": 2}
        super().__init__(dataset, **fixed_options, **kwargs)

In [None]:
# GPU to run on. If CUDA is missing we use the CPU; if this particular GPU index
# does not exist on the machine we fall back to the first GPU instead of crashing later.
GPU_INDEX = 2
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cuda:0")

In [None]:
dataset = StressDataset(["data/stress_vor_w.mat", "data/stress_lat_w.mat"])
dataset_tr, dataset_te = split_dataset(dataset)
# Reshuffle the training set each epoch (seeded, for reproducibility). Without
# shuffle=True every epoch sees exactly the same batches in the same order.
# The test order stays fixed.
train_loader = StressDataLoader(dataset_tr, batch_size=4, shuffle=True,
                                generator=torch.Generator().manual_seed(216))
test_loader = StressDataLoader(dataset_te, batch_size=4)

In [None]:




conv_info = get_default_conv_info("GCNConv")
model = GraphUNet(
    num_layers=3,
    num_channels=64,
    conv_info=conv_info,
    mlp_dims=[128, 128, 128],
).to(device)
# 64 minutes, 100 epochs, lr=0.001, "stress-model-gcn"
# Total parameters:  76225
# Performance:
#                     Median R2
#         Train set:   0.5444
#          Test set:   0.5409

# conv_info = get_default_conv_info("EdgeConv")
# conv_info["dims"] = [64,64]
# model = GraphUNet(num_layers=3, num_channels=64, conv_info=conv_info, mlp_dims=[128,128,128]).to(device)
# 67 minutes, 75 epochs, lr=0.001, "stress-model"
# Total parameters:  159745
# Performance:
#                     Median R2
#         Train set:   0.9013
#          Test set:   0.8743

# model = JustConvNet(num_convs=4, conv_dims=[64,64], channels=64, mlp_dims = [128,128,128]).to(device)
# 120 minutes, 200 epochs, lr=0.0001, "stress-model-plain"
# Total parameters:  91457
# Performance:
#                     Median R2
#         Train set:   0.439
#          Test set:   0.4154

# Number of trainable parameters of the model selected above.
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(n_trainable_params)

In [None]:
# Training hyperparameters for this run.
lr = 0.0005
epochs = 50
opt = optim.Adam(params=model.parameters(), lr=lr)
lossfun = nn.MSELoss().to(device)

# Loss histories persist across re-runs of this cell so that training can be
# resumed. Create them only if they do not exist yet (e.g. on a fresh kernel).
if "all_losses" not in globals():
    all_losses = []
if "val_losses" not in globals():
    val_losses = []

for epoch in range(1, epochs + 1):
    # --- training pass ---
    # Re-enable train mode every epoch in case another cell left the model in eval mode.
    model.train()
    losses = []
    for i, data in enumerate(train_loader):
        loss = lossfun(model(data, device).reshape(-1, 1), data.y.reshape(-1, 1).to(device))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
        print(f"Batch {i+1}/{len(train_loader)}: Loss = {loss.item():.6e}, Avg = {np.mean(losses):.6e}          ", end="\r")
    print(f"Epoch: {epoch}/{epochs}, Loss: {np.mean(losses):.6e}                             ")
    all_losses.append(np.array(losses))

    # --- validation pass: no parameter updates, so don't build the autograd graph ---
    model.eval()
    losses = []
    with torch.no_grad():
        for data in test_loader:
            loss = lossfun(model(data, device).reshape(-1, 1), data.y.reshape(-1, 1).to(device))
            losses.append(loss.item())
    print(f"   Val. Loss: {np.mean(losses):.6e}                              ")
    val_losses.append(np.array(losses))

model.train()

In [None]:
# Name of the current experiment; the loss history and checkpoint file names are derived from it.
# Runs so far: "stress-model-plain" (JustConvNet), "stress-model" (GraphUNet + EdgeConv),
# "stress-model-gcn" (GraphUNet + GCNConv).
RUN_NAME = "stress-model-gcn"

np.savez(f"{RUN_NAME}-losses.npz", all_losses=np.array(all_losses), val_losses=np.array(val_losses))
torch.save(model, f"{RUN_NAME}.pth")
# To reload instead of retraining (pickled full model object -- only load trusted files):
# model = torch.load(f"{RUN_NAME}.pth", map_location=device).to(device)


In [None]:
# Snapshot of an earlier evaluation result (the plain-model run). `vals` is only
# created by the evaluation cell below, so on a fresh kernel there is nothing to
# snapshot yet and this stays None instead of raising NameError.
vals_plain = globals().get("vals")

In [None]:
# Evaluate the current model on both splits; results are keyed "tr" / "te".
# NOTE(review): the model may still be in train mode here -- confirm that
# eval_model_multiple switches it to eval mode itself.
eval_splits = {"tr": dataset_tr, "te": dataset_te}
vals = eval_model_multiple(model, eval_splits, device=device)

In [None]:
# Parameter count plus median R2 per split, in the same layout as the run notes above.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters: ", n_trainable)
print("Performance:")
print("                    Median R2")
for row_label, split_key in (("        Train set:  ", "tr"), ("         Test set:  ", "te")):
    print(row_label, np.round(np.median(vals[split_key]), 4))

In [None]:
# Summary plot of the R2 values in `vals` (keys "tr"/"te") using the star-imported
# plot_boxes -- presumably one box per split.
plot_boxes(vals)


In [None]:
# Visual spot-checks of the current model on two fixed test samples.
for i in (107, 105):
    data = dataset_te[i]
    plot_comparison(model, data, device=device)

In [None]:
# Whole-model checkpoints written by torch.save(model, ...) -- these are pickles,
# so only load trusted files. map_location lets a checkpoint saved on one GPU load
# on whatever `device` is in use here.
# NOTE(review): since torch 2.6, torch.load defaults to weights_only=True, which
# cannot unpickle whole model objects; pass weights_only=False there if loading fails.
checkpoint_files = dict(model1="stress-model-plain.pth", model2="stress-model-gcn.pth", model3="stress-model.pth")
models = {key: torch.load(path, map_location=device).to(device) for key, path in checkpoint_files.items()}
model1, model2, model3 = models["model1"], models["model2"], models["model3"]
model_names = dict(model1="Plain GNN (EdgeConv)", model2="TAG U-Net (GCNConv)", model3="TAG U-Net (EdgeConv)")


In [None]:
# Per-sample R2 of every model on the train ("tr") and test ("te") splits, cached on
# disk so they are computed only once. Entries are stored flat as "<model key>_<split>".
# (np.savez(r2file, r2s) would store the whole dict as one pickled "arr_0" array,
# which np.load cannot index by model key.) The cache is only checked for the
# expected keys, so delete r2_data.npz after retraining a model.
r2file = "r2_data.npz"
r2_splits = ("tr", "te")
expected_keys = {f"{key}_{split}" for key in models for split in r2_splits}

r2s = None
if os.path.isfile(r2file):
    with np.load(r2file) as cached:
        if expected_keys <= set(cached.files):
            r2s = {key: {split: cached[f"{key}_{split}"] for split in r2_splits} for key in models}
if r2s is None:
    r2s = {key: eval_model_multiple(models[key], dict(tr=dataset_tr, te=dataset_te), device=device)
           for key in models}
    np.savez(r2file, **{f"{key}_{split}": np.asarray(r2s[key][split])
                        for key in models for split in r2_splits})

In [None]:
# Test-set R2 distribution of each model, one step histogram per model.
model_color = dict(model1="darkred", model2="blue", model3="green")
bins = 20
plt.figure(dpi=200)
for key in models:
    te = r2s[key]["te"]
    # R2 can be arbitrarily negative. Only samples with R2 > -1 are histogrammed,
    # so the worst predictions do not appear in this plot.
    plt.hist(te[te > -1], bins=bins, density=True, histtype="step", lw=3.,
             edgecolor=model_color[key], label=model_names[key])
plt.legend()
plt.xlabel("$R^2$")
plt.ylabel("Probability Density")
plt.title("2-D Stress Prediction Task")
plt.show()

In [None]:
def plot_model_comparison(models, data, model_names, filename=None, dpi=300, size=13, device="cpu"):
    '''
    plot_model_comparison: Plots the ground truth next to each model's prediction for one sample

    The figure is a 2 x (N+1) grid, N = len(models). The top row holds the ground
    truth followed by one prediction panel per model, all on the shared color range
    [0, maxval]. The bottom row holds plot_model_r2 for each model, under its prediction.

    Inputs:
    - models - dict of key -> model; each model is moved with .to(device)
      (in place for nn.Modules)
    - data - the Data() sample to plot (needs x, edge_index and y)
    - model_names - dict of key -> panel title
    - filename - if given, the figure is saved there and closed; otherwise it is shown
    - dpi - figure resolution
    - size - marker size of each node
    - device - device the models are run on, defaults to "cpu"
    '''
    n_models = len(models)
    n_cols = n_models + 1
    plt.figure(figsize=(4.2*n_cols, 9.5), dpi=dpi)
    small_axes = []

    # Run each model once. The same prediction sets the shared color range and is plotted.
    preds = {}
    with torch.no_grad():
        for key in models:
            preds[key] = models[key].to(device)(data, device=device).cpu().flatten()

    maxval = np.max(data.y.detach().numpy())
    for pred in preds.values():
        maxval = max(maxval, np.max(pred.numpy()))

    plt.subplot(2, n_cols, 1)
    plt.title("Ground Truth", fontsize=17)
    plot_data(data, data.y, size=size, color_bounds=[0, maxval])

    for col, key in enumerate(models, start=2):
        model = models[key]
        plt.subplot(2, n_cols, col)
        plt.title(model_names[key], fontsize=17)
        plot_data(data, preds[key], size=size, color_bounds=[0, maxval])

        # Bottom-row panel directly below this model's prediction (one full row later).
        ax = plt.subplot(2, n_cols, col + n_cols)
        plot_model_r2(model, data, device=device)
        plt.axis("scaled")
        small_axes.append(ax)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.14, hspace=0)
    # Shrink the bottom-row panels to 90% and nudge them right and up.
    for ax in small_axes:
        pos = ax.get_position()
        ax.set_position([pos.x0 + 0.01, pos.y0 + 0.1/n_models, pos.width * 0.9, pos.height * 0.9])

    if filename is not None:
        plt.savefig(filename, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

In [None]:
# `imed` is chosen in the model-comparison cell below. Compute the same index here
# when that cell has not run yet, so this cell also works on a fresh kernel.
if "imed" not in globals():
    idx = np.argsort(r2s["model3"]["te"])
    imed = idx[3*len(idx)//4-10]

# SDF image of level 0 of that test sample, seismic colormap clipped to [-2, 2].
plt.figure(dpi=300)
plt.imshow(np.flipud(dataset_te[imed].data[0].sdf.squeeze().detach().numpy()), cmap="seismic", vmin=-2, vmax=2)
plt.axis("off")
plt.show()

In [None]:
# Pick the test sample ranked 10 places below the 75th percentile of model3's test R2.
idx = np.argsort(r2s["model3"]["te"])
rank = (3 * len(idx)) // 4 - 10
imed = idx[rank]
plot_model_comparison(models, dataset_te[imed], model_names)

In [None]:
def plot_graph(xy, edges, node_color='black', edge_color=None, node_size = 100,
               color_bounds = None, label = None, linewidth = 1, colorbar=True):
    '''
    plot_graph: Plots a 2D graph's nodes and edges on the current axes

    Args:
    xy - Two-column array with coordinates: [x, y]
    edges - 2 x E array of node index pairs. Only pairs with edges[0] < edges[1]
            are drawn, so each undirected edge stored in both directions appears once
    node_color - Color string, or per-node values mapped through the 'jet' colormap,
                 defaults to 'black'. Any non-string (including an rgb triple) takes
                 the colormap / colorbar path
    edge_color - Color string or rgb triple for the edges (None: no edges)
    node_size - Size of each node, defaults to 100
    color_bounds - [value of lowest color, value of highest color], when node_color is an array;
                   a one-element list fixes only the lowest value
    label - Title of the plot and legend label of the nodes (optional)
    linewidth - Width of the edge lines, defaults to 1
    colorbar - Whether to add a horizontal colorbar below the plot when node_color is
               not a string, defaults to True

    Returns:
    - handle of node scatter plot
    '''
    x = xy[:, 0]
    y = xy[:, 1]

    if edge_color is not None:
        from matplotlib.collections import LineCollection

        edges = edges[:, edges[0, :] < edges[1, :]]
        # One LineCollection instead of one plt.plot call per edge: same drawing,
        # but a single artist instead of thousands on large meshes.
        points = np.column_stack((x, y))
        segments = np.stack((points[edges[0]], points[edges[1]]), axis=1)
        plt.gca().add_collection(LineCollection(segments, colors=edge_color, linewidths=linewidth, zorder=0))

    if label is not None:
        plt.title(label, fontsize=12, y=0.88)

    if isinstance(node_color, str):
        # A single color: no colormapping involved.
        handle = plt.scatter(x, y, s=node_size, c=node_color, zorder=1, label=label)
    else:
        if color_bounds is None:
            tick_min = np.round(np.min(node_color), 3)
            tick_max = np.round(np.max(node_color), 3)
        elif len(color_bounds) == 1:
            tick_min = color_bounds[0]
            tick_max = np.round(np.max(node_color), 3)
        else:
            tick_min, tick_max = color_bounds[0], color_bounds[1]
        tick_med = np.round((tick_min + tick_max)/2, 3)
        handle = plt.scatter(x, y, s=node_size, c=node_color, zorder=1, label=label, cmap='jet',
                             vmin=tick_min, vmax=tick_max)

        if colorbar:
            ticks = [tick_min, tick_med, tick_max]
            bar = plt.colorbar(handle, shrink=0.9, location='bottom', pad=-0.1, ticks=ticks)
            bar.ax.set_xticklabels(ticks)

    plt.axis("equal")
    plt.axis("off")
    return handle

def plot_data(data, colors="black", show_edges=False, color_bounds=None, label = None, size=30, width=.4, colorbar=True):
    '''
    plot_data: Plots a Data() graph's nodes (and optionally its edges) on the current axes

    Inputs:
    - data - graph with node features x (its first two columns are the coordinates) and edge_index
    - colors - color string, or per-node values (array or tensor) to color the nodes by
    - show_edges - whether to draw the edges (in black)
    - color_bounds - color range, forwarded to plot_graph
    - label - plot title / legend label
    - size - node marker size
    - width - edge line width
    - colorbar - whether to draw a colorbar when colors holds per-node values

    Returns:
    - the handle returned by plot_graph
    '''
    def to_numpy(tensor):
        # .cpu() so that tensors living on a GPU can be plotted as well.
        return tensor.detach().cpu().numpy()

    xy = to_numpy(data.x[:, :2])
    edges = to_numpy(data.edge_index)
    # isinstance also catches Tensor subclasses such as nn.Parameter.
    node_color = to_numpy(colors) if isinstance(colors, torch.Tensor) else colors
    edge_color = "black" if show_edges else None
    return plot_graph(xy, edges, node_color=node_color, edge_color=edge_color, node_size=size,
                      color_bounds=color_bounds, label=label, linewidth=width, colorbar=colorbar)

In [None]:
# First graph exactly as loaded from the .mat file (dataset.data is never coarsened).
data = dataset.data[0]
fig = plt.figure(dpi=300)
plot_data(data, show_edges=True, size=5)
plt.show()

In [None]:
# dataset[0] goes through coarsen_data; level 1 of its result is presumably the
# first coarsened graph -- confirm against coarsen_data.
data = dataset[0].data[1]
fig = plt.figure(dpi=300)
plot_data(data, show_edges=True, size=10)
plt.show()