In [1]:
%env WANDB_NOTEBOOK_NAME=sg_predict.ipynb

env: WANDB_NOTEBOOK_NAME=sg_predict.ipynb


In [2]:
from wyckoff_transformer.tokenization import (
    load_tensors_and_tokenisers)
from cascade_transformer.model import get_pyramid_perceptron

In [3]:
tensors, tokenisers, engineers = load_tensors_and_tokenisers("mp_20", "mp_20_CSP")

In [4]:
import torch
torch.set_float32_matmul_precision('high')
def composition_to_many_hot(
    composition_tokens: torch.Tensor,
    composition_counts: torch.Tensor,
    max_elements: int = len(tokenisers["elements"])):

    many_hot = torch.zeros(max_elements, device=composition_counts.device, dtype=torch.int64)
    many_hot.scatter_add_(0, composition_tokens.to(torch.int64), composition_counts.to(torch.int64))
    return many_hot

In [5]:
device = torch.device("cuda:0")
present_space_groups = frozenset(tensors["train"]["spacegroup_number"].tolist())
with torch.no_grad():
    for dataset in tensors.values():
        valid_examples = torch.tensor([space_group in present_space_groups for space_group in dataset["spacegroup_number"].tolist()], dtype=torch.bool)
        dataset["composition_vector_filtered"] = torch.stack([composition_to_many_hot(
            composition_tokens, composition_counts) for composition_tokens, composition_counts in zip(
            dataset["composition_tokens"], dataset["composition_counts"])])[valid_examples].to(device, dtype=torch.float32)
        dataset["spacegroup_number_filtered"] = dataset["spacegroup_number"][valid_examples].to(device, dtype=torch.int64)


In [6]:
different_compositions = frozenset(map(tuple, tensors["train"]["composition_vector_filtered"].tolist()))

In [7]:
len(different_compositions)

25396

In [8]:
tensors["train"]["composition_vector_filtered"].shape

torch.Size([27136, 92])

In [9]:
sg_predictor = get_pyramid_perceptron(
    len(tokenisers["elements"]),
    len(tokenisers["spacegroup_number"]), 3, 0.3).to(device)


In [10]:
import wandb
from tqdm.auto import trange
optimizer = torch.optim.Adam(sg_predictor.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

with wandb.init(
            project="WyckoffTransformer",
            job_type="train",
            tags=["SG_toy"]) as wandb_run:
    for epoch in trange(7000):
        with torch.no_grad():
            sg_predictor.eval()
            val_prediction = sg_predictor(tensors["val"]["composition_vector_filtered"])
            val_loss = loss_fn(val_prediction, tensors["val"]["spacegroup_number_filtered"])
            val_accuracy = (val_prediction.argmax(dim=-1) == tensors["val"]["spacegroup_number_filtered"]).float().mean()
            train_predictions_no_drop = sg_predictor(tensors["train"]["composition_vector_filtered"])
            train_loss_no_drop = loss_fn(train_predictions_no_drop, tensors["train"]["spacegroup_number_filtered"])
            train_accuracy_no_drop = (train_predictions_no_drop.argmax(dim=-1) == tensors["train"]["spacegroup_number_filtered"]).float().mean()
        sg_predictor.train()
        optimizer.zero_grad(set_to_none=True)
        train_predictions = sg_predictor(tensors["train"]["composition_vector_filtered"])
        train_loss = loss_fn(train_predictions, tensors["train"]["spacegroup_number_filtered"])
        train_loss.backward()
        optimizer.step()
        wandb_run.log({
            "train_loss": train_loss_no_drop, "val_loss": val_loss,
            "val_accuracy": val_accuracy, "train_accuracy": train_accuracy_no_drop,
            "train_loss_dropout": train_loss})

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m ([33msymmetry-advantage[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/7000 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.007 MB uploaded\r'), FloatProgress(value=0.9038217458750529, max=1.0…

0,1
train_accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇███████████████████████
train_loss,█▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_dropout,█▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▆▇▇▇▇▇▇████████████████████████████
val_loss,█▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.57127
train_loss,1.79211
train_loss_dropout,2.92236
val_accuracy,0.49054
val_loss,2.23768


In [11]:
import catboost

In [12]:
cb_train = catboost.Pool(
    data=tensors["train"]["composition_vector_filtered"].cpu().numpy(),
    label=tensors["train"]["spacegroup_number_filtered"].cpu().numpy())
cb_val = catboost.Pool(
    data=tensors["val"]["composition_vector_filtered"].cpu().numpy(),
    label=tensors["val"]["spacegroup_number_filtered"].cpu().numpy())
catboost_model = catboost.CatBoostClassifier(
    iterations=5000,
    learning_rate=1e-1,
    loss_function="MultiClass",
    eval_metric="Accuracy",
    task_type="GPU",
    devices="0")

In [13]:
catboost_model.fit(cb_train, eval_set=cb_val)

0:	learn: 0.1687795	test: 0.1660215	best: 0.1660215 (0)	total: 75.8ms	remaining: 6m 18s
1:	learn: 0.1765183	test: 0.1765291	best: 0.1765291 (1)	total: 143ms	remaining: 5m 57s
2:	learn: 0.1767025	test: 0.1763079	best: 0.1765291 (1)	total: 207ms	remaining: 5m 44s
3:	learn: 0.1831884	test: 0.1829444	best: 0.1829444 (3)	total: 261ms	remaining: 5m 25s
4:	learn: 0.1829673	test: 0.1827232	best: 0.1829444 (3)	total: 313ms	remaining: 5m 12s
5:	learn: 0.1875369	test: 0.1853777	best: 0.1853777 (5)	total: 365ms	remaining: 5m 4s
6:	learn: 0.1876843	test: 0.1859308	best: 0.1859308 (6)	total: 421ms	remaining: 5m
7:	learn: 0.1882739	test: 0.1863732	best: 0.1863732 (7)	total: 474ms	remaining: 4m 56s
8:	learn: 0.1890478	test: 0.1872580	best: 0.1872580 (8)	total: 529ms	remaining: 4m 53s
9:	learn: 0.1894531	test: 0.1881429	best: 0.1881429 (9)	total: 582ms	remaining: 4m 50s
10:	learn: 0.1903007	test: 0.1896914	best: 0.1896914 (10)	total: 637ms	remaining: 4m 48s
11:	learn: 0.1905218	test: 0.1898020	best: 0.

<catboost.core.CatBoostClassifier at 0x7f6597156b00>