In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
import torch_scatter

import e3nn
from e3nn import o3
from data_helpers import DataPeriodicNeighbors
from e3nn.nn.models.gate_points_2101 import Convolution, Network
from e3nn.o3 import Irreps

import pymatgen as mg
import pymatgen.io
from pymatgen.core.structure import Structure
from pymatgen.ext.matproj import MPRester
import pymatgen.analysis.magnetism.analyzer as pg
import numpy as np
import pickle
from mendeleev import element
import matplotlib.pyplot as plt

from sklearn.metrics import average_precision_score
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

import io
import random
import math
import sys
import time
import os
import datetime

In [2]:
data = torch.load('magnetic_order_data.pt')
id_list = []  # list of material ids

run_name = (time.strftime("%y%m%d-%H%M", time.localtime()))

order_list_mp = []
structures_list_mp = []
formula_list_mp = []
sites_list = []
id_list_mp = []
y_values_mp = []
order_encode = {"NM": 0, "AFM": 1, "FM": 2, "FiM": 2}

magnetic_atoms = ['Ga', 'Tm', 'Y', 'Dy', 'Nb', 'Pu', 'Th', 'Er', 'U',
                  'Cr', 'Sc', 'Pr', 'Re', 'Ni', 'Np', 'Nd', 'Yb', 'Ce',
                  'Ti', 'Mo', 'Cu', 'Fe', 'Sm', 'Gd', 'V', 'Co', 'Eu',
                  'Ho', 'Mn', 'Os', 'Tb', 'Ir', 'Pt', 'Rh', 'Ru']

In [3]:
torch.set_default_dtype(torch.float64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {'len_embed_feat': 64,
          'num_channel_irrep': 32,
          'num_e3nn_layer': 2,
          'max_radius': 5,
          'num_basis': 10,
          'adamw_lr': 0.005,
          'adamw_wd': 0.03,
          'radial_layers': 3
          }

# Used for debugging
identification_tag = "1:1:1.1 Relu wd:0.03 4 Linear"
cost_multiplier = 1.0

In [4]:
len_element = 118
atom_types_dim = 3*len_element
embedding_dim = params['len_embed_feat']
lmax = 1
# Roughly the average number (over entire dataset) of nearest neighbors for a given atom
n_norm = 35

# num_atom_types scalars (L=0) with even parity
irreps_in = Irreps([(45, (0, 1))])
irreps_hidden = Irreps([(64, (0, 1))])  # not sure
irreps_out = Irreps([(3, (0, 1))])  # len_dos scalars (L=0) with even parity

model_kwargs = {
    "irreps_in": irreps_in,
    "irreps_hidden": irreps_hidden,
    "irreps_out": irreps_out,
    "irreps_node_attr": '0e+1e',  # not really sure
    "irreps_edge_attr": '0e+1e',  # not really sure
    "layers": params['num_e3nn_layer'],
    "max_radius": params['max_radius'],
    "number_of_basis": params['num_basis'],
    "radial_layers": params['radial_layers'],
    # for these last 3 I don't know what's normal
    "radial_neurons": 5,
    "num_neighbors": 5,
    "num_nodes": 5
}

In [20]:
class AtomEmbeddingAndSumLastLayer(torch.nn.Module):
    def __init__(self, atom_type_in, atom_type_out, model):
        super().__init__()
        self.linear = torch.nn.Linear(atom_type_in, 128)
        self.model = model
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(128, 96)
        self.linear3 = torch.nn.Linear(96, 64)
        self.linear4 = torch.nn.Linear(64, 45)
        #self.linear5 = torch.nn.Linear(45, 32)
        #self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x, *args, batch=None, **kwargs):
        output = self.linear(x)
        output = self.relu(output)
        print(f"Input: {x}")
        output = self.linear2(output)
        output = self.relu(output)
        output = self.linear3(output)
        output = self.relu(output)
        output = self.linear4(output)
        #output = self.linear5(output)
        output = self.relu(output)
        print(output)
        output = self.model(output, *args, **kwargs)
        if batch is None:
            N = output.shape[0]
            batch = output.new_ones(N)
        output = torch_scatter.scatter_add(output, batch, dim=0)
        print(f"Output: {output}")
        #output = self.softmax(output)
        return output

In [21]:
model = AtomEmbeddingAndSumLastLayer(
    atom_types_dim, embedding_dim, Network(**model_kwargs))
opt = torch.optim.AdamW(
    model.parameters(), lr=params['adamw_lr'], weight_decay=params['adamw_wd'])

In [22]:
indices = np.arange(len(data))
np.random.shuffle(indices)
index_tr, index_va, index_te = np.split(
    indices, [int(.8 * len(indices)), int(.9 * len(indices))])

assert set(index_tr).isdisjoint(set(index_te))
assert set(index_tr).isdisjoint(set(index_va))
assert set(index_te).isdisjoint(set(index_va))


with open('loss.txt', 'a') as f:
    f.write(f"Iteration: {identification_tag}")

In [23]:
batch_size = 1
dataloader = torch_geometric.loader.DataLoader(
    [data[i] for i in index_tr], batch_size=batch_size, shuffle=True)
dataloader_valid = torch_geometric.loader.DataLoader(
    [data[i] for i in index_va], batch_size=batch_size)

loss_fn = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.78)

In [25]:
# def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
#         """evaluate the network

#         Parameters
#         ----------
#         data : `torch_geometric.data.Data` or dict
#             data object containing
#             - ``pos`` the position of the nodes (atoms)
#             - ``x`` the input features of the nodes, optional
#             - ``z`` the attributes of the nodes, for instance the atom type, optional
#             - ``batch`` the graph to which the node belong, optional
#         """

In [24]:
for j, d in enumerate(dataloader):
    d.to(device)
    len1 = len(d.x)
    len2 = len(d.x[0]) if len1 > 0 else 0
    output = model(d.x,  batch=d.batch)
    # output = model([[len1, len2], d.edge_index, d.edge_attr], batch=d.batch)

Input: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0.1477, 0.2837, 0.2596,  ..., 0.2000, 0.0595, 0.3252],
        [0.1477, 0.2837, 0.2596,  ..., 0.2000, 0.0595, 0.3252],
        [0.1477, 0.2837, 0.2596,  ..., 0.2000, 0.0595, 0.3252],
        ...,
        [0.2707, 0.0982, 0.2417,  ..., 0.0401, 0.0211, 0.1366],
        [0.2707, 0.0982, 0.2417,  ..., 0.0401, 0.0211, 0.1366],
        [0.2707, 0.0982, 0.2417,  ..., 0.0401, 0.0211, 0.1366]],
       grad_fn=<ReluBackward0>)


RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

In [9]:
datapoint = next(iter(dataloader))

In [14]:
len(datapoint.x[0])

354

In [15]:
datapoint

DataPeriodicNeighborsBatch(edge_index=[2, 2844], edge_attr=[2844, 3], pos=[52, 3], x=[52, 354], y=[1], n_norm=[1], lattice=[3, 3], batch=[52], ptr=[2])