# Inference with UniMOF Model for WS24

This notebook demonstrates how to do inference on a finetuned UniMOF model to predict water adsorption behaviors.

In [1]:
#@title Install Uni-Core and dependencies
%%bash
cd /content

# install dependencies if not done already
if [ ! -f ENV_READY ]; then
    pip3 install rdkit
    pip3 install lmdb
    pip3 install pymatgen
    touch ENV_READY
fi

UNICORE_GIT='https://github.com/dptech-corp/Uni-Core.git'
WS24_UNIMOF_GIT='https://github.com/emd-aquila/Xc51-MOFs.git'

# install Uni-Core if not done already
if [ ! -f UNICORE_READY ]; then
    git clone -b main ${UNICORE_GIT}
    # fix error in code before installing
    perl -pi -e 's/state = torch\.load\(f, map_location=torch\.device\("cpu"\)\)/state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)/' ./Uni-Core/unicore/checkpoint_utils.py
    pip3 install -e ./Uni-Core
    git clone -b main ${WS24_UNIMOF_GIT}
    touch UNICORE_READY
fi

Collecting rdkit
  Downloading rdkit-2025.3.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2025.3.1-cp311-cp311-manylinux_2_28_x86_64.whl (34.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 34.6/34.6 MB 51.8 MB/s eta 0:00:00
Installing collected packages: rdkit
Successfully installed rdkit-2025.3.1
Collecting lmdb
  Downloading lmdb-1.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading lmdb-1.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (297 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 297.8/297.8 kB 5.4 MB/s eta 0:00:00
Installing collected packages: lmdb
Successfully installed lmdb-1.6.2
Collecting pymatgen
  Downloading pymatgen-2025.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting monty>=2025.1.9 (from pymatgen)
  Downloading monty-2025.3.3-py3-none-any.whl.metadata (3.6 kB)
Collecting palettable>=3.3.3 (from pymatgen)
  Downloading palettable-3.3.3

Cloning into 'Uni-Core'...
Cloning into 'Xc51-MOFs'...


## (optional) download a checkpoint from Dropbox
Find the checkpoint you want to download from the Dropbox, and copy the link here.

In [None]:
!wget 'https://www.dropbox.com/scl/fi/0fzemjaybokz673qqqeb9/checkpoint.best_f1_0.26.pt?rlkey=4vjtzmtcrqv00aj88lelqm5wf&st=93oljhun&dl=0' -O checkpoint.best_f1_0.26.pt

In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [3]:
%cd /content/drive/MyDrive/Xc51-MOFs/WS24-UniMOF

/content/drive/MyDrive/Xc51-MOFs/WS24-UniMOF


## Inference

In [4]:
!chmod +x inference_scripts/infer-mostly-freeze-weighted.sh

In [5]:
!./inference_scripts/infer-mostly-freeze-weighted.sh

2025-05-12 15:31:47.362321: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747063907.384507    1445 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747063907.391162    1445 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-12 15:31:47.412807: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
fused_multi_tensor is not installed corrected
fused_rounding is not installed corrected
fused_layer_norm is not installed cor

In [None]:
import os

# Get the current working directory
current_directory = os.getcwd()
print("Current directory:", current_directory)


Current directory: /content/drive/MyDrive/Xc51-MOFs/WS24-UniMOF


In [None]:
import pickle
# Load and inspect the pickle file
with open('./evaluation/logs_train_tsne_.out.pkl', 'rb') as file:
    train_data = pickle.load(file)

with open('./evaluation/logs_valid_tsne_.out.pkl', 'rb') as file:
    valid_data = pickle.load(file)

with open('./evaluation/logs_test_tsne_.out.pkl', 'rb') as file:
    test_data = pickle.load(file)

# Display the contents of the loaded data
print(test_data)

[{'loss': 0.430419921875, 'logits': tensor([[-3.6367, -0.1173,  2.5176,  1.7500]], dtype=torch.float16), 'predict': tensor([3]), 'target': tensor([3.]), 'bsz': 1, 'sample_size': 1}, {'loss': 3.740234375, 'logits': tensor([[-2.3711,  2.5410,  1.6777, -0.8174]], dtype=torch.float16), 'predict': tensor([2]), 'target': tensor([4.]), 'bsz': 1, 'sample_size': 1}, {'loss': 5.3828125, 'logits': tensor([[-0.4792, -0.5312,  3.0352, -2.2871]], dtype=torch.float16), 'predict': tensor([3]), 'target': tensor([4.]), 'bsz': 1, 'sample_size': 1}, {'loss': 0.60546875, 'logits': tensor([[ 0.5254,  0.9966,  1.6738, -3.2969]], dtype=torch.float16), 'predict': tensor([3]), 'target': tensor([3.]), 'bsz': 1, 'sample_size': 1}, {'loss': 0.018707275390625, 'logits': tensor([[-2.3965, -0.5586,  4.1641, -0.5947]], dtype=torch.float16), 'predict': tensor([3]), 'target': tensor([3.]), 'bsz': 1, 'sample_size': 1}, {'loss': 0.5087890625, 'logits': tensor([[-3.3828,  1.9805,  2.4707, -0.5728]], dtype=torch.float16), '

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
from sklearn.preprocessing import label_binarize

# Suppose all your outputs are gathered like this:
all_preds = []
all_targets = []
all_logits = []

# During inference, collect from each batch
for log in logging_outputs:
    all_preds.append(log["predict"])
    all_targets.append(log["target"])
    all_logits.append(log["logits"])

# Concatenate
y_pred = torch.cat(all_preds).numpy()
y_true = torch.cat(all_targets).numpy()
y_logits = torch.cat(all_logits).numpy()  # shape: (N, num_classes)

# -------- 1. Confusion Matrix --------
plt.figure(figsize=(6, 5))
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=range(1, 5), yticklabels=range(1, 5))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# -------- 2. AUC-ROC Curves (Multiclass) --------
# Binarize targets for ROC curve
num_classes = y_logits.shape[1]
y_true_bin = label_binarize(y_true, classes=range(num_classes))  # e.g., [0, 1, 2, 3]

fpr = {}
tpr = {}
roc_auc = {}

for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_logits[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot
plt.figure(figsize=(8, 6))
for i in range(num_classes):
    plt.plot(fpr[i], tpr[i], label=f"Class {i} (AUC = {roc_auc[i]:.2f})")

plt.plot([0, 1], [0, 1], 'k--')  # chance
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve (Multiclass)")
plt.legend()
plt.grid(True)
plt.show()