In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
import torch_scatter

import e3nn
from e3nn import o3
from data_helpers import DataPeriodicNeighbors
from e3nn.nn.models.gate_points_2101 import Convolution, Network
from e3nn.o3 import Irreps

import pymatgen as mg
import pymatgen.io
from pymatgen.core.structure import Structure
from pymatgen.ext.matproj import MPRester
import pymatgen.analysis.magnetism.analyzer as pg
import numpy as np
import pickle
from mendeleev import element
import matplotlib.pyplot as plt

from sklearn.metrics import average_precision_score
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

import io
import random
import math
import sys
import time
import os
import datetime

In [3]:
data = torch.load('run/magnetic_order_data.pt')

run_name = (time.strftime("%y%m%d-%H%M", time.localtime()))


In [4]:
torch.set_default_dtype(torch.float64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {'len_embed_feat': 64,
          'num_channel_irrep': 32,
          'num_e3nn_layer': 2,
          'max_radius': 5,
          'num_basis': 10,
          'adamw_lr': 0.005,
          'adamw_wd': 0.03,
          'radial_layers': 3
          }

# Used for debugging
identification_tag = "1:1:1.1 Relu wd:0.03 4 Linear"
cost_multiplier = 1.0

In [5]:
len_element = 118
atom_types_dim = 3*len_element
embedding_dim = params['len_embed_feat']
lmax = 1

# I think the in and out irreps have to be scalars, but this is causing issues with no paths being created? like, from what I know about tensor products I'm *pretty sure* you need at least 1 path to add things over, but I'm not 100% sure

# num_atom_types scalars (L=0) with even parity
irreps_in = Irreps([(45, (0, 1))])
irreps_hidden = Irreps([(64, (0, 1)), (64, (1, 1))])  # not sure - is this too large?
irreps_out = Irreps([(3, (0, 1))])  # len_dos scalars (L=0) with even parity

model_kwargs = {
    "irreps_in": irreps_in,
    "irreps_hidden": irreps_hidden,
    "irreps_out": irreps_out,
    "irreps_node_attr": '3x0e',  # not really sure, but I think it needs dim=3
    # "irreps_node_attr": '0e',  # use this if I'm not inputting a z argument
    "irreps_edge_attr": '0e+1e',  # not really sure
    "layers": params['num_e3nn_layer'],
    "max_radius": params['max_radius'],
    "number_of_basis": params['num_basis'],
    "radial_layers": params['radial_layers'],
    # for these last 3 I don't know what's normal
    "radial_neurons": 35,
    "num_neighbors": 35, # I think this is correct
    "num_nodes": 35
}

In [6]:
class AtomEmbeddingAndSumLastLayer(torch.nn.Module):
    def __init__(self, atom_type_in, atom_type_out, model):
        super().__init__()
        self.linear = torch.nn.Linear(atom_type_in, 128)
        self.model = model
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(128, 96)
        self.linear3 = torch.nn.Linear(96, 64)
        self.linear4 = torch.nn.Linear(64, 45)
        #self.linear5 = torch.nn.Linear(45, 32)
        #self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x, *args, batch=None, **kwargs):
        print(args)
        output = self.linear(x)
        output = self.relu(output)
        print(f"Input: {x}")
        output = self.linear2(output)
        output = self.relu(output)
        output = self.linear3(output)
        output = self.relu(output)
        output = self.linear4(output)
        #output = self.linear5(output)
        output = self.relu(output)
        output = self.model({'x': output, 'batch': batch, **kwargs})
        if batch is None:
            N = output.shape[0]
            batch = output.new_ones(N)
        # print(f'not-quite-output: {output}')
        # output = torch_scatter.scatter_add(output, batch, dim=0)
        print(f"Output: {output}")
        #output = self.softmax(output)
        return output

In [7]:
model = AtomEmbeddingAndSumLastLayer(
    atom_types_dim, embedding_dim, Network(**model_kwargs))
opt = torch.optim.AdamW(
    model.parameters(), lr=params['adamw_lr'], weight_decay=params['adamw_wd'])

In [8]:
irreps_in1 = o3.Irreps('45x0e')
irreps_in2 = o3.Irreps('1x1e')
irreps_out = o3.Irreps('45x0e')
# the reason this doesn't work is because 0e x 1e = 1e, not 0e

instr = [
    (i_1, i_2, i_out, 'uvw', True, 1.0)
    for i_1, (_, ir_1) in enumerate(irreps_in1)
    for i_2, (_, ir_2) in enumerate(irreps_in2)
    for i_out, (_, ir_out) in enumerate(irreps_out)
    if ir_out in ir_1 * ir_2
]

# update: what did I do that I don't have as many issues with dimensions not matching up?
# not that I don't have issues, but there are fewer of them
# it's the linear layers that are causing problems

# self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)
# self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

In [9]:
comment = """
OK here's the old model

GatedConvParityNetwork(
  (layers): ModuleList(
    (0): ModuleList(
      (0): Convolution()
      (1): GatedBlockParity (32x0e + 32x0e + 32x1o -> 32x0e,32x1o)
    )
    (1): ModuleList(
      (0): Convolution()
      (1): GatedBlockParity (32x0e + 64x0e + 32x1e,32x1o -> 32x0e,32x1e,32x1o)
    )
    (2): Convolution()
  )
)
"""

In [8]:
indices = np.arange(len(data))
np.random.shuffle(indices)
index_tr, index_va, index_te = np.split(
    indices, [int(.8 * len(indices)), int(.9 * len(indices))])

assert set(index_tr).isdisjoint(set(index_te))
assert set(index_tr).isdisjoint(set(index_va))
assert set(index_te).isdisjoint(set(index_va))


with open('loss.txt', 'a') as f:
    f.write(f"Iteration: {identification_tag}")

In [9]:
batch_size = 1
dataloader = torch_geometric.loader.DataLoader(
    [data[i] for i in index_tr], batch_size=batch_size, shuffle=True)
dataloader_valid = torch_geometric.loader.DataLoader(
    [data[i] for i in index_va], batch_size=batch_size)

loss_fn = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.78)

In [14]:
for j, d in enumerate(dataloader):
    d.to(device)
    # output = model(x=d.x, batch=d.batch, pos=d.pos) # if I define node_attr to be 0e, try using this line
    # output = model(x=d.x, batch=d.batch, pos=d.pos, z=d.edge_attr) # ok so z is actually supposed to be node attribute
    output = model(x=d.x, batch=d.batch, pos=d.pos, z=d.pos.new_ones((d.pos.shape[0], 3)))

()
Input: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Output: tensor([[-0.1659, -1.6830, -0.3038]], grad_fn=<DivBackward0>)
()
Input: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Output: tensor([[ 0.2459, -1.0176,  0.0089]], grad_fn=<DivBackward0>)
()
Input: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Output: tensor([[ 0.2201, -1.9613, -0.0194]], grad_fn=<DivBackward0>

KeyboardInterrupt: 

In [10]:
for i, index in enumerate(index_te):
    with torch.no_grad():
        print(len(index_te))
        print(f"Index being tested: {index}")
        d = torch_geometric.data.Batch.from_data_list([data[index]])
        print(d)
        print(vars(d))

645
Index being tested: 3452
DataPeriodicNeighborsBatch(edge_index=[2, 1572], edge_attr=[1572, 3], pos=[32, 3], x=[32, 354], y=[1], n_norm=[1], lattice=[3, 3], batch=[32], ptr=[2])
{'_store': {'edge_index': tensor([[ 0,  0,  0,  ..., 31, 31, 31],
        [31,  2, 13,  ..., 16, 20, 25]]), 'edge_attr': tensor([[-1.3999,  0.2193, -4.6460],
        [ 0.0000, -1.6172, -4.3183],
        [-2.9085,  2.0708, -3.4958],
        ...,
        [ 1.3999, -0.4331,  2.3001],
        [-0.2060,  4.4100,  1.8230],
        [-3.0171, -3.6252, -0.6724]]), 'pos': tensor([[2.9085, 2.6212, 2.8231],
        [2.9085, 4.6293, 5.8136],
        [2.9085, 1.0040, 7.1414],
        [2.9085, 6.2465, 1.4953],
        [0.0000, 0.6788, 1.2917],
        [0.0000, 6.5716, 7.3450],
        [0.0000, 2.9464, 5.6100],
        [0.0000, 4.3041, 3.0266],
        [1.4452, 0.0000, 4.3183],
        [4.3717, 0.0000, 4.3183],
        [4.3717, 3.6252, 0.0000],
        [1.4452, 3.6252, 0.0000],
        [0.0000, 2.5584, 0.6728],
        [0.0

KeyboardInterrupt: 