In [None]:
#@title Install Uni-MOF and pretrained weights
%%bash
# Work from the Colab workspace root; every relative path below is resolved against it.
cd /content

# One-time install of the Python dependencies. The ENV_READY marker file makes
# later runs of this cell skip the block.
if [ ! -f ENV_READY ]; then

    pip3 install rdkit

    pip3 install lmdb

    pip3 install pymatgen

    touch ENV_READY
fi

# Source repositories and the fine-tuned checkpoint attached to the Uni-MOF v0.1 release.
UNIMOF_GIT='https://github.com/dptech-corp/Uni-MOF'
UNICORE_GIT='https://github.com/dptech-corp/Uni-Core.git'
PARAM_URL='https://github.com/dptech-corp/Uni-MOF/releases/download/v0.1/unimof_CoRE_MOFX_DB_finetune_best.pt'

# One-time checkout, patching and weight download, guarded by the UNIMOF_READY marker file.
# No step below is error-checked, so the marker is touched even when an earlier command failed.
if [ ! -f UNIMOF_READY ]; then
    git clone -b main ${UNICORE_GIT}
    # Patch Uni-Core's checkpoint loader in place so torch.load is called with weights_only=False.
    # NOTE(review): presumably required because newer PyTorch releases default to weights_only=True - confirm.
    perl -pi -e 's/state = torch\.load\(f, map_location=torch\.device\("cpu"\)\)/state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)/' ./Uni-Core/unicore/checkpoint_utils.py
    # Editable install of the patched Uni-Core checkout.
    pip3 install -e ./Uni-Core
    git clone -b main ${UNIMOF_GIT}
    # Rewrite every "unimat." reference to "unimof." in the task modules and the two __init__ files.
    # NOTE(review): presumably stale module paths left over from a package rename - confirm upstream.
    perl -pi -e 's/unimat\./unimof./g' ./Uni-MOF/unimof/tasks/*.py
    perl -pi -e 's/unimat\./unimof./g' ./Uni-MOF/unimof/__init__.py
    perl -pi -e 's/unimat\./unimof./g' ./Uni-MOF/unimof/losses/__init__.py
    # Download the checkpoint into ./Uni-MOF/weights (wget -P sets the target directory).
    mkdir ./Uni-MOF/weights
    wget ${PARAM_URL} -P ./Uni-MOF/weights
    touch UNIMOF_READY
fi

# NOTE(review): this cd runs inside the %%bash cell and presumably does not change the
# notebook kernel's working directory; a later cell issues its own %cd - confirm.
cd /content/Uni-MOF

In [None]:
#@title Download Data
%%bash
# Work inside the Uni-MOF checkout so ./data ends up next to the code.
cd /content/Uni-MOF

# Dataset archive URL and the base name used for both the zip file and the extraction directory.
DATA_URL='https://mof.tech.northwestern.edu/Datasets/CoREMOF%202019-10%201021%20acs%20jced%209b00835-all-mofdb-version%3Adc8a0295db.zip'
DATA_NAME='CoRE_MOF_2019'

# One-time download and extraction, guarded by the DATA_DOWNLOADED marker file.
# The steps are not error-checked, so the marker is touched even if wget or unzip failed.
if [ ! -f DATA_DOWNLOADED ]; then

    mkdir data

    # -O stores the archive as ./data/CoRE_MOF_2019.zip regardless of the remote file name.
    wget ${DATA_URL} -O ./data/${DATA_NAME}.zip

    # Extract into ./data/CoRE_MOF_2019, the directory the later Python cells read from.
    unzip ./data/${DATA_NAME}.zip -d ./data/${DATA_NAME}

    touch DATA_DOWNLOADED
fi


In [None]:
#@title Convert json data to csv
%cd /content/Uni-MOF

import json
import os
import pandas as pd

# Integer gas id -> chemical formula for the gases that are kept in the exported CSV.
gas_dic = {1:"CH4", 2:"CO2", 3:"Ar", 4:"Kr", 5:"Xe", 6:"O2", 7:"N2"}

# Reverse lookup: chemical formula -> integer gas id.
inv_gas_dic = {v: k for k, v in gas_dic.items()}

# Six numeric descriptors per gas formula. The export loop writes them, in list order, to the
# columns gas-CriticalTemp, gas-CriticalPressure, gas-AcentricFactor, gas-MolecularWeight,
# gas-MeltPoint and gas-BoilingPoint.
# NOTE(review): the values look standardized rather than raw physical quantities - confirm their origin.
# "He" and "H2" have descriptors here but no id in gas_dic, so rows for those gases are skipped
# by the formula check in the export loop.
GAS2ATTR = {
    "CH4":[0.295589,0.165132,0.251511019,-0.61518,0.026952,0.25887781],
    "CO2":[1.475242,1.475921,1.620478155,0.086439,1.976795,1.69928074],
    "Ar":[-0.11632,0.294448,0.1914686,-0.01667,-0.07999,-0.1631478],
    "Kr":[0.48802,0.602454,0.215485568,1.084671,0.415991,0.39885917],
    "Xe":[1.324657,0.751519,0.233498293,2.276323,1.12122,1.18462811],
    "O2":[-0.08095,0.37909,0.335570404,-0.61626,-0.5363,-0.1130181],
    "He":[-1.66617,-1.88746,-2.15618995,-0.9173,-1.36413,-1.6042445],
    "N2":[-0.37636,-0.3968,0.41962979,-0.31495,-0.40022,-0.3355659],
    "H2":[-1.34371,-1.3843,-1.11145188,-0.96708,-1.16031,-1.3256695],
}

# Directory holding the extracted per-MOF JSON files.
directory = './data/CoRE_MOF_2019'


def _rows_from_mof_record(record):
    """Flatten one parsed MOF JSON document into a list of CSV row dicts.

    Each single-species isotherm point whose adsorbate formula is listed in
    ``inv_gas_dic`` yields one row. Points that report more than one species
    are skipped, as are gases without an id in ``inv_gas_dic``.

    Raises:
        KeyError: if the document lacks one of the expected keys.
    """
    rows = []
    name = record['name']
    for iso in record['isotherms']:
        temperature = iso['temperature']
        # Species entries reference their adsorbate by InChIKey; map it to the formula.
        formula_by_key = {ads['InChIKey']: ads['formula'] for ads in iso['adsorbates']}
        for iso_data in iso['isotherm_data']:
            pressure = iso_data['pressure']
            species_data = iso_data['species_data']
            if len(species_data) > 1:
                continue  # skip multiple species data for now
            for species in species_data:
                gas_formula = formula_by_key[species['InChIKey']]
                if gas_formula not in inv_gas_dic:
                    continue
                gas_attr = GAS2ATTR[gas_formula]
                rows.append({
                    'name': name,
                    'gas-name': inv_gas_dic[gas_formula],
                    'gas-CriticalTemp': gas_attr[0],
                    'gas-CriticalPressure': gas_attr[1],
                    'gas-AcentricFactor': gas_attr[2],
                    'gas-MolecularWeight': gas_attr[3],
                    'gas-MeltPoint': gas_attr[4],
                    'gas-BoilingPoint': gas_attr[5],
                    'temperature': temperature,
                    'pressure': pressure,
                    # not sure if loading is always given in the same units but leaving as is for now
                    'loading_[cm^3(STP)/gr(framework)]_abs_num': species['adsorption'],
                })
    return rows


data_rows = []

# os.listdir returns entries in arbitrary order; sorting keeps the CSV row order (and
# therefore the seeded shuffle applied downstream) identical from run to run.
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(".json"):
        continue
    filepath = os.path.join(directory, filename)
    try:
        # Only hold the file open while reading it; the flattening happens afterwards.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON in {filename}")
        continue
    except FileNotFoundError:
        print(f"Error: File not found: {filename}")
        continue
    data_rows.extend(_rows_from_mof_record(data))

# Convert list of dicts to DataFrame
df = pd.DataFrame(data_rows)

# Export to CSV (the positional index is written as an extra unnamed first column)
df.to_csv("./data/data.csv", index=True)
print("Exported parsed data to data.csv")

In [None]:
#@title Preprocess Data
from pymatgen.core import Structure
from pymatgen.transformations.standard_transformations import ConventionalCellTransformation
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifParser
from multiprocessing import Process, Queue, Pool
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import pickle
import lmdb
import sys
import glob
import os
import re
from sklearn.preprocessing import MinMaxScaler
import itertools
import pdb
import random

def normalize_atoms(atom):
    """Return ``atom`` with every run of decimal digits removed (e.g. "Cu2+" -> "Cu+")."""
    # Raw string: "\d" in a plain literal is an invalid escape sequence and triggers a
    # warning on current Python versions.
    return re.sub(r"\d+", "", atom)

def transform(cif_path):
    """Parse ``cif_path`` with an occupancy tolerance of 100 and return the first
    parsed structure after applying ``ConventionalCellTransformation``."""
    parser = CifParser(cif_path, occupancy_tolerance=100)
    first_structure = parser.get_structures()[0]
    return ConventionalCellTransformation().apply_transformation(first_structure)

def cif_parser(cif_path, primitive=False):
    """
    Parser for single cif file

    Args:
        cif_path: path to the ``.cif`` file.
        primitive: forwarded to ``Structure.from_file``.

    Returns:
        dict with the keys ``ID`` (file name without its extension), ``atoms``
        (species labels with digits stripped), ``coordinates`` (x/y/z columns as
        float32), ``abc``, ``angles``, ``volume``, ``lattice_matrix`` and
        ``abc_coordinates`` (a/b/c columns as float32).
    """
    try:
        s = Structure.from_file(cif_path, primitive=primitive)
    except Exception:
        # Fall back to the tolerant parser. ``Exception`` (rather than a bare except)
        # lets KeyboardInterrupt/SystemExit propagate instead of triggering a re-parse.
        s = transform(cif_path)

    # File name without directory or extension; avoids shadowing the builtin ``id``
    # and does not assume a '/' separator or a 4-character extension.
    structure_id = os.path.splitext(os.path.basename(cif_path))[0]

    lattice = s.lattice
    abc = lattice.abc # lattice vectors
    angles = lattice.angles # lattice angles
    volume = lattice.volume # lattice volume
    lattice_matrix = lattice.matrix # lattice 3x3 matrix

    df = s.as_dataframe()
    atoms = df['Species'].astype(str).map(normalize_atoms).tolist()
    coordinates = df[['x', 'y', 'z']].values.astype(np.float32)
    abc_coordinates = df[['a', 'b', 'c']].values.astype(np.float32)
    # Every atom label must have exactly one coordinate row in both arrays.
    assert len(atoms) == coordinates.shape[0]
    assert len(atoms) == abc_coordinates.shape[0]

    return {'ID':structure_id,
            'atoms':atoms,
            'coordinates':coordinates,
            'abc':abc,
            'angles':angles,
            'volume':volume,
            'lattice_matrix':lattice_matrix,
            'abc_coordinates':abc_coordinates,
            }

def single_parser(content, dir_path='./data/CoRE_MOF_2019'):
    """Build the pickled sample for one CSV row.

    Args:
        content: 7-tuple ``(cif_name, gas, gas_attr, temperature, pressure,
            targets, task_name)``; ``gas_attr`` must provide ``astype``.
        dir_path: directory searched for ``<cif_name>.cif``. Defaults to the
            extracted CoRE MOF directory, so existing single-argument calls
            (e.g. through ``Pool.imap``) behave as before.

    Returns:
        ``pickle.dumps`` of the ``cif_parser`` result extended with ``gas``,
        ``gas_attr``, ``temperature``, ``pressure`` (stored as log10), ``target``
        and ``task_name``; ``None`` when the CIF file is missing.
    """
    cif_name, gas, gas_attr, temperature, pressure, targets, task_name = content
    cif_path = os.path.join(dir_path, cif_name+'.cif')
    if not os.path.exists(cif_path):
        print(f'{cif_path} does not exist!')
        return None

    data = cif_parser(cif_path, primitive=False)
    data['gas'] = np.array(gas, dtype=np.int32)
    data['gas_attr'] = gas_attr.astype(np.float32)
    data['temperature'] = np.array(temperature, dtype=np.float32)
    # Pressure is stored on a log10 scale.
    data['pressure'] = np.array(np.log10(pressure), dtype=np.float32)
    data['target'] = np.array(targets, dtype=np.float32)
    data['task_name'] = task_name
    # protocol=-1 selects the highest pickle protocol available.
    return pickle.dumps(data, protocol=-1)

def get_data(path):
    """Read the exported CSV at ``path`` and return one tuple per row.

    Each tuple is ``(name, gas id, gas attribute row, temperature, pressure,
    loading, task_name)`` where ``task_name`` joins name, gas id, temperature
    and pressure (each cast to str) with '#'. The mean and standard deviation
    of ``log1p`` of the loading column are printed as a side effect.
    """
    table = pd.read_csv(path)

    name_col = 'name'
    gas_col = 'gas-name'
    attr_cols = [
        'gas-CriticalTemp',
        'gas-CriticalPressure',
        'gas-AcentricFactor',
        'gas-MolecularWeight',
        'gas-MeltPoint',
        'gas-BoilingPoint',
    ]
    temp_col = 'temperature'
    pressure_col = 'pressure'
    target_col = 'loading_[cm^3(STP)/gr(framework)]_abs_num'

    # Assemble "<name>#<gas>#<temperature>#<pressure>" column by column.
    key_parts = [table[col].astype(str) for col in (name_col, gas_col, temp_col, pressure_col)]
    task_names = key_parts[0]
    for part in key_parts[1:]:
        task_names = task_names + '#' + part
    table['task_name'] = task_names

    # print mean and std
    log_targets = np.log1p(table[target_col])
    target_mean = log_targets.mean()
    target_std = log_targets.std()
    print(f'mean and std of target values are: {target_mean}, {target_std}')

    return list(zip(
        table[name_col],
        table[gas_col],
        table[attr_cols].values,
        table[temp_col],
        table[pressure_col],
        table[target_col],
        table['task_name'],
    ))

# Split the database into train, validation and test sets by gas.
def train_valid_test_split(data, train_ratio=0.8, valid_ratio=0.1, test_ratio=0.1):
    """Partition ``data`` so that each gas id falls into exactly one split.

    Gas id 6 (O2) is always the test gas, one of the remaining ids 1-7 is drawn
    as the validation gas, and the rest form the training set. The draw uses
    NumPy's global generator right after seeding it with 42, so repeated calls
    return the same split.

    Args:
        data: sequence of tuples whose element at index 1 is the gas id.
        train_ratio, valid_ratio, test_ratio: accepted for interface
            compatibility; the split is by gas, so the ratios are not used.

    Returns:
        ``(train_data, valid_data, test_data)`` as lists of the input tuples.
    """
    np.random.seed(42)
    unique_gas_ids = {item[1] for item in data}  # gas ids present in the data
    print(f'length of data is {len(data)}')
    print(f'length of unique_id_list is {len(unique_gas_ids)}')

    gas_dic = {1:"CH4", 2:"CO2", 3:"Ar", 4:"Kr", 5:"Xe", 6:"O2", 7:"N2"}
    gas_list = [1,2,3,4,5,6,7]

    print("*******************************")

    test_mat_id = 6
    print("test_id_list:", gas_dic[test_mat_id])

    gas_list.remove(test_mat_id)
    # Draw from the generator seeded above; the previous random.sample call used Python's
    # unseeded ``random`` module, which made the validation gas differ between runs.
    valid_mat_id = int(np.random.choice(gas_list))
    print("valid_id_list:", gas_dic[valid_mat_id])

    gas_list.remove(valid_mat_id)
    train_id_list = np.array(gas_list)
    print("train_id_list:", train_id_list)

    train_ids = set(gas_list)  # constant-time membership checks
    train_data = [item for item in data if item[1] in train_ids]
    valid_data = [item for item in data if item[1] == valid_mat_id]
    test_data = [item for item in data if item[1] == test_mat_id]

    print("*******************************")
    print(f'train_len:{len(train_data)}')
    print(f'valid_len:{len(valid_data)}')
    print(f'test_len:{len(test_data)}')

    return train_data, valid_data, test_data

def rand_test_split(data, num_samples):
    """Return ``num_samples`` pseudo-randomly chosen items of ``data`` as the test split.

    NumPy's global generator is seeded with 10 before shuffling, so the selection
    is reproducible. A copy is shuffled, so the caller's sequence keeps its order.

    Returns:
        ``([], [], test_data)`` - the train and validation splits are always empty.
    """
    np.random.seed(10)
    shuffled = list(data)  # shuffle a copy instead of mutating the argument
    np.random.shuffle(shuffled)
    return [], [], shuffled[:num_samples]

def write_lmdb(inpath='./', outpath='./', nthreads=40, num_test_samples=16):
    """Parse the rows of the CSV at ``inpath`` and store them in LMDB files.

    Writes ``train.lmdb``, ``valid.lmdb`` and ``test.lmdb`` under ``outpath``,
    replacing any existing file of the same name. Keys are consecutive ASCII
    integers starting at 0; values are the bytes returned by ``single_parser``.
    Rows for which ``single_parser`` returns ``None`` are skipped.

    Args:
        inpath: CSV file passed to ``get_data``.
        outpath: directory receiving the three LMDB files (created if missing).
        nthreads: number of worker processes used for parsing.
        num_test_samples: number of rows sampled into the test split by
            ``rand_test_split``; the train and validation splits stay empty.
    """
    data = get_data(inpath)
    # Alternative gas-wise split: train_valid_test_split(data)
    train_data, valid_data, test_data = rand_test_split(data, num_test_samples)
    print(len(train_data), len(valid_data), len(test_data))
    for name, content in [ ('train.lmdb', train_data),
                            ('valid.lmdb', valid_data),
                            ('test.lmdb', test_data) ]:
        outputfilename = os.path.join(outpath, name)
        # dirname is '' when outpath is empty; fall back to the current directory.
        os.makedirs(os.path.dirname(outputfilename) or '.', exist_ok=True)
        try:
            os.remove(outputfilename)
        except FileNotFoundError:
            # Nothing to replace. Other errors (e.g. permissions) are left to surface.
            pass
        env_new = lmdb.open(
            outputfilename,
            subdir=False,
            readonly=False,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=1,
            map_size=int(100e9),
        )
        try:
            txn_write = env_new.begin(write=True)
            i = 0
            with Pool(nthreads) as pool:
                for inner_output in tqdm(pool.imap(single_parser, content), total=len(content)):
                    if inner_output is not None:
                        txn_write.put(f'{i}'.encode("ascii"), inner_output)
                        i += 1
                        # Commit in batches of 1000 records and start a fresh transaction.
                        if i % 1000 == 0:
                            txn_write.commit()
                            txn_write = env_new.begin(write=True)
            print('{} process {} lines'.format(name, i))
            txn_write.commit()
        finally:
            # Close the environment even when parsing or writing fails.
            env_new.close()

inpath = './data/data.csv' # replace with your data path
outpath = './data/CoRE_DB' # replace with your output path
# Pure-Python replacement for the `!mkdir -p $outpath` shell escape.
os.makedirs(outpath, exist_ok=True)
write_lmdb(inpath=inpath, outpath=outpath, nthreads=8)


In [None]:
# Copy infer.py from the unimof package directory to the repository root.
%cp /content/Uni-MOF/unimof/infer.py /content/Uni-MOF/infer.py
# Copy the example dict.txt into the data directory.
%cp /content/Uni-MOF/examples/mof/dict.txt /content/Uni-MOF/data/dict.txt
# Make sure the results directory exists.
!mkdir -p /content/Uni-MOF/results

In [None]:
import subprocess

# Command line for the inference script; the paths are relative, so the kernel's working
# directory must be the Uni-MOF checkout when this cell runs.
cmd = [
    "python", "infer.py", "./data",
    "--user-dir", "./unimof",
    "--path", "./weights/unimof_CoRE_MOFX_DB_finetune_best.pt",
    "--task-name", "CoRE_DB",
    "--valid-subset", "test",
    "--num-workers", "0",
    "--task", "unimof_v2",
    "--arch", "unimof_v2",
    "--loss", "mof_v2_mse",
    "--batch-size", "4",
    "--seed", "1",
    "--fp16",
    "--fp16-init-scale", "4",
    "--fp16-scale-window", "256",
    "--num-classes", "1",
    "--remove-hydrogen",
    "--results-path", "results",
    "--log-interval", "1",
]

# stderr is merged into stdout and the stream is line-buffered so output appears live.
# The context manager closes the pipe and waits for the child, which sets returncode.
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) as process:
    for line in process.stdout:
        print(line, end="", flush=True)  # prints output live

# Fail loudly instead of letting later cells read missing or stale result files.
if process.returncode != 0:
    raise subprocess.CalledProcessError(process.returncode, cmd)


In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
import pickle
import torch
import matplotlib.pyplot as plt

# Load the pickled results file.
with open('./results/weights_test.out.pkl', 'rb') as f:
    results = pickle.load(f)

# Flatten the per-entry predictions, targets and task names into parallel lists.
all_preds, all_targets, all_names = [], [], []
for entry in results:
    all_preds.extend(entry['predict'].cpu().numpy().flatten())
    all_targets.extend(entry['target'].cpu().numpy().flatten())
    all_names.extend(entry['task_name'])

# Mapping from gas ID to gas name
gas_dic = {1: "CH4", 2: "CO2", 3: "Ar", 4: "Kr", 5: "Xe", 6: "O2", 7: "N2"}

# Group (target, prediction) pairs by gas and remember the label data for every point.
# Task names are '#'-separated; fields 1-3 are read as gas id, temperature and pressure.
gas_to_points = defaultdict(list)
annotations = []
for pred, target, name in zip(all_preds, all_targets, all_names):
    fields = name.split('#')
    gas_id = int(fields[1])
    temp, pressure = fields[2], fields[3]
    gas_name = gas_dic.get(gas_id, f"Unknown({gas_id})")

    gas_to_points[gas_name].append((target, pred))
    annotations.append((target, pred, gas_name, temp, pressure))

# One scatter series per gas, coloured from the tab10 palette.
fig, ax = plt.subplots(figsize=(10, 10))
palette = plt.cm.tab10.colors
gas_names = sorted(gas_to_points)
color_map = {gas: palette[idx % len(palette)] for idx, gas in enumerate(gas_names)}

for gas_name in gas_names:
    xs, ys = zip(*gas_to_points[gas_name])
    ax.scatter(xs, ys, label=gas_name, color=color_map[gas_name], alpha=0.7)

# Label every point with its temperature and pressure.
for x, y, _gas, temp, pressure in annotations:
    ax.annotate(f'{temp}K, {pressure}Pa', (x, y), textcoords="offset points", xytext=(0,5), ha='center', fontsize=8, alpha=0.6)

# Diagonal reference line spanning the full value range.
combined = all_targets + all_preds
min_val, max_val = min(combined), max(combined)
ax.plot([min_val, max_val], [min_val, max_val], 'k--', label='Ideal (y = x)')

ax.set_xlabel('Target')
ax.set_ylabel('Predicted')
ax.set_title('Predicted vs Target (Gas Type + Temp/Pressure)')
ax.legend(title="Gas")
ax.grid(True)
ax.axis('equal')
fig.tight_layout()
plt.show()


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the copy destination used below lives under this path.
drive.mount('/content/drive')
import shutil
import os

# Define source and destination paths
source_dir = "/content"
dest_dir = '/content/drive/My Drive/Uni-MOF_inference_test'

# File or directory names to skip at any depth of the copied tree
excluded_names = ['weights','logs','data','results', '__pycache__', 'drive', 'sample_data']  # Add any others

def ignore_func(dir, files):
    """``ignore`` callback for ``shutil.copytree``.

    Args:
        dir: directory currently being copied (unused; part of the callback signature).
        files: names of the entries found in ``dir``.

    Returns:
        The entries of ``files`` whose name exactly matches one of
        ``excluded_names``, in their original order and without duplicates.
    """
    # Exact-name matching only: the earlier path-prefix test also dropped unrelated
    # top-level entries that merely started with an excluded name (e.g. "database"
    # for "data").
    excluded = set(excluded_names)
    return [name for name in files if name in excluded]

# Recursively copy source_dir into dest_dir, skipping the names returned by ignore_func.
# dest_dir lives under source_dir (/content/drive/...), so the 'drive' entry in excluded_names
# is what keeps the copy from descending into its own destination.
# dirs_exist_ok=True allows the copy to run when dest_dir already exists.
shutil.copytree(source_dir, dest_dir, ignore=ignore_func, dirs_exist_ok=True)