<a href="https://colab.research.google.com/github/gcourtade/GeqShift/blob/main/GeqShift.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GeqShift

Easy-to-use carbohydrate <sup>13</sup>C and <sup>1</sup>H NMR chemical shift prediction using GeqShift, an E(3) equivariant graph neural network.

The original GeqShift code is available at https://github.com/mariabankestad/GeqShift.

The dataset of <sup>1</sup>H and <sup>13</sup>C NMR chemical shifts is available at https://github.com/mariabankestad/GeqShift.

Please read and cite the GeqShift paper:
[Bånkestad M., Dorst K. M., Widmalm G., Rönnols J. Carbohydrate NMR chemical shift prediction by GeqShift employing E(3) equivariant graph neural networks.
*RSC Advances*, 2024](https://doi.org/10.1039/D4RA03428G)

##### Disclaimer
I made this Google Colab notebook for my own use and have no connection with the authors of the GeqShift paper. This notebook is heavily inspired by, and uses code from, the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb#scrollTo=mbaIO9pWjaN0) notebook. The model was trained for 20 epochs using 100 conformations of each carbohydrate in the training set. I cannot guarantee the correctness of the results generated using this code.

--[Gaston Courtade](https://folk.ntnu.no/courtade), 2025-04-10

In [8]:
#@title Input carbohydrate SMILES, then hit `Runtime` -> `Run all`
import os


query_smiles = 'C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O' #@param {type:"string"}
#@markdown  - Tip: Use [SMILES generator/checker](https://www.cheminfo.org/flavor/malaria/Utilities/SMILES_generator___checker/index.html) to edit SMILES.
jobname = 'bDGlc' #@param {type:"string"}
# number of conformations to generate for the ensemble
num_conformations = 100 #@param {type: "integer"}
#@markdown - Specify how many conformations should be generated in the ensemble for chemical shift prediction
#@markdown - Tip: Best results are expected with around 100 conformations

# Remove results from a previous run so old files do not end up in the download
if os.path.exists("/content/predict"):
  !rm -r /content/predict

In [4]:
#@title Install dependencies
import os
import torch
if "2.6.0" in torch.__version__:
  !pip uninstall torch torch-cluster torch-geometric torch-scatter torchaudio torchvision -y
  !pip install torch==2.4.0+cu124 torchvision torchaudio -f https://download.pytorch.org/whl/cu124/torch/
  import torch

torch_url = "https://pytorch-geometric.com/whl/torch-{}.html".format(torch.__version__).replace('+', '%2B')
!pip install torch-cluster torch-geometric torch-scatter -f $torch_url
!pip install e3nn
!pip install rdkit
!pip install mdanalysis
!pip install py3Dmol

if not os.path.exists("GeqShift"):
  !git clone https://github.com/gcourtade/GeqShift.git

checkpoint_13C = "_20240823_checkpoint_epoch-20_13C_nbr-confs-100.pkl"
if not os.path.exists(checkpoint_13C):
  checkpoint_url = "https://folk.ntnu.no/courtade/GeqShift_models/" + checkpoint_13C
  !wget $checkpoint_url

checkpoint_1H = "_20250410_checkpoint_epoch-7_1H_nbr-confs-100.pkl"
if not os.path.exists(checkpoint_1H):
  checkpoint_url = "https://folk.ntnu.no/courtade/GeqShift_models/" + checkpoint_1H
  !wget $checkpoint_url


Found existing installation: torch 2.4.0+cu124
Uninstalling torch-2.4.0+cu124:
  Successfully uninstalled torch-2.4.0+cu124
Found existing installation: torch_cluster 1.6.3+pt24cu124
Uninstalling torch_cluster-1.6.3+pt24cu124:
  Successfully uninstalled torch_cluster-1.6.3+pt24cu124
Found existing installation: torch-geometric 2.6.1
Uninstalling torch-geometric-2.6.1:
  Successfully uninstalled torch-geometric-2.6.1
Found existing installation: torch_scatter 2.1.2+pt24cu124
Uninstalling torch_scatter-2.1.2+pt24cu124:
  Successfully uninstalled torch_scatter-2.1.2+pt24cu124
Found existing installation: torchaudio 2.4.0
Uninstalling torchaudio-2.4.0:
  Successfully uninstalled torchaudio-2.4.0
Found existing installation: torchvision 0.19.0
Uninstalling torchvision-0.19.0:
  Successfully uninstalled torchvision-0.19.0
Looking in links: https://download.pytorch.org/whl/cu124/torch/
Collecting torch==2.4.0+cu124
  Using cached https://download.pytorch.org/whl/cu124/torch-2.4.0%2Bcu124-cp31

Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0%2Bcu124.html
Collecting torch-cluster
  Using cached https://data.pyg.org/whl/torch-2.4.0%2Bcu124/torch_cluster-1.6.3%2Bpt24cu124-cp311-cp311-linux_x86_64.whl (3.4 MB)
Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting torch-scatter
  Using cached https://data.pyg.org/whl/torch-2.4.0%2Bcu124/torch_scatter-2.1.2%2Bpt24cu124-cp311-cp311-linux_x86_64.whl (10.7 MB)
Using cached torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Installing collected packages: torch-scatter, torch-cluster, torch-geometric
Successfully installed torch-cluster-1.6.3+pt24cu124 torch-geometric-2.6.1 torch-scatter-2.1.2+pt24cu124
--2025-04-10 12:39:48--  https://folk.ntnu.no/courtade/GeqShift_models/_20240823_checkpoint_epoch-20_13C_nbr-confs-100.pkl
Resolving folk.ntnu.no (folk.ntnu.no)... 129.241.56.95, 2001:700:300:3::95
Connecting to folk.ntnu.no (folk.ntnu.no)|129.241.56.95|:443... connected

In [5]:
#@title Run prediction
!python GeqShift/predict_gpu.py --smiles_list "$query_smiles" --mol_name "$jobname" --checkpoint_path_13C $checkpoint_13C --checkpoint_path_1H $checkpoint_1H --nbr_confs $num_conformations

  @torch.cuda.amp.autocast(enabled=False)
{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0, }
{0: [0.0,0], 1: [0.0,0], 2: [0.0,0], 3: [0.0,0], 4: [0.0,0], 5: [0.0,0], }
Loaded checkpoint from _20240823_checkpoint_epoch-20_13C_nbr-confs-100.pkl
Loaded checkpoint from _20250410_checkpoint_epoch-7_1H_nbr-confs-100.pkl
Predict 13C chemical shifts ...
Predict 1H chemical shifts ...
Predictions saved to predict//bDGlc_predictions_13C.pkl
Predictions saved to predict//bDGlc_predictions_1H.pkl
[12:40:19] Molecule does not have explicit Hs. Consider calling AddHs()
[12:40:19] Molecule does not have explicit Hs. Consider calling AddHs()
Coordinates saved to predict//bDGlc.pdb
Coordinates with shifts in Bfactor col saved to predict//bDGlc_13C_shifts.pdb
Coordinates with shifts in Bfactor col saved to predict//bDGlc_1H_shifts.pdb
Predicted chemical shifts:
13C_idx 13C_CS 1H_CS
0 61.84 3.81
1 75.22 3.60
2 70.69 3.42
3 75.69 3.58
4 74.22 3.35
5 95.42 4.85


In [6]:
#@title Display 3D structure with <sup>13</sup>C chemical shifts

import py3Dmol
import glob

def parse_pdb(pdb_file):
    atoms = []
    with open(pdb_file, 'r') as file:
        for line in file:
            if line.startswith("HETATM"):
                # Extract atom name, coordinates, and B-factor
                atom_name = line[12:16].strip()
                x = float(line[30:38].strip())
                y = float(line[38:46].strip())
                z = float(line[46:54].strip())
                bfactor = float(line[60:66].strip())
                atoms.append((atom_name, x, y, z, bfactor))
    return atoms

pdb_filename = f"/content/predict/{jobname}_13C_shifts.pdb"
atoms = parse_pdb(pdb_filename)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
# Read the PDB text and close the file handle afterwards
with open(pdb_filename, 'r') as pdb_handle:
    view.addModel(pdb_handle.read(), 'pdb')

view.setStyle({'stick': {}})

for atom_name, x, y, z, bfactor in atoms:
    if bfactor != 0:
        label_content = f"{bfactor:.2f}"
        view.addLabel(label_content, {'position': {'x': x, 'y': y, 'z': z}, 'fontSize': 12})

view.zoomTo()

view.show()


In [7]:
#@title Display 3D structure with <sup>1</sup>H chemical shifts

pdb_filename = f"/content/predict/{jobname}_1H_shifts.pdb"
atoms = parse_pdb(pdb_filename)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
# Read the PDB text and close the file handle afterwards
with open(pdb_filename, 'r') as pdb_handle:
    view.addModel(pdb_handle.read(), 'pdb')

view.setStyle({'stick': {}})

for atom_name, x, y, z, bfactor in atoms:
    if bfactor != 0:
        label_content = f"{bfactor:.2f}"
        view.addLabel(label_content, {'position': {'x': x, 'y': y, 'z': z}, 'fontSize': 12})

view.zoomTo()

view.show()


In [None]:
#@title Download the results
zip_path = f"/content/{jobname}.zip"
!zip -r  $zip_path /content/predict
from google.colab import files
files.download(zip_path)

  adding: content/predict/ (stored 0%)
  adding: content/predict/bDGlc_shifts.pdb (deflated 65%)
  adding: content/predict/bDGlc_conformations.pickle (deflated 13%)
  adding: content/predict/bDGlc_predictions.pkl (deflated 16%)
  adding: content/predict/bDGlc_pred_data.pkl (deflated 94%)
  adding: content/predict/bDGlc.pdb (deflated 73%)



The chemical shifts are stored in the B-factor column of the `_13C_shifts.pdb` and `_1H_shifts.pdb` files. To visualize chemical shifts on the structure, you can label the atoms by B-factor:
1. Open `JOBNAME_13C_shifts.pdb` or `JOBNAME_1H_shifts.pdb` in PyMOL
2. Type `label all, b` or click `L > b-factor`
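
If you would rather get the shifts as a plain table than as labels on a structure, the B-factor column can also be read directly in Python. The sketch below is illustrative only: `read_shifts` and `hetatm_line` are hypothetical helpers, not part of GeqShift, and the two-atom file is synthetic. It assumes the standard fixed-column PDB layout (atom name in columns 13-16, B-factor in columns 61-66) and, as in the display cells above, treats a B-factor of 0 as "no shift assigned".

```python
import tempfile
from pathlib import Path


def read_shifts(pdb_path):
    """Return (atom_name, shift) pairs for atoms with a non-zero B-factor."""
    shifts = []
    for line in Path(pdb_path).read_text().splitlines():
        # Accept ATOM as well as HETATM records (an assumption; the cells
        # above only read HETATM records).
        if line.startswith(("ATOM", "HETATM")):
            atom_name = line[12:16].strip()   # PDB columns 13-16
            bfactor = float(line[60:66])      # PDB columns 61-66
            if bfactor != 0:
                shifts.append((atom_name, bfactor))
    return shifts


def hetatm_line(serial, name, x, y, z, bfactor):
    """Build a fixed-column HETATM record (hypothetical helper for the demo)."""
    return (
        f"HETATM{serial:5d} {name:<4s} UNL A{1:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{bfactor:6.2f}"
    )


# Synthetic two-atom file, for demonstration only (not a real prediction).
demo_records = [
    hetatm_line(1, "C1", 0.0, 0.0, 0.0, 95.42),
    hetatm_line(2, "O1", 1.4, 0.0, 0.0, 0.0),
]

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "demo_13C_shifts.pdb"
    demo_path.write_text("\n".join(demo_records) + "\n")
    shifts = read_shifts(demo_path)

for atom_name, shift in shifts:
    print(f"{atom_name}\t{shift:.2f}")
```

In Colab, the same function could be pointed at the real output, for example `read_shifts(f"/content/predict/{jobname}_13C_shifts.pdb")`.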