# Pretrained ALIGNN Models
Ashley S. Dale

This notebook loads a pretrained ALIGNN model and calculates its loss landscape.

In [2]:
# Select Matplotlib's interactive "widget" backend for this notebook session.
%matplotlib widget

In [5]:
import copy
import numpy as np
import matplotlib.pyplot as plt

from alignn.pretrained import *
from jarvis.db.figshare import data

import ipywidgets as widgets
from torchinfo import summary

import loss_landscapes
import loss_landscapes.metrics

from abc import ABC, abstractmethod
from loss_landscapes.model_interface.model_wrapper import ModelWrapper
import torch

# Load Model

In [6]:
# Names of every pretrained model known to alignn (the keys of the model registry).
list_of_pretrained_models = list(get_all_models())

-> Select the `jv_formation_energy_peratom_alignn` model for the demo

In [None]:
# 'initial' lets the description label take its natural width instead of being truncated.
style = dict(description_width='initial')

# Dropdown over all pretrained model names; nothing is preselected (value=None).
config_selector = widgets.Dropdown(
    description='Select Model',
    options=list_of_pretrained_models,
    value=None,
    disabled=False,
    style=style,
)

display(config_selector)

In [None]:
# Name of the pretrained model chosen in the dropdown. It stays None until a
# selection is made, because the widget is created with value=None.
model_name = config_selector.value
print("Selected: ", model_name)

In [None]:
# Load the selected pretrained model through alignn's get_figshare_model
# (presumably fetches the checkpoint from figshare on first use -- TODO confirm).
model = get_figshare_model(model_name)

In [None]:
# Show the torchinfo summary of the loaded model as the cell output.
summary(model)

We can use the distribution of the model's weights to help scale the distance traveled on the loss landscape:

In [None]:
# Final fully connected layer weights as a NumPy array. The .cpu() call makes
# this work when the model was loaded onto a GPU: Tensor.numpy() only supports
# CPU tensors and would raise otherwise. It is a no-op for CPU models.
fc_layer_wts = model.fc.weight.detach().cpu().numpy()
counts, bins = np.histogram(fc_layer_wts)

# Histogram of log|w|. Exact zeros are excluded first: log(0) is -inf, and
# np.histogram raises on a non-finite value range.
nonzero_wts = fc_layer_wts[fc_layer_wts != 0]
log_counts, log_bins = np.histogram(np.log(np.abs(nonzero_wts)))

fig, ax = plt.subplots(1,2, figsize=(8,4))

# Re-draw the precomputed histograms: one sample per bin (its left edge),
# weighted by that bin's count.
ax[0].hist(bins[:-1], bins, weights=counts)
ax[0].set_title('Weights from Final FC Layer')
ax[0].set_xlabel('Wt Value')
ax[0].set_ylabel('Count')

ax[1].hist(log_bins[:-1], log_bins, weights=log_counts)
ax[1].set_title('Weights from Final FC Layer')
ax[1].set_xlabel('Log(Abs(Wt Value))')
ax[1].set_ylabel('Count')

fig.tight_layout()
plt.show()

In [None]:
# Standard deviation of the final FC layer weights; reused later to scale the
# distance explored on the loss landscape.
wt_std = np.std(fc_layer_wts)

print('avg wt val: ', fc_layer_wts.mean())
print('std wt val: ', wt_std)

# Load Data

In [None]:
## For larger sample
# Property to evaluate against, and how many dataset entries to keep.
target = "formation_energy_peratom"
n_samples = 250

# Load the JARVIS "dft_3d" dataset and keep only its first n_samples entries.
d = data("dft_3d")[:n_samples]

In [14]:
def get_data_loader(atoms_array, target):
    """Build a non-shuffled, batch-size-1 DataLoader over ``atoms_array``.

    Args:
        atoms_array: Iterable of mapping-like entries, each providing an
            ``'atoms'`` item and an item under the key ``target``.
        target: Key of the property to use as the prediction target.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset produced by
        ``get_torch_dataset``, collated with that dataset's
        ``collate_line_graph``.
    """
    from torch.utils.data import DataLoader

    # Re-pack each entry into the record layout handed to get_torch_dataset;
    # the running index serves as the identifier.
    records = [
        {"atoms": entry['atoms'], "prop": entry[target], "jid": str(idx)}
        for idx, entry in enumerate(atoms_array)
    ]

    dataset = get_torch_dataset(
        dataset=records,
        target="prop",
        neighbor_strategy="k-nearest",
        atom_features="cgcnn",
        use_canonize=True,
        line_graph=True,
    )

    # One sample per batch, original order, loaded in the main process.
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_line_graph,
        drop_last=False,
        num_workers=0,
        pin_memory=False,
    )

In [None]:
# Build the evaluation DataLoader from the loaded dataset entries and target key.
test_dataloader = get_data_loader(d, target)

In [None]:
# Pull the first batch from the loader; it is the only batch used below.
single_batch = next(iter(test_dataloader))
# Number of entries in the batch container (shown as the cell output).
len(single_batch)

In [None]:
# Show the three batch entries. The labels treat entries 0 and 1 as graphs and
# entry 2 as the target values (presumably the atom graph, its line graph and
# the labels -- TODO confirm against collate_line_graph).
print('Graph 1: ', single_batch[0])
print('Graph 2: ', single_batch[1])
print(target+': ', single_batch[2])

# Prepare Loss Landscape

In [18]:
# Work on a deep copy so the originally loaded model object is left untouched.
model_final = copy.deepcopy(model)

In [19]:
# Grid resolution, passed as `steps` to loss_landscapes.random_plane.
STEPS=20
# Passed as `distance` to random_plane: ten standard deviations of the final
# FC layer weights.
DISTANCE=10*wt_std

### Define Loss Function

In [20]:
# Absolute-error loss (torch.nn.L1Loss with its default settings), used as the
# quantity evaluated over the landscape.
criterion = torch.nn.L1Loss()

### Define Model Wrapper

In [21]:

class Metric(ABC):
    """Abstract base class for a callable that scores a wrapped model."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def __call__(self, model_wrapper: ModelWrapper):
        """Evaluate the metric on ``model_wrapper``; subclasses must override."""

class Loss(Metric):
    """Evaluate a loss function on one fixed (inputs, target) pair.

    Args:
        loss_fn: Callable invoked as ``loss_fn(outputs, target)``; expected to
            return a single-element tensor (e.g. a torch loss with a reducing
            ``reduction``).
        model: Stored on the instance as ``self.model``; not used by
            ``__call__``, which evaluates whatever wrapper it is given.
        inputs: Object forwarded unchanged to ``model_wrapper.forward``.
        target: Reference values the model outputs are compared against.
    """
    def __init__(self, loss_fn, model, inputs: torch.Tensor, target: torch.Tensor):
        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs
        self.model = model
        self.target = target

    def __call__(self, model_wrapper: ModelWrapper) -> float:
        """Return the loss of ``model_wrapper`` on the stored pair as a float."""
        # Only the loss value is needed. Disabling autograd avoids building a
        # graph on every landscape grid point, and keeps the returned values
        # from holding those graphs alive.
        with torch.no_grad():
            outputs = model_wrapper.forward(self.inputs)
            # torch losses take (input, target); the order matters for
            # asymmetric losses even though L1 is symmetric.
            err = self.loss_fn(outputs, self.target)
        # Return a plain Python float, matching the annotated return type.
        return err.item()

## Calculate Loss Landscape

In [22]:
# Metric evaluated at each landscape point: the criterion on the single batch,
# with entries 0 and 1 as the model inputs and entry 2 as the target.
# .eval() puts model_final into evaluation mode and returns the same module.
metric = Loss(criterion, model_final.eval(), (single_batch[0], single_batch[1]), single_batch[2])

In [None]:
# Evaluate the metric on a STEPS x STEPS grid over a random plane through the
# model's parameters. No normalization is applied, and deepcopy_model=True asks
# the library to work on a copy (presumably leaving model_final unchanged --
# TODO confirm).
loss_data_fin = loss_landscapes.random_plane(model_final, metric, distance=DISTANCE, steps=STEPS, normalization=None, deepcopy_model=True)

# Plot the surface

In [24]:

# Eight evenly spaced values from -DISTANCE/2 to +DISTANCE/2, rounded to three
# decimals; used as axis tick labels in the plots.
perturbation_range = np.round(np.linspace(-0.5*DISTANCE, 0.5*DISTANCE, 8), 3)

In [None]:
# Saved next to the notebook. A bare filename needs no path joining, which
# also removes the reliance on an `os` name this notebook never imports.
save_fig_name = 'loss_contours.png'
fig, ax = plt.subplots(1, 1)

# Contours of the natural log of the loss; axes are in grid-index units
# (0 .. STEPS-1) because no explicit coordinates are supplied.
contours = ax.contourf(np.log(loss_data_fin), levels=50)
ax.set_title('Loss Contours \n'+ r'$L(\theta + \alpha i + \beta j$)')
ax.axis('square')

# Red star at the centre of the grid.
grid_center = (STEPS-1)/2.
ax.scatter(grid_center, grid_center, s=20, c='r', marker='*')

# The data occupies indices 0 .. STEPS-1, so the ticks must span that range
# (not 0 .. STEPS). Otherwise the labels are shifted relative to the centre
# marker and the axis is stretched past the data.
tick_positions = np.linspace(0, STEPS - 1, len(perturbation_range), endpoint=True)
ax.set_xticks(tick_positions)
ax.set_xticklabels(perturbation_range)
ax.set_yticks(tick_positions)
ax.set_yticklabels(perturbation_range)
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$\beta$')

# Label the colourbar: the plotted quantity is log(loss), not the raw loss.
fig.colorbar(contours, ax=ax, label='Log(Loss)')
fig.savefig(save_fig_name, transparent=True, dpi=300)
plt.show()

In [None]:
import numpy as np

# Saved next to the notebook. A bare filename needs no path joining, which
# also removes the reliance on an `os` name this notebook never imports.
save_fig_name = 'loss_surface.png'
fig = plt.figure()
ax = plt.axes(projection='3d')

# Grid-index coordinates: X[i, j] == j and Y[i, j] == i.
grid_index = np.arange(STEPS)
X, Y = np.meshgrid(grid_index, grid_index)
ax.plot_surface(X, Y, np.log(loss_data_fin), rstride=1, cstride=1, cmap='viridis', edgecolor='none')

# The data occupies indices 0 .. STEPS-1, so the ticks must span that range
# (not 0 .. STEPS) for the perturbation labels to line up with the surface.
tick_positions = np.linspace(0, STEPS - 1, len(perturbation_range), endpoint=True)
ax.set_xticks(tick_positions)
ax.set_xticklabels(perturbation_range)
ax.set_yticks(tick_positions)
ax.set_yticklabels(perturbation_range)
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$\beta$')
# The surface height is log(loss), so label it as such.
ax.set_zlabel('Log(Loss)')
fig.savefig(save_fig_name, transparent=True, dpi=300)
plt.show()