# Grad-CAM for MOFs

- This notebook requires a trained model. We provide the .pth file containing the model we trained on the hMOFX-DB benchmark.
- All global descriptors and parameters required for scaling are hard-coded into the notebook. 

**Input**: CIF File + Adsorption Conditions (P, T, Adsorbate)

**Output**: Atomic importance heatmap displaying adsorption sites and interaction hotspots from learned structural-property relationships.

### Definitions:

In [1]:
# Adsorption conditions
adsorbate = "CO2" # Select between CO2, H2, N2, and CH4
pressure = 1000 # Pa; must be > 0 because log10(pressure) is taken when the model input is built
temperature = 298 # K. The model was only trained on 77 K for H2 and 298 K for the other adsorbates

# Path for CIF File (empty placeholder: must be set before running the main cell)
grad_cam_cif_path = ""

# Path of saved model (empty placeholder: must be set before running the main cell)
model_path = ""

# Path for saving GradCAM figure (passed to the plotter as the screenshot file name)
fig_save_path  = "datasets/gradcam.png"

### Load Libraries

In [2]:
import os
import torch
import h5py
import vtk
import pyvista as pv
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

from ase.io import read, write
from ase.build import make_supercell
from ase.data import atomic_numbers
from rdkit import Chem
from rdkit.Chem import AllChem

from models.pointnet2_v2_film_clipped import get_model
from data_util.pre_processing_v2 import encode_gas

### Convert CIF to Point Cloud for GradCAM Evaluation

In [3]:
# Per-element electronegativity descriptor, keyed by element symbol.
# featureize_atomic_point_cloud looks symbols up with .get(el, 0.0), so any
# element that is absent from this table silently gets 0.0.
# NOTE(review): values look normalised so that carbon == 1.0 -- confirm the
# source data and the normalisation used.
pauling_en = {
    "H": 0.8492077756084468, "He": 0.9234948886510139, "Li": 0.47493569014597764, "Be": 0.6642111641550714, 
    "B": 0.840168050416806, "C": 1.0, "N": 1.1258799375612023, "O": 1.2909944487358056,
    "F": 1.3693063937629155, "Ne": 1.4194806702221516, "Na": 0.5723140958578757, "Mg": 0.6878662178008077, 
    "Al": 0.8006407690254357, "Si": 0.908623583766197, "P": 0.9989599581169395, "S": 1.1050160600213705, 
    "Cl": 1.1924392738883696, "Ar": 1.2738116634126702, "K": 0.5089466451388751, "Ca": 0.6201736729460421, 
    "Sc": 0.6839411288813299, "Ti": 0.731096616353433, "V": 0.7538648981947593, "Cr": 0.7469990402967646, 
    "Mn": 0.835538990118209, "Fe": 0.8637251994466477, "Co": 0.9004503377814961, "Ni": 0.9217648016985401, 
    "Cu": 0.8731338026686615, "Zn": 0.9223432547217107, "Ga": 0.9646352117828865, "Ge": 1.0380552995055379, 
    "As": 1.0961412988206825, "Se": 1.1758511788040868, "Br": 1.2403473458920846, "Kr": 1.2756249193674616, 
    "Rb": 0.49168917189444195, "Sr": 0.596246052824969, "Y": 0.6517120879556719, "Zr": 0.6870429186215166, 
    "Nb": 0.6629935441317955, "Mo": 0.7023610444702197, "Tc": 0.8056292332943622, "Ru": 0.7745966692414834, 
    "Rh": 0.7922703501282296, "Pd": 1.1477402547212905, "Ag": 0.816741885599305, "Cd": 0.8591403680107819, 
    "In": 0.9014253790393105, "Sn": 0.9650485383226495, "Sb": 1.0190493307301367, "Te": 1.085955175195522, 
    "I": 1.1483385035264293, "Xe": 1.2055362602143995, "Cs": 0.46779577942376166, "Ba": 0.5792730788177665, 
    "La": 0.6201736729460422, "Ce": 0.6517120879556719, "Pr": 0.6113009170521597, "Nd": 0.6148041028495473, 
    "Pm": 0.6165784329539861, "Sm": 0.6183682144631293, "Eu": 0.6256864362310539, "Gd": 0.6400386879521874, 
    "Tb": 0.6256864362310539, "Dy": 0.6275569529190723, "Ho": 0.6294443465127104, "Er": 0.631348872337157, 
    "Tm": 0.6332707911583599, "Yb": 0.6219950386090711, "Lu": 0.6537204504606134, "Hf": 0.691548166360639, 
    "Ta": 0.7222199706358509, "W": 0.7623215819277851, "Re": 0.7963510440251993, "Os": 0.8190487086369509, 
    "Ir": 0.8588975014708029, "Pt": 0.8161135189404553, "Au": 0.8298105855798943, "Hg": 0.8687758883885015, 
    "Tl": 0.8951435925492911, "Pb": 0.9515506912134104, "Bi": 0.9812298519789698, "Po": 1.0517132668916793, 
    "At": 1.0922877925089394, "Rn": 1.1579018646071315, "Fr": 0.4771422345272372, "Ra": 0.5720228171406342, 
    "Ac": 0.6100888760865631, "Th": 0.6445033866354902, "Pa": 0.6400386879521874, "U": 0.6381534447172755, 
    "Np": 0.6362847629757779, "Pu": 0.6183682144631301, "Am": 0.6294443465127104, "Cm": 0.6457962733644264, 
    "Bk": 0.6256864362310547, "Cf": 0.6256864362310547, "Es": 0.631348872337157, "Fm": 0.6275569529190731, 
    "Md": 0.616578432953987, "No": 0.6113009170521605, "Lr": 0.6557474954819612, "Rf": 0.6804471538998222, 
    "Db": 0.7149123293218679, "Sg": 0.7461574507879195, "Bh": 0.7675923631762812, "Hs": 0.8036226816523603, 
    "Mt": 0.8352690695845574, "Ds": 0.8000150238973879, "Rg": 0.8400345102346997, "Cn": 0.907096861568455, 
    "Nh": 0.921095157729515, "Fl": 0.9548719951727265, "Mc": 0.9473309334313417, "Lv": 0.9573314559145735, 
    "Ts": 1.0309883912717273, "Og": 1.101199737705377
}

# Per-element hardness descriptor, keyed by element symbol.
# This table is sparser than the other two: several elements (for example
# Ne, Mg and Kr) have no entry. featureize_atomic_point_cloud looks symbols up
# with .get(el, 0.0), so atoms of a missing element receive a hardness of 0.0
# rather than raising an error.
# NOTE(review): units and data source are not stated here -- confirm.
hardness = {
    "H": 6.422119502568, "He": 22.143693968, "Li": 2.3868328805, "Be": 5.8613495, 
    "B": 4.009148, "C": 4.9990885, "N": 7.967065, "O": 6.0784702500000005, 
    "F": 7.010815150000001, "Na": 2.29557535, "Al": 2.7764692, "Si": 3.3810809500000003, 
    "P": 4.870039500000001, "S": 4.141452985000001, "Cl": 4.677452499999999, "Ar": 13.629805600000001,
    "K": 1.9195967700000003, "Ca": 3.0443026, "Sc": 3.186745, "Ti": 3.3745600000000002, 
    "V": 3.1105935, "Cr": 3.050255, "Mn": 3.717009, "Fe": 3.8757339, 
    "Co": 3.60937677, "Ni": 3.2419385000000003, "Cu": 3.2456899999999997, "Zn": 4.6970995, 
    "Ga": 2.7846509, "Ge": 3.3333615, "As": 4.4925, "Se": 3.865861, 
    "Br": 4.2251109, "Rb": 1.8456039999999998, "Sr": 2.8234336, "Y": 2.9551299999999996, 
    "Zr": 3.1039499999999998, "Nb": 2.920722, "Mo": 3.172215, "Tc": 3.2846905, 
    "Ru": 3.15525, "Rh": 3.1609499999999997, "Pd": 3.8874299999999997, "Ag": 3.1371170000000004, 
    "Cd": 4.496911, "In": 2.7431776, "Sn": 3.1159250000000003, "Sb": 3.7811945000000002, 
    "Te": 3.5193920000000003, "I": 3.6961116, "Xe": 6.09292155, "Cs": 1.711139774, 
    "Ba": 2.533522, "La": 2.55345, "Ce": 2.4442999999999997, "Pr": 2.2555, 
    "Nd": 1.8045000000000002, "Eu": 2.4031925, "Tb": 2.3494, "Dy": 2.793525, 
    "Tm": 2.577655, "Yb": 3.1370799999999996, "Lu": 2.5429355, "Hf": 3.405535, 
    "Ta": 3.613785, "W": 3.523885, "Re": 3.84176, "Os": 3.6691150000000006, 
    "Ir": 3.7016099999999996, "Pt": 3.4154150000000003, "Au": 3.4584615, "Tl": 2.8656435, 
    "Pb": 3.52996822, "Bi": 3.171577, "Po": 3.2569999999999997, "At": 3.2587550000000003, 
    "Fr": 1.7933704500000003, "Ra": 2.5892120000000003, "Ac": 2.5151130000000004
}

# Per-element polarizability descriptor, keyed by element symbol.
# featureize_atomic_point_cloud looks symbols up with .get(el, 0.0), so any
# element that is absent from this table silently gets 0.0.
# NOTE(review): values look normalised so that carbon == 1.0 -- confirm the
# source data and the normalisation used.
polarizability = {
    "H": 0.3988592920353982, "He": 0.12245575221238937, "Li": 14.523230088495575, "Be": 3.3398230088495575, 
    "B": 1.8141592920353982, "C": 1.0, "N": 0.6548672566371682, "O": 0.4690265486725663, 
    "F": 0.3309734513274336, "Ne": 0.2354955752212389, "Na": 14.398230088495573, "Mg": 6.300884955752212, 
    "Al": 5.11504424778761, "Si": 3.300884955752212, "P": 2.2123893805309733, "S": 1.7168141592920352, 
    "Cl": 1.2920353982300883, "Ar": 0.9807964601769911, "K": 25.637168141592916, "Ca": 14.230088495575222, 
    "Sc": 8.584070796460177, "Ti": 8.849557522123893, "V": 7.699115044247787, "Cr": 7.345132743362831, 
    "Mn": 6.017699115044247, "Fe": 5.486725663716814, "Co": 4.867256637168142, "Ni": 4.336283185840708, 
    "Cu": 4.11504424778761, "Zn": 3.42212389380531, "Ga": 4.424778761061947, "Ge": 3.5398230088495573, 
    "As": 2.654867256637168, "Se": 2.557522123893805, "Br": 1.8584070796460175, "Kr": 1.4849557522123893, 
    "Rb": 28.300884955752213, "Sr": 17.451327433628318, "Y": 14.336283185840706, "Zr": 9.91150442477876, 
    "Nb": 8.672566371681416, "Mo": 7.699115044247787, "Tc": 6.991150442477876, "Ru": 6.371681415929203, 
    "Rh": 5.840707964601769, "Pd": 2.3132743362831856, "Ag": 4.867256637168142, "Cd": 4.070796460176991, 
    "In": 5.752212389380531, "Sn": 4.6902654867256635, "Sb": 3.805309734513274, "Te": 3.3628318584070795, 
    "I": 2.911504424778761, "Xe": 2.4176991150442477, "Cs": 35.477876106194685, "Ba": 24.07079646017699, 
    "La": 19.02654867256637, "Ce": 18.141592920353983, "Pr": 19.11504424778761, "Nd": 18.4070796460177, 
    "Pm": 17.699115044247787, "Sm": 16.991150442477874, "Eu": 16.283185840707965, "Gd": 13.982300884955752, 
    "Tb": 15.044247787610619, "Dy": 14.424778761061946, "Ho": 13.805309734513273, "Er": 13.27433628318584, 
    "Tm": 12.743362831858406, "Yb": 12.300884955752212, "Lu": 12.123893805309734, "Hf": 9.11504424778761, 
    "Ta": 6.548672566371681, "W": 6.017699115044247, "Re": 5.486725663716814, "Os": 5.044247787610619, 
    "Ir": 4.778761061946902, "Pt": 4.2477876106194685, "Au": 3.1858407079646014, "Hg": 3.0008849557522117, 
    "Tl": 4.424778761061947, "Pb": 4.15929203539823, "Bi": 4.2477876106194685, "Po": 3.893805309734513, 
    "At": 3.716814159292035, "Rn": 3.0973451327433628, "Fr": 28.123893805309734, "Ra": 21.769911504424776, 
    "Ac": 17.964601769911503, "Th": 19.20353982300885, "Pa": 13.628318584070795, "U": 11.415929203539822, 
    "Np": 13.362831858407079, "Pu": 11.681415929203538, "Am": 11.5929203539823, "Cm": 12.743362831858406, 
    "Bk": 11.061946902654867, "Cf": 10.79646017699115, "Es": 10.442477876106194, "Fm": 10.0, 
    "Md": 9.646017699115044, "No": 9.734513274336283, "Lr": 28.318584070796458, "Rf": 9.91150442477876, 
    "Db": 3.716814159292035, "Sg": 3.5398230088495573, "Bh": 3.3628318584070795, "Hs": 3.1858407079646014, 
    "Mt": 3.0088495575221237, "Ds": 2.831858407079646, "Rg": 2.831858407079646, "Cn": 2.47787610619469, 
    "Nh": 2.566371681415929, "Fl": 2.743362831858407, "Mc": 6.283185840707964, "Ts": 6.725663716814159, 
    "Og": 5.132743362831858
}

def extract_coord_and_element_from_cif(
    structure_id,          # Structural ID assigned to each MOF in dataset (currently unused)
    cif_path,              # Path of the CIF file to read
    supercell_matrix=None  # 3x3 supercell matrix; None means identity (a single unit cell)
):
    """Read a CIF file and return its atoms as an origin-centred point cloud.

    Args:
        structure_id: Identifier of the structure. It is not used inside this
            function; the parameter is kept so existing call sites keep working.
        cif_path: Path of the CIF file, handed to ``ase.io.read``.
        supercell_matrix: 3x3 matrix handed to ``ase.build.make_supercell``.
            ``None`` (the default) selects the identity matrix, i.e. one unit
            cell. A ``None`` sentinel is used instead of a list literal so that
            no mutable object is shared between calls.

    Returns:
        dict with keys ``"x"``, ``"y"``, ``"z"`` (1-D arrays of Cartesian
        coordinates shifted so that their mean lies at the origin) and
        ``"element"`` (list of chemical symbols, one per atom).
    """
    if supercell_matrix is None:
        supercell_matrix = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # Load and generate supercell
    loaded_cif = read(cif_path)
    supercell_atoms = make_supercell(loaded_cif, supercell_matrix, order='cell-major')

    # Center point cloud to origin (subtract the mean atomic position)
    positions = supercell_atoms.get_positions()
    shifted_positions = positions - positions.mean(axis=0)

    # Return coordinates and elements as dictionary
    return {
        "x": shifted_positions[:, 0],
        "y": shifted_positions[:, 1],
        "z": shifted_positions[:, 2],
        "element": supercell_atoms.get_chemical_symbols(),
    }
    
def featureize_atomic_point_cloud(
    coord_and_atom_type,                                # Point cloud containing (xyz + element)
    pauling_en_dict, hardness_dict, polarizability_dict # Pre-defined atomic property dictionaries
):
    """Turn an (xyz + element) point cloud into a per-atom feature matrix.

    Args:
        coord_and_atom_type: Mapping with keys "x", "y", "z" (coordinates) and
            "element" (chemical symbols).
        pauling_en_dict: Element symbol -> electronegativity descriptor.
        hardness_dict: Element symbol -> hardness descriptor.
        polarizability_dict: Element symbol -> polarizability descriptor.

    Returns:
        Array with one row per atom and 11 columns, in this order:
        x, y, z, electronegativity, hardness, polarizability, van der Waals
        radius, covalent radius, atomic mass, atomic number, number of outer
        electrons.

    Notes:
        A symbol missing from one of the three descriptor dictionaries yields
        0.0 for that descriptor. A symbol missing from ``atomic_numbers``
        raises ``KeyError``.
    """
    periodic_table = Chem.GetPeriodicTable()
    symbols = coord_and_atom_type["element"]

    xyz = np.column_stack(
        [coord_and_atom_type[axis] for axis in ("x", "y", "z")]
    )

    # Atomic numbers drive every periodic-table lookup below.
    z_values = [atomic_numbers[symbol] for symbol in symbols]

    def from_dict(table):
        # Per-atom descriptor with a 0.0 fallback for unknown symbols.
        return [table.get(symbol, 0.0) for symbol in symbols]

    def from_periodic_table(getter):
        # Per-atom property looked up by atomic number.
        return [getter(z) for z in z_values]

    feature_columns = [
        xyz[:, 0],
        xyz[:, 1],
        xyz[:, 2],
        from_dict(pauling_en_dict),
        from_dict(hardness_dict),
        from_dict(polarizability_dict),
        from_periodic_table(periodic_table.GetRvdw),
        from_periodic_table(periodic_table.GetRcovalent),
        from_periodic_table(periodic_table.GetAtomicWeight),
        z_values,
        from_periodic_table(periodic_table.GetNOuterElecs),
    ]
    return np.column_stack(feature_columns)

### GradCAM Logic:

In [4]:
def gradcam_pointnet2(model, xyz_rgb, extra_features=None, target_layer="sa1"):
    """Compute a Grad-CAM heatmap for one layer of a PointNet++-style model.

    A forward hook stores the layer's activations and a full backward hook
    stores the gradient of the summed prediction with respect to them. The
    channel-wise mean gradient weights the activations, the weighted sum over
    channels is passed through ReLU, and the result is attached to the
    coordinates of the points the layer sampled.

    Args:
        model: Model exposing the attribute named by ``target_layer`` and a
            ``fps_indices`` mapping from layer name to sampled point indices.
            Calling it must return ``(prediction, something)``.
        xyz_rgb: Input tensor indexed as (batch, channel, point); the first
            three channels are read as xyz coordinates.
        extra_features: Optional second argument forwarded to the model.
        target_layer: Name of the layer to hook. Its output is expected to be
            a tuple whose element 1 holds the point features
            (assumed (batch, channel, point) -- TODO confirm against the model).

    Returns:
        ``(heatmap_out, pred)`` where ``heatmap_out`` concatenates the sampled
        coordinates with the Grad-CAM value along the last axis and ``pred``
        is the model prediction from the same forward pass.

    Raises:
        RuntimeError: If the hooks did not capture activations or gradients.
    """
    model.eval()
    feats = {}
    grads = {}

    def save_feats(module, input, output):
        # output = (new_xyz, new_points, fps_idx)
        feats['value'] = output[1].detach()  # new_points

    def save_grads(module, grad_input, grad_output):
        # grad_output[1] corresponds to the new_points gradient; fall back to
        # element 0 when grad_output is a tuple with a single element.
        try:
            grads['value'] = grad_output[1].detach()
        except IndexError:
            grads['value'] = grad_output[0].detach()

    hook_layer = getattr(model, target_layer)
    handles = []
    try:
        handles.append(hook_layer.register_forward_hook(save_feats))
        handles.append(hook_layer.register_full_backward_hook(save_grads))

        # Forward pass
        pred, _ = model(xyz_rgb, extra_features)

        # Backward pass
        model.zero_grad()
        pred.sum().backward()
    finally:
        # Always detach the hooks, even when the forward or backward pass
        # raises, so a failed call does not leave them on the model.
        for handle in handles:
            handle.remove()

    # Check that both hooks fired
    if 'value' not in feats:
        raise RuntimeError(f"Activations not captured for layer {target_layer}")
    if 'value' not in grads:
        raise RuntimeError(f"Gradients not captured for layer {target_layer}")

    # GradCAM computation
    weights = grads['value'].mean(dim=2)  # (B, C)
    cam = (weights.unsqueeze(2) * feats['value']).sum(dim=1)  # (B, N)
    cam = F.relu(cam)

    # Map back to original coordinates
    fps_idx = model.fps_indices[target_layer]
    coords = xyz_rgb[:, :3, :].permute(0, 2, 1)
    coords_sampled = torch.stack(
        [coords[b, fps_idx[b]] for b in range(coords.size(0))], dim=0
    )
    heatmap_out = torch.cat([coords_sampled, cam.unsqueeze(-1)], dim=-1)

    return heatmap_out, pred

# Utility

In [11]:
def upsample_single(pc, upsample_points):
    """Resample a single point cloud to exactly ``upsample_points`` points.

    Args:
        pc: 3-D tensor whose last axis indexes points; the leading axis is
            squeezed away, so it is expected to have size 1.
        upsample_points: Number of points in the returned cloud.

    Returns:
        Tensor with a leading axis of size 1 and ``upsample_points`` points on
        the last axis. When the input already has that many points it is
        returned unchanged.

    Notes:
        Points are drawn uniformly at random *with replacement*, so the result
        is not guaranteed to contain every input point and it differs between
        calls unless the torch RNG is seeded.
    """
    _, _, n_points = pc.shape
    cloud = pc.squeeze(0)

    if n_points == upsample_points:
        return cloud.unsqueeze(0)

    # Two draws: one covering the whole multiples of n_points, one covering
    # what is left over.
    n_bulk = (upsample_points // n_points) * n_points
    n_extra = upsample_points % n_points

    pieces = []
    for count in (n_bulk, n_extra):
        if count > 0:
            picks = torch.randint(
                low=0, high=n_points, size=(count,), device=cloud.device
            )
            pieces.append(cloud[:, picks])

    resampled = torch.cat(pieces, dim=1) if pieces else cloud
    return resampled.unsqueeze(0)

def prepare_single_sample(
    atomic_pc,
    upsample_points, # represents the number of points the point cloud must be upsampled to
    pressure,
    temperature,
    adsorbate
):
    """Build the model inputs for one structure under one adsorption condition.

    Args:
        atomic_pc: Float array with one row per atom and at least 11 columns
            (xyz followed by eight atomic features).
            **It is modified in place**: columns 3-10 are divided by the
            hard-coded scaling coefficients, so the caller's array holds the
            scaled features after this call. Calling the function twice on the
            same array scales it twice.
        upsample_points: Number of points the cloud is resampled to.
        pressure: Pressure; must be strictly positive because its base-10
            logarithm is taken.
        temperature: Temperature; divided by 298.0.
        adsorbate: Gas name passed to ``encode_gas``.

    Returns:
        ``(pc_tensor, extra_features)``: the resampled float32 point-cloud
        tensor and a 1-D float32 tensor made of the encoded gas followed by
        the scaled temperature and the scaled log-pressure.

    Raises:
        ValueError: If ``pressure`` is not strictly positive. The check runs
            before ``atomic_pc`` is touched.
    """
    # Fail fast: log10 of a non-positive pressure would yield -inf/NaN and
    # propagate silently into the model input.
    if pressure <= 0:
        raise ValueError(f"pressure must be > 0 to take log10, got {pressure!r}")

    # Prepare Point Cloud for Input
    pc_array = atomic_pc
    print(pc_array.shape)

    # Hard Coded Scaling Coefficients, one per feature column 3..10
    feature_scales = np.array([
        1.3693063937629155,
        7.967065,
        9.91150442477876,
        2.3,
        1.75,
        91.224,
        40.0,
        11.0,
    ])
    pc_array[:, 3:11] /= feature_scales  # in place, see the docstring

    # Prepare for Model Inference:
    pc_tensor = torch.tensor(pc_array.T, dtype=torch.float32).unsqueeze(0)
    pc_tensor = upsample_single(pc_tensor, upsample_points)

    forcefield = "TraPPE"

    gas_encoded = torch.tensor(encode_gas(adsorbate, forcefield), dtype=torch.float32)
    T_tensor = torch.tensor([temperature / 298.0], dtype=torch.float32)
    P_tensor = torch.tensor([np.log10(pressure) / 7.0], dtype=torch.float32)
    extra_features = torch.cat([gas_encoded, T_tensor, P_tensor], dim=0)

    # Return Values:
    return pc_tensor, extra_features

def assign_gradcam_importance(atomic_pc,
                              grad_cam, max_noise=0.1,
                              normalize=True):
    """Attach a Grad-CAM importance value to every atom of a point cloud.

    Args:
        atomic_pc: Array with one row per atom; the first three columns are
            xyz coordinates.
        grad_cam: Array with one row per Grad-CAM point: x, y, z, value.
        max_noise: Distance threshold used both for merging Grad-CAM points
            and for matching atoms to merged points.
        normalize: When True, min-max scale the assigned values to [0, 1].

    Returns:
        ``atomic_pc`` with one extra trailing column holding the importance.
        Atoms with no merged Grad-CAM point closer than ``max_noise`` keep the
        sentinel value -1.
    """
    cam_xyz = grad_cam[:, :3]
    cam_score = grad_cam[:, 3]

    # Step 1: greedily merge Grad-CAM points that lie within max_noise of a
    # seed point, averaging both their position and their value.
    n_cam = len(cam_xyz)
    consumed = np.zeros(n_cam, dtype=bool)
    clusters = []
    for seed in range(n_cam):
        if consumed[seed]:
            continue
        spread = np.linalg.norm(cam_xyz - cam_xyz[seed], axis=1)
        neighbours = np.flatnonzero(spread < max_noise)
        consumed[neighbours] = True
        clusters.append(
            np.append(cam_xyz[neighbours].mean(axis=0), cam_score[neighbours].mean())
        )
    clusters = np.array(clusters)

    # Step 2: give each atom the value of its nearest merged point, provided
    # that point is closer than max_noise; otherwise leave the -1 sentinel.
    n_atoms = atomic_pc.shape[0]
    importance = np.full(n_atoms, -1.0)
    if len(clusters) > 0:
        cluster_xyz = clusters[:, :3]
        for row in range(n_atoms):
            gaps = np.linalg.norm(cluster_xyz - atomic_pc[row, :3], axis=1)
            nearest = np.argmin(gaps)
            if gaps[nearest] < max_noise:
                importance[row] = clusters[nearest, 3]

    # Step 3: min-max normalise the assigned values (sentinels are skipped).
    if normalize:
        assigned = importance != -1
        if np.any(assigned):
            scores = importance[assigned]
            lowest, highest = scores.min(), scores.max()
            if highest > lowest:
                importance[assigned] = (scores - lowest) / (highest - lowest)
            else:
                # All assigned values are equal: avoid dividing by zero.
                importance[assigned] = 1.0

    return np.hstack([atomic_pc, importance.reshape(-1, 1)])

# Visualization

In [14]:
def visualize_heatmap_fast(point_cloud_with_importance, filename, image_size=(1600, 1200), remove_null_features=False, sphere_radius=1.0):
    """
    Render a point cloud as importance-coloured spheres and save a screenshot.

    Args:
        point_cloud_with_importance: Array with one row per point. Columns 0-2
            are xyz, column 6 is used as the per-point sphere scale and the
            last column is the importance (-1 marks points with no value).
        filename: Path of the image written by the plotter. Missing parent
            directories are created.
        image_size: Window size of the off-screen plotter.
        remove_null_features: When True, points whose importance is -1 are
            additionally drawn as white spheres of fixed size.
        sphere_radius: Radius of the base sphere glyph.

    Importance values are mapped to colours with the 'magma' colormap after a
    mild gamma correction (exponent 1/1.25). Points without a value are
    coloured like importance 0.
    """
    import os  # local import: only needed to create the output directory

    vtk.vtkObject.GlobalWarningDisplayOff()

    coords = point_cloud_with_importance[:, :3]
    radii = point_cloud_with_importance[:, 6]
    importance = point_cloud_with_importance[:, -1].copy()

    # Identify null points (importance = -1) BEFORE any arithmetic: raising -1
    # to a fractional power gives NaN, after which the sentinel can no longer
    # be detected and the points would be rendered with NaN scalars.
    null_mask = importance == -1

    # Replace all -1 with 0 for the colormap, clamp to [0, 1], then apply the
    # gamma correction (which keeps values inside [0, 1]).
    importance[null_mask] = 0
    importance_clipped = np.power(np.clip(importance, 0, 1), 1.0 / 1.25)

    # Create main cloud
    cloud = pv.PolyData(coords)
    cloud['importance'] = importance_clipped
    cloud['radii'] = radii * 1.5
    glyphs = cloud.glyph(scale='radii', geom=pv.Sphere(radius=sphere_radius))

    plotter = pv.Plotter(off_screen=True, window_size=image_size)

    # Add all points with importance coloring
    plotter.add_mesh(
        glyphs,
        scalars='importance',
        cmap='magma',
        smooth_shading=True,
        clim=[0, 1],
        specular=0.25,           # shiny reflections
        specular_power=25,       # crisp highlights
        ambient=0.3,
        diffuse=0.8
    )

    # Optional: white spheres for original -1 importance
    if remove_null_features and null_mask.any():
        null_cloud = pv.PolyData(coords[null_mask])
        null_glyphs = null_cloud.glyph(scale=False, geom=pv.Sphere(radius=sphere_radius))
        plotter.add_mesh(null_glyphs, color='white', smooth_shading=True)

    # Camera setup (top-down XY view)
    plotter.view_xy()

    # White background with the light kit enabled
    plotter.background_color = 'white'
    plotter.enable_lightkit()

    # Make sure the target directory exists before the screenshot is written.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    plotter.show(screenshot=filename)
    print(f"Saved visualization as '{filename}'")



# Main

In [None]:
# Fail fast on the empty placeholders from the definitions cell.
if not grad_cam_cif_path:
    raise ValueError("grad_cam_cif_path is empty: set it to the CIF file to analyse")
if not model_path:
    raise ValueError("model_path is empty: set it to the saved model (.pth) file")

# Load Atomic Point Cloud from CIF for Grad-CAM
basic_pc = extract_coord_and_element_from_cif("GRAD_CAM_MOF", grad_cam_cif_path)
atomic_pc = featureize_atomic_point_cloud(basic_pc, pauling_en, hardness, polarizability)

# Load Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed - used for model training
extra_feature_dim            = 24
atomic_feature_channel_num   = 8

model = get_model(
    extra_feat_dim=extra_feature_dim,
    atomic_feature_channel_num=atomic_feature_channel_num
).to(device)

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# NOTE: prepare_single_sample rescales feature columns 3-10 of atomic_pc in
# place, so atomic_pc holds the scaled features from here on (the xyz columns
# are untouched).
pc_tensor, extra_features = prepare_single_sample(atomic_pc, upsample_points=2048,
                                                  pressure=pressure, temperature=temperature, adsorbate=adsorbate)

pc_tensor = pc_tensor.to(device)
extra_features = extra_features.unsqueeze(0).to(device)

# Perform Single Sample Prediction + Obtain GradCAM
# gradcam_pointnet2 already runs the forward pass and returns its prediction,
# so the reported value comes from the same pass that produced the heatmap.
# This also avoids a second forward pass that would build an unused autograd
# graph and, if the model's point sampling is stochastic, could differ from
# the pass behind the heatmap.
target_layer = "sa1"
heatmap, pred = gradcam_pointnet2(model, pc_tensor, extra_features, target_layer=target_layer)
coords_with_importance = heatmap[0].cpu().numpy()

# Print Results:
print("Point cloud with importance shape:", coords_with_importance.shape)
print("Prediction: " + str(pred))

pc_for_vis = assign_gradcam_importance(
    atomic_pc,
    coords_with_importance,
    normalize=True
)

visualize_heatmap_fast(pc_for_vis, filename=fig_save_path)