In [18]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Apply rho_plus' matplotlib styling and keep the two objects it returns.
# NOTE(review): the positional False presumably selects the light (non-dark)
# theme, and `cs` presumably holds the theme's colour sequence — confirm
# against the rho_plus documentation.
theme, cs = rp.mpl_setup(False)

In [19]:
# Element embedding table fetched from the aviary repository.
# Columns are element symbols; rows are the embedding dimensions.
elem_embs = pd.read_json(
    "https://raw.githubusercontent.com/CompRhys/aviary/refs/heads/main/aviary/embeddings/element/megnet16.json"
)
elem_embs

Unnamed: 0,Null,H,He,Li,Be,B,C,N,O,F,...,At,Rn,Fr,Ra,Ac,Th,Pa,U,Np,Pu
0,-0.044911,0.352363,-0.06722,-0.161449,-0.111666,0.260108,0.398148,0.611494,-0.113972,-0.308105,...,0.013335,-0.011944,-0.04345,0.016968,-0.918221,-0.353667,-0.498012,-0.061629,-0.47124,-0.278194
1,0.004152,0.635952,0.141113,0.179496,0.760182,0.707898,0.744485,0.00181,-0.188673,-0.575614,...,0.024167,0.017634,0.04272,-0.025037,0.073024,-0.067413,0.487307,0.243641,0.448124,0.044107
2,0.012933,0.217338,0.164495,-0.114184,0.057829,0.064846,0.662636,0.620457,0.108998,-0.171835,...,-0.009026,-0.005495,-0.044104,0.031968,-0.595052,-0.824566,-0.745676,-0.346598,-0.170911,-0.15482
3,-0.010163,-0.191956,0.136701,0.13651,0.250147,-0.300478,-0.520578,-0.446009,0.214548,0.462324,...,0.046445,-0.009977,0.007274,-0.017894,0.450472,-0.025612,-0.009069,0.056789,0.299127,0.376608
4,0.007606,0.253751,0.016505,0.106477,-0.396934,-0.510219,-0.673885,0.094431,0.371144,0.702146,...,0.003559,-0.030206,-0.02309,-0.02751,-0.345009,-0.346526,-0.26353,-0.442932,-0.490186,-0.445192
5,0.029269,-0.423261,0.073929,0.047109,0.12885,0.082544,-0.476654,-0.817922,-0.103028,0.220106,...,0.030743,-0.04844,0.029772,0.013903,0.607715,0.809945,0.749626,0.640021,0.501166,0.493097
6,0.03398,0.221297,0.151093,0.065104,0.037942,0.589792,0.69872,-0.130236,-0.404307,-0.503306,...,-0.014643,-0.007266,0.023424,-0.024473,-0.072247,0.161332,0.464819,0.237543,0.197754,-0.023476
7,0.018202,-0.452411,-0.13817,-0.069099,0.135463,0.523288,0.607385,-0.645908,-0.5028,-0.39498,...,-0.016754,-0.025606,-0.033319,0.042808,0.17949,0.469464,0.553796,0.706783,0.673036,0.487316
8,0.042308,-1.007713,-0.180052,-0.210581,-0.202221,-0.672266,-0.56216,0.021001,0.313225,0.28345,...,0.046804,0.004872,0.038471,-0.026307,0.290643,0.045801,-0.281812,0.170462,0.243258,0.303537
9,-0.01932,-0.289936,-0.130642,-0.030035,-0.046165,0.490484,0.673558,-0.271362,-0.465547,-0.58429,...,0.003551,0.044316,-0.016308,-0.023788,0.057812,0.178863,0.412992,0.436729,0.359024,0.260558


In [20]:
import torch
from torch import nn
import torch.nn.functional as F
from aviary.roost.model import DescriptorNetwork
from pymatgen.core import Composition
from torch import Tensor, LongTensor

def comp2graph(composition):
    """Convert a composition into the dense element graph consumed by the model.

    Each distinct element of the composition becomes a node, and an edge is
    created for every ordered pair of nodes, including a node with itself.

    Args:
        composition: Anything accepted by ``pymatgen.core.Composition``,
            e.g. the formula string ``"Nb3Si"``.

    Returns:
        Tuple ``(elem_weights, elem_fea, self_idx, nbr_idx)``:
            elem_weights: Tensor of shape ``(n_elems, 1)`` with the fractional
                amount of each element (amounts divided by their sum).
            elem_fea: Tensor with one row per element, taken from that
                element's column of the module-level ``elem_embs`` table.
            self_idx: LongTensor of length ``n_elems ** 2`` holding the source
                node of every edge.
            nbr_idx: LongTensor of length ``n_elems ** 2`` holding the
                neighbour node of every edge.

    Raises:
        ValueError: If ``Composition`` rejects the input with a ValueError, or
            stacking the features raises one.
        KeyError: If an element has no column in ``elem_embs``.
        AssertionError: If the feature lookup raises an AssertionError.
    """
    try:
        comp_dict = Composition(composition).get_el_amt_dict()
    except ValueError as exc:
        raise ValueError(
            f"{composition} composition cannot be parsed into elements"
        ) from exc
    elements = list(comp_dict)

    weights = list(comp_dict.values())
    # column vector of fractional amounts
    weights = np.atleast_2d(weights).T / np.sum(weights)

    try:
        # elem_embs holds one column per element; transpose so rows are elements
        elem_fea = np.vstack([elem_embs[elements]]).T
    except AssertionError as exc:
        raise AssertionError(
            f"{composition} contains element types not in embedding"
        ) from exc
    except KeyError as exc:
        # pandas reports unknown column labels with KeyError, so this is the
        # branch that actually fires when an element is missing from the table
        raise KeyError(
            f"{composition} contains element types not in embedding"
        ) from exc
    except ValueError as exc:
        raise ValueError(
            f"{composition} composition cannot be parsed into elements"
        ) from exc

    # fully connected graph: node i is paired with every node j (self included)
    n_elems = len(elements)
    self_idx = [i for i in range(n_elems) for _ in range(n_elems)]
    nbr_idx = list(range(n_elems)) * n_elems

    # convert all data to tensors
    elem_weights = Tensor(weights)
    elem_fea = Tensor(elem_fea)
    self_idx = LongTensor(self_idx)
    nbr_idx = LongTensor(nbr_idx)
    return (elem_weights, elem_fea, self_idx, nbr_idx)

# https://github.com/CompRhys/aviary/blob/181e2b2b2d679a12f6dbb430853d92508e8d71f2/aviary/roost/data.py#L140C1-L212C6
def collate_batch(samples):
    """Merge per-composition graphs into a single batched graph.

    Args:
        samples: Iterable of ``(elem_weights, elem_fea, self_idx, nbr_idx)``
            tuples, one per composition.

    Returns:
        Tuple of five tensors, each concatenated along dim 0 in sample order:
        element weights, element features, edge source indices, edge
        neighbour indices, and for every element node the position (within
        ``samples``) of the composition it belongs to. Edge indices are
        shifted so they address rows of the concatenated feature tensor.
    """
    weight_chunks, feature_chunks = [], []
    src_chunks, dst_chunks = [], []
    owner_chunks = []

    offset = 0  # number of element nodes already placed in the batch
    for sample_no, (weights, features, src, dst) in enumerate(samples):
        node_count = features.shape[0]

        weight_chunks.append(weights)
        feature_chunks.append(features)

        # re-base the per-sample edge indices onto the batched node numbering
        src_chunks.append(src + offset)
        dst_chunks.append(dst + offset)

        # every node of this sample points back at its composition
        owner_chunks.append(torch.tensor([sample_no] * node_count))

        offset += node_count

    all_chunks = (
        weight_chunks,
        feature_chunks,
        src_chunks,
        dst_chunks,
        owner_chunks,
    )
    return tuple(torch.cat(chunks, dim=0) for chunks in all_chunks)

class CompositionEmbedding(torch.nn.Module):
    """Embed compositions and score pairs of them by embedding distance.

    A ``DescriptorNetwork`` followed by a linear head maps a batched
    composition graph to one vector per composition. Two batches are compared
    through the Euclidean distance of their embeddings, which is squashed to
    ``1 - tanh(distance * rescale)`` with a learnable ``rescale`` factor.
    """

    def __init__(self, elem_input_dim: int = 16, elem_hidden_dim: int = 64, comp_embed_dim: int = 64):
        super().__init__()
        self.gnn = DescriptorNetwork(
            elem_emb_len=elem_input_dim,
            elem_fea_len=elem_hidden_dim,
            n_graph=1,
        )
        self.head = nn.Linear(elem_hidden_dim, comp_embed_dim)
        # learnable multiplier applied to distances before the tanh squashing
        self.rescale = nn.Parameter(torch.ones(1, dtype=torch.float32))

    def embed(self, X):
        """Return the embedding of the batched graph ``X`` (a tuple unpacked into the GNN)."""
        graph_features = self.gnn(*X)
        return self.head(graph_features)

    def forward(self, X1, X2):
        """Score row-aligned batches ``X1`` and ``X2`` pair by pair."""
        delta = self.embed(X1) - self.embed(X2)
        # row-wise Euclidean distance between the paired embeddings
        dists = torch.sqrt(torch.square(delta).sum(dim=1))
        return self.to_probability(dists)

    def to_probability(self, dists):
        """Map distances to scores: 1 at zero distance, decreasing as distance grows."""
        scaled = dists * self.rescale
        return 1 - torch.tanh(scaled)

In [21]:
# The checkpoint contains the whole pickled CompositionEmbedding module (not
# just a state_dict), so full unpickling must be requested explicitly:
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to load it, and
# earlier releases warn when the argument is omitted.
# NOTE: unpickling can execute arbitrary code — only load checkpoints that
# come from a trusted source.
model = torch.load('checkpoints/test.pt', weights_only=False)
model

  model = torch.load('checkpoints/test.pt')


CompositionEmbedding(
  (gnn): DescriptorNetwork(n_graph=1, cry_heads=3, elem_emb_len=16, elem_fea_len=63)
  (head): Linear(in_features=64, out_features=64, bias=True)
)

In [22]:
# CSP benchmark test set; the material ids are kept separately for later lookups.
benchmark = pd.read_csv(
    "https://raw.githubusercontent.com/usccolumbia/cspbenchmark/main/data/CSPbenchmark_test_data.csv"
)
benchmark_ids = benchmark["material_id"]

In [23]:
def compute_scores(comp_1: str, other_comps: list[str], batch_size: int = 32):
    """Score one composition against many using the module-level ``model``.

    Args:
        comp_1: Query composition, e.g. ``"Nb3Si"``.
        other_comps: Compositions to compare against. Any iterable of formula
            strings works (list, tuple, pandas Series, ...); items are taken
            in iteration order.
        batch_size: Number of compositions embedded per forward pass.

    Returns:
        1-D numpy array with one score per entry of ``other_comps``, in the
        same order, computed as ``model.to_probability`` of the Euclidean
        distance between the embeddings. An empty input yields an empty
        array (float32 — assumed to match the model's output dtype).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # materialise once so lists and Series are sliced the same (positional) way
    candidates = list(other_comps)
    if not candidates:
        # nothing to compare against; torch.cat would fail on an empty list
        return np.empty(0, dtype=np.float32)

    X1 = collate_batch([comp2graph(comp_1)])
    X2 = [
        collate_batch([comp2graph(c) for c in candidates[start:start + batch_size]])
        for start in range(0, len(candidates), batch_size)
    ]

    model.eval()
    # inference only: skip building the autograd graph
    with torch.no_grad():
        z1 = model.embed(X1)
        z2 = torch.cat([model.embed(x) for x in X2])
        # z1 holds a single row, so cdist gives a (1, n) matrix; flatten it to (n,)
        probs = model.to_probability(torch.cdist(z1, z2).reshape(-1))
    return probs.numpy(force=True)

In [24]:
# Score Nb3Si against every formula in the benchmark and show the rounded scores.
scores = compute_scores("Nb3Si", benchmark["full_formula"])
np.round(scores, 2)

array([0.18, 0.26, 0.05, 0.06, 0.  , 0.11, 0.18, 0.06, 0.02, 0.03, 0.01,
       0.  , 0.1 , 0.04, 0.03, 0.07, 0.06, 0.03, 0.01, 0.22, 0.04, 0.12,
       0.  , 0.06, 0.04, 0.02, 0.13, 0.05, 0.03, 0.08, 0.08, 0.12, 0.07,
       0.05, 0.1 , 1.  , 0.02, 0.06, 0.04, 0.03, 0.07, 0.  , 0.04, 0.02,
       0.  , 0.  , 0.03, 0.  , 0.13, 0.02, 0.  , 0.  , 0.  , 0.01, 0.  ,
       0.05, 0.  , 0.08, 0.18, 0.  , 0.04, 0.05, 0.04, 0.02, 0.  , 0.  ,
       0.02, 0.02, 0.03, 0.02, 0.03, 0.  , 0.02, 0.02, 0.  , 0.01, 0.04,
       0.01, 0.09, 0.04, 0.11, 0.01, 0.  , 0.08, 0.01, 0.01, 0.13, 0.13,
       0.04, 0.06, 0.19, 0.1 , 0.22, 0.04, 0.03, 0.04, 0.04, 0.04, 0.03,
       0.03, 0.03, 0.03, 0.01, 0.03, 0.02, 0.03, 0.02, 0.03, 0.04, 0.  ,
       0.04, 0.01, 0.02, 0.04, 0.01, 0.  , 0.02, 0.  , 0.02, 0.  , 0.02,
       0.02, 0.03, 0.  , 0.01, 0.15, 0.06, 0.01, 0.  , 0.15, 0.02, 0.  ,
       0.01, 0.  , 0.04, 0.  , 0.  , 0.  , 0.  , 0.06, 0.19, 0.25, 0.03,
       0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.17, 0.