In [1]:
# Process environment for this run. These are set before `torch` is first
# imported (next cell), i.e. before CUDA is initialized; changing the CUDA
# variables after CUDA has been initialized will likely have no effect
# without a kernel restart.
# Comments must stay on their own lines: `%env` would take a trailing
# comment as part of the value.
# Notebook name, presumably read by wandb (used by the trainer below) to
# record which notebook produced the run.
%env WANDB_NOTEBOOK_NAME Train.ipynb
# Number GPUs by PCI bus ID so that the index chosen below is stable.
%env CUDA_DEVICE_ORDER PCI_BUS_ID
# Expose only GPU 0 to this process; it is the 'cuda' device used later.
%env CUDA_VISIBLE_DEVICES 0
# Select the cudaMallocAsync backend for PyTorch's CUDA caching allocator.
%env PYTORCH_CUDA_ALLOC_CONF backend:cudaMallocAsync

env: WANDB_NOTEBOOK_NAME=Train.ipynb
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0
env: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync


In [2]:
import torch

# Let float32 matmuls use TF32 tensor cores. Without this, PyTorch warns that
# they are available but not enabled. This trades a little matmul precision
# for speed.
torch.set_float32_matmul_precision('high')

# Seed PyTorch's RNGs (CPU and all CUDA devices) so that weight
# initialisation, dropout and presumably the trainer's torch-based sampling
# are repeatable across Restart & Run All.
# Caveats: some CUDA kernels are nondeterministic, so bit-exact reruns are
# not guaranteed, and RNGs outside torch (if any are used) are not seeded here.
SEED = 42
torch.manual_seed(SEED)

In [3]:
from mp_20_utils import load_all_data

# Which prepared dataset to load, and where tensors / the model will live.
dataset = 'mp_20_biternary'
device = 'cuda'

# load_all_data returns a 9-tuple: the raw tables, the torch datasets, the
# three token vocabularies (sites, elements, space groups) and four scalar
# limits describing sequence length and the enumeration vocabulary.
(
    datasets_pd,
    torch_datasets,
    site_to_ids,
    element_to_ids,
    spacegroup_to_ids,
    max_len,
    max_enumeration,
    enumeration_stop,
    enumeration_pad,
) = load_all_data(dataset=dataset)

# Echo the scalar limits returned by the loader.
print(max_len, max_enumeration, enumeration_stop, enumeration_pad)

20 7 8 9


In [4]:
from cascade_transformer.model import CascadeTransformer
from wyckoff_transformer import WyckoffTrainer
from tokenization import PAD_TOKEN, MASK_TOKEN

# Size of the start-token vocabulary: the distinct space groups present in
# the data (not all 230 occur).
n_space_groups = len(spacegroup_to_ids)

# nn.Embedding rejects uint8 indices, so token tensors use int64.
dtype = torch.int64

# Order of the cascade fields; the entries of `cascade` below follow it.
cascade_order = ("elements", "symmetry_sites", "symmetry_sites_enumeration")

# Enumeration token layout: ids up to max_enumeration, then STOP (+1),
# PAD (+2) and a MASK id (+3) defined here.
assert enumeration_stop == max_enumeration + 1
assert enumeration_pad == max_enumeration + 2
enumeration_mask = max_enumeration + 3
assert enumeration_mask < torch.iinfo(dtype).max


def _pad_index(token_id):
    """Wrap a scalar token id as a 0-d tensor of the global `dtype` on `device`."""
    return torch.tensor(token_id, dtype=dtype, device=device)


# One (N_i, d_i, pad_i) triple per cascade field:
#   N_i   - vocabulary size of the field,
#   d_i   - per-field width passed to CascadeTransformer (None for the last field),
#   pad_i - the field's PAD token id as a tensor.
cascade = (
    (len(element_to_ids), 64, _pad_index(element_to_ids[PAD_TOKEN])),
    (len(site_to_ids), 64 - 1, _pad_index(site_to_ids[PAD_TOKEN])),
    # The vocabulary must also cover the MASK id, hence enumeration_mask + 1.
    (enumeration_mask + 1, None, _pad_index(enumeration_pad)),
)

model = CascadeTransformer(
    n_start=n_space_groups,
    cascade=cascade,
    n_head=4,
    d_hid=256,
    n_layers=4,
    dropout=0.1,
    use_mixer=True,
)
model = model.to(device)
# torch.compile is deliberately not applied: dynamically discarding PAD
# predictions would cause frequent recompilation.
# model = torch.compile(model)

In [5]:
import torch

# Special-token ids per cascade field, keyed by the field names in
# `cascade_order`.
pad_dict = {
    "elements": element_to_ids[PAD_TOKEN],
    "symmetry_sites": site_to_ids[PAD_TOKEN],
    "symmetry_sites_enumeration": enumeration_pad
}
mask_dict = {
    "elements": element_to_ids[MASK_TOKEN],
    "symmetry_sites": site_to_ids[MASK_TOKEN],
    "symmetry_sites_enumeration": enumeration_mask
}
# Fail fast if the dicts drift out of sync with the cascade definition
# (e.g. a field is renamed or added in the model cell but not here).
assert set(pad_dict) == set(mask_dict) == set(cascade_order), \
    "pad_dict/mask_dict keys must match cascade_order exactly"

# Training schedule.
N_EPOCHS = 10000  # total number of training epochs
VAL_PERIOD = 20   # run validation every VAL_PERIOD epochs

trainer = WyckoffTrainer(
    model,
    torch_datasets,
    pad_dict,
    mask_dict,
    cascade_order,
    "spacegroup_number",  # presumably the start/conditioning field name — TODO confirm
    max_len,
    device,
    dtype=dtype,
)
trainer.train(epochs=N_EPOCHS, val_period=VAL_PERIOD)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 20 val_loss_epoch 63.65313720703125 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt




Epoch 120 val_loss_epoch 51.414215087890625 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 280 val_loss_epoch 47.08189392089844 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 320 val_loss_epoch 46.95974349975586 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 440 val_loss_epoch 46.286102294921875 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 620 val_loss_epoch 45.68657302856445 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 880 val_loss_epoch 45.53960037231445 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 920 val_loss_epoch 45.190494537353516 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 940 val_loss_epoch 43.77199172973633 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 1260 val_loss_epoch 43.431400299072266 saved to checkpoints/2024-05-24_02-23-34/best_model_params.pt
Epoch 1480 val_loss_epoch 42.27555465698242 saved 

VBox(children=(Label(value='1.764 MB of 1.764 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
known_cascade_len,▁█▁▅▁▁▁▁█▅█▁▁████▅▅██▁▅█▅▅▁▅▁▅▁█▁▅█▁▁▅▅▅
known_seq_len,▂▄▁▃▂▂▇▁█▂▂▃▃█▇▇▂▆▄▄▆▁▆▆▁▆▃█▃▂▃▅█▃▄▅▂▃▄▃
lr,█▇▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_batch,▄▁█▁▂▂▁▆▁▃▁▁▂▁▁▁▁▁▁▁▁▆▁▁▄▁▁▁▂▁▁▁▁▁▁▁▃▁▁▁
train_loss_epoch,█▃▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss_epoch,█▃▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,10000.0
known_cascade_len,2.0
known_seq_len,18.0
lr,0.00728
train_loss_batch,5.60211
train_loss_epoch,35.17665
val_loss_epoch,35.07162
