In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Apply the rho_plus matplotlib theme. `theme` and `cs` are its two return values
# (cs is presumably a colour sequence). NOTE(review): confirm in the rho_plus docs
# what the positional False argument toggles.
theme, cs = rp.mpl_setup(False)

In [2]:
# Install aviary from GitHub if it is missing: %pip install -U git+https://github.com/CompRhys/aviary.git

In [None]:
# Pairwise composition data. Later cells read the columns id_1, id_2,
# pretty_formula_1, pretty_formula_2 and dist from this frame.
pairs_path = 'pairs_data.feather'
df = pd.read_feather(pairs_path)
df

In [None]:
# 16-dimensional MEGNet element embeddings shipped with aviary.
megnet16_url = 'https://raw.githubusercontent.com/CompRhys/aviary/refs/heads/main/aviary/embeddings/element/megnet16.json'
elem_embs = pd.read_json(megnet16_url)
elem_embs

In [None]:
import torch
import torch.nn.functional as F
from aviary.roost.model import DescriptorNetwork
from pymatgen.core import Composition
from torch import Tensor, LongTensor
from data import collate_batch, comp2graph

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Tensors and modules created from here on land on `device` by default.
torch.set_default_device(device)

# Smoke test: push a small sample of compositions through aviary's Roost
# DescriptorNetwork and check that the shapes line up.
elem_embed_dim = 112  # element-feature width, passed as elem_emb_len below
comp_embed_dim = 64   # not used by this cell

# Fixed random_state so the printed shapes and outputs are reproducible.
batch = collate_batch([comp2graph(x) for x in df.sample(16, random_state=0)['pretty_formula_1']])
print([tuple(x.shape) for x in batch])

gnn = DescriptorNetwork(elem_emb_len=elem_embed_dim, elem_fea_len=64)

# Shape check only; no gradients are needed.
with torch.no_grad():
    out = gnn(*batch)

print(out.shape)
out

In [6]:
# CSP benchmark test set. Its material ids are excluded from the training pairs later on.
benchmark_url = 'https://raw.githubusercontent.com/usccolumbia/cspbenchmark/main/data/CSPbenchmark_test_data.csv'
benchmark = pd.read_csv(benchmark_url)
benchmark_ids = benchmark['material_id']

In [None]:
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader, random_split, IterableDataset
from tqdm import tqdm, trange

fn = 'ds-oh.pt'  # older single-file dataset cache; the shard loader does not read it
regen = False    # True: rebuild the shards in full_dataset/ from df_train; False: load existing shards

val_frac = 0.1      # fraction of training batches held out for validation
test_frac = 0.5     # fraction of shard files held out (not loaded for training)
batch_size = 256    # pairs per pre-collated batch
shard_size = 256    # batches per shard file
subsample_step = 1  # keep every k-th pair of df; 1 = use all pairs

# Keep only pairs with a non-negligible distance (dist > 0.01).
df_train = df.iloc[::subsample_step].query('dist > 0.01')
print(df_train.shape)
# Exclude every pair that touches a CSP-benchmark material, so the benchmark stays unseen.
df_train = df_train.query('id_1 not in @benchmark_ids and id_2 not in @benchmark_ids')
print(df_train.shape)
# Trim to a whole number of batches so every pre-collated batch holds exactly
# batch_size pairs. The keep-count is computed explicitly because
# iloc[:-(n % batch_size)] turns into iloc[:0] (an empty frame) whenever n is
# already a multiple of batch_size.
n_keep = len(df_train) - len(df_train) % batch_size
df_train = df_train.iloc[:n_keep]
print(df_train.shape)


class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset over three parallel sequences.

    Item ``idx`` is ``(X1[idx], X2[idx], y[idx])`` and the length is taken from
    ``X1``. In this notebook each item is one pre-collated batch. The attribute
    names X1 / X2 / y must stay as they are: the shards are written with
    ``torch.save`` (pickled), and loading them restores these attributes by name.
    """

    def __init__(self, X1, X2, y):
        self.X1, self.X2, self.y = X1, X2, y

    def __len__(self):
        return len(self.X1)

    def __getitem__(self, idx):
        first, second, target = self.X1, self.X2, self.y
        return first[idx], second[idx], target[idx]
shard_dir = Path('full_dataset')

if regen:
    # Pre-collate df_train into batches of batch_size pairs and write them to disk,
    # shard_size batches per file. Full shards are saved as 1.pt, 2.pt, ...; the
    # trailing partial shard (possibly empty) is saved as 0.pt.
    # NOTE(review): shard files left over from an earlier, larger run are not deleted.
    shard_dir.mkdir(parents=True, exist_ok=True)
    X1 = []
    X2 = []
    y = []
    for i in trange(0, len(df_train.index), batch_size):
        df_batch = df_train.iloc[i:i+batch_size]
        X1.append(collate_batch([comp2graph(x) for x in df_batch['pretty_formula_1']]))
        X2.append(collate_batch([comp2graph(x) for x in df_batch['pretty_formula_2']]))
        y.append(torch.tensor(df_batch['dist'].values))

        n_batches_done = i // batch_size + 1
        if n_batches_done % shard_size == 0:
            torch.save(MyDataset(X1, X2, y), shard_dir / f'{n_batches_done // shard_size}.pt')
            X1 = []
            X2 = []
            y = []

    torch.save(MyDataset(X1, X2, y), shard_dir / '0.pt')

# Always build the training set from the shards on disk. A regenerating run and a
# cached run then use the same shard split and the full training data; `ds` is not
# left holding only the last partial shard.
# NOTE: sorted() orders paths lexicographically ('0.pt', '1.pt', '10.pt', '2.pt', ...),
# not numerically. The split is still deterministic and is kept as-is so it does
# not change between runs.
shards = sorted(shard_dir.glob('*.pt'))
train_len = int((1 - test_frac) * len(shards))
train_shards = shards[:train_len]
test_shards = shards[train_len:]  # held out; not loaded here
if not train_shards:
    raise RuntimeError(
        f'no training shards in {shard_dir} ({len(shards)} shard files found, test_frac={test_frac}); '
        'set regen = True to build them'
    )
# weights_only=False is needed because each shard is a pickled MyDataset;
# only load shard files written by this notebook.
dses = [torch.load(shard, weights_only=False) for shard in tqdm(train_shards)]
ds = torch.utils.data.ConcatDataset(dses)

# Seeded generator, so the train/validation split is reproducible.
train_ds, val_ds = random_split(ds, [1 - val_frac, val_frac], generator=torch.Generator(device=device).manual_seed(123))
# batch_size=None: every dataset item is already a collated batch, so the loader must not re-batch.
train_dl = DataLoader(train_ds, batch_size=None, shuffle=True, generator=torch.Generator(device=device))
val_dl = DataLoader(val_ds, batch_size=None)
# Pull one validation batch as a quick check that loading works end to end.
val_X1, val_X2, val_y = next(iter(val_dl))

In [24]:
# Let float32 matmuls use faster reduced-precision internals (e.g. TF32) where supported.
torch.set_float32_matmul_precision('high')

# Optimisation.
lr = 3e-3
num_epochs = 2

# Pairs with dist < tau form the positive class of the BCE objective.
tau = 1.0

# Model widths. These override the values set in the DescriptorNetwork smoke-test cell.
elem_embed_dim: int = 64
comp_embed_dim: int = 128

In [None]:
from pathlib import Path

from tqdm import trange
from model import CompositionEmbedding

# Resume from this checkpoint if it exists (the torch.save cell further down writes it).
# Otherwise start from a freshly initialised model, so a fresh kernel can run the notebook.
resume_path = Path('checkpoints/full-2.pt')
log_every = 100  # training batches between progress-bar loss refreshes

hist = []
if resume_path.exists():
    # The checkpoint is a pickled module, so weights_only=False is required;
    # only load checkpoints produced by this notebook.
    model = torch.load(resume_path, weights_only=False).to(device)
else:
    # 112 = element-feature width (the same value used as elem_emb_len in the
    # DescriptorNetwork smoke test).
    model = CompositionEmbedding(
        elem_input_dim=112,
        elem_hidden_dim=elem_embed_dim,
        comp_embed_dim=comp_embed_dim,
        rescale_init=np.sqrt(comp_embed_dim),
    ).to(device)
print(model.rescale)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# Stepped once per epoch. With the default power=1 this decays lr linearly to 0 over num_epochs.
scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=num_epochs)

with trange(num_epochs * len(train_dl)) as bar:
    for epoch in range(num_epochs):
        model.train()
        loss_vals = []
        for X1, X2, y in train_dl:
            bar.update()
            # Binary target: a pair counts as positive when its distance is below tau.
            loss_val = F.binary_cross_entropy(model(X1, X2), (y < tau).float())
            loss_val.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_vals.append(loss_val.detach().item())
            if (bar.n + 1) % log_every == 0:
                bar.set_description_str('Train: {:.3f} Valid: {:.3f}'.format(
                    np.mean(loss_vals[-log_every:]),
                    hist[-1]["Validation Loss"] if hist else 0,
                ))

        model.eval()
        with torch.no_grad():
            val_losses = []
            # Explicit loop rather than a comprehension, so X1, X2, y still hold the
            # last validation batch afterwards (the inspection cells below use them).
            for X1, X2, y in val_dl:
                val_losses.append(F.binary_cross_entropy(model(X1, X2), (y < tau).float()))
            # torch.stack is the supported way to combine a list of tensors;
            # torch.tensor() is meant for Python data.
            val_losses = torch.stack(val_losses)
        hist.append({
            'Epoch': epoch,
            'Train Loss': sum(loss_vals) / len(loss_vals),
            'Validation Loss': val_losses.mean().item()
        })

        print({k: f'{v:.4f}' for k, v in hist[-1].items()})

        bar.set_description_str('Train: {:.3f} Valid: {:.3f}'.format(hist[-1]["Train Loss"], hist[-1]["Validation Loss"]))
        scheduler.step()

hist = pd.DataFrame(hist)

fig, ax = plt.subplots()
sns.lineplot(hist, x='Epoch', y='Train Loss', label='Train', ax=ax)
sns.lineplot(hist, x='Epoch', y='Validation Loss', label='Validation', ax=ax)
ax.set(ylabel='BCE loss', title='Training history');

In [None]:
# Boolean mask of positive pairs (dist < tau) for the batch currently held in y.
torch.lt(y, tau)

In [None]:
# Model outputs for the positive pairs (dist < tau) of the batch currently held in X1, X2, y.
positive_mask = y < tau
model(X1, X2)[positive_mask]

In [None]:
# Inspect the pairs whose distance falls in a narrow window (open interval).
df.query('2.044 < dist < 2.045')

In [None]:
# Inspect model.rescale. It is built from rescale_init=np.sqrt(comp_embed_dim)
# when constructed fresh; presumably a learned scale — TODO confirm in model.py.
model.rescale

In [19]:
from pathlib import Path

# Save the whole module (pickled), which is the form the training cell reloads with
# torch.load(..., weights_only=False). Create the directory first so a fresh
# checkout does not fail here.
checkpoint_path = Path('checkpoints/full-2.pt')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, checkpoint_path)

In [22]:
# Collect model predictions and binary labels (dist < tau) over the whole validation split.
# Inference only: eval mode, and no autograd graphs are built.
model.eval()
ypred_batches = []
ytrue_batches = []
with torch.no_grad():
    # X1, X2, y keep the last validation batch after this loop; later cells use them.
    for X1, X2, y in val_dl:
        ypred_batches.append(model(X1, X2).cpu())
        ytrue_batches.append(y)

ypred = torch.cat(ypred_batches)
ytrue = torch.cat(ytrue_batches).cpu() < tau

In [None]:
# Threshold the predicted probabilities at 0.5.
pred_pos = ypred > 0.5
# Accuracy of the thresholded predictions against the true labels.
print((pred_pos == ytrue).float().mean())
# 2x2 correlation matrix of predicted vs. true labels; the off-diagonal entry is
# their Pearson correlation (the phi coefficient for binary labels).
print(torch.corrcoef(torch.vstack([pred_pos, ytrue]).float()))

In [None]:
# Distribution of every 100th validation prediction (subsampled to keep plotting cheap).
pred_sample = ypred[::100].numpy(force=True)
sns.displot(pred_sample)

In [None]:
# Distance between the two embeddings of each pair in one batch, split by label.
# NOTE: X1, X2 and y are the loop variables left over from the most recent loop
# over val_dl (hidden state). Re-run that cell first.
model.eval()
with torch.no_grad():
    dists = torch.linalg.vector_norm(model.embed(X1) - model.embed(X2), dim=1).numpy(force=True)
y_np = y.numpy(force=True) < tau

fig, ax = plt.subplots()
sns.histplot(x=dists[y_np], label='y = 1 (dist < tau)', fill=False, element='step', bins=20, ax=ax)
sns.histplot(x=dists[~y_np], label='y = 0 (dist >= tau)', fill=False, element='step', bins=20, ax=ax)
ax.set(xlabel='embedding distance ||z1 - z2||', ylabel='count', title='Embedding distance by label')
ax.legend();

- **TODO:** train directly from a folder of CIF files
- **TODO:** add oxidation states (e.g. predicted with BERTOS) and use them when building the graph nodes

In [None]:
# Rank all partners of one material by embedding distance.
mp_id = 'mp-5615'

# .copy(): a z_dist column is added below. Writing into the query result
# directly triggers pandas' SettingWithCopyWarning.
df_id = df.query('id_1 == @mp_id').copy()
if df_id.empty:
    raise ValueError(f'{mp_id} does not appear as id_1 in df')

# Dedicated names (not X1/X2) so the batch tensors used by other inspection cells are not clobbered.
query_graph = collate_batch([comp2graph(df_id['pretty_formula_1'].iloc[0])])
partner_formulas = df_id['pretty_formula_2']
partner_batches = [
    collate_batch([comp2graph(c) for c in partner_formulas.iloc[i:i + batch_size]])
    for i in range(0, len(partner_formulas), batch_size)
]

model.eval()
with torch.no_grad():
    z1 = model.embed(query_graph)
    z2 = torch.cat([model.embed(b) for b in partner_batches])
    dists = torch.cdist(z1, z2)

# One distance per partner row, in df_id's row order.
df_id['z_dist'] = dists.numpy(force=True).reshape(-1)

In [None]:
# Partners ordered from nearest to farthest in embedding space.
df_id.sort_values(by='z_dist', ascending=True)