In [29]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Configure rho_plus for dark-mode figures in both matplotlib and plotly.
is_dark = True
# mpl_setup returns two values, kept here as `theme` and `cs` for later cells.
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [30]:
# Work from ~/cdv so that the relative `precomputed/...` paths used below resolve.
%cd ~/cdv

/home/nmiklaucic/cdv


In [167]:
# Count how many structures fall into each dimensionality class.
# The Series method is used because the top-level `pd.value_counts` function is
# deprecated and emits a FutureWarning.
# NOTE: `clean` is created by the `pd.read_pickle` cell, which must have been run first.
clean['dimensionality'].value_counts()

  pd.value_counts(clean['dimensionality'])


dimensionality
3D-bulk                  32349
intercalated ion         30180
2D-bulk                   4327
0D-bulk                   1755
1D-bulk                   1278
na                         315
intercalated molecule      196
Name: count, dtype: int64

In [31]:
# Load the cleaned dataframe stored under precomputed/jarvis_dft3d_cleaned
# (path is relative to the current working directory).
# NOTE(review): unpickling can execute arbitrary code; only load files produced locally.
clean = pd.read_pickle('precomputed/jarvis_dft3d_cleaned/dataframe.pkl')
clean

Unnamed: 0,space_group,formula,e_form,bandgap,atoms,e_total,ehull,dimensionality,density,num_atoms,num_spec,magmom
0,129,"(Ti, Cu, Si, As)",-0.42762,0.000,"[[2.67519992 2.67519992 7.37609819] Ti, [0.891...",-3.37474,0.0423,3D-bulk,5.956,8,4,0.0
1,221,"(Dy, B)",-0.41596,0.000,"[[0. 0. 0.] Dy, [0.81214866 2.04453946 2.04453...",-5.79186,0.0456,3D-bulk,5.522,7,2,0.0
2,119,"(Be, Os, Ru)",0.04847,0.000,"[[0. 0. 0.] Be, [1.83359072 0. 1.74248...",-4.65924,0.3183,intercalated ion,10.960,4,3,0.0
3,14,"(K, Bi)",-0.44140,0.472,"[[0.91524308 6.85585362 9.07269063] K, [ 4.624...",-0.42496,0.0000,intercalated ion,5.145,32,2,0.0
4,164,"(V, Se)",-0.71026,0.000,"[[0. 0. 0.] V, [ 1.67774838 -0.96865035 4.652...",-3.87823,0.0156,2D-bulk,5.718,3,2,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...
75980,164,"(Nd, Mg, Sb)",-0.69038,0.000,"[[0. 0. 0.] Nd, [2.3394364 1.35070495 2.69383...",-1.53096,0.0000,3D-bulk,5.116,5,3,0.0
75983,216,"(Yb, Pr, Pd)",-0.56635,0.000,"[[5.25036997 5.25036997 5.25036997] Yb, [1.750...",-1.70796,0.0000,3D-bulk,8.139,3,3,0.0
75984,216,"(Tb, Tl, Zn)",-0.29921,0.000,"[[1.66349368 1.66349368 1.66349368] Tb, [4.990...",0.05135,0.0000,3D-bulk,9.666,3,3,0.0
75989,216,"(Pr, Tl, Zn)",-0.34112,0.000,"[[1.71052992 1.71052992 1.71052992] Pr, [5.131...",-0.10696,0.0000,3D-bulk,8.517,3,3,0.0


In [170]:
# Show every column of a single row (positional index 4).
clean.iloc[4]

space_group                                                     164
formula                                                     (V, Se)
e_form                                                     -0.71026
bandgap                                                         0.0
atoms             [[0. 0. 0.] V, [ 1.67774838 -0.96865035  4.652...
e_total                                                    -3.87823
ehull                                                        0.0156
dimensionality                                              2D-bulk
density                                                       5.718
num_atoms                                                         3
num_spec                                                          2
magmom                                                          0.0
Name: 4, dtype: object

In [32]:
from pymatgen.core import Structure
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis.local_env import CrystalNN
from tqdm import tqdm
import warnings
# Neighbor-finding strategy whose get_bonded_structure is applied to every structure below.
nn = CrystalNN()

from multiprocessing import Pool

# Silence UserWarnings for the rest of this process.
warnings.simplefilter('ignore', category=UserWarning)

# Compute one bonded-structure graph per entry of clean['atoms'] using 32 worker processes.
# NOTE(review): the cell output shows a warning raised at `os.fork()`; presumably this is
# the fork-after-JAX-import warning, in which case a spawn/forkserver context would be
# safer. TODO confirm.
with Pool(processes=32) as P:    
    graphs = P.map(nn.get_bonded_structure, clean['atoms'])

  self.pid = os.fork()


In [164]:
# Gather the set of chemical elements that occur in at least one structure.
# The iteration variables stay local to the comprehension: a top-level loop variable
# named `struct` would leak into the notebook namespace and shadow the `flax.struct`
# module that another cell of this notebook imports under the same name.
elements = {
    element
    for structure in clean['atoms']
    for element in structure.elements
}

# Print the element symbols in the elements' own sort order.
print(' '.join(e.symbol for e in sorted(elements)))

K Rb Ba Na Sr Li Ca La Tb Yb Ce Pr Nd Sm Dy Y Ho Er Tm Hf Mg Zr Sc U Ta Ti Mn Be Nb Al Tl V Zn Cr Cd In Ga Fe Co Cu Si Ni Ag Sn Hg Ge Bi B Sb Te Mo As P H Ir Os Pd Ru Pt Rh Pb W Au C Se S I Br N Cl O F


In [165]:
import pickle

# Persist the computed structure graphs so they can be reloaded instead of recomputed.
graphs_path = 'precomputed/jarvis_dft3d_cleaned/graphs.pkl'
with open(graphs_path, mode='wb') as graphs_file:
    pickle.dump(graphs, graphs_file)

# To restore the graphs from disk instead:
# with open('precomputed/jarvis_dft3d_cleaned/graphs.pkl', 'rb') as graphs_file:
#     graphs = pickle.load(graphs_file)

We want to split the samples into batches. Because JAX needs fixed shapes, we pad every
batch to the same number of nodes and edges. We therefore want to assign samples to batches so
that the maximum number of edges in any batch is minimized, ideally landing just below a power
of 2. The code below does this.

In [43]:
from tqdm import tqdm
import functools as ft


# One row per graph: column 0 is the node count, column 1 is the edge count.
sizes = np.array([(len(g.graph.nodes), len(g.graph.edges)) for g in graphs])

def get_parts(numbers, batch, chunk_size):
    """Greedily assign samples to equally sized batches with balanced total size.

    Samples are ordered by decreasing size and cut into ``batch // chunk_size`` groups.
    Within a group, each sample goes to the batch whose running total is currently the
    smallest among the batches that have not yet received ``chunk_size`` samples from
    that group, so every batch ends up with exactly ``batch`` samples.

    Args:
        numbers: 1-D array-like holding one size per sample (e.g. edge counts).
        batch: Number of samples per batch. Must divide ``len(numbers)``.
        chunk_size: Number of samples each batch takes from one group. Must divide
            ``batch``.

    Returns:
        A tuple ``(parts, part_sizes)``. ``parts`` is an int32 array of shape
        ``(batch, n_batches)`` whose column ``j`` lists the sample indices assigned to
        batch ``j``. ``part_sizes`` has shape ``(n_batches,)`` and holds the summed size
        of each batch.

    Raises:
        ValueError: If ``chunk_size`` does not divide ``batch``, or ``batch`` does not
            divide ``len(numbers)``.
    """
    numbers = np.asarray(numbers)
    if batch % chunk_size != 0:
        raise ValueError(
            f'chunk_size ({chunk_size}) must evenly divide batch ({batch})'
        )
    if len(numbers) % batch != 0:
        raise ValueError(
            f'batch ({batch}) must evenly divide the number of samples ({len(numbers)})'
        )

    n_batches = len(numbers) // batch
    n_chunks = batch // chunk_size
    parts = np.zeros((batch, n_batches), dtype=np.int32)
    part_sizes = np.zeros(n_batches, dtype=int)

    # Sentinel that can never win the argmin. A fixed additive penalty would stop
    # excluding full batches once batch totals differ by more than the penalty.
    full_marker = np.iinfo(part_sizes.dtype).max

    # Largest samples first; each row is the group handled in one pass.
    order = np.argsort(-numbers).reshape(n_chunks, chunk_size * n_batches)
    for chunk_i, sample_is in enumerate(tqdm(order)):
        # How many samples of this group each batch has received so far.
        n_filled = np.zeros(n_batches, dtype=np.int32)
        for sample_i in sample_is:
            open_sizes = np.where(n_filled < chunk_size, part_sizes, full_marker)
            next_i = np.argmin(open_sizes)
            parts[chunk_i * chunk_size + n_filled[next_i], next_i] = sample_i
            n_filled[next_i] += 1
            part_sizes[next_i] += numbers[sample_i]

    return parts, part_sizes

# Balance batches by edge count (column 1 of `sizes`): batch=32 samples per batch,
# chunk_size=8 samples taken per batch from each size group.
parts, part_sizes = get_parts(sizes[:, 1], 32, 8)
print(part_sizes)

100%|██████████| 4/4 [00:00<00:00,  4.54it/s]

[950 950 949 ... 947 947 947]





In [129]:
# Largest total node count over all batches (each column of `parts` is one batch).
# The array's own `.max()` reduction is used; the builtin `max` would iterate over
# the array one element at a time.
jnp.take(sizes[:, 0], parts).sum(axis=0).max()

Array(452, dtype=int32)

In [52]:
# Save the batch assignment (`parts`: one column of sample indices per batch) for reuse.
jnp.save('precomputed/jarvis_dft3d_cleaned/batches.npy', parts)

In [94]:
# Pick one example structure to inspect. `crystal` is not assigned by any earlier cell,
# so on a fresh kernel this cell would raise a NameError. Row 1 of `clean` (the Dy/B
# compound with 7 sites) is the structure whose summary is shown in this cell's output.
crystal = clean['atoms'].iloc[1]
crystal

Structure Summary
Lattice
    abc : 4.089078911208881 4.089078911208881 4.089078911208881
 angles : 90.0 90.0 90.0
 volume : 68.37171521292233
      A : 4.089078911208881 0.0 0.0
      B : -0.0 4.089078911208881 -0.0
      C : 0.0 -0.0 4.089078911208881
    pbc : True True True
PeriodicSite: Dy (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]
PeriodicSite: B (0.8121, 2.045, 2.045) [0.1986, 0.5, 0.5]
PeriodicSite: B (2.045, 3.277, 2.045) [0.5, 0.8014, 0.5]
PeriodicSite: B (2.045, 0.8121, 2.045) [0.5, 0.1986, 0.5]
PeriodicSite: B (3.277, 2.045, 2.045) [0.8014, 0.5, 0.5]
PeriodicSite: B (2.045, 2.045, 0.8121) [0.5, 0.5, 0.1986]
PeriodicSite: B (2.045, 2.045, 3.277) [0.5, 0.5, 0.8014]

In [159]:
from flax import struct
from collections import defaultdict
from jaxtyping import Float, Array, Int
from cdv.utils import debug_structure

@struct.dataclass
class NodeData:
    """Per-node arrays for a batch of graphs, concatenated over all graphs in the batch.

    As filled in by ``process_graph``: ``species`` holds ``spec.number`` for each site,
    ``frac`` and ``cart`` hold the structure's ``frac_coords`` and ``cart_coords``, and
    ``graph_i`` is the position within the batch of the graph that owns each node.
    """
    species: Int[Array, 'nodes']
    frac: Float[Array, 'nodes 3']
    cart: Float[Array, 'nodes 3']
    graph_i: Int[Array, 'nodes']

@struct.dataclass
class EdgeData:
    """Per-edge arrays for a batch of graphs, concatenated over all graphs in the batch.

    As filled in by ``process_graph``: ``to_jimage`` is the ``to_jimage`` attribute of the
    structure-graph edge (negated for the reverse direction), ``graph_i`` is the position
    within the batch of the graph that owns each edge, and ``sender``/``receiver`` are the
    endpoint indices written for each directed edge.
    """
    to_jimage: Int[Array, 'edges 3']
    graph_i: Int[Array, 'edges']
    sender: Int[Array, 'edges']
    receiver: Int[Array, 'edges']

@struct.dataclass
class CrystalData:
    """Per-graph data for a batch.

    As filled in by ``process_graph``: ``dataset_i`` is the index of each graph in the
    notebook-level ``graphs`` list, ``abc`` holds the first three lattice parameters, and
    ``angles_rad`` holds the last three lattice parameters converted from degrees to radians.
    """
    dataset_i: Int[Array, 'batch']
    abc: Float[Array, 'batch 3']
    angles_rad: Float[Array, 'batch 3']

@struct.dataclass
class Graphs:
    """A batch of graphs stored as concatenated node/edge arrays.

    Intended as a drop-in substitute for jraph.GraphsTuple: the properties below expose
    the same data under the names ``senders``, ``receivers`` and ``globals``.
    """
    nodes: NodeData
    edges: EdgeData
    n_node: Int[Array, 'graphs']
    n_edge: Int[Array, 'graphs']
    graph_data: CrystalData

    @property
    def globals(self):
        """Per-graph data (alias for ``graph_data``)."""
        return self.graph_data

    @property
    def senders(self) -> Int[Array, 'edges']:
        """Sender index of every edge (alias for ``edges.sender``)."""
        edge_data = self.edges
        return edge_data.sender

    @property
    def receivers(self) -> Int[Array, 'edges']:
        """Receiver index of every edge (alias for ``edges.receiver``)."""
        edge_data = self.edges
        return edge_data.receiver

    @property
    def n_total_edges(self) -> int:
        """Length of the concatenated edge arrays."""
        return self.edges.graph_i.shape[0]

    @property
    def n_total_nodes(self) -> int:
        """Length of the concatenated node arrays."""
        return self.nodes.graph_i.shape[0]



def process_graph(graph_is):
    """Collate the structure graphs selected by ``graph_is`` into one ``Graphs`` batch.

    Nodes and edges of all selected graphs are concatenated in the order given. Every
    edge of an underlying structure graph is emitted twice, once per direction, with the
    ``to_jimage`` offset negated for the reverse direction. Sender and receiver indices
    are shifted by the number of nodes that precede the graph in the batch, so they
    index into the concatenated node arrays.

    Args:
        graph_is: Iterable of indices into the notebook-level ``graphs`` list.

    Returns:
        A ``Graphs`` instance holding the concatenated node, edge and per-graph arrays.
    """
    n_node = []
    n_edge = []
    nodes = defaultdict(list)
    edge_features = defaultdict(list)
    graph_data = defaultdict(list)

    # Number of nodes already emitted for earlier graphs in this batch. The shift must be
    # a node count (not an edge count) because senders/receivers address nodes.
    node_i_offset = 0

    for batch_samp_i, graph_i in enumerate(graph_is):
        sg = graphs[graph_i]

        crystal = sg.structure
        graph = sg.graph
        n_graph_nodes = len(graph.nodes)

        n_node.append(n_graph_nodes)
        # every edge goes in both directions
        n_edge.append(2 * len(graph.edges))
        nodes['species'].extend([spec.number for spec in crystal.species])
        nodes['frac'].extend(crystal.frac_coords)
        nodes['cart'].extend(crystal.cart_coords)
        nodes['graph_i'].extend([batch_samp_i] * n_graph_nodes)

        for i, j, pos_offset in graph.edges(data='to_jimage'):
            neg_offset = tuple(-x for x in pos_offset)
            for sender, receiver, offset in ((i, j, pos_offset), (j, i, neg_offset)):
                edge_features['to_jimage'].append(offset)
                edge_features['graph_i'].append(batch_samp_i)
                edge_features['sender'].append(node_i_offset + sender)
                edge_features['receiver'].append(node_i_offset + receiver)

        node_i_offset += n_graph_nodes

        graph_data['dataset_i'].append(graph_i)
        graph_data['abc'].append(crystal.lattice.parameters[:3])
        graph_data['angles_rad'].append(np.deg2rad(crystal.lattice.parameters[3:]))

    # Compact storage types; keys not listed here use jnp.array's default dtype.
    # NOTE(review): int4 only covers [-8, 7] and uint16 only [0, 65535]; this assumes
    # image offsets and batch-level indices stay inside those ranges.
    dtypes = {
        'species': jnp.uint8,
        'graph_i': jnp.uint16,
        'to_jimage': jnp.int4,
        'sender': jnp.uint16,
        'receiver': jnp.uint16
    }
    for d in nodes, edge_features, graph_data:
        for k in d:
            dtype = dtypes.get(k, None)
            d[k] = jnp.array(d[k], dtype=dtype)

    G = Graphs(
        nodes=NodeData(**nodes),
        edges=EdgeData(**edge_features),
        graph_data=CrystalData(**graph_data),
        n_node=jnp.array(n_node),
        n_edge=jnp.array(n_edge)
    )
    return G

# Collate the batch stored in column 1 of `parts` and pass it to debug_structure for inspection.
G = process_graph(parts[:, 1])
debug_structure(G)

Graphs(nodes=NodeData(species=Array([12, 12, 12, 12, 12, 12, 25, 25, 25, 25, 13, 13, 13, 13, 13, 13, 13,
       13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
       13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 70, 70, 22, 22, 48,
       48, 51, 51,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
       67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 49, 49, 49, 26, 26,
       23, 23, 23, 23, 23, 23, 23, 23,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  9,  9,  9,  9, 62, 62, 15, 15, 15, 78, 78, 78, 78,  3,
        3, 39, 39, 42, 42, 42, 42,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  3,  3,  3, 25, 23, 23, 23, 23,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8, 48, 48, 78, 78, 78, 78, 78, 78,
        8,  8,  8,  8,  8,  8,  8,  8, 59, 59, 30, 30, 30, 30, 32, 32, 32,
       51, 51, 52, 52, 52, 52, 52, 52, 81, 81, 81, 49, 49, 49, 50, 34, 34,
       34, 34, 34, 34, 34, 34, 13, 28, 28, 28, 30, 30, 74, 74,  8,  8,