In [1]:
%env "WANDB_NOTEBOOK_NAME" "Train.ipynb"

env: "WANDB_NOTEBOOK_NAME"="Train.ipynb"


In [2]:
import pickle
import logging
import gzip
from data import read_all_MP_csv
from tokenization import get_tokens, MASK_SITE

MP_20_cache_data = "cache/mp_20/data.pkl.gz"
MP_20_cache_tensors = "cache/mp_20/tensors.pkl.gz"
MP_20_path = "cdvae/data/mp_20"

try:
    with gzip.open(MP_20_cache_data, "rb") as f:
        datasets_pd, max_len = pickle.load(f)
except Exception as e:
    logging.info("Error reading data cache:")
    logging.info(e)
    logging.info("Reading from csv")
    datasets_pd, max_len = read_all_MP_csv()
    with gzip.open(MP_20_cache_data, "wb") as f:
        pickle.dump((datasets_pd, max_len), f)

In [3]:
try:
    with gzip.open(MP_20_cache_tensors, "rb") as f:
        tensors, site_to_ids, element_to_ids, spacegroup_to_ids = pickle.load(f)
except Exception as e:
    logging.warning("Error reading tensor cache! The token order will change!")
    logging.info(e)
    logging.info("Generating new tensors")
    tensors, site_to_ids, element_to_ids, spacegroup_to_ids = get_tokens(datasets_pd)
    with gzip.open(MP_20_cache_tensors, "wb") as f:
        pickle.dump((tensors, site_to_ids, element_to_ids, spacegroup_to_ids), f)

In [4]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def to_combined_dataset(dataset):
    return dict(
        symmetry_sites=torch.stack(dataset['symmetry_sites_tensor'].to_list()).T.to(device),
        symmetry_elements=torch.stack(dataset['symmetry_elements_tensor'].to_list()).T.to(device),
        spacegroup_number=torch.stack(dataset['spacegroup_number_tensor'].to_list()).to(device),
        padding_mask = torch.stack(dataset['padding_mask_tensor'].to_list()).to(device)
    )
torch_datasets = dict(zip(tensors.keys(), map(to_combined_dataset, tensors.values())))

In [5]:
from wyckoff_transformer import WyckoffTransformerModel, WyckoffTrainer
d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 4  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 4  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
n_space_groups = len(spacegroup_to_ids)
# Nope, not all SGs are present
# assert n_space_groups == 230
model = WyckoffTransformerModel(
    n_space_groups = n_space_groups,
    n_sites = len(site_to_ids),
    n_elements = len(element_to_ids),
    d_space_groups = 16,
    d_sites = 64,
    d_species = 64,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout).to(device)



In [6]:
trainer = WyckoffTrainer(model, torch_datasets, max_len, torch.tensor(site_to_ids[MASK_SITE]).to(device))
trainer.train(epochs=2000, val_period=10)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113049355076832, max=1.0…



New best val_loss_epoch 9.011067390441895, saving the model
Epoch 10 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 20 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 30 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 40 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 50 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 60 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 70 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 80 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoch 90 val_loss_epoch 9.011067390441895 saved to checkpoints/2024-01-16_15-32-26/best_model_params.pt
Epoc



VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,2000.0
val_loss_epoch,9.01107
