In [1]:
from mof import MOF_CGCNN
import csv
from sklearn.model_selection import train_test_split
import pandas as pd
import os
from pathlib import Path

def create_single_label_df(molecule, target_pressure):
    """Build a single-target table for one adsorbate at one pressure.

    Reads ``{molecule}_data_all_labels.csv`` from the current working directory.
    It keeps the descriptor columns plus the column whose header equals
    ``str(target_pressure)``, renames that column to ``'target'`` and indexes the
    result by ``'id'``. Every other column (e.g. ``surface_area_m2g`` or the
    other pressures) is left out, so it need not be present in the file.

    Args:
        molecule (str): prefix of the labels file, e.g. ``'co2'``.
        target_pressure: value whose string form names the label column
            (e.g. ``0.1`` selects the column ``'0.1'``).

    Returns:
        pandas.DataFrame: indexed by ``'id'`` with columns
        ``surface_area_m2cm3, void_fraction, lcd, pld, target`` in that order.

    Raises:
        KeyError: if ``'id'``, a descriptor column or the pressure column is missing.
    """
    data_all_labels = pd.read_csv(f'{molecule}_data_all_labels.csv')

    # CSV headers are read as strings, so match the pressure by its str() form.
    target_column = f'{target_pressure}'
    feature_columns = ['surface_area_m2cm3', 'void_fraction', 'lcd', 'pld']

    # Selecting the columns in this order already gives the final layout, so no
    # separate drop/copy/reindex is needed, and the frame read above is not mutated.
    return (
        data_all_labels[['id', *feature_columns, target_column]]
        .rename(columns={target_column: 'target'})
        .set_index('id')
    )

def find_directory():
    """Return the directory that project-relative paths are resolved against.

    * Inside a Jupyter kernel (``ZMQInteractiveShell``) this is the current
      working directory, since a notebook kernel normally has no ``__file__``.
    * When run as a script it is the directory containing the script
      (``Path(__file__).parent``).
    * Anywhere ``__file__`` is undefined (e.g. the terminal IPython shell or a
      plain interactive interpreter), it falls back to the current working
      directory instead of raising ``NameError``.

    Returns:
        pathlib.Path: the directory. It can be used directly in f-strings and
        ``os.path`` functions.
    """
    try:
        shell = get_ipython().__class__.__name__
    except NameError:  # get_ipython only exists when running under IPython
        shell = None

    if shell == 'ZMQInteractiveShell':  # running in a Jupyter notebook
        return Path(os.getcwd())
    try:
        return Path(__file__).parent
    except NameError:  # interactive session with no backing file
        return Path(os.getcwd())

def get_cif_IDs():
    """List the structure IDs available in the ``cif`` folder.

    The folder is ``<find_directory()>/cif``. Each regular file contributes its
    name without the extension; sub-directories are skipped.

    Returns:
        list[str]: file stems, in ``os.listdir`` order.
    """
    cif_directory = f'{find_directory()}/cif'
    ids = []
    for entry in os.listdir(cif_directory):
        if not os.path.isfile(os.path.join(cif_directory, entry)):
            continue  # ignore sub-directories and other non-file entries
        ids.append(Path(entry).stem)
    return ids

def build_dataset(molecule, target_pressure):
    """Build the filtered single-label dataset, write it to ``data.csv`` and return it.

    Only rows whose ``id`` matches a file in the ``cif`` directory (see
    ``get_cif_IDs``) are kept. The CSV is written without a header row, with the
    id index as the first column.

    Args:
        molecule (str): prefix of the labels file, e.g. ``'co2'``.
        target_pressure: pressure whose label column becomes ``'target'``.

    Returns:
        pandas.DataFrame: the filtered dataset, indexed by ``'id'``.
    """
    labelled = create_single_label_df(molecule, target_pressure)

    available_ids = get_cif_IDs()
    has_structure = labelled.index.isin(available_ids)
    labelled = labelled[has_structure]

    # This is the complete filtered dataset (train + validation + test rows),
    # not only the training/validation portion.
    labelled.to_csv('data.csv', header=False)
    return labelled

def split_dataset(molecule, target_pressure, train_val_test_split, random_state=42):
    """Hold out a test set from the filtered dataset and save both parts as CSV.

    Calls ``build_dataset`` (which also writes ``data.csv``), then moves
    ``train_val_test_split[2]`` of the rows into the test set. It writes
    ``training_val.csv`` and ``test.csv`` without header rows. The
    train/validation division of ``training_val.csv`` is left to the caller.

    Args:
        molecule (str): prefix of the labels file, e.g. ``'co2'``.
        target_pressure: pressure whose label column becomes the target.
        train_val_test_split (sequence of 3 floats): train, validation and test
            fractions of the dataset; they must sum to 1.
        random_state (int): seed for the test hold-out. The default 42 is the
            value that was previously hard-coded.

    Returns:
        tuple[pandas.DataFrame, pandas.DataFrame]: ``(training_val_set, test_set)``.

    Raises:
        ValueError: if ``train_val_test_split`` is not three fractions summing to 1.
    """
    # Validate before building anything, so a bad config does not leave
    # half-written CSV files behind.
    if len(train_val_test_split) != 3:
        raise ValueError(
            f'train_val_test_split must have 3 entries, got {len(train_val_test_split)}')
    if abs(sum(train_val_test_split) - 1.0) > 1e-6:
        raise ValueError(
            f'train_val_test_split must sum to 1, got {sum(train_val_test_split)}')

    data_single_label = build_dataset(molecule=molecule, target_pressure=target_pressure)

    # Split away the test set
    training_val_set, test_set = train_test_split(
        data_single_label, test_size=train_val_test_split[2], random_state=random_state)

    # Save these dataframes as headerless .csv files (id first, target last)
    training_val_set.to_csv('training_val.csv', header=False)
    test_set.to_csv('test.csv', header=False)
    return training_val_set, test_set

# --- Hyperparameters ---
molecule = 'co2'                        # labels are read from f'{molecule}_data_all_labels.csv'
target_pressure = 0.1                   # label column is the one whose header is str(target_pressure)
train_val_test_split = [0.7, 0.2, 0.1]  # train / validation / test fractions of the filtered dataset
epochs = 200
initial_lr = 0.002
lr_decay_rate = 0.9
# NOTE(review): with epochs = 200, a milestone at epoch 200 is likely never reached,
# so lr_decay_rate may have no effect -- confirm against MOF_CGCNN's LR scheduler.
lr_milestones = [200]
batch_size = 24
h_fea_len = 480
n_conv = 5
weight_decay = 1e-7
dropout = 0.2
use_cuda = True                         # NOTE(review): presumably set False on machines without a GPU
val_random_state = 24                   # seed for the train/validation split below
# -----------------------

split_dataset(molecule, target_pressure, train_val_test_split)

## Read data
# The splits were written without a header, one row per MOF (id first, target last).
# The csv module expects files to be opened with newline=''.
with open('./training_val.csv', newline='') as f:
    trainandval = list(csv.reader(f))
with open('./test.csv', newline='') as f:
    test = list(csv.reader(f))

# Validation fraction *of the train+val rows*, chosen so that the validation set is
# train_val_test_split[1] of the whole dataset (0.2 / 0.9 for the default split).
test_size = train_val_test_split[1] / (1 - train_val_test_split[2])
train, val = train_test_split(trainandval, test_size=test_size, random_state=val_random_state)

# Directory of .cif structure files handed to the model
root = './cif'

mof = MOF_CGCNN(cuda=use_cuda,
                root_file=root,
                trainset=train,
                valset=val,
                testset=test,
                epoch=epochs,
                lr=initial_lr,
                optim='Adam',
                batch_size=batch_size,
                h_fea_len=h_fea_len,
                n_conv=n_conv,
                lr_milestones=lr_milestones,
                lr_decay_rate=lr_decay_rate,
                weight_decay=weight_decay,
                dropout=dropout)
# Train the model
mof.train_MOF()



Predicting  1  properties!!
Epoch: [0][0/917]	Loss 27.8569 (27.8569)	MAE 3.698 (3.698)
Epoch: [0][10/917]	Loss 1174.7810 (1308.7495)	MAE 25.554 (22.173)
Epoch: [0][20/917]	Loss 25.3808 (837.5422)	MAE 3.740 (16.849)
Epoch: [0][30/917]	Loss 88.6656 (611.1092)	MAE 6.557 (13.892)
Epoch: [0][40/917]	Loss 57.7713 (472.4361)	MAE 5.525 (11.531)
Epoch: [0][50/917]	Loss 3.7973 (382.2437)	MAE 1.320 (9.704)
Epoch: [0][60/917]	Loss 1.3376 (320.4774)	MAE 0.575 (8.357)
Epoch: [0][70/917]	Loss 0.7752 (275.6117)	MAE 0.576 (7.296)
Epoch: [0][80/917]	Loss 0.5948 (241.7028)	MAE 0.521 (6.461)
Epoch: [0][90/917]	Loss 0.9880 (215.2468)	MAE 0.586 (5.807)
Epoch: [0][100/917]	Loss 0.2329 (193.9913)	MAE 0.263 (5.275)
Epoch: [0][110/917]	Loss 1.2372 (176.5878)	MAE 0.569 (4.841)
Epoch: [0][120/917]	Loss 0.5365 (162.0593)	MAE 0.525 (4.484)
Epoch: [0][130/917]	Loss 0.7303 (149.7426)	MAE 0.481 (4.178)
Epoch: [0][140/917]	Loss 0.2724 (139.1501)	MAE 0.287 (3.905)
Epoch: [0][150/917]	Loss 0.5424 (129.9748)	MAE 0.332 (3.



Epoch: [0][750/917]	Loss 0.8542 (26.6531)	MAE 0.408 (1.096)
Epoch: [0][760/917]	Loss 0.2097 (26.3108)	MAE 0.236 (1.087)
Epoch: [0][770/917]	Loss 0.1359 (25.9763)	MAE 0.196 (1.078)
Epoch: [0][780/917]	Loss 0.2753 (25.6506)	MAE 0.341 (1.069)
Epoch: [0][790/917]	Loss 1.4928 (25.3337)	MAE 0.532 (1.061)
Epoch: [0][800/917]	Loss 0.9966 (25.0271)	MAE 0.472 (1.054)
Epoch: [0][810/917]	Loss 0.4795 (24.7246)	MAE 0.310 (1.045)
Epoch: [0][820/917]	Loss 0.5717 (24.4325)	MAE 0.479 (1.039)
Epoch: [0][830/917]	Loss 0.3446 (24.1454)	MAE 0.370 (1.031)
Epoch: [0][840/917]	Loss 0.3075 (23.8646)	MAE 0.290 (1.024)
Epoch: [0][850/917]	Loss 1.3910 (23.5908)	MAE 0.624 (1.016)
Epoch: [0][860/917]	Loss 0.4710 (23.3238)	MAE 0.475 (1.010)
Epoch: [0][870/917]	Loss 0.7262 (23.0639)	MAE 0.403 (1.004)
Epoch: [0][880/917]	Loss 0.1964 (22.8075)	MAE 0.255 (0.997)




Epoch: [0][890/917]	Loss 0.3942 (22.5590)	MAE 0.372 (0.991)
Epoch: [0][900/917]	Loss 0.3767 (22.3149)	MAE 0.429 (0.985)
Epoch: [0][910/917]	Loss 0.2451 (22.0749)	MAE 0.302 (0.978)
 * MAE tensor(0.9747)
Test: [0/262]	Loss 0.2262 (0.2262)	MAE 0.301 (0.301)
Test: [10/262]	Loss 0.1751 (0.5027)	MAE 0.272 (0.401)
Test: [20/262]	Loss 0.3625 (0.5009)	MAE 0.351 (0.402)
Test: [30/262]	Loss 0.4478 (0.4863)	MAE 0.428 (0.401)
Test: [40/262]	Loss 0.7369 (0.5059)	MAE 0.490 (0.408)
Test: [50/262]	Loss 0.8107 (0.5340)	MAE 0.451 (0.413)
Test: [60/262]	Loss 0.4453 (0.5335)	MAE 0.380 (0.415)
Test: [70/262]	Loss 0.3981 (0.5236)	MAE 0.368 (0.415)
Test: [80/262]	Loss 0.4972 (0.5182)	MAE 0.381 (0.413)
Test: [90/262]	Loss 0.7261 (0.5213)	MAE 0.429 (0.414)
Test: [100/262]	Loss 0.8672 (0.5382)	MAE 0.533 (0.417)
Test: [110/262]	Loss 0.2257 (0.5292)	MAE 0.312 (0.415)
Test: [120/262]	Loss 0.3027 (0.5278)	MAE 0.314 (0.414)
Test: [130/262]	Loss 0.6908 (0.5353)	MAE 0.471 (0.416)
Test: [140/262]	Loss 0.3316 (0.5329)	MA