<a href="https://colab.research.google.com/github/gcourtade/GeqShift/blob/main/GeqShift.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GeqShift

Easy-to-use carbohydrate <sup>13</sup>C and <sup>1</sup>H NMR chemical shift prediction using GeqShift: an E(3) equivariant graph neural network.

The original GeqShift code is available at https://github.com/mariabankestad/GeqShift.

The dataset of <sup>1</sup>H and <sup>13</sup>C NMR chemical shifts is also available in the same repository: https://github.com/mariabankestad/GeqShift.

Please read and cite the GeqShift paper:
[Bånkestad M., Dorst K. M., Widmalm G., Rönnols J. Carbohydrate NMR chemical shift prediction by GeqShift employing E(3) equivariant graph neural networks.
*RSC Advances*, 2024](https://doi.org/10.1039/D4RA03428G)

##### Disclaimer
I made this Google Colab notebook for my own use and have no connection with the authors of the GeqShift paper. This notebook is heavily inspired by, and uses code from, the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb#scrollTo=mbaIO9pWjaN0) notebook. The models were trained for 20 epochs using 100 conformations of each carbohydrate in the training set. I cannot guarantee the correctness of the results generated using this code.

--[Gaston Courtade](https://folk.ntnu.no/courtade), 2025-04-10

In [2]:
#@title Input carbohydrate SMILES, then hit `Runtime` -> `Run all`
import os


query_smiles = 'C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O' #@param {type:"string"}
#@markdown  - Tip: Use [SMILES generator/checker](https://www.cheminfo.org/flavor/malaria/Utilities/SMILES_generator___checker/index.html) to edit SMILES.
jobname = 'bDGlc' #@param {type:"string"}
# number of conformations to generate for the ensemble
num_conformations = 100 #@param {type: "integer"}
#@markdown - Specify how many conformations should be generated in the ensemble for chemical shift prediction
#@markdown - Tip: Best results are expected with around 100 conformations

# Remove results from a previous run so they are not mixed with the new ones
if os.path.exists("/content/predict"):
  !rm -r /content/predict
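The ensemble size set above controls how many 3D conformations are generated for the input molecule. How `predict_gpu.py` combines the per-conformation predictions is not shown in this notebook; a common choice, assumed here purely for illustration, is to report each atom's mean over the ensemble together with its spread. A minimal sketch in plain Python, where the input layout (one dict of atom label to shift per conformation), the function name `ensemble_average`, and the numbers are all hypothetical:

```python
from statistics import mean, stdev


def ensemble_average(per_conformer_shifts):
    """Average per-atom shifts over an ensemble of conformations.

    per_conformer_shifts: list of dicts, one per conformation, mapping an atom
    label to the shift (ppm) predicted for that conformation. This layout is a
    hypothetical illustration, not GeqShift's internal data format.
    Returns {atom: (mean_ppm, stdev_ppm)}.
    """
    result = {}
    for atom in per_conformer_shifts[0]:
        values = [conf[atom] for conf in per_conformer_shifts]
        # stdev needs at least two values; report 0.0 for a single conformation
        spread = stdev(values) if len(values) > 1 else 0.0
        result[atom] = (mean(values), spread)
    return result


# Made-up predictions for two atoms over three conformations
ensemble = [
    {"C1": 97.0, "C2": 75.0},
    {"C1": 96.0, "C2": 75.5},
    {"C1": 98.0, "C2": 74.5},
]
for atom, (avg, sd) in ensemble_average(ensemble).items():
    print(f"{atom}: {avg:.2f} ± {sd:.2f} ppm")
# C1: 97.00 ± 1.00 ppm
# C2: 75.00 ± 0.50 ppm
```

Reporting the spread next to the mean makes it easy to spot atoms whose predicted shift depends strongly on conformation.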

In [3]:
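#@title Install dependencies
import subprocess
import sys

def installed_torch_version():
  # Ask a fresh interpreter: re-running `import torch` in this kernel returns the
  # already-loaded module, so torch.__version__ would not reflect a reinstall.
  result = subprocess.run([sys.executable, "-c", "import torch; print(torch.__version__)"],
                          capture_output=True, text=True, check=True)
  return result.stdout.strip()

if "2.6.0" in installed_torch_version():
  !pip uninstall torch torch-cluster torch-geometric torch-scatter torchaudio torchvision -y
  !pip install torch==2.4.0+cu124 torchvision torchaudio -f https://download.pytorch.org/whl/cu124/torch/

# torch-cluster and torch-scatter are compiled extensions: their wheels must match the installed torch build
torch_url = "https://pytorch-geometric.com/whl/torch-{}.html".format(installed_torch_version()).replace('+', '%2B')
!pip install torch-cluster torch-geometric torch-scatter -f $torch_url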
!pip install e3nn
!pip install rdkit
!pip install mdanalysis
!pip install py3Dmol

if not os.path.exists("GeqShift"):
  !git clone https://github.com/gcourtade/GeqShift.git

checkpoint_13C = "_20240823_checkpoint_epoch-20_13C_nbr-confs-100.pkl"
if not os.path.exists(checkpoint_13C):
  checkpoint_url = "https://folk.ntnu.no/courtade/GeqShift_models/" + checkpoint_13C
  !wget $checkpoint_url

checkpoint_1H = "_20250411_checkpoint_epoch-20_1H_nbr-confs-100.pkl"
if not os.path.exists(checkpoint_1H):
  checkpoint_url = "https://folk.ntnu.no/courtade/GeqShift_models/" + checkpoint_1H
  !wget $checkpoint_url
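Both checkpoint filenames appear to encode the training settings mentioned in the disclaimer (20 epochs, 100 conformations). Assuming that naming pattern holds, the short sketch below parses a filename and flags a mismatch between the training ensemble size and the `num_conformations` chosen in the input cell. The regular expression and the helper names `checkpoint_info` and `mismatch_note` are hypothetical, inferred from the two filenames used in this notebook rather than from any GeqShift documentation.

```python
import re

# Assumed pattern, inferred from this notebook's two checkpoint filenames:
# _<YYYYMMDD>_checkpoint_epoch-<N>_<nucleus>_nbr-confs-<M>.pkl
CHECKPOINT_RE = re.compile(
    r"_(?P<date>\d{8})_checkpoint_epoch-(?P<epoch>\d+)"
    r"_(?P<nucleus>\w+?)_nbr-confs-(?P<confs>\d+)\.pkl$"
)


def checkpoint_info(filename):
    """Return the training settings encoded in a checkpoint filename, or None."""
    m = CHECKPOINT_RE.search(filename)
    if m is None:
        return None
    return {
        "date": m["date"],
        "epochs": int(m["epoch"]),
        "nucleus": m["nucleus"],
        "nbr_confs": int(m["confs"]),
    }


def mismatch_note(filename, num_conformations):
    """Return a warning if the requested ensemble size differs from training, else None."""
    info = checkpoint_info(filename)
    if info is None or info["nbr_confs"] == num_conformations:
        return None
    return (f"{info['nucleus']} model was trained with {info['nbr_confs']} "
            f"conformations, but {num_conformations} were requested.")


print(checkpoint_info("_20240823_checkpoint_epoch-20_13C_nbr-confs-100.pkl"))
# {'date': '20240823', 'epochs': 20, 'nucleus': '13C', 'nbr_confs': 100}
print(mismatch_note("_20250411_checkpoint_epoch-20_1H_nbr-confs-100.pkl", 50))
# 1H model was trained with 100 conformations, but 50 were requested.
```

In the notebook this could be called as `mismatch_note(checkpoint_13C, num_conformations)` once the install cell has run.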




In [4]:
#@title Run prediction
!python GeqShift/predict_gpu.py --smiles_list "$query_smiles" --mol_name "$jobname" --checkpoint_path_13C $checkpoint_13C --checkpoint_path_1H $checkpoint_1H --nbr_confs $num_conformations


In [4]:
#@title Display 3D structure with <sup>13</sup>C chemical shifts

import glob
import py3Dmol

def parse_pdb(pdb_file):
    """Return (atom_name, x, y, z, bfactor) for every ATOM/HETATM record.

    The predicted chemical shift of each atom is stored in the B-factor column.
    """
    atoms = []
    with open(pdb_file, 'r') as file:
        for line in file:
            if line.startswith(("ATOM", "HETATM")):
                # Fixed-width PDB fields: atom name, x, y, z, B-factor
                atom_name = line[12:16].strip()
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])
                bfactor = float(line[60:66])
                atoms.append((atom_name, x, y, z, bfactor))
    return atoms

pdb_filename = f"/content/predict/{jobname}_13C_shifts.pdb"
if not glob.glob(pdb_filename):
    raise FileNotFoundError(f"{pdb_filename} not found; check that the 'Run prediction' cell finished without errors.")
atoms = parse_pdb(pdb_filename)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
with open(pdb_filename, 'r') as f:
    view.addModel(f.read(), 'pdb')
view.setStyle({'stick': {}})

# Label each atom that has a predicted (non-zero) shift
for atom_name, x, y, z, bfactor in atoms:
    if bfactor != 0:
        view.addLabel(f"{bfactor:.2f}", {'position': {'x': x, 'y': y, 'z': z}, 'fontSize': 12})

view.zoomTo()
view.show()



In [None]:
#@title Display 3D structure with <sup>1</sup>H chemical shifts
# Reuses parse_pdb and the imports from the 13C cell above

pdb_filename = f"/content/predict/{jobname}_1H_shifts.pdb"
if not glob.glob(pdb_filename):
    raise FileNotFoundError(f"{pdb_filename} not found; check that the 'Run prediction' cell finished without errors.")
atoms = parse_pdb(pdb_filename)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
with open(pdb_filename, 'r') as f:
    view.addModel(f.read(), 'pdb')
view.setStyle({'stick': {}})

# Label each atom that has a predicted (non-zero) shift
for atom_name, x, y, z, bfactor in atoms:
    if bfactor != 0:
        view.addLabel(f"{bfactor:.2f}", {'position': {'x': x, 'y': y, 'z': z}, 'fontSize': 12})

view.zoomTo()
view.show()


In [None]:
#@title Download the results
from google.colab import files

zip_path = f"/content/{jobname}.zip"
!zip -r $zip_path /content/predict
files.download(zip_path)

The chemical shifts are stored in the B-factor column of the `JOBNAME_13C_shifts.pdb` and `JOBNAME_1H_shifts.pdb` files. To visualize the chemical shifts on the structure in PyMOL, label the atoms by B-factor:
1. Open `JOBNAME_13C_shifts.pdb` or `JOBNAME_1H_shifts.pdb` in PyMOL
2. Type `label all, b` or click `L > b-factor`
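Outside PyMOL, the same B-factor column can be read with a few lines of plain Python, for example to tabulate the predicted shifts or paste them into a spreadsheet. This is a minimal standard-library sketch; it assumes, as stated above, that each predicted shift sits in the B-factor field (columns 61-66) of the ATOM/HETATM records, and that atoms without a prediction have a B-factor of 0 (the convention the 3D display cells rely on). The function names `shifts_from_pdb` and `shifts_to_csv` are hypothetical.

```python
import csv
import io


def shifts_from_pdb(pdb_text):
    """Return [(atom_name, shift_ppm), ...] for atoms with a non-zero B-factor."""
    shifts = []
    for line in pdb_text.splitlines():
        # Only coordinate records carry per-atom B-factors
        if line.startswith(("ATOM", "HETATM")):
            name = line[12:16].strip()    # atom name, PDB columns 13-16
            bfactor = float(line[60:66])  # B-factor, PDB columns 61-66
            if bfactor != 0:              # 0 is assumed to mean "no prediction"
                shifts.append((name, bfactor))
    return shifts


def shifts_to_csv(shifts):
    """Format (atom_name, shift_ppm) pairs as CSV text with a header row."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["atom", "shift_ppm"])
    for name, ppm in shifts:
        writer.writerow([name, f"{ppm:.2f}"])
    return buffer.getvalue()
```

In the notebook, `print(shifts_to_csv(shifts_from_pdb(open(f"/content/predict/{jobname}_13C_shifts.pdb").read())))` would print the <sup>13</sup>C shifts of the current job as CSV.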