In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Single light/dark switch handed to both plotting setups below, so matplotlib
# and plotly figures are configured from the same flag.
is_dark = False
# mpl_setup returns a pair, unpacked here as `theme` and `cs`
# (presumably the theme object and its colour sequence — confirm against rho_plus).
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [2]:
# Read the feature table from a Feather file hosted on GitHub; this is fetched
# over the network every time the cell runs.
# NOTE(review): the URL points at the `main` branch, so the file contents can
# change between runs — consider pinning to a commit if reproducibility matters.
data = pd.read_feather('https://github.com/nicholas-miklaucic/ood-materials/raw/main/mpc_full_feats_scaled_split.feather')
# Bare expression: shows the frame as the cell output.
data

Unnamed: 0,comp,0-norm,2-norm,3-norm,5-norm,7-norm,10-norm,minimum Number,maximum Number,range Number,...,infoY_delta_e,statY_bandgap,infoY_bandgap,Rsplt1,Rsplt2,Rsplt3,Rsplt4,Rsplt5,piezo,dataset_split
0,In1,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,2.864847,0.010989,-1.760193,...,False,False,False,False,False,False,False,False,False,1
1,Mg1,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,-0.041846,-1.664706,-1.760193,...,False,False,False,False,True,False,False,True,False,0
2,Be1,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,-0.670320,-2.027018,-1.760193,...,False,False,False,False,False,False,False,False,False,2
3,Hf1,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,4.671710,1.052637,-1.760193,...,False,False,False,False,False,False,False,False,False,1
4,P1,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,0.193832,-1.528839,-1.760193,...,False,False,False,False,False,False,False,False,False,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84185,Sb2W1O6,-0.302404,0.587564,0.680144,0.747711,0.764994,0.771629,-0.356083,1.143215,1.446811,...,False,False,False,False,False,False,False,False,True,0
84186,Sr1Hf1O3,-0.302404,-0.032556,0.033257,0.130844,0.170592,0.193219,-0.356083,1.052637,1.349629,...,False,False,False,False,False,False,False,False,True,1
84187,Rb1Ta1O3,-0.302404,-0.032556,0.033257,0.130844,0.170592,0.193219,-0.356083,1.097926,1.398220,...,False,False,False,False,False,False,False,False,True,1
84188,Ba1Ni1O3,-0.302404,-0.032556,0.033257,0.130844,0.170592,0.193219,-0.356083,0.328012,0.572174,...,False,False,False,False,False,False,False,False,True,0


In [6]:
# NOTE(review): `comps` is only assigned in the next cell (executed as In [4]);
# no earlier visible cell defines it, so this cell (In [6]) will not work on a
# fresh top-to-bottom run unless it is moved below that cell.
# Element -> amount mapping for the first composition.
comps[0].get_el_amt_dict()

{'In': 1.0}

In [4]:
from pymatgen.core import Composition

# Build one Composition per entry of the `comp` column, preserving row order.
comps = list(map(Composition, data['comp']))

# Display the first parsed composition.
comps[0]

[Composition('In1'),
 Composition('Mg1'),
 Composition('Be1'),
 Composition('Hf1'),
 Composition('P1'),
 Composition('Xe1'),
 Composition('Hg1'),
 Composition('Br1'),
 Composition('Sr1'),
 Composition('Xe1'),
 Composition('Ti1'),
 Composition('Cr1'),
 Composition('Cs1'),
 Composition('Ac1'),
 Composition('Sc1'),
 Composition('Li1'),
 Composition('S1'),
 Composition('I1'),
 Composition('Xe1'),
 Composition('Cu1'),
 Composition('Se1'),
 Composition('K1'),
 Composition('Al1'),
 Composition('Eu1'),
 Composition('Sb1'),
 Composition('Sc1'),
 Composition('Rb1'),
 Composition('Pr1'),
 Composition('Li1'),
 Composition('U1'),
 Composition('Hg1'),
 Composition('Se1'),
 Composition('Si1'),
 Composition('Ru1'),
 Composition('Na1'),
 Composition('Si1'),
 Composition('Kr1'),
 Composition('W1'),
 Composition('Na1'),
 Composition('Nb1'),
 Composition('Ho1'),
 Composition('Hg1'),
 Composition('Tc1'),
 Composition('Kr1'),
 Composition('Xe1'),
 Composition('Br1'),
 Composition('I1'),
 Composition('Hg1'),

In [3]:
from flax import linen as nn
from jaxtyping import Float, Array
from eins import EinsOp

class DeepSetEncoder(nn.Module):
    """Encode a set of tokens as one fixed-size vector per batch element.

    Each input is passed through ``phi``, then summarised over the token axis
    with two order-independent statistics (mean and standard deviation). The two
    summaries are concatenated feature-wise and layer-normalised. The pooling
    itself ignores token order; the encoder as a whole is permutation-invariant
    provided ``phi`` treats tokens independently.
    """

    # Embedding network. It is called as ``phi(x, training=...)`` and its result
    # is consumed as a ``batch token out_dim`` array.
    phi: nn.Module

    @nn.compact
    def __call__(self, x: Float[Array, 'batch token chan'], training: bool) -> Float[Array, 'batch out_dim']:
        """Embed, pool over tokens, and normalise.

        Args:
            x: Input array, annotated as ``batch token chan``.
            training: Forwarded unchanged to ``phi`` as a keyword argument.

        Returns:
            The layer-normalised pooled features. The last axis holds the mean
            features followed by the std features, so its width is twice the
            feature width produced by ``phi``.
        """
        embedded = self.phi(x, training=training)
        # Move the token axis to the end before reducing over it.
        token_last = EinsOp('batch token out_dim -> batch out_dim token')(embedded)

        pool_spec = 'batch out_dim token -> batch out_dim'
        pooled = [EinsOp(pool_spec, reduce=stat)(token_last) for stat in ('mean', 'std')]

        # Mean block first, std block second, joined along the feature axis.
        return nn.LayerNorm(dtype=x.dtype)(jnp.concatenate(pooled, axis=-1))
    

class CompRegressor(nn.Module):
    