In [1]:

%env WANDB_NOTEBOOK_NAME Train.ipynb
%env CUDA_DEVICE_ORDER PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES 1

env: WANDB_NOTEBOOK_NAME=Train.ipynb
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [2]:
import gzip
import logging
import os
import pickle

from data import read_all_MP_csv
from tokenization import get_tokens, MASK_SITE

MP_20_cache_data = "cache/mp_20/data.pkl.gz"
MP_20_cache_tensors = "cache/mp_20/tensors.pkl.gz"
# NOTE(review): not referenced by the cells shown here; read_all_MP_csv() is
# called without it — confirm whether it is still needed.
MP_20_path = "cdvae/data/mp_20"

# Load the parsed MP-20 splits and the maximum sequence length from the
# gzip-pickled cache. On any failure (missing or unreadable cache) rebuild them
# from the CSV files via read_all_MP_csv() and rewrite the cache.
# NOTE(review): pickle.load can execute arbitrary code — only load caches you
# produced yourself.
try:
    with gzip.open(MP_20_cache_data, "rb") as f:
        datasets_pd, max_len = pickle.load(f)
except Exception as e:
    logging.info("Error reading data cache:")
    logging.info(e)
    logging.info("Reading from csv")
    datasets_pd, max_len = read_all_MP_csv()
    # The cache directory may not exist yet (e.g. fresh checkout); without this
    # gzip.open(..., "wb") raises FileNotFoundError after the CSVs were parsed.
    os.makedirs(os.path.dirname(MP_20_cache_data), exist_ok=True)
    with gzip.open(MP_20_cache_data, "wb") as f:
        pickle.dump((datasets_pd, max_len), f)

In [3]:
# Restore the token tensors and the site / element / space-group vocabularies
# from the gzip-pickled cache; if that fails for any reason, tokenise
# `datasets_pd` again with get_tokens() and overwrite the cache.
# NOTE(review): this cache is not invalidated when the data cache above is
# rebuilt, so a stale tensor cache could be paired with fresh `datasets_pd` /
# `max_len` — delete both cache files together when the source data changes.
try:
    with gzip.open(MP_20_cache_tensors, "rb") as tensor_cache:
        cached_tokens = pickle.load(tensor_cache)
        tensors, site_to_ids, element_to_ids, spacegroup_to_ids = cached_tokens
except Exception as cache_error:
    # Regenerated vocabularies may assign different ids, hence the warning.
    logging.warning("Error reading tensor cache! The token order will change!")
    logging.info(cache_error)
    logging.info("Generating new tensors")
    tensors, site_to_ids, element_to_ids, spacegroup_to_ids = get_tokens(datasets_pd)
    fresh_tokens = (tensors, site_to_ids, element_to_ids, spacegroup_to_ids)
    with gzip.open(MP_20_cache_tensors, "wb") as tensor_cache:
        pickle.dump(fresh_tokens, tensor_cache)

In [4]:
import torch
# Use the GPU when one is visible (CUDA_VISIBLE_DEVICES is pinned in the first
# cell); otherwise fall back to the CPU instead of crashing on the first
# `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    logging.warning("CUDA is not available, falling back to %s", device)


def to_combined_dataset(dataset):
    """Stack one split's per-structure tensors into batched tensors on ``device``.

    Args:
        dataset: mapping from column name to a column exposing ``.to_list()``
            (presumably a pandas DataFrame produced by get_tokens — confirm).
            Each of ``symmetry_sites_tensor``, ``symmetry_elements_tensor``,
            ``spacegroup_number_tensor`` and ``padding_mask_tensor`` holds one
            tensor per structure.

    Returns:
        dict with keys ``symmetry_sites``, ``symmetry_elements``,
        ``spacegroup_number`` and ``padding_mask``. The site and element
        tensors are transposed after stacking, so for 1-D per-structure
        sequences the structure axis comes last. The space-group numbers and
        the padding mask keep the structure axis first.
    """
    def stack(column):
        # One tensor per structure -> a single tensor with a leading structure axis.
        return torch.stack(dataset[column].to_list())

    return dict(
        symmetry_sites=stack('symmetry_sites_tensor').T.to(device),
        symmetry_elements=stack('symmetry_elements_tensor').T.to(device),
        spacegroup_number=stack('spacegroup_number_tensor').to(device),
        padding_mask=stack('padding_mask_tensor').to(device),
    )
# Convert every split in `tensors` into its stacked, on-device form, keeping the split names.
torch_datasets = {
    split_name: to_combined_dataset(split_tensors)
    for split_name, split_tensors in tensors.items()
}

In [5]:
from wyckoff_transformer import WyckoffTransformerModel, WyckoffTrainer
# Transformer hyper-parameters (forwarded to WyckoffTransformerModel below).
d_hid = 200  # width of the feed-forward sub-layer in ``nn.TransformerEncoder``
nlayers = 4  # number of stacked ``nn.TransformerEncoderLayer`` blocks in ``nn.TransformerEncoder``
nhead = 4  # attention heads per ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability

# Vocabulary sizes come from the token dictionaries built above. Not every one
# of the 230 space groups is present in the data, so n_space_groups is
# intentionally not asserted to equal 230.
n_space_groups = len(spacegroup_to_ids)

model_config = dict(
    n_space_groups=n_space_groups,
    n_sites=len(site_to_ids),
    n_elements=len(element_to_ids),
    # presumably the embedding widths for space groups / sites / species —
    # confirm against WyckoffTransformerModel
    d_space_groups=16,
    d_sites=64,
    d_species=64,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
)
model = WyckoffTransformerModel(**model_config).to(device)



In [6]:
# Id of the MASK_SITE token as a tensor on the training device; presumably used
# by the trainer to mask sites — confirm in WyckoffTrainer.
mask_site_token = torch.tensor(site_to_ids[MASK_SITE]).to(device)
trainer = WyckoffTrainer(model, torch_datasets, max_len, mask_site_token)
# Train for 2000 epochs; val_period=10 (the log shows a validation every 10 epochs).
trainer.train(epochs=2000, val_period=10)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


New best val_loss_epoch 4.947813510894775, saving the model
Epoch 10 val_loss_epoch 4.947813510894775 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
New best val_loss_epoch 2.034229278564453, saving the model
Epoch 20 val_loss_epoch 2.034229278564453 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
Epoch 30 val_loss_epoch 2.212127447128296 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
New best val_loss_epoch 1.8879921436309814, saving the model
Epoch 40 val_loss_epoch 1.8879921436309814 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
New best val_loss_epoch 1.6738146543502808, saving the model
Epoch 50 val_loss_epoch 1.6738146543502808 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
Epoch 60 val_loss_epoch 1.8593095541000366 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
Epoch 70 val_loss_epoch 1.6759945154190063 saved to checkpoints/2024-01-23_16-21-43/best_model_params.pt
Epoch 80 val_loss_epoch 1.

VBox(children=(Label(value='0.574 MB of 0.574 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
known_seq_len,▂▄▇▄▅▁▃▁█▁▁▇▅█▇▄▆▇▇▄▅▁▂▄▆▅▅▅▁▆▂▇▃▇█▄▅▆▁▇
lr,█▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_batch,▅▁▄▁▁▇▂▆▁█▇▁▁▁▁▁▁▁▁▁▁▅▂▁▁▁▁▁▅▁▂▁▁▁▁▁▁▁▅▁
val_loss_epoch,▃▃█▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,2000.0
known_seq_len,2.0
lr,0.03116
train_loss_batch,4.09911
val_loss_epoch,0.62934
