In [1]:
import os
import deepxde as dde
import numpy as np
import matplotlib.pyplot as plt
from HPO_train import *

lmbd = 1.0
mu = 0.5
Q = 4.0

domain = np.array([[0.0, 1.0], [0.0, 1.0]])
geom = dde.geometry.Rectangle([0, 0], [1, 1])

Using backend: pytorch



In [2]:
print(f"Using GPU: {dde.backend.torch.cuda.is_available()}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Using GPU: False


In [3]:
lmbd = 1.0
mu = 0.5
Q = 4.0

domain = np.array([[0.0, 1.0], [0.0, 1.0]])
geom = dde.geometry.Rectangle([0, 0], [1, 1])

# plotting utilities

def pcolor_plot(AX, X, Y, C, title,colormap="jet",**kwargs):
    ## plot the pcolor plot of the given data C on the given axis AX with the given title and optional colorbar limits cmin and cmax
    if len(kwargs) == 0:
        im = AX.pcolor(X, Y, C, cmap=colormap,shading='auto')
    else:
        cmin = kwargs["cmin"]
        cmax = kwargs["cmax"]
        im = AX.pcolor(X, Y, C, cmap=colormap, vmin=cmin, vmax=cmax,shading='auto')
    AX.axis("equal")
    AX.axis("off")
    AX.set_title(title, fontsize=14)
    return im

def plot_field(domain,model,output_func=None,V_exact=None,plot_diff=False,n_points=10000,fields_name=None):

    X = np.linspace(domain[0][0], domain[0][1], int(np.sqrt(n_points)))
    Y = np.linspace(domain[1][0], domain[1][1], int(np.sqrt(n_points)))
    Xgrid, Ygrid = np.meshgrid(X, Y)
    Xinput = np.hstack((Xgrid.reshape(-1, 1), Ygrid.reshape(-1, 1)))

    plotify = lambda x: x.reshape(Xgrid.shape)

    if output_func is None:
        V_nn = model.predict(Xinput)
    else:
        V_nn = model.predict(Xinput, operator=output_func)

    V_nn = [plotify(V) for V in V_nn]

    n_fields = len(V_nn) if type(V_nn) is list else 1
    n_plot = 1

    if fields_name is None:
        fields_name = V_exact.__name__.replace('_exact','') if V_exact is not None else 'V'

    coord = ["_x","_y","_xy"] if n_fields > 1 else [""]
    fields_name = [fields_name + coord[i] for i in range(n_fields)]

    if V_exact is not None:
        V_exact = V_exact(Xinput)
        V_exact = [plotify(V) for V in V_exact]
        n_plot = 3 if plot_diff else 2
        
    fig, ax = plt.subplots(n_fields, n_plot, figsize=(4*n_plot, 3*n_fields), dpi=200)

    for i in range(n_fields):

        if V_exact is not None:

            cmax = max(V_nn[i].max(), V_exact[i].max())
            cmin = min(V_nn[i].min(), V_exact[i].min())

            im1 = pcolor_plot(ax[i][0], Xgrid, Ygrid, V_exact[i], f"{fields_name[i]}*",colormap="jet", cmin=cmin, cmax=cmax)
            im2 = pcolor_plot(ax[i][1], Xgrid, Ygrid, V_nn[i], f"{fields_name[i]}_nn",colormap="jet", cmin=cmin, cmax=cmax)

            fig.colorbar(im1, ax=ax[i][1])
        else:
            im1 = pcolor_plot(ax[i], Xgrid, Ygrid, V_nn[i], f"{fields_name[i]}_nn")
            fig.colorbar(im1, ax=ax[i])

        if plot_diff:
            diff = V_nn[i] - V_exact[i]
            abs_diff = np.abs(diff)
            cmax = abs_diff.max() if diff.max() > 0 else 0
            cmin = -abs_diff.max() if diff.min() < 0 else 0
            im3 = pcolor_plot(ax[i][2], Xgrid, Ygrid,diff, f"{fields_name[i]}_nn - {fields_name[i]}*",colormap="jet", cmin=cmin, cmax=cmax)
            fig.colorbar(im3, ax=ax[i][2])
            ax[i][2].text(1.075,0.5,f"mean($\mid${fields_name[i]}_nn - {fields_name[i]}*$\mid$): {np.mean(abs_diff):.2e}", fontsize=6,ha = "center",rotation = "vertical",rotation_mode = "anchor")

    return fig

In [4]:
import yaml
import glob

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    
def load_config(path):  
    with open(path, 'r') as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
        config.pop('wandb_version')
        config.pop('_wandb')
        for key, val in config.items():
            config[key] = val["value"]
    return dotdict(config)

def path_from_id(run_id):
    config_path = os.path.join('wandb',f"*-{run_id}","files")
    return glob.glob(config_path)[0]

In [5]:
run_id = "mmf62c2a"

config_path = os.path.join(path_from_id(run_id),"config.yaml")
model_path = os.path.join(path_from_id(run_id),"model-13000.pt")
config = load_config(config_path)

print(config)
[net_type,n_layers,size_layers,activation] = config.net_type, config.n_layers, config.size_layers, config.activation
[loss_type,num_sample,train_distribution,bc_type] = config.loss_type, config.num_samples, config.train_distribution, config.bc_type

net_exact = Unet_exact if net_type == 'Unet' else USnet_exact
size_output = 2 if net_type == 'Unet' else 5
num_domain = num_sample**2
num_boundary = num_sample

net = dde.nn.FNN([2] + [size_layers]*n_layers + [size_output], activation, 'Glorot uniform')
net, total_loss, bc, pde_net, energy_net, mat_net = loss_and_bc_setup(geom,net, net_type, bc_type, loss_type)

data = dde.data.PDE(
    geom,
    total_loss,
    bc,
    num_domain=num_domain,
    num_boundary=num_boundary,
    num_test= 100**2,
    train_distribution = train_distribution,
    solution=net_exact,
)
model = dde.Model(data,net)
model.compile('adam', lr=0.001)
model.restore(model_path,device)

# model = dde.Model(data, net)

# Calculate the final solution

U_output = lambda x, ouput: (ouput[:,0], ouput[:,1]) 
S_output = lambda x, ouput: S_nn(E_nn(x,ouput)) if config.net_type == 'Unet' else (ouput[:,2], ouput[:,3], ouput[:,4]) 

U_field = plot_field(domain,model,output_func=U_output,V_exact=lambda x: (U_exact(x)[:,0], U_exact(x)[:,1]),plot_diff=True,fields_name="U")
Eps_field = plot_field(domain,model,output_func=E_nn,V_exact=E_exact,plot_diff=True)
Sig_field =plot_field(domain,model,output_func=S_output,V_exact=S_exact,plot_diff=True)

{'activation': 'tanh', 'bc_type': 'hard', 'iterations': [3000, None], 'learning_rate': [0.001, None], 'loss_type': 'pde', 'n_layers': 7, 'net_type': 'USnet', 'num_samples': 50, 'optimizer': ['adam', 'L-BFGS'], 'size_layers': 100, 'train_distribution': 'Hammersley'}
Compiling model...
'compile' took 0.000251 s



KeyError: 'step'

In [15]:
wandb.

AttributeError: module 'wandb' has no attribute 'init'