# ChemFlow Demo

## Friendly advice:
### 1. Please follow the detailed step-by-step instructions for each code block to run the demo.

### 2. If you encounter an error, please re-run all code blocks, starting from the first one.

### 3. Please use GPU (CUDA) to run this notebook.

In [None]:
#@title Install Packages (Need to run this block twice! See the following for details; ~3min) { display-mode: "form" }

# @markdown 1. First, run this code block to install required packages.

# @markdown 2. After all packages are installed, Colab will prompt you to restart the session.

# @markdown 3. Click restart session

# @markdown 4. Then run this code block again to load the installed packages into the new session.

# Third-party packages for the demo, installed with one pip invocation each.
# rdkit is imported directly by this notebook (Chem, Draw); PyTDC is presumably
# the distribution behind the `tdc` import. selfies, lightning and pandarallel
# are not imported by the notebook itself - presumably they are required by
# the cloned ChemFlow sources (TODO confirm).
!pip install rdkit
!pip install PyTDC
!pip install selfies
!pip install lightning
!pip install pandarallel

In [None]:
# @title Clone ChemFlow Repo (~0.5min) { display-mode: "form" }
%%capture
# Same package list as the install cell above.
# NOTE(review): presumably repeated so that this cell also works right after a
# session restart - confirm whether the duplication is still needed.
!pip install rdkit
!pip install PyTDC
!pip install selfies
!pip install lightning
!pip install pandarallel

import os
# Start from /content, clone the ChemFlow sources into ./ChemFlow and make the
# clone the working directory. Later cells import `src.*` and use relative
# `checkpoints/...` paths, which presumably resolve against this directory.
os.chdir('/content')
CODE_DIR = 'ChemFlow'
# `$CODE_DIR` is expanded by IPython to the value of the Python variable.
!git clone https://anonymous.4open.science/r/ChemFlow-7409 $CODE_DIR
os.chdir(f'./{CODE_DIR}')

In [None]:
#@title Import Libraries (~0.5min) { display-mode: "form" }
%%capture

from typing import Callable
import wget

import random
import numpy as np
import matplotlib.pyplot as plt

import torch

from src.vae import load_vae_demo
from src.pinn.pde import load_wavepde
from src.pinn import PropGenerator, VAEGenerator
from src.predictor import Predictor

from rdkit import Chem
from rdkit.Chem import Draw
from tdc import Oracle

# Call signature shared by the property oracles: one SMILES string (or a list
# of them) in, one score (or a list of scores) out.
SmilesScorer = Callable[[str | list[str]], float | list[float]]

# Property key -> oracle that scores that property from SMILES.
PROP_FN: dict[str, SmilesScorer] = {
    "sa": Oracle(name="SA"),
    "qed": Oracle(name="QED"),
    "plogp": Oracle(name="LogP"),  # This is actually pLogP
    "gsk3b": Oracle(name="GSK3B"),
    "jnk3": Oracle(name="JNK3"),
    "drd2": Oracle(name="DRD2"),
}

# Per-property aliases bound to the same oracle objects as the table above.
smiles2sa: SmilesScorer = PROP_FN["sa"]
smiles2qed: SmilesScorer = PROP_FN["qed"]
smiles2plogp: SmilesScorer = PROP_FN["plogp"]
smiles2gsk3b: SmilesScorer = PROP_FN["gsk3b"]
smiles2jnk3: SmilesScorer = PROP_FN["jnk3"]
smiles2drd2: SmilesScorer = PROP_FN["drd2"]

In [None]:
#@title Demo Setup (~1min) { display-mode: "form" }

# Pick CUDA when it is available; the demo refuses to continue on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type != "cuda":
    raise ValueError("This notebook requires a GPU")

def normalize(x, step_size=None, relative=False):
    """Rescale the update direction ``x`` according to ``step_size``.

    Args:
        x: Direction to rescale, normally a tensor.
        step_size: Target scale. ``None`` disables rescaling.
        relative: When true, ``x`` is simply multiplied by ``step_size``.
            Otherwise ``x`` is first divided by its norm along the last
            dimension and then multiplied by ``step_size``.

    Returns:
        The rescaled direction, or ``x`` itself when ``step_size`` is ``None``
        or when the norm-based path raises ``AttributeError``.
    """
    if step_size is not None:
        if relative:
            return x * step_size
        try:
            length = torch.norm(x, dim=-1, keepdim=True)
            unit = x / length
            return unit * step_size
        except AttributeError:
            # Fall through and hand the input back unchanged.
            pass
    return x


def get_dm_vae():
    try:
        _, vae = load_vae_demo(device=device)
    except Exception:
        !wget -O temp.npy https://www.dropbox.com/scl/fi/mqfrzphivf1n6t12kx5vr/vocab.npy?rlkey=5g6zu596a4i56e5q7ete5pk83&st=h6o7pjvp&dl=1
        !mv temp.npy vocab.npy
        os.makedirs(f"checkpoints/vae", exist_ok=True)
        !wget -O checkpoints/vae.zip https://www.dropbox.com/scl/fo/sg1xs50b8lzbg3i5ezn37/AG6hzxYnnytWW6tAXCjhgvQ?rlkey=z74v3omdqg63lp3b5opsyqjf4&st=uebhl5k0&dl=1
        !unzip -o checkpoints/vae.zip -d checkpoints/vae
        !rm checkpoints/vae.zip
        _, vae = load_vae_demo(device=device)
    return _, vae


def get_model(model_name, prop_name, dm, vae):
    """Gets model by name."""

    if model_name == 'Random-1D':
        rand_1d_u_z = torch.zeros(1024).to(device)
        rand_1d_u_z[torch.randint(0, 1024, (1,))] = random.choice([-1, 1])
        return rand_1d_u_z, None

    elif model_name == 'Random':
        rand_u_z = torch.randn(1024, device=device)
        return rand_u_z, None

    elif model_name == 'ChemSpace':
        try:
            boundary = np.load(f"checkpoints/chemspace_boundary/zmc/boundary_{prop_name}.npy")
            boundary = torch.tensor(np.repeat(boundary, 1, axis=0)).to(device)
        except Exception:
            os.makedirs(f"checkpoints/chemspace_boundary", exist_ok=True)
            !wget -O checkpoints/chemspace_boundary.zip https://www.dropbox.com/scl/fo/t9n3j50g300tklka5ixun/AMkDmjr4rentHybNZ9AmqxY?rlkey=nd8cnaehsdwpsgyxe44p1jn4t&st=ecf4ooml&dl=1
            !unzip -o checkpoints/chemspace_boundary.zip -d checkpoints/chemspace_boundary
            !rm checkpoints/chemspace_boundary.zip
            boundary = np.load(f"checkpoints/chemspace_boundary/zmc/boundary_{prop_name}.npy")
            boundary = torch.tensor(np.repeat(boundary, 1, axis=0)).to(device)
        return boundary, None


    predictor = Predictor(dm.max_len * dm.vocab_size)

    try:
        prop_name = prop_name.lower()
        predictor.load_state_dict(
            torch.load(
                f"checkpoints/prop_predictor/{prop_name}/checkpoint.pt", map_location=device
            )
        )
    except Exception:
        os.makedirs(f"checkpoints/prop_predictor", exist_ok=True)
        !wget -O checkpoints/prop_predictor.zip https://www.dropbox.com/scl/fo/styd2ag4ck56c4yjj70uo/AKNpsVNrzT45WWlOZl38qyA?rlkey=rb80cacesxspwjsc93l8lhfil&st=pw1fyyig&dl=1
        !unzip -o checkpoints/prop_predictor.zip -d checkpoints/prop_predictor
        !rm checkpoints/prop_predictor.zip
        prop_name = prop_name.lower()
        predictor.load_state_dict(
            torch.load(
                f"checkpoints/prop_predictor/{prop_name}/checkpoint.pt", map_location=device
            )
        )

    # print('loaded predictor')
    for p in predictor.parameters():
        p.requires_grad = False

    sup_generator = PropGenerator(vae, predictor).to(device)
    for p in sup_generator.parameters():
        p.requires_grad = False

    unsup_generator = VAEGenerator(vae).to(device)

    try:
        unsup_pde = load_wavepde(
        checkpoint=f"checkpoints/wavepde/zmc/checkpoint.pt",
        generator=unsup_generator,
        k=10,
        device=device,
    )
    except Exception:
        os.makedirs(f"checkpoints/wavepde", exist_ok=True)
        !wget -O checkpoints/wavepde.zip https://www.dropbox.com/scl/fo/hz12c47c9m6fwtgrl98vf/ADsbM2662Vd_nwH_zKJMwms?rlkey=n19aqtqvth8zqdwj3zk7q5h24&st=tcs3otfh&dl=1
        !unzip -o checkpoints/wavepde.zip -d checkpoints/wavepde
        !rm checkpoints/wavepde.zip
        unsup_pde = load_wavepde(
        checkpoint=f"checkpoints/wavepde/zmc/checkpoint.pt",
        generator=unsup_generator,
        k=10,
        device=device,
    )

    # print('loaded unsup pde')
    wave_idx_map = {"plogp": 0, "sa": 6, "qed": 4, "drd2": 2, "jnk3": 0, "gsk3b": 0}
    unsup_pde_idx = wave_idx_map[prop_name]
    for p in unsup_pde.parameters():
        p.requires_grad = False

    try:
        sup_pde = load_wavepde(
            checkpoint=f"checkpoints/wavepde_prop/zmc/{prop_name}/checkpoint.pt",
            generator=sup_generator,
            k=1,
            device=device,
        )
    except Exception:
        os.makedirs(f"checkpoints/wavepde_prop", exist_ok=True)
        !wget -O checkpoints/wavepde_prop.zip https://www.dropbox.com/scl/fo/irebwgigzvn728spntxeb/APJV5CgE9nx7f7WYshpwQYM?rlkey=4crk9aco465qd6yblpokdcy2a&st=x6z9cpv6&dl=1
        !unzip -o checkpoints/wavepde_prop.zip -d checkpoints/wavepde_prop
        !rm checkpoints/wavepde_prop.zip
        sup_pde = load_wavepde(
            checkpoint=f"checkpoints/wavepde_prop/zmc/{prop_name}/checkpoint.pt",
            generator=sup_generator,
            k=1,
            device=device,
        )

    # print('loaded sup pde')
    sup_pde_idx = 0
    for p in sup_pde.parameters():
        p.requires_grad = False

    if model_name == 'Wave (unsupervised)':
        return unsup_pde, unsup_pde_idx
    elif model_name == 'Wave (supervised)':
        return sup_pde, sup_pde_idx
    elif model_name == 'Gradient Flow' or model_name == 'Langevin Dynamics':
        return sup_pde, None,

    try:
        unsup_hj = load_wavepde(
            checkpoint=f"checkpoints/hjpde/zmc/checkpoint.pt",
            generator=unsup_generator,
            k=10,
            device=device,
        )
    except Exception:
        os.makedirs(f"checkpoints/hjpde", exist_ok=True)
        !wget -O checkpoints/hjpde.zip https://www.dropbox.com/scl/fo/b5sxhaour6d45qt9d6hs0/AH_PCsLmkiNoQeZGbQHmHL0?rlkey=p70qyi120depfk4as1o6pbe25&st=34ze8kcp&dl=1
        !unzip -o checkpoints/hjpde.zip -d checkpoints/hjpde
        !rm checkpoints/hjpde.zip
        unsup_hj = load_wavepde(
            checkpoint=f"checkpoints/hjpde/zmc/checkpoint.pt",
            generator=unsup_generator,
            k=10,
            device=device,
        )

    # print('loaded unsup hj')
    hj_idx_map = {"plogp": 5, "sa": 1, "qed": 9, "drd2": 6, "jnk3": 5, "gsk3b": 5}
    unsup_hj_idx = hj_idx_map[prop_name]
    for p in unsup_hj.parameters():
        p.requires_grad = False

    try:
        sup_hj = load_wavepde(
            checkpoint=f"checkpoints/hjpde_prop/zmc/{prop_name}/checkpoint.pt",
            generator=sup_generator,
            k=1,
            device=device,
        )
    except Exception:
        os.makedirs(f"checkpoints/hjpde_prop", exist_ok=True)
        !wget -O checkpoints/hjpde_prop.zip https://www.dropbox.com/scl/fo/wnphbhkoicx8elggxlbg7/AH2Nr7Vjfeb3-qn55PWXJt8?rlkey=mcisxpw4w16hlk95dtb0jm78n&st=jdycagl1&dl=1
        !unzip -o checkpoints/hjpde_prop.zip -d checkpoints/hjpde_prop
        !rm checkpoints/hjpde_prop.zip
        sup_hj = load_wavepde(
            checkpoint=f"checkpoints/hjpde_prop/zmc/{prop_name}/checkpoint.pt",
            generator=sup_generator,
            k=1,
            device=device,
        )

    # print('loaded sup hj')
    sup_hj_idx = 0
    for p in sup_hj.parameters():
        p.requires_grad = False


    if model_name == 'HJ (unsupervised)':
        return unsup_hj, unsup_hj_idx
    elif model_name == 'HJ (supervised)':
        return sup_hj, sup_hj_idx
    else:
        raise ValueError(f"Invalid model name: {model_name}")



def sample(num_samples=1):
    """Draw ``num_samples`` latent codes from a standard normal distribution.

    Returns:
        A ``(num_samples, 1024)`` tensor created on the global ``device``.
    """
    latent_shape = (num_samples, 1024)
    return torch.randn(latent_shape, device=device)


def synthesize_original(_code, _vae, _dm, _device):
    """Decode a latent code into a molecule image plus a label of its scores.

    Args:
        _code: Latent tensor; it is cloned and detached, never modified.
        _vae: Model whose ``decode`` output is exponentiated before being
            handed to ``_dm.decode``.
        _dm: Object whose ``decode`` result is indexed at ``[0]`` to get the
            SMILES string.
        _device: Device the latent copy is moved to before decoding.

    Returns:
        ``(img, label)``: the RDKit drawing of the decoded molecule and a
        multi-line string with the SMILES followed by one score per property.
    """
    latent = torch.from_numpy(_code.clone().detach().cpu().numpy()).float().to(_device)
    decoded = _vae.decode(latent).exp()
    smiles = _dm.decode(decoded)[0]

    # (property key, arrow shown next to it), in display order.
    layout = (
        ("plogp", "⬆"),
        ("sa", "⬇"),
        ("qed", "⬆"),
        ("drd2", "⬆"),
        ("jnk3", "⬆"),
        ("gsk3b", "⬆"),
    )
    pieces = [f"{smiles}"]
    for key, arrow in layout:
        pieces.append(f"{key}{arrow} {PROP_FN[key](smiles):.3f}")
    label = " \n ".join(pieces)

    mol = Chem.MolFromSmiles(smiles)
    # NOTE(review): RDKit documents the keyword as `legend`; `legends` may be
    # silently ignored here - confirm before relying on an in-image caption.
    img = Draw.MolToImage(mol, legends=label)
    return img, label


def synthesize_final(_code, _vae, _dm, _device):
    """Decode a traversed latent code into a molecule image and score label.

    Args:
        _code: Latent tensor; it is detached and copied through NumPy.
        _vae: Model whose ``decode`` output is exponentiated before being
            handed to ``_dm.decode``.
        _dm: Object whose ``decode`` result is indexed at ``[0]`` to get the
            SMILES string.
        _device: Device the latent copy is moved to before decoding.

    Returns:
        ``(img, label)``: the RDKit drawing of the decoded molecule and a
        multi-line string with the SMILES followed by one score per property.
    """
    host_copy = _code.detach().cpu().numpy()
    latent = torch.from_numpy(host_copy).float().to(_device)
    smiles = _dm.decode(_vae.decode(latent).exp())[0]

    # Score every property once, in the order the label lists them.
    scores = {
        key: PROP_FN[key](smiles)
        for key in ("plogp", "sa", "qed", "drd2", "jnk3", "gsk3b")
    }
    label = (
        f'{smiles} \n plogp⬆ {scores["plogp"]:.3f}'
        f' \n sa⬇ {scores["sa"]:.3f}'
        f' \n qed⬆ {scores["qed"]:.3f}'
        f' \n drd2⬆ {scores["drd2"]:.3f}'
        f' \n jnk3⬆ {scores["jnk3"]:.3f}'
        f' \n gsk3b⬆ {scores["gsk3b"]:.3f}'
    )

    img = Draw.MolToImage(Chem.MolFromSmiles(smiles), legends=label)
    return img, label

# Property keys, in the order used to build and index the per-property model
# list and the slider-to-property mapping in later cells.
properties = ["plogp", "sa", "qed", "drd2", "jnk3", "gsk3b"]

# Load the data module and VAE once; both are shared by every later cell.
dm, vae = get_dm_vae()

In [None]:
#@title Choose a method (~5min, including download checkpoints) { display-mode: "form", run: "auto" }
model_name = "Wave (unsupervised)" #@param ['Wave (unsupervised)', 'Wave (supervised)', 'HJ (unsupervised)', 'HJ (supervised)']

# One (model, index) pair per property, in the same order as `properties`.
model_lst = [get_model(model_name, prop_name, dm, vae) for prop_name in properties]

In [None]:
#@title Randomly sample a molecule (~0.5min){ display-mode: "form", run: "auto" }

# Draw one random latent code; `base_codes` is kept untouched so that the
# traversal cell can always restart from the same molecule.
base_codes = sample(num_samples=1)
z = base_codes.clone()

original_image, original_label = synthesize_original(z, vae, dm, device)

# Show the decoded molecule with its SMILES and property scores as the title.
plt.figure(figsize=(10, 5))
plt.imshow(original_image)
plt.axis('off')
plt.title(f'Initial Molecule\n{original_label}')
print('')

In [None]:
#@title Adjust the traversal steps for each property { display-mode: "form", run: "auto" }
# Restart from the sampled molecule every time a slider changes.
z = base_codes.clone()

# NOTE(review): not referenced anywhere else in the visible notebook.
num_semantics = 6

# One slider per property. The sign selects the direction of travel and the
# magnitude is the number of loop iterations for that property. The first
# iteration of the traversal loop applies a zero update, so a slider value of
# N results in N-1 actual moves.
#@markdown plogP ⬆
plogp_steps = 0 #@param {type:"slider", min:-10, max:10, step:1}
#@markdown SA ⬇
sa_steps  = 0 #@param {type:"slider", min:-10, max:10, step:1}
#@markdown QED ⬆
qed_steps  = 0 #@param {type:"slider", min:-10, max:10, step:1}
#@markdown DRD2 ⬆
drd2_steps  = 0 #@param {type:"slider", min:-10, max:10, step:1}
#@markdown JNK3 ⬆
jnk3_steps  = 0 #@param {type:"slider", min:-10, max:10, step:1}
#@markdown GSK3B  ⬆
gsk3b_steps = 0 #@param {type:"slider", min:-10, max:10, step:1}

# Keys are positions in `properties` / `model_lst`; values are the slider settings.
steps = {0 : plogp_steps, 1 : sa_steps, 2 : qed_steps, 3 : drd2_steps, 4 : jnk3_steps, 5 : gsk3b_steps}

# Passed to `normalize` for every update; with relative=True the raw direction
# is multiplied by step_size instead of being rescaled to that length.
step_size = 0.05
relative = True

# Apply each property's slider setting in turn, moving `z` one update per step.
for sem_idx, step in steps.items():
    total_steps = abs(step)
    for t in range(total_steps):
        model = model_lst[sem_idx]
        if t == 0:
            # The first iteration never moves the latent code.
            u_z = 0
        elif model_name in ("Random-1D", "Random", "ChemSpace"):
            # Fixed direction vector stored in the first tuple slot.
            u_z = normalize(model[0], step_size, relative)
        elif model_name in (
            "Wave (unsupervised)",
            "Wave (supervised)",
            "HJ (unsupervised)",
            "HJ (supervised)",
        ):
            # PDE models: direction comes from `inference` at time t mod 10.
            u, u_z = model[0].inference(model[1], z, t % 10)
            u_z = normalize(u_z, step_size, relative)
        elif model_name == "Langevin Dynamics":
            assert relative and step_size is not None
            z = z.detach().requires_grad_(True)
            u_z = torch.autograd.grad(model[0].generator(z).sum(), z)[0]
            # Scaled gradient plus Gaussian noise.
            u_z = u_z * step_size + torch.randn_like(u_z) * np.sqrt(2 * step_size) * 0.01
        elif model_name == "Gradient Flow":
            z = z.detach().requires_grad_(True)
            u_z = torch.autograd.grad(model[0].generator(z).sum(), z)[0]
            u_z = normalize(u_z, step_size, relative)
        else:
            raise ValueError(f"Unknown model_name {model_name}")

        # These properties have their direction flipped before it is applied.
        if properties[sem_idx] in ["sa", "molwt"]:
            u_z = -u_z
        # A negative slider value walks against the direction.
        z = z - u_z if step < 0 else z + u_z

# original_image, original_label = synthesize_original(z, vae, dm, device)
# plt.figure(figsize=(10, 5))
# plt.imshow(original_image)
# plt.title('Initial Molecule\n'+original_label)
# plt.axis('off')
# print('')

# Decode the traversed latent code and show the resulting molecule, with its
# SMILES and property scores as the title.
final_image, final_label = synthesize_final(z, vae, dm, device)
plt.figure(figsize=(10, 5))
plt.imshow(final_image)
plt.axis('off')
plt.title(f'After Traversal:\n{final_label}')
print('')