### Load LlaMol

In [None]:
import rdkit
from rdkit import Chem
import rdkit.rdBase as rkrb
import rdkit.RDLogger as rkl
import os
import torch 
import logging
import numpy as np
from plot_utils import check_metrics
from sample import Sampler
import pandas as pd

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Half precision on CUDA, full precision on CPU. `device` is only "cuda" when
# torch.cuda.is_available() is true, so no second availability check is needed.
# (bfloat16 is an alternative where torch.cuda.is_bf16_supported() holds.)
dtype = "float16" if "cuda" in device else "float32"

# Quiet RDKit: raise its logger level to ERROR and disable the rdApp.error log.
rkl.logger().setLevel(rkl.ERROR)
rkrb.DisableLog("rdApp.error")

torch.set_num_threads(8)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# sampler
# The checkpoint path is resolved relative to the current working directory.
sampler = Sampler(
    load_path=os.path.join(
        os.getcwd(), "out", "llama2-M-Full-RSS-Canonical.pt"
    ),
    device=device,
    seed=1234,
    dtype=dtype,
    compile=False,
)

# settings
num_samples = 100
df_comp = pd.read_parquet(os.path.join(os.getcwd(), "data", "OrganiX13.parquet"))
# DataFrame.sample (without replacement) raises ValueError when n exceeds the
# number of rows, so cap the subsample size at the size of the dataset.
df_comp = df_comp.sample(n=min(2_500_000, len(df_comp)))
# Property name -> NumPy array of that column for the sampled comparison rows.
comp_context_dict = {c: df_comp[c].to_numpy() for c in ["logp", "sascore", "mol_weight"]}
comp_smiles = df_comp["smiles"]

### Convert SMILES to a Chemiscope dataset

In [None]:
from typing import List, Dict
import json
from rdkit.Chem import AllChem

def _to_builtin(value):
    """Return *value* as a plain Python object if it is a NumPy/torch scalar."""
    item = getattr(value, "item", None)
    return item() if callable(item) else value


@torch.no_grad()
def convert_to_chemiscope(smiles_list : List[str], context_dict : Dict[str, List[float]]):
    """Write a Chemiscope JSON dataset for the given molecules.

    Every SMILES string is parsed with RDKit and embedded to obtain 3D
    coordinates. Molecules that cannot be parsed or embedded are skipped, and
    the property values at the same positions are dropped so that properties
    and structures stay aligned. The dataset is written to
    ``chemiscope_gen.json`` in the current working directory.

    For more details on the file format:
    https://chemiscope.org/docs/tutorial/input-reference.html

    Args:
        smiles_list: SMILES strings, one per molecule.
        context_dict: Maps a property name to a sequence holding one value per
            entry of ``smiles_list``. NumPy / torch scalar values are converted
            to built-in Python values before serialisation, because the
            ``json`` module cannot encode e.g. ``numpy.float32``.

    Returns:
        None. The result is the JSON file written to disk.
    """
    structures = []
    # Indices of skipped molecules; a set keeps the membership test below O(1).
    skipped = set()
    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            logging.info(f"Mol invalid: {smi} ! Skipping...")
            skipped.add(i)
            continue

        # A fixed random seed makes the embedding reproducible between runs.
        res = AllChem.EmbedMolecule(mol, randomSeed=0xf00d, maxAttempts=20)

        # Anything other than conformer id 0 is treated as a failed embedding.
        if res != 0:
            logging.info(f"Could not calculate coordinates for {smi}! Skipping..")
            skipped.add(i)
            continue

        # (n_atoms, 3) array of coordinates of the conformer just generated.
        positions = mol.GetConformer().GetPositions()
        structures.append({
            "size": mol.GetNumAtoms(),
            "names": [atom.GetSymbol() for atom in mol.GetAtoms()],
            "x": positions[:, 0].tolist(),
            "y": positions[:, 1].tolist(),
            "z": positions[:, 2].tolist(),
        })

    properties = {
        name: {
            "target": "structure",
            "values": [
                _to_builtin(v)
                for i, v in enumerate(values)
                if i not in skipped
            ],
        }
        for name, values in context_dict.items()
    }

    data = {
        "meta": {
            # the name of the dataset
            "name": "Test Dataset",
            # description of the dataset, OPTIONAL
            "description": "This contains data from generated molecules",
            # authors of the dataset, OPTIONAL
            "authors": ["Niklas Dobberstein, niklas.dobberstein@scai.fraunhofer.de"],
            # references for the dataset, OPTIONAL
            "references": [
                "",
            ],
        },
        "properties": properties,
        "structures": structures,
    }

    out_path = os.path.join(os.getcwd(), "chemiscope_gen.json")
    with open(out_path, "w") as f:
        json.dump(data, f)

    logging.info(f"Wrote file {out_path}")

# Quick run of the converter on two small molecules, each with one "logp"
# and one "sascore" value.
convert_to_chemiscope(
    ["CC=O", "s1ccnc1"],
    {"logp": [1.0, 2.0], "sascore": [1.5, -2.0]},
)

### Condition: SMILES token sequence

In [None]:
# SMILES string that is passed to sampler.generate() as `context_smi`.
# Replace it with your own SMILES if needed.
context_smi = "C=C(C)C(=O)O[11CH3]"  # methyl methacrylate with an 11C-labelled methyl group

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
import os

from plot_utils import calc_context_from_smiles

# Context columns that can be used for conditioning
context_cols_options = ["logp", "sascore", "mol_weight"]

# Sampling parameters
temperature = 0.8
num_samples = 40  # number of molecules to generate
# NOTE(review): this rebinds the global `device`, while the sampler was built
# with the earlier value - confirm CPU context tensors are what it expects.
device = 'cpu'

# Target value for each conditioning property
logp = 0.0
sascore = 2.0
mol_weight = 3.0

# Build the context dictionary: one constant float tensor of length
# num_samples per selected property.
selected_context_cols = ["logp", "sascore", "mol_weight"]
named_targets = {"logp": logp, "sascore": sascore}
context_dict = {}
for c in selected_context_cols:
    # Any column other than logp / sascore falls back to mol_weight.
    val = round(named_targets.get(c, mol_weight), 2)
    context_dict[c] = val * torch.ones((num_samples,), device=device, dtype=torch.float)

# Generate SMILES conditioned on the context values and the SMILES fragment
smiles, context = sampler.generate(
    context_cols=context_dict,
    context_smi=context_smi,
    start_smiles=None,
    num_samples=num_samples,
    max_new_tokens=256,
    temperature=temperature,
    top_k=25,
    total_gen_steps=int(np.ceil(num_samples / 1000)),
    return_context=True,
)

# Save the generated SMILES to gen_smiles.txt, one per line
with open("gen_smiles.txt", "w") as f:
    f.writelines(f"{s}\n" for s in smiles)

# Display SMILES as RDKit molecules
def display_molecules(smiles_list, context_dict, num_cols=1, figsize=(150, 150),
                      out_path="gen_mols.png"):
    """Draw the molecules in a grid and save the figure as an image.

    Each molecule is rendered with RDKit and annotated, for every key of
    ``context_dict``, with the value stored for that molecule next to the
    value returned by ``calc_context_from_smiles`` for its SMILES (rounded to
    two decimals).

    SMILES that RDKit cannot parse are logged and skipped; the remaining
    molecules keep their original index, so the annotations still read the
    matching entry of ``context_dict``. If no molecule is valid, nothing is
    drawn and no file is written.

    Args:
        smiles_list: SMILES strings to draw.
        context_dict: Maps a property name to an indexable holding one value
            per entry of ``smiles_list``.
        num_cols: Number of columns of the image grid.
        figsize: Figure size passed to ``plt.subplots``.
        out_path: Path of the image file to write.

    Returns:
        None.
    """
    # Keep the original index with every parsed molecule so that look-ups
    # into context_dict stay aligned once invalid SMILES are dropped.
    valid = []
    for idx, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            logging.info(f"Mol invalid: {smi} ! Skipping...")
            continue
        valid.append((idx, smi, mol))

    if not valid:
        logging.info("No valid molecules to display; no image written.")
        return

    # Ceiling division: enough rows to hold every molecule.
    num_rows = (len(valid) + num_cols - 1) // num_cols

    # squeeze=False always yields a 2-D array of Axes, so `.flat` also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    fig.subplots_adjust(hspace=0.5)

    for slot, ax in enumerate(axes.flat):
        if slot >= len(valid):
            # Remove empty cells when the grid is larger than the image count.
            fig.delaxes(ax)
            continue

        idx, smi, mol = valid[slot]
        ax.imshow(Draw.MolToImage(mol))
        for j, c in enumerate(context_dict):
            smi_con = round(calc_context_from_smiles([smi], c)[0], 2)
            # Stack the annotation lines below the image, one per property.
            ax.text(0.5, -0.1 * j, f"{c}: {context_dict[c][idx]} vs {smi_con}",
                    transform=ax.transAxes, fontsize=10, ha='center')
        ax.axis('off')

    fig.savefig(out_path)
    plt.close(fig)

# Draw the generated SMILES and annotate each one with the values from
# context_dict; the function saves the figure to disk.
display_molecules(smiles, context_dict)