# Import Libraries

In [None]:
# Import Libraries
import numpy as np
import pandas as pd
import os
import glob
import json
import csv
from IPython import display
import matplotlib.pyplot as plt
import hashlib

# Dataset 

In [None]:
# Load every CSV under `path` and tag each row with the (1-based) number of the
# file it came from. The tag is stored as a string in the 'root' column.
path = 'data'
# glob returns files in filesystem order, which is not stable across machines.
# Sorting keeps the 'root' labels and the concatenated row order reproducible,
# which the seeded train/test split further down depends on.
csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
if not csv_files:
    # pd.concat would raise the same exception type with a less helpful message.
    raise ValueError(f"No CSV files found in {path!r}")

dfs = []
for i, f in enumerate(csv_files):
    df = pd.read_csv(f, engine="python")
    df['root'] = f'{i+1}'
    dfs.append(df)

final_data = pd.concat(dfs)
final_data.head()

In [None]:
#final_data.to_csv("data.csv", encoding="utf-8", index=False)

In [None]:
# Quick look at the source-file labels: column shape, distinct labels, and the
# number of rows that came from file number 7.
root_labels = final_data['root']
print(root_labels.shape)
print(root_labels.unique())
print(len(final_data[root_labels == '7']))

In [None]:
# Print column names, dtypes and non-null counts of the concatenated frame.
final_data.info()

# Train-Valid-Test Split

In [None]:
# Work on a copy so the split below never mutates the concatenated frame.
df = final_data.copy()
df.info()

In [None]:
# Train-test split
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing, then 10% of the remaining rows for
# validation (72% / 8% / 20% overall). Both splits are stratified on the
# source-file label so every file is represented in each subset.
tmp, test_df = train_test_split(df, test_size=0.2, random_state=101, stratify=df['root'])
train_df, valid_df = train_test_split(tmp, test_size=0.1, random_state=101, stratify=tmp['root'])

# The label was only needed for stratification; drop it before saving.
train_df = train_df.drop('root', axis=1).reset_index(drop=True)
valid_df = valid_df.drop('root', axis=1).reset_index(drop=True)
test_df = test_df.drop('root', axis=1).reset_index(drop=True)

# to_csv does not create missing directories, so make sure the target exists.
os.makedirs("dataset", exist_ok=True)
train_df.to_csv("dataset/org-train.csv", encoding="utf-8", index=False)
valid_df.to_csv("dataset/org-valid.csv", encoding="utf-8", index=False)
test_df.to_csv("dataset/org-test.csv", encoding="utf-8", index=False)

In [None]:
from sklearn import preprocessing

# Every column except the first and the last is rescaled to [0, 1].
# The scaler is fitted on the training split only and then reused for the
# validation and test splits, so no statistics leak from held-out data.
x_cols = train_df.columns[1:-1].tolist()
scaler = preprocessing.MinMaxScaler(feature_range=(0, 1)).fit(train_df[x_cols])

scaled_frames = []
for raw_df in (train_df, valid_df, test_df):
    scaled_df = raw_df.copy()
    scaled_df[x_cols] = scaler.transform(raw_df[x_cols])
    scaled_frames.append(scaled_df)
new_train_df, new_valid_df, new_test_df = scaled_frames

# Persist the scaled splits next to the unscaled "org-*" files.
new_train_df.to_csv("dataset/train.csv", encoding="utf-8", index=False)
new_valid_df.to_csv("dataset/valid.csv", encoding="utf-8", index=False)
new_test_df.to_csv("dataset/test.csv", encoding="utf-8", index=False)

In [None]:
# Display (rows, columns) of the scaled train / validation / test frames.
new_train_df.shape, new_valid_df.shape, new_test_df.shape

# Dataset Loader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

from pymatgen.core.structure import Structure

import pandas as pd
import numpy as np
import json
import warnings


class GaussianDistance(object):
    """
    Expands the distance by Gaussian basis.

    Unit: angstrom
    """
    def __init__(self, dmin, dmax, step, var=None):
        """
        Parameters
        ----------

        dmin: float
          Minimum interatomic distance
        dmax: float
          Maximum interatomic distance
        step: float
          Step size for the Gaussian filter
        var: float, optional
          Width of each Gaussian; it enters the expansion as ``var**2`` in
          the denominator. Defaults to ``step`` when omitted.
        """
        assert dmin < dmax
        assert dmax - dmin > step
        # Centres of the Gaussians: dmin, dmin + step, ... up to about dmax.
        self.filter = np.arange(dmin, dmax+step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """
        Apply Gaussian distance filter to a distance array

        Parameters
        ----------

        distances: array-like, n-d
          A distance matrix of any shape. Lists, tuples and scalars are
          accepted and converted with ``np.asarray``.

        Returns
        -------
        expanded_distance: shape (n+1)-d array
          Expanded distance matrix with the last dimension of length
          len(self.filter)
        """
        # np.asarray is a no-op for ndarrays and lets plain Python sequences
        # and scalars use the same broadcasting path.
        distances = np.asarray(distances)
        return np.exp(-(distances[..., np.newaxis] - self.filter)**2 /
                      self.var**2)


class AtomInitializer(object):
    """
    Base class that stores one feature vector per atom type.

    !!! Use one AtomInitializer per dataset !!!
    """
    def __init__(self, atom_types):
        # Known atom types; the embedding table starts empty and is filled by
        # subclasses or by load_state_dict.
        self.atom_types = set(atom_types)
        self._embedding = {}

    @staticmethod
    def _invert(embedding):
        """Return the reverse mapping (embedding value -> atom type)."""
        return dict(zip(embedding.values(), embedding.keys()))

    def get_atom_fea(self, atom_type):
        """Return the stored feature for ``atom_type`` (must be a known type)."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """Replace the embedding table and refresh the derived lookups."""
        self._embedding = state_dict
        self.atom_types = set(state_dict)
        self._decodedict = self._invert(state_dict)

    def state_dict(self):
        """Return the embedding table itself (not a copy)."""
        return self._embedding

    def decode(self, idx):
        """Map an embedding value back to its atom type."""
        try:
            lookup = self._decodedict
        except AttributeError:
            # Built lazily on first use when load_state_dict was never called.
            lookup = self._decodedict = self._invert(self._embedding)
        return lookup[idx]


class AtomCustomJSONInitializer(AtomInitializer):
    """
    Atom initializer backed by a JSON file.

    The file holds a dictionary that maps an element number (stored as a JSON
    string key) to a list of numbers, the feature vector of that element.

    Parameters
    ----------

    elem_embedding_file: str
        The path to the .json file
    """
    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as handle:
            raw_embedding = json.load(handle)
        # JSON object keys are always strings; convert them to int element
        # numbers so they can be looked up by atomic number.
        by_number = {int(number): vector
                     for number, vector in raw_embedding.items()}
        super().__init__(by_number.keys())
        self._embedding.update(
            (number, np.array(vector, dtype=float))
            for number, vector in by_number.items()
        )
            

class CrystalDataset(Dataset):
    """
    Dataset that builds a crystal graph for each row of a CSV index file.

    Each row is read positionally: the first column is passed to
    ``Structure.from_file``, the last column is used as the sample id, and
    every column in between forms the regression target vector.

    Parameters
    ----------

    data_path: str
        Path to the CSV index file
    atom_init_file: str
        JSON file with per-element feature vectors
    max_num_nbr: int
        Number of neighbours kept per atom (nearest first)
    radius: float
        Cutoff radius used for the neighbour search
    dmin: float
        Smallest centre of the Gaussian distance expansion
    step: float
        Spacing of the Gaussian distance expansion
    """

    def __init__(self, data_path, atom_init_file="./atom_init.json", max_num_nbr=12, radius=8, dmin=0, step=0.2):
        self.df = pd.read_csv(data_path)
        self.max_num_nbr = max_num_nbr
        self.radius = radius

        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns
        -------
        ((x, edge_attr, edge_index), target, cif_id)
            x: float tensor, one feature row per atom
            edge_attr: float tensor, Gaussian-expanded neighbour distances
            edge_index: long tensor, neighbour site indices per atom
            target: float tensor built from the middle columns of the row
            cif_id: value of the last column of the row
        """
        row = self.df.iloc[idx].values
        struct, target, cif_id = row[0], row[1:-1], row[-1]

        crystal = Structure.from_file(struct)
        atom_fea = torch.Tensor(np.vstack(
            [self.ari.get_atom_fea(site.specie.number) for site in crystal]
        ))

        # Neighbour records are sorted by element 1 (used as the distance);
        # element 2 is used as the neighbouring site's index.
        neighbour_lists = [
            sorted(found, key=lambda rec: rec[1])
            for found in crystal.get_all_neighbors(self.radius, include_index=True)
        ]

        nbr_fea_idx, nbr_fea = [], []
        for found in neighbour_lists:
            n_missing = self.max_num_nbr - len(found)
            if n_missing > 0:
                warnings.warn('{} not find enough neighbors to build graph. '
                              'If it happens frequently, consider increase '
                              'radius.'.format(cif_id))
                # Pad with site index 0 and a distance just beyond the cutoff.
                index_row = [rec[2] for rec in found] + [0] * n_missing
                dist_row = [rec[1] for rec in found] + [self.radius + 1.] * n_missing
            else:
                nearest = found[:self.max_num_nbr]
                index_row = [rec[2] for rec in nearest]
                dist_row = [rec[1] for rec in nearest]
            nbr_fea_idx.append(index_row)
            nbr_fea.append(dist_row)

        nbr_fea = torch.Tensor(self.gdf.expand(np.array(nbr_fea)))
        nbr_fea_idx = torch.LongTensor(np.array(nbr_fea_idx))
        target = torch.Tensor(np.array(target, dtype=float))

        x, edge_attr, edge_index = atom_fea, nbr_fea, nbr_fea_idx
        return (x, edge_attr, edge_index), target, cif_id

In [None]:
# One dataset per scaled split written earlier; all three use the same graph
# settings (12 neighbours per atom, cutoff radius 8).
train_dataset = CrystalDataset("dataset/train.csv", max_num_nbr=12, radius=8)
valid_dataset = CrystalDataset("dataset/valid.csv", max_num_nbr=12, radius=8)
test_dataset = CrystalDataset("dataset/test.csv", max_num_nbr=12, radius=8)

In [None]:
# Every train_dataset[0] call re-reads and re-parses the structure file (and
# repeats any neighbour warnings), so fetch the sample once and reuse it.
first_sample = train_dataset[0]
print(first_sample[0][0].shape)  # atom features
print(first_sample[0][1].shape)  # expanded neighbour distances
print(first_sample[0][2].shape)  # neighbour indices
print(first_sample[1].shape)     # target vector
print(first_sample[2])           # sample id

In [None]:
def collate_fn(data_list):
    """
    Merge individual crystal samples into one batch.

    Atoms of all crystals are stacked along the first dimension. Neighbour
    indices are shifted by the number of atoms that precede each crystal so
    that they keep pointing at the right rows of the stacked atom tensor.

    Parameters
    ----------

    data_list: list of ((x, edge_attr, edge_index), target, id)
        Samples as returned by the dataset

    Returns
    -------
    ((batch_x, batch_edge_attr, batch_edge_index, crystal_x_idx),
     batch_target, batch_ids)
        crystal_x_idx is a list with one LongTensor per crystal holding the
        row numbers of its atoms inside batch_x; batch_target stacks the
        targets along a new first dimension; batch_ids is a plain list.
    """
    node_chunks, edge_chunks, index_chunks = [], [], []
    crystal_x_idx, target_chunks, batch_ids = [], [], []
    offset = 0

    for (x, edge_attr, edge_index), target, _id in data_list:
        num_atoms = x.shape[0]

        node_chunks.append(x)
        edge_chunks.append(edge_attr)
        index_chunks.append(edge_index + offset)
        crystal_x_idx.append(
            torch.LongTensor(np.arange(offset, offset + num_atoms))
        )
        target_chunks.append(target)
        batch_ids.append(_id)

        offset += num_atoms

    graph = (
        torch.cat(node_chunks, dim=0),
        torch.cat(edge_chunks, dim=0),
        torch.cat(index_chunks, dim=0),
        crystal_x_idx,
    )
    return graph, torch.stack(target_chunks, dim=0), batch_ids

In [None]:
# Wrap each split in a DataLoader that batches crystals with collate_fn.
# NOTE(review): the training loader is created without shuffle=True, so every
# epoch sees the batches in the same order — confirm this is intended.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [None]:
# Sanity check: print the tensor shapes of the first training batch only.
# batch[0][3] (per-crystal atom indices), batch[1] (targets) and batch[2]
# (ids) are available as well but are not printed here.
for batch in train_loader:
    graph_inputs = batch[0]
    for label, tensor in zip(("x", "edge_attr", "edge_index"), graph_inputs):
        print(label, tensor.shape)
    break

# CGConv model

In [None]:
# Read the input widths the model needs from one real sample: the size of the
# last dimension of the atom features and of the expanded edge features.
structures, _, _ = train_dataset[0]
node_fea, edge_fea = structures[0], structures[1]
original_x_len, edge_attr_len = node_fea.shape[-1], edge_fea.shape[-1]

print(original_x_len, edge_attr_len)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

import pytorch_lightning as pl

In [None]:
class ConvLayer(nn.Module):
    """
    Gated graph-convolution layer over a fixed number of neighbours per atom.

    For every (atom, neighbour) pair the atom features, the neighbour features
    and the edge features are concatenated and passed through one linear
    layer. Its output is split into a sigmoid gate and a softplus core; their
    product is summed over the neighbours and added back to the atom features.

    Parameters
    ----------

    x_len: int
        Number of features per atom (input and output width)
    edge_attr_len: int
        Number of features per edge
    """
    def __init__(self, x_len, edge_attr_len):
        super().__init__()

        self.x_len = x_len
        self.edge_attr_len = edge_attr_len

        gated_width = 2 * x_len  # gate half + core half
        self.fc = nn.Linear(gated_width + edge_attr_len, gated_width)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(gated_width)
        self.bn2 = nn.BatchNorm1d(x_len)
        self.softplus2 = nn.Softplus()

    def forward(self, x, edge_attr, edge_index):
        """
        Parameters
        ----------

        x: tensor (N, x_len)
            Atom features
        edge_attr: tensor (N, M, edge_attr_len)
            Edge features for the M neighbours of each atom
        edge_index: long tensor (N, M)
            Row numbers in ``x`` of each atom's neighbours

        Returns
        -------
        tensor (N, x_len)
            Updated atom features
        """
        num_atoms, num_nbrs = edge_index.shape

        center = x.unsqueeze(1).expand(num_atoms, num_nbrs, self.x_len)
        neighbours = x[edge_index, :]
        pair_fea = torch.cat((center, neighbours, edge_attr), dim=2)

        gated = self.fc(pair_fea)
        # BatchNorm1d normalises over the first dimension, so flatten the
        # (atom, neighbour) axes before normalising and restore them after.
        flat = gated.view(-1, 2 * self.x_len)
        gated = self.bn1(flat).view(num_atoms, num_nbrs, 2 * self.x_len)

        gate, core = gated.chunk(2, dim=2)
        message = self.sigmoid(gate) * self.softplus1(core)

        aggregated = self.bn2(torch.sum(message, dim=1))
        return self.softplus2(x + aggregated)
        

class CrystalGraphConvNet(pl.LightningModule):
    """
    Crystal graph network that regresses a fixed-length target vector.

    Pipeline: linear atom embedding -> ``n_conv`` ConvLayer blocks -> mean
    pooling over the atoms of each crystal -> ``n_h_layers`` softplus-activated
    fully connected layers -> linear output of width ``out_dim``.

    Logged metrics are ``{train,val,test}_loss_mse`` and
    ``{train,val,test}_loss_mae``. The validation prefix is ``val`` so that it
    matches the key monitored by the checkpoint callback and the columns read
    by the plotting cells of this notebook.

    Parameters
    ----------

    original_x_len: int
        Width of the raw atom feature vectors
    edge_attr_len: int
        Width of the edge feature vectors
    out_dim: int
        Length of the predicted target vector
    x_len: int
        Width of the atom embedding used inside the convolution layers
    n_conv: int
        Number of convolution layers
    n_h_features: int
        Width of the hidden fully connected layers
    n_h_layers: int
        Number of hidden fully connected layers after pooling
    learning_rate: float
        Initial learning rate handed to AdamW
    bias: bool
        Accepted for interface compatibility; no layer in this class reads it.
    """

    def __init__(
        self,
        original_x_len,
        edge_attr_len,
        out_dim=381,
        x_len=64,
        n_conv=3,
        n_h_features=128,
        n_h_layers=1,
        learning_rate=1e-3,
        bias=True
    ):
        # Inheritances
        super().__init__()

        # Params
        self.learning_rate = learning_rate

        # Embedding
        self.embedding = nn.Linear(original_x_len, x_len)

        # CGC layer
        self.convs = nn.ModuleList([
            ConvLayer(x_len=x_len, edge_attr_len=edge_attr_len)
            for _ in range(n_conv)
        ])

        # Fully connected head: the first layer maps x_len -> n_h_features,
        # any further layers keep the width at n_h_features.
        self.fc = nn.ModuleList(
            [nn.Linear(x_len, n_h_features)]
            +
            [nn.Linear(n_h_features, n_h_features) for _ in range(n_h_layers - 1)]
        )
        self.ac = nn.ModuleList(
            [nn.Softplus()]
            +
            [nn.Softplus() for _ in range(n_h_layers - 1)]
        )

        self.out = nn.Linear(n_h_features, out_dim)

    def forward(self, x, edge_attr, edge_index, crystal_x_idx):
        """
        Predict one target vector per crystal in the batch.

        Parameters
        ----------

        x: tensor
            Stacked atom features of all crystals in the batch
        edge_attr: tensor
            Edge features for the neighbours of each atom
        edge_index: long tensor
            Row numbers in ``x`` of each atom's neighbours
        crystal_x_idx: list of long tensors
            For each crystal, the rows of ``x`` that belong to it

        Returns
        -------
        tensor (number of crystals, out_dim)
        """
        hidden = self.embedding(x)

        for conv in self.convs:
            hidden = conv(hidden, edge_attr, edge_index)

        # Pooling
        hidden = self.pooling(hidden, crystal_x_idx)

        # Fully-connection
        for fc, ac in zip(self.fc, self.ac):
            hidden = ac(fc(hidden))

        # Final linear layer, no output activation.
        out = self.out(hidden)

        return out

    def pooling(self, x, crystal_x_idx):
        """Average the atom features of each crystal into one row per crystal."""
        # Every atom row must belong to exactly one crystal.
        assert sum([len(idx) for idx in crystal_x_idx]) == x.data.shape[0]

        hidden = torch.cat(
            [torch.mean(x[idx], dim=0, keepdim=True) for idx in crystal_x_idx],
            dim=0
        )

        return hidden

    def configure_optimizers(self):
        """Return AdamW plus an exponential learning-rate schedule."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=0.01)
        # gamma=0.1 multiplies the learning rate by 0.1 at every scheduler step.
        # NOTE(review): Lightning steps schedulers once per epoch by default,
        # which makes this a very steep decay — confirm it is intended.
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
        return (
            {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        )

    def _shared_step(self, batch, stage):
        """Run a forward pass, log MSE and MAE under ``stage``, return the MSE."""
        (x, edge_attr, edge_index, crystal_x_idx), y, ids = batch
        y_hat = self(x, edge_attr, edge_index, crystal_x_idx)

        loss = F.mse_loss(y_hat, y)
        loss_mae = F.l1_loss(y_hat, y)

        # batch_size is passed explicitly because the batch is a nested tuple
        # from which Lightning cannot reliably infer the number of crystals.
        self.log(
            f'{stage}_loss_mse', loss,
            on_step=True, on_epoch=True, prog_bar=True, logger=True,
            batch_size=y_hat.size(0)
        )
        self.log(
            f'{stage}_loss_mae', loss_mae,
            on_step=True, on_epoch=True, prog_bar=True, logger=True,
            batch_size=y_hat.size(0)
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss_mse`` / ``train_loss_mae``; the MSE is optimised."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss_mse`` / ``val_loss_mae`` for the validation batch."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss_mse`` / ``test_loss_mae`` for the test batch."""
        return self._shared_step(batch, 'test')

In [None]:
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """
    Hyperparameters used to build CrystalGraphConvNet.

    The attributes are annotated so that ``@dataclass`` registers them as
    real fields: instances can override individual values, e.g.
    ``ModelConfig(learning_rate=1e-3)``, and get a useful repr. The defaults
    remain readable on the class itself (``ModelConfig.x_len``).
    """
    # Input widths, taken from the variables measured on a real sample above.
    original_x_len: int = original_x_len
    edge_attr_len: int = edge_attr_len
    # Length of the predicted target vector.
    out_dim: int = 381
    # Width of the atom embedding inside the convolution layers.
    x_len: int = 64
    # Number of convolution layers.
    n_conv: int = 3
    # Width and count of the hidden fully connected layers after pooling.
    n_h_features: int = 128
    n_h_layers: int = 1
    # Initial learning rate handed to the optimizer.
    learning_rate: float = 1e-2

In [None]:
# Build the network from the shared configuration and report its size.
cfg = ModelConfig
model = CrystalGraphConvNet(
    original_x_len=cfg.original_x_len,
    edge_attr_len=cfg.edge_attr_len,
    out_dim=cfg.out_dim,
    x_len=cfg.x_len,
    n_conv=cfg.n_conv,
    n_h_features=cfg.n_h_features,
    n_h_layers=cfg.n_h_layers,
    learning_rate=cfg.learning_rate,
)
print(model)

n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params:,}")

In [None]:
# Smoke test: push the first training batch through the untrained model and
# print the shape of the prediction.
for batch in train_loader:
    graph_inputs, batch_target, batch_ids = batch
    batch_x, batch_edge_attr, batch_edge_index, crystal_x_idx = graph_inputs

    out = model(*graph_inputs)
    print(out.shape)
    break

# Training

In [None]:
# Create a new model instance for the training run below, so training does
# not start from the instance used in the smoke test above.
cfg = ModelConfig
model = CrystalGraphConvNet(
    original_x_len=cfg.original_x_len,
    edge_attr_len=cfg.edge_attr_len,
    out_dim=cfg.out_dim,
    x_len=cfg.x_len,
    n_conv=cfg.n_conv,
    n_h_features=cfg.n_h_features,
    n_h_layers=cfg.n_h_layers,
    learning_rate=cfg.learning_rate,
)
print(model)

n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params:,}")

In [None]:
# Number of visible CUDA devices (0 means CPU only).
gpus = torch.cuda.device_count()
# early_stop = pl.callbacks.early_stopping.EarlyStopping(monitor="val_loss", patience=5, mode="min")
# Keep only the single best checkpoint, ranked by the lowest monitored value.
# NOTE(review): `monitor` must be exactly a metric key that the
# LightningModule logs in validation_step — verify the names match.
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath="content/ckpts/", 
    save_top_k=1, 
    monitor="val_loss_mse",
    mode="min",
)
print(gpus)

In [None]:
from pytorch_lightning.loggers import CSVLogger

# Metrics go to CSV files under log/gnn/; the plotting cells further down read
# log/gnn/*/metrics.csv.
logger = CSVLogger(save_dir="log/", name="gnn")

In [None]:
# trainer = pl.Trainer(gpus=gpus, strategy=strategy, max_epochs=50, log_every_n_steps=5, callbacks=[checkpoint])
# Short 5-epoch run with the checkpoint callback and the CSV logger attached.
# NOTE(review): the `gpus=` argument is presumably only accepted by
# pytorch_lightning 1.x (newer releases use accelerator=/devices=) — confirm
# against the installed version.
trainer = pl.Trainer(gpus=gpus, max_epochs=5, log_every_n_steps=10, callbacks=[checkpoint], logger=logger)
trainer

In [None]:
# Train on train_loader and validate on valid_loader.
trainer.fit(model, train_loader, valid_loader)

In [None]:
# Save the trainer's current state to a fixed path, in addition to the
# checkpoint kept by the ModelCheckpoint callback.
trainer.save_checkpoint("content/ckpts/final.ckpt")

# Testing stage

In [None]:
# Evaluate on the held-out test split; ckpt_path='best' asks the trainer to
# load the best checkpoint tracked by the checkpoint callback.
trainer.test(ckpt_path='best', dataloaders=test_loader)

In [None]:
# Collect the CSVLogger output of every run directory that has a metrics file.
metric_files = [f"{run_dir}/metrics.csv" for run_dir in glob.glob('log/gnn/*')]
metrics = pd.concat(
    [pd.read_csv(csv_path) for csv_path in metric_files if os.path.exists(csv_path)]
)

# Average all logged rows of an epoch into one row per epoch.
agg_col = "epoch"
aggreg_metrics = [
    {**dict(epoch_rows.mean()), agg_col: epoch}
    for epoch, epoch_rows in metrics.groupby(agg_col)
]
df_metrics = pd.DataFrame(aggreg_metrics)

# The column names must match the metric keys logged by the LightningModule
# (the "_epoch" suffix marks the per-epoch value of a metric).
mse_columns = ["train_loss_mse_epoch", "val_loss_mse_epoch"]
df_metrics[mse_columns].plot(grid=True, legend=True, figsize=(10, 8))
plt.show()

In [None]:
# Same aggregation as for the MSE plot, this time for the MAE curves.
metric_files = [f"{run_dir}/metrics.csv" for run_dir in glob.glob('log/gnn/*')]
metrics = pd.concat(
    [pd.read_csv(csv_path) for csv_path in metric_files if os.path.exists(csv_path)]
)

# Average all logged rows of an epoch into one row per epoch.
agg_col = "epoch"
aggreg_metrics = [
    {**dict(epoch_rows.mean()), agg_col: epoch}
    for epoch, epoch_rows in metrics.groupby(agg_col)
]
df_metrics = pd.DataFrame(aggreg_metrics)

# The column names must match the metric keys logged by the LightningModule
# (the "_epoch" suffix marks the per-epoch value of a metric).
mae_columns = ["train_loss_mae_epoch", "val_loss_mae_epoch"]
df_metrics[mae_columns].plot(grid=True, legend=True, figsize=(10, 8))
plt.show()

# Prediction stage

In [None]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback.
final_model_path = checkpoint.best_model_path
print(final_model_path)

In [None]:
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Rebuild the network from the best checkpoint. The constructor arguments are
# passed again explicitly, using the same ModelConfig values as for training,
# and the restored model is moved to the inference device.
final_model = CrystalGraphConvNet.load_from_checkpoint(
    final_model_path, 
    original_x_len=ModelConfig.original_x_len,
    edge_attr_len=ModelConfig.edge_attr_len,
    out_dim=ModelConfig.out_dim,
    x_len=ModelConfig.x_len,
    n_conv=ModelConfig.n_conv,
    n_h_features=ModelConfig.n_h_features,
    n_h_layers=ModelConfig.n_h_layers,
    learning_rate=ModelConfig.learning_rate,
).to(device)
final_model

In [None]:
# Inference mode: BatchNorm layers must use their running statistics instead
# of the statistics of the current batch (and must not update them).
final_model.eval()

# Number of batches to predict on. 1 keeps the quick single-batch preview;
# set to None to run over the whole loader.
max_batches = 1

y_pred, y_real = [], []
with torch.no_grad():
    for batch_no, batch in enumerate(train_loader, start=1):
        (x, edge_attr, edge_index, crystal_x_idx), y, ids = batch
        x = x.to(device)
        edge_attr = edge_attr.to(device)
        edge_index = edge_index.to(device)
        # Keep the pooling indices on the same device as the features.
        crystal_x_idx = [atom_rows.to(device) for atom_rows in crystal_x_idx]

        pred = final_model(x, edge_attr, edge_index, crystal_x_idx)

        pred = pred.cpu().numpy()
        y = y.cpu().numpy()

        y_pred.append(pred)
        y_real.append(y)

        if max_batches is not None and batch_no >= max_batches:
            break

# One row per crystal, one column per target value.
y_pred = np.vstack(y_pred)
y_real = np.vstack(y_real)

print(y_pred.shape)
print(y_real.shape)

In [None]:
# Compare ten consecutive target values (columns 200-209) of one sample:
# prediction first, ground truth second.
idx = 20
print(y_pred[idx, 200:210].tolist())
print(y_real[idx, 200:210].tolist())

In [None]:
# Overlay ground truth (blue) and prediction (red) for sample 20.
plt.rcParams['figure.figsize'] = (12, 8)

sample_real = y_real[20, :]
sample_pred = y_pred[20, :]
plt.plot(sample_real, 'b')
plt.plot(sample_pred, 'r')

plt.title('y-real vs y_pred')
plt.xlabel('number of records')
plt.ylabel('intensity values')
plt.show()

In [None]:
# Ground truth (top) and prediction (bottom) for one sample.
# The prediction arrays may hold fewer than 101 rows (a single batch has at
# most batch_size rows), so clamp the requested row to the last available one
# instead of failing with an IndexError.
sample_idx = min(100, len(y_real) - 1)

fig, (ax1, ax2) = plt.subplots(2, figsize=(12,8))

yr = y_real[sample_idx, :]
yp = y_pred[sample_idx, :]

ax1.plot(yr)
ax1.set_title('y_real')
ax1.set_ylabel('intensity value')
ax2.plot(yp, 'r')
ax2.set_title('y_pred')
ax2.set_xlabel('number of records (381)')
ax2.set_ylabel('intensity value')
plt.show()