In [1]:
import os
from datasets import load_dataset
from transformers import BertTokenizerFast
import re
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataclasses import dataclass

from dotenv import load_dotenv
load_dotenv("../.env")


  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
dataset=load_dataset('ur-whitelab/mapi', token=os.environ['HF_TOKEN'])

print(dataset['train'].column_names)

target = ["band_gap"]
features=['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density', 'density_atomic', 'crystal_system', 'symbol', 'number', 'point_group', 'structure']




In [11]:
print(dataset['train'][2]['structure'])

Full Formula (Nb1 V2 Mo1)
Reduced Formula: NbV2Mo
abc   :  10.220753  10.220753  10.220753
angles: 128.933454 117.899846  84.471274
pbc   :       True       True       True
Sites (4)
  #  SP      a         b         c    magmom
---  ----  ---  --------  --------  --------
  0  Nb      0  0         0            0.516
  1  V       0  0.251541  0.251541     1.349
  2  V       0  0.748459  0.748459     1.349
  3  Mo      0  0.5       0.5          1.103


In [None]:
train_dataset = dataset['train'].select_columns(features+target).to_pandas()
test_dataset = dataset['test'].select_columns(features+target).to_pandas()

train_dataset = train_dataset.dropna(subset=target, axis=0)
test_dataset = test_dataset.dropna(subset=target, axis=0)

formula = train_dataset.iloc[0]['formula_pretty']
all_elements = "-".join(train_dataset['chemsys'].to_list())
all_elements = set(all_elements.split("-"))
elements_dics = {k: v for (v, k) in enumerate(all_elements, 1)}

voc = 	list(all_elements) + [str(i) for i in range(10)] + ["Full Formula", "Reduced Formula", "abc", "angles", "pbc", "Sites", "True", "False", "magmom", ".", "(", ")", "-", "\n", " ", "a", "b", "c"]


tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') # Check if there's a cased. It might be important for elements
tokenizer.add_tokens(voc)

# def tokenize_function(examples):
    # return tokenizer(examples['structure'], padding="max_length", truncation=True, return_tensors='pt')

# tokenized_datasets = dataset.map(tokenize_function, batched=True)
# tokenizer

In [None]:
train_dataset = dataset['train'].select_columns(['structure']+target)
test_dataset = dataset['test'].select_columns(['structure']+target)

def filter_none(example):
    return all(value is not None for value in example.values())

train_dataset = train_dataset.filter(filter_none)
test_dataset = test_dataset.filter(filter_none)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [None]:
@dataclass
class KDESolConfig:
    vocab_size: int = 30618
    batch_size: int = 256
    buffer_size: int = 10000
    rnn_units: int = 1028
    hidden_dim: int = 512
    embedding_dim: int = tokenizer.model_max_length
    reg_strength: float = 0.01
    lr: float = 1e-4
    drop_rate: float = 0.35
    nmodels: int = 10
    adv_epsilon: float = 1e-3
    epochs: int = 150
    pad_to_len: int = 512

class RNN(nn.Module):
    def __init__(self, config=KDESolConfig()):
        super(RNN, self).__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(config.drop_rate)
        self.rnn1 = nn.LSTM(config.embedding_dim, config.rnn_units, batch_first=True, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * config.rnn_units, config.rnn_units, batch_first=True, bidirectional=True)
        self.layer_norm = nn.LayerNorm(2 * config.rnn_units)
        self.dense1 = nn.Linear(2 * config.rnn_units, config.hidden_dim)
        self.dense2 = nn.Linear(config.hidden_dim, config.hidden_dim // 2)
        self.out_mu = nn.Linear(config.hidden_dim // 2, 1)
        self.out_std = nn.Linear(config.hidden_dim // 2, 1)

        self.softplus = nn.Softplus()

    def forward(self, x):
        x = self.embedding(x)
        x = self.dropout(x)
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        x = self.layer_norm(x[:, -1, :]) 
        x = nn.SiLU()(self.dense1(x))
        x = self.dropout(x)
        x = nn.SiLU()(self.dense2(x))
        x = self.dropout(x)
        mu = self.out_mu(x)
        std = self.softplus(self.out_std(x))
        return mu
        return torch.cat((mu, std), dim=-1)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = RNN()
model.to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []

for _ in range(10):
  print(f"Starting epoch {_}.")
  model.train()
  for batch, d in enumerate(train_dataloader):
    optimizer.zero_grad()
    size = len(train_dataloader.dataset)
    
    X = tokenizer(d['structure'], padding="max_length", truncation=True, return_tensors='pt')

    X = X['input_ids'].to(device)
    y = d['band_gap'].to(device)
    
    pred = model(X)
    loss = loss_fn(torch.flatten(pred), y.to(torch.float32))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 500 == 0:
      losses.append(loss.item())
      loss_item = loss.item()
      current = batch * len(X) + len(X)
      print(f"\tloss: {loss_item:>7f}  [{current:>5d}/{size:>5d}]")
  print(f"Epoch {_} done.")


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import urllib.request
import matplotlib as mpl
import matplotlib.font_manager as font_manager
urllib.request.urlretrieve('https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf', 'IBMPlexMono-Regular.ttf')
fe = font_manager.FontEntry(
    fname='IBMPlexMono-Regular.ttf',
    name='plexmono')
font_manager.fontManager.ttflist.append(fe)
plt.rcParams.update({'axes.facecolor':'#f5f4e9',
            'grid.color' : '#AAAAAA',
            'axes.edgecolor':'#333333',
            'figure.facecolor':'#FFFFFF',
            'axes.grid': False,
            'axes.prop_cycle':   plt.cycler('color', plt.cm.Dark2.colors),
            'font.family': fe.name,
            'figure.figsize': (3.5,3.5 / 1.2),
            'ytick.left': True,
            'xtick.bottom': True
           })

In [2]:
k=8

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error

yhat=[]
y=[]
for k in range(100):
    try:
        y.append(test_dataset.iloc[k][target])
        yhat.append(model.predict([test_dataset.iloc[k][features]]))
    except Exception as e:
        print(k, e)

y = np.array(y).flatten()
yhat = np.array(yhat).flatten()
print(y.shape, yhat.shape)

0 name 'test_dataset' is not defined
1 name 'test_dataset' is not defined
2 name 'test_dataset' is not defined
3 name 'test_dataset' is not defined
4 name 'test_dataset' is not defined
5 name 'test_dataset' is not defined
6 name 'test_dataset' is not defined
7 name 'test_dataset' is not defined
8 name 'test_dataset' is not defined
9 name 'test_dataset' is not defined
10 name 'test_dataset' is not defined
11 name 'test_dataset' is not defined
12 name 'test_dataset' is not defined
13 name 'test_dataset' is not defined
14 name 'test_dataset' is not defined
15 name 'test_dataset' is not defined
16 name 'test_dataset' is not defined
17 name 'test_dataset' is not defined
18 name 'test_dataset' is not defined
19 name 'test_dataset' is not defined
20 name 'test_dataset' is not defined
21 name 'test_dataset' is not defined
22 name 'test_dataset' is not defined
23 name 'test_dataset' is not defined
24 name 'test_dataset' is not defined
25 name 'test_dataset' is not defined
26 name 'test_dataset'

In [None]:
lim = (min(y),max(y))
plt.xlabel('True')
plt.ylabel('Predicted')
plt.plot(y, yhat, 'o', alpha=0.2)
plt.plot(lim, lim, '--')
plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 1*0.1*(max(y)-min(y)), f"correlation = {np.corrcoef(y, yhat)[0,1]:.3f}")
plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 2*0.1*(max(y)-min(y)), f"MAE = {mean_squared_error(y, yhat):.3f}")
plt.show()
