# Finetuning UniMOF Model for WS24

This notebook demonstrates how to fine-tune a pretrained UniMOF model on the WS24 dataset of MOF structures to predict each MOF's Burtch stability label (1 = unstable, 2 = low kinetic stability, 3 = high kinetic stability, 4 = thermodynamic stability). We reuse the pretrained weights and adapt the model to this task.

In [1]:
#@title Mount Google Drive
# Colab-only step: mount the user's Google Drive at /content/drive/ so that
# later cells can read from / write to paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [4]:
#@title Install Uni-Core and dependencies
%%bash
# Stop at the first failing command so the *_READY marker files below are only
# created after a fully successful install (otherwise a failed install would
# be skipped forever on re-run).
set -euo pipefail

cd /content

# install dependencies if not done already
if [ ! -f ENV_READY ]; then
    pip3 install molSimplify
    pip3 install rdkit
    pip3 install lmdb
    pip3 install pymatgen
    touch ENV_READY
fi

UNICORE_GIT='https://github.com/dptech-corp/Uni-Core.git'
UNICORE_CKPT_UTILS='./Uni-Core/unicore/checkpoint_utils.py'

# install Uni-Core if not done already
if [ ! -f UNICORE_READY ]; then
    # A previous, partially failed attempt may have left the clone behind;
    # reuse it instead of letting `git clone` fail on the existing directory.
    if [ ! -d Uni-Core ]; then
        git clone -b main "${UNICORE_GIT}"
    fi
    # fix error in code before installing: force weights_only=False in the
    # checkpoint loader (presumably required by newer PyTorch releases whose
    # torch.load default changed -- TODO confirm against the installed version)
    perl -pi -e 's/state = torch\.load\(f, map_location=torch\.device\("cpu"\)\)/state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)/' "${UNICORE_CKPT_UTILS}"
    # perl exits 0 even when nothing matched, so make a silent no-op visible.
    if ! grep -q 'weights_only=False' "${UNICORE_CKPT_UTILS}"; then
        echo "WARNING: torch.load patch was not applied to ${UNICORE_CKPT_UTILS}" >&2
    fi
    pip3 install -e ./Uni-Core
    touch UNICORE_READY
fi

Collecting molSimplify
  Downloading molSimplify-1.7.6-py3-none-any.whl.metadata (48 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 48.7/48.7 kB 1.7 MB/s eta 0:00:00
Collecting openbabel-wheel (from molSimplify)
  Downloading openbabel_wheel-3.1.1.21-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading molSimplify-1.7.6-py3-none-any.whl (15.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 15.7/15.7 MB 79.8 MB/s eta 0:00:00
Downloading openbabel_wheel-3.1.1.21-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16.1/16.1 MB 61.8 MB/s eta 0:00:00
Installing collected packages: openbabel-wheel, molSimplify
Successfully installed molSimplify-1.7.6 openbabel-wheel-3.1.1.21


In [3]:
!wget https://zenodo.org/records/12110918/files/data_sets.zip
!unzip data_sets.zip

--2025-05-09 05:36:10--  https://zenodo.org/records/12110918/files/data_sets.zip
Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.45.92, 188.185.48.194, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4826475 (4.6M) [application/octet-stream]
Saving to: ‘data_sets.zip’


2025-05-09 05:36:12 (5.02 MB/s) - ‘data_sets.zip’ saved [4826475/4826475]

Archive:  data_sets.zip
   creating: data_sets/
  inflating: __MACOSX/._data_sets    
   creating: data_sets/WS14s/
  inflating: __MACOSX/data_sets/._WS14s  
   creating: data_sets/CIF_preparation_examples/
  inflating: __MACOSX/data_sets/._CIF_preparation_examples  
   creating: data_sets/text_mining_example/
  inflating: __MACOSX/data_sets/._text_mining_example  
  inflating: data_sets/.DS_Store     
  inflating: __MACOSX/data_sets/._.DS_Store  
   creating: data_sets/WS24s/
  inflating: __MACOSX/data_sets/._WS24s  
   creating: data_sets/validation_se

In [None]:
#@title Navigate to WS24 file in drive
# Changes the kernel's working directory to the project folder on the mounted Drive.
# NOTE(review): the cells below use relative paths such as './data_sets/...',
# which resolve against whatever the working directory is when they run. This
# cell has no execution count in the saved notebook, so the recorded outputs
# were presumably produced without it -- confirm which directory is intended.
%cd /content/drive/MyDrive/X.C51_project/WS24-UniMOF

In [6]:
# Clean up MOF CIF files (remove overlapping atoms and floating solvent from CIFs)

from molSimplify.Informatics.MOF.PBC_functions import solvent_removal, overlap_removal
import os

# Path to the directory containing your .cif files
input_dir = "./data_sets/WS24s/CIFs"
output_dir = "./data_sets/WS24s/CIFs"  # Can be the same as input_dir if you want

# Suffix appended to the base name of every cleaned structure.
clean_suffix = "_clean"

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Loop through all files in the directory (sorted so the processing order is stable)
for filename in sorted(os.listdir(input_dir)):
    if not filename.endswith(".cif"):
        continue
    base_name = os.path.splitext(filename)[0]
    # input_dir and output_dir may be the same folder: skip files written by a
    # previous run of this cell, otherwise re-running it would clean them again
    # and produce "<name>_clean_clean.cif".
    if base_name.endswith(clean_suffix):
        continue

    input_path = os.path.join(input_dir, filename)
    output_filename = f"{base_name}{clean_suffix}.cif"
    cleaned_path = os.path.join(output_dir, output_filename)

    # Step 1 writes the overlap-free structure to cleaned_path,
    # step 2 then rewrites that same file with the solvent removed.
    overlap_removal(input_path, cleaned_path) # Input CIF should have P1 symmetry.
    solvent_removal(cleaned_path, cleaned_path)

('cell vectors: ', 'alpha, beta, gamma = 81.911, 84.45, 70.545')
n_components: 1
labels_components: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0]
len is 194
('cell vectors: ', 'alpha, beta, gamma = 90.0, 96.876, 90.0')
n_components: 1
labels_components: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0

KeyboardInterrupt: 

## 1. Setup Environment and Dependencies

## 3. Prepare the Dataset

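The WS24 data is provided as one CIF file per MOF plus a CSV file of labels, whereas UniMOF reads its training data from LMDB databases. The cells below parse each CIF (atom types, coordinates and lattice parameters), attach the corresponding label, and write the results to `train.lmdb`, `valid.lmdb` and `test.lmdb`.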

In [12]:
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifParser
from multiprocessing import Process, Queue, Pool
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import pickle
import lmdb
import sys
import glob
import os
import re

def normalize_atoms(atom):
    """Return ``atom`` with every run of decimal digits removed.

    Args:
        atom: species label as a string (e.g. ``"Cu2"``).

    Returns:
        The label without digits (e.g. ``"Cu"``). Non-digit characters are
        left untouched.
    """
    # Raw string: "\d" inside a normal string literal is an invalid escape
    # sequence and triggers a warning on recent Python versions.
    return re.sub(r"\d+", "", atom)

def cif_parser(cif_path, primitive=False):
    """
    Read a single CIF file and collect its atoms and lattice data in a dict.

    Args:
        cif_path: path of the CIF file; the part after the last '/' minus its
            final four characters (the ".cif" suffix) becomes the record ID.
        primitive: forwarded to ``Structure.from_file``.

    Returns:
        dict with keys 'ID', 'atoms', 'coordinates', 'abc', 'angles',
        'volume', 'lattice_matrix' and 'abc_coordinates'.
    """
    structure = Structure.from_file(cif_path, primitive=primitive)
    cell = structure.lattice

    # One row per site; the 'Species', 'x'/'y'/'z' and 'a'/'b'/'c' columns are used.
    # (x/y/z are presumably Cartesian and a/b/c fractional coordinates -- confirm
    # against the pymatgen documentation.)
    sites = structure.as_dataframe()
    species_labels = sites['Species'].astype(str)
    atoms = species_labels.map(normalize_atoms).tolist()
    coordinates = sites[['x', 'y', 'z']].to_numpy().astype(np.float32)
    abc_coordinates = sites[['a', 'b', 'c']].to_numpy().astype(np.float32)

    # Both coordinate arrays must hold exactly one row per atom.
    n_atoms = len(atoms)
    assert n_atoms == coordinates.shape[0]
    assert n_atoms == abc_coordinates.shape[0]

    file_name = cif_path.rsplit('/', 1)[-1]
    record = {
        'ID': file_name[:-4],
        'atoms': atoms,
        'coordinates': coordinates,
        'abc': cell.abc,
        'angles': cell.angles,
        'volume': cell.volume,
        'lattice_matrix': cell.matrix,  # 3x3 lattice matrix
        'abc_coordinates': abc_coordinates,
    }
    return record

def single_parser(content):
    """
    Parse the CIF of one labelled MOF and return it as a pickled record.

    Args:
        content: ``(cif_name, targets)`` pair; the structure is read from
            ``<dir_path>/<cif_name>.cif``.

    Returns:
        The pickled dict produced by ``cif_parser`` with 'mof-name' and
        'target' added, or ``None`` when the file is missing or parsing
        raises ``ValueError`` (a message is printed in both cases).
    """
    # NOTE(review): this reads "<name>.cif"; the "<name>_clean.cif" files written
    # by the cleaning cell earlier in the notebook are never picked up here --
    # confirm which set of structures is intended.
    dir_path = './data_sets/WS24s/CIFs'  # replace with your MOF database path
    cif_name, targets = content
    cif_path = os.path.join(dir_path, cif_name + '.cif')

    # Guard clause: nothing to parse when the structure file is absent.
    if not os.path.exists(cif_path):
        print(f"{cif_path} does not exist!")
        return None

    try:
        record = cif_parser(cif_path, primitive=False)
        record['mof-name'] = cif_name
        record['target'] = targets
        return pickle.dumps(record, protocol=-1)
    except ValueError as e:
        print(f"Error processing {cif_path}: {e}")
        return None

def get_data(path):
    """
    Load the label CSV and return its rows as ``(mof name, target)`` pairs.

    Also prints the mean and standard deviation of the target column.

    Args:
        path: CSV file containing the 'mof-name' and 'target' columns.

    Returns:
        list of ``(mof-name, target)`` tuples in file order.
    """
    target_column = 'target' # replace to your target column
    name_column = 'mof-name' # replace to your mof name column

    table = pd.read_csv(path)
    targets = table[target_column]
    _mean = targets.mean()
    _std = targets.std()
    print(f'mean and std of target values are: {_mean}, {_std}')

    # zip() already yields 2-tuples, one per CSV row.
    return list(zip(table[name_column], targets.values))

def train_valid_test_split(data, train_ratio=0.8, valid_ratio=0.1, test_ratio=0.1):
    """
    Split ``data`` into train/valid/test lists, grouped by ID.

    The unique IDs (first element of every item) are shuffled with a fixed
    seed and partitioned; every item then follows its ID, so items sharing an
    ID always end up in the same split.

    Args:
        data: list of items whose first element is the ID, e.g.
            ``(mof_name, target)`` tuples.
        train_ratio: fraction of the unique IDs assigned to the train split.
        valid_ratio: fraction of the unique IDs assigned to the valid split.
        test_ratio: kept for interface compatibility; it is not used, the
            test split receives all IDs left over after train and valid.

    Returns:
        ``(train_data, valid_data, test_data)``, each preserving the item
        order of ``data``.
    """
    np.random.seed(42)
    # Sort before shuffling: the iteration order of a set of strings changes
    # between interpreter sessions (hash randomization), which would make the
    # split differ from run to run even though the seed is fixed.
    unique_id_list = sorted({item[0] for item in data})
    unique_id_list = np.random.permutation(unique_id_list)
    print(f'length of data is {len(data)}')
    print(f'length of unique_id_list is {len(unique_id_list)}')
    train_size = int(len(unique_id_list) * train_ratio)
    valid_size = int(len(unique_id_list) * valid_ratio)

    # Python sets give constant-time membership tests; testing `in` against a
    # numpy array scans the whole array for every item.
    train_ids = set(unique_id_list[:train_size].tolist())
    valid_ids = set(unique_id_list[train_size:train_size+valid_size].tolist())
    test_ids = set(unique_id_list[train_size+valid_size:].tolist())

    train_data = [item for item in data if item[0] in train_ids]
    valid_data = [item for item in data if item[0] in valid_ids]
    test_data = [item for item in data if item[0] in test_ids]

    print(f'train_len:{len(train_data)}')
    print(f'valid_len:{len(valid_data)}')
    print(f'test_len:{len(test_data)}')

    return train_data, valid_data, test_data

def write_lmdb(inpath='./', outpath='./',nthreads=40):
    """
    Build train/valid/test LMDB files from a label CSV.

    The CSV is loaded with ``get_data``, split with ``train_valid_test_split``
    and every entry is parsed by ``single_parser`` in a worker pool. Records
    for which the parser returns ``None`` are skipped; the remaining ones are
    stored under consecutive ASCII integer keys ("0", "1", ...).

    Args:
        inpath: path of the label CSV.
        outpath: directory receiving ``train.lmdb``, ``valid.lmdb`` and
            ``test.lmdb``; created if missing. Existing files with these
            names are replaced.
        nthreads: number of worker processes in the pool.
    """
    data = get_data(inpath)
    train_data, valid_data, test_data = train_valid_test_split(data)
    print(len(train_data), len(valid_data), len(test_data))

    # The databases are opened as plain files (subdir=False), so make sure the
    # directory that will contain them exists before opening.
    if outpath:
        os.makedirs(outpath, exist_ok=True)

    for name, content in [ ('train.lmdb', train_data),
                            ('valid.lmdb', valid_data),
                            ('test.lmdb', test_data) ]:
        outputfilename = os.path.join(outpath, name)
        # Start from a fresh file; only "nothing to delete" is an expected failure.
        try:
            os.remove(outputfilename)
        except FileNotFoundError:
            pass
        env_new = lmdb.open(
            outputfilename,
            subdir=False,
            readonly=False,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=1,
            map_size=int(100e9),
        )
        # Close the environment even when parsing or writing fails part-way.
        try:
            txn_write = env_new.begin(write=True)
            with Pool(nthreads) as pool:
                i = 0
                for inner_output in tqdm(pool.imap(single_parser, content), total=len(content)):
                    if inner_output is not None:
                        txn_write.put(f'{i}'.encode("ascii"), inner_output)
                        i += 1
                        # Commit in batches of 1000 records and continue in a new transaction.
                        if i % 1000 == 0:
                            txn_write.commit()
                            txn_write = env_new.begin(write=True)
                print('{} process {} lines'.format(name, i))
                txn_write.commit()
        finally:
            env_new.close()


In [10]:
# Build the two-column (mof-name, target) table that get_data() reads,
# starting from the raw WS24 label file.
label_column = 'Burtch label (1=unstable; 2=low kinetic stability; 3=high kinetic stability; 4=thermodynamic stability)'

df = pd.read_csv('./data_sets/WS24s/labels.csv').assign(
    **{
        'mof-name': lambda table: table['refcode'],                   # CIF base name
        'target': lambda table: table[label_column].astype(int),     # integer class label
    }
)

stability_df = df.loc[:, ['mof-name', 'target']]

# Show the first rows and the table size, then save it for write_lmdb()
display(stability_df.head())
print(stability_df.shape)
stability_df.to_csv('./data_sets/WS24s/stability_data.csv', index=False)

Unnamed: 0,mof-name,target
0,IYATEI,1
1,WOLREV,1
2,WOLRIZ,1
3,ALAMUW,1
4,UVUFUN,1


(964, 2)


In [11]:
# Write train/valid/test LMDB files to the project folder on Google Drive.
# The output directory is given as an absolute path (Drive is mounted at
# /content/drive/): the former relative './drive/...' form only resolved
# correctly while the working directory was /content.
write_lmdb(inpath='./data_sets/WS24s/stability_data.csv', outpath='/content/drive/MyDrive/X.C51_project/WS24-UniMOF/data/WS24', nthreads=8)

mean and std of target values are: 2.5975103734439835, 0.7514845293254065
length of data is 964
length of unique_id_list is 964
train_len:771
valid_len:96
test_len:97
771 96 97


  0%|          | 1/771 [00:01<13:44,  1.07s/it]

./data_sets/WS24s/CIFs/BELYEY.cif does not exist!


  4%|▎         | 28/771 [00:03<00:53, 13.89it/s]

./data_sets/WS24s/CIFs/IZEWIU.cif does not exist!


  6%|▌         | 47/771 [00:08<02:29,  4.84it/s]

./data_sets/WS24s/CIFs/TOKJAG.cif does not exist!
./data_sets/WS24s/CIFs/TOKJIO.cif does not exist!
./data_sets/WS24s/CIFs/TOKJOU.cif does not exist!
./data_sets/WS24s/CIFs/TOKKAH.cif does not exist!


  7%|▋         | 54/771 [00:16<05:38,  2.12it/s]

./data_sets/WS24s/CIFs/IZIKOT01.cif does not exist!


 22%|██▏       | 171/771 [00:30<01:09,  8.61it/s]

./data_sets/WS24s/CIFs/IVEQAC.cif does not exist!
./data_sets/WS24s/CIFs/LUNVUL.cif does not exist!
./data_sets/WS24s/CIFs/LUNWAS.cif does not exist!
./data_sets/WS24s/CIFs/RAXMUZ01.cif does not exist!./data_sets/WS24s/CIFs/WEHPUV.cif does not exist!



 25%|██▍       | 192/771 [00:36<01:30,  6.39it/s]

./data_sets/WS24s/CIFs/FORVUE.cif does not exist!
./data_sets/WS24s/CIFs/LITHOL.cif does not exist!
./data_sets/WS24s/CIFs/WEFXAI.cif does not exist!
./data_sets/WS24s/CIFs/GEBCAT.cif does not exist!
./data_sets/WS24s/CIFs/QAVWAN.cif does not exist!
./data_sets/WS24s/CIFs/ZONBAH.cif does not exist!


 61%|██████    | 469/771 [01:08<00:31,  9.51it/s]

./data_sets/WS24s/CIFs/LAZXIV.cif does not exist!./data_sets/WS24s/CIFs/LAZXOB.cif does not exist!



 70%|███████   | 540/771 [01:17<00:48,  4.77it/s]

./data_sets/WS24s/CIFs/CELZIE.cif does not exist!
./data_sets/WS24s/CIFs/ROSCOR.cif does not exist!


 73%|███████▎  | 566/771 [01:19<00:25,  7.95it/s]

./data_sets/WS24s/CIFs/HAWYUB.cif does not exist!


 76%|███████▌  | 584/771 [01:23<00:27,  6.84it/s]

./data_sets/WS24s/CIFs/GUGJEZ.cif does not exist!


 84%|████████▍ | 649/771 [01:27<00:11, 10.79it/s]

./data_sets/WS24s/CIFs/HENYUV.cif does not exist!
./data_sets/WS24s/CIFs/PITYUN04.cif does not exist!
./data_sets/WS24s/CIFs/VECVOQ.cif does not exist!
./data_sets/WS24s/CIFs/YEYBOV.cif does not exist!
./data_sets/WS24s/CIFs/YEYBOV01.cif does not exist!
./data_sets/WS24s/CIFs/BOHJOZ.cif does not exist!
./data_sets/WS24s/CIFs/BOHWIG.cif does not exist!
./data_sets/WS24s/CIFs/BOHXAZ.cif does not exist!
./data_sets/WS24s/CIFs/BOHXED.cif does not exist!
./data_sets/WS24s/CIFs/PAQDEQ.cif does not exist!
./data_sets/WS24s/CIFs/PAQDEQ01.cif does not exist!


 85%|████████▌ | 656/771 [01:43<01:00,  1.92it/s]

./data_sets/WS24s/CIFs/PEZCII.cif does not exist!


 90%|████████▉ | 691/771 [01:51<00:30,  2.61it/s]

./data_sets/WS24s/CIFs/FATLUJ.cif does not exist!
./data_sets/WS24s/CIFs/XICYIT.cif does not exist!
./data_sets/WS24s/CIFs/YINSIZ.cif does not exist!
./data_sets/WS24s/CIFs/DOGBEI.cif does not exist!


100%|██████████| 771/771 [02:14<00:00,  5.75it/s]


train.lmdb process 731 lines


  0%|          | 0/96 [00:00<?, ?it/s]

./data_sets/WS24s/CIFs/TOKJUA.cif does not exist!


 33%|███▎      | 32/96 [00:05<00:06,  9.32it/s]

./data_sets/WS24s/CIFs/CIZBAP.cif does not exist!


 39%|███▊      | 37/96 [00:08<00:13,  4.42it/s]

./data_sets/WS24s/CIFs/CELZOK.cif does not exist!


100%|██████████| 96/96 [00:20<00:00,  4.72it/s]

valid.lmdb process 93 lines



  0%|          | 0/97 [00:00<?, ?it/s]

./data_sets/WS24s/CIFs/BELYAU.cif does not exist!
./data_sets/WS24s/CIFs/TOKJEK.cif does not exist!


 33%|███▎      | 32/97 [00:03<00:04, 13.01it/s]

./data_sets/WS24s/CIFs/HIFHEL.cif does not exist!
./data_sets/WS24s/CIFs/CELZUQ.cif does not exist!


 36%|███▌      | 35/97 [00:06<00:15,  3.96it/s]

./data_sets/WS24s/CIFs/BOHJUF.cif does not exist!./data_sets/WS24s/CIFs/DODBUV.cif does not exist!

./data_sets/WS24s/CIFs/BOHKAM.cif does not exist!
./data_sets/WS24s/CIFs/XICNOO02.cif does not exist!
./data_sets/WS24s/CIFs/XICYUF.cif does not exist!
./data_sets/WS24s/CIFs/YEZKUL.cif does not exist!


100%|██████████| 97/97 [00:08<00:00, 11.57it/s]

test.lmdb process 87 lines



