In [1]:
# Let's import the stuff we need
import torch
from gflownet.config import init_empty, Config
from gflownet.models.graph_transformer import GraphTransformerGFN
from gflownet.envs.graph_building_env import GraphBuildingEnv
from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext
from gflownet.algo.trajectory_balance import TrajectoryBalance
import numpy as np
torch.__version__

from rdkit import Chem
import numpy as np
from rdkit import RDLogger

# Get the logger
lg = RDLogger.logger()

# Disable all RDKit warnings
lg.setLevel(RDLogger.CRITICAL)

def graph_to_obj(self, g):

    def safe_mol(m):
        """Convert None or invalid input into an empty RDKit molecule."""
        return m if isinstance(m, Chem.Mol) else Chem.RWMol().GetMol()

    # ---------------------------------------------------------
    # 1) Base molecule: always start with an empty mol
    # ---------------------------------------------------------
    mol = Chem.RWMol().GetMol()      # <-- ALWAYS valid, never None
    first = True                     # track first fragment

    # Atom index offsets
    offsets = np.cumsum(
        [0] + [self.frags_numatm[g.nodes[i]["v"]] for i in g]
    )

    # ---------------------------------------------------------
    # 2) Add all fragments
    # ---------------------------------------------------------
    for i in g.nodes:
        idx = g.nodes[i]["v"]

        # force fragment into a valid molecule
        frag = safe_mol(self.frags_mol[idx])

        if first:
            mol = frag
            first = False
        else:
            mol = Chem.CombineMols(mol, frag)

    # If graph had *no nodes*, mol is still empty â€” that's OK!
    rw = Chem.RWMol(mol)  # <-- NOW GUARANTEED SAFE

    # ---------------------------------------------------------
    # 3) Add bonds
    # ---------------------------------------------------------
    bond_atoms = []
    for a, b in g.edges:
        afrag = g.nodes[a]["v"]
        bfrag = g.nodes[b]["v"]

        src = g.edges[(a, b)].get("src_attach", 0)
        dst = g.edges[(a, b)].get("dst_attach", 0)

        try:
            u = int(self.frags_stems[afrag][src] + offsets[a])
            v = int(self.frags_stems[bfrag][dst] + offsets[b])
        except Exception:
            continue

        try:
            rw.AddBond(u, v, Chem.BondType.SINGLE)
            bond_atoms.extend([u, v])
        except Exception:
            pass

    mol = rw.GetMol()

    # ---------------------------------------------------------
    # 4) Remove one H from each attachment atom if possible
    # ---------------------------------------------------------
    for idx in bond_atoms:
        atom = mol.GetAtomWithIdx(idx)
        h = atom.GetNumExplicitHs()
        if h > 0:
            atom.SetNumExplicitHs(h - 1)

    # ---------------------------------------------------------
    # 5) Try sanitizing, but tolerate failure
    # ---------------------------------------------------------
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        pass

    return mol




FragMolBuildingEnvContext.graph_to_obj = graph_to_obj


In [2]:
torch.manual_seed(1)  # For demonstration purposes
cfg = Config()
env = GraphBuildingEnv()
ctx = FragMolBuildingEnvContext()
model = GraphTransformerGFN(ctx, cfg)
algo = TrajectoryBalance(env, ctx, cfg)


In [3]:
dev = torch.device('cpu')
dev


device(type='cpu')

In [4]:
from gflownet.tasks.seh_frag import SEHTask
task = SEHTask(cfg)

In [5]:
from tqdm.notebook import tqdm
from logger import *
import rdkit.Chem as Chem
from rdkit.Chem import Draw
import base64
from gflownet.envs.graph_building_env import ActionIndex, Graph
from rdkit.Chem import AllChem, DataStructs

def imagefn(mols):
    out = []
    for mol in mols:
        if mol is None:
            out.append(None)
            continue
        svg_obj = Draw.MolsToGridImage([mol], molsPerRow=1, subImgSize=(200, 200), useSVG=True)
        svg_str = svg_obj.data if hasattr(svg_obj, "data") else svg_obj._repr_svg_()
        b64 = base64.b64encode(svg_str.encode("utf-8")).decode("utf-8")
        out.append(b64)
    return out

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
import numpy as np

def featurefn(mol_list, radius=2, n_bits=1024):
    """
    Compute Morgan fingerprints for a list of RDKit molecules, handling partial/incomplete molecules.
    
    Returns:
        fps: np.ndarray of shape (len(mol_list), n_bits)
        success_mask: list of bool, True if fingerprint computed successfully, False if fallback used
    """
    fps = np.zeros((len(mol_list), n_bits), dtype=np.uint8)
    success_mask = [False] * len(mol_list)

    for i, mol in enumerate(mol_list):
        if mol is None or mol.GetNumAtoms() == 0:
            continue  # leave as zero and False
        try:
            mol_copy = Chem.Mol(mol)
            # Try to sanitize, skip kekulization to avoid RingInfo errors
            try:
                Chem.SanitizeMol(mol_copy, Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE)
            except:
                pass

            # Compute Morgan fingerprint
            fp = AllChem.GetMorganFingerprintAsBitVect(mol_copy, radius=radius, nBits=n_bits)
            arr = np.zeros((n_bits,), dtype=np.int32)
            DataStructs.ConvertToNumpyArray(fp, arr)
            fps[i] = arr.astype(np.uint8)

            success_mask[i] = True
        except Exception as e:
            # Leave as zero vector, success_mask[i] remains False
            #print(f"{i}: fingerprint failed -> {e}")
            pass

    return fps, success_mask




def textfn(mol_list):
    smiles_list = []
    for mol in mol_list:
        if mol is None:
            smiles_list.append("[INVALID_NONE]")
            continue
        try:
            # Try normal smiles
            smi = Chem.MolToSmiles(mol, canonical=False)
            if smi:
                smiles_list.append(smi)
                continue
        except:
            pass
        try:
            # Try unsanitized SMILES fallback
            smi = Chem.MolToSmiles(mol, canonical=False, isomericSmiles=False)
            smiles_list.append(smi if smi else "[INVALID_EMPTY]")
        except:
            smiles_list.append("[INVALID_ERROR]")
    return smiles_list
        

logger = VisLogger(
    #path="./seh_small",
    s0_included = True,
    fn_compute_features=featurefn,
    fn_state_to_text=textfn,
)

In [6]:
def sample(n):
    with torch.no_grad():  # We don't need to compute gradients here, they will be later
        trajs = algo.create_training_data_from_own_samples(model, n)

        objs = [ctx.graph_to_obj(i['result']) for i in trajs]
        obj_props, _ = task.compute_obj_properties(objs)
        log_rewards = task.cond_info_to_logreward({'beta': torch.ones(len(trajs))}, obj_props)
        batch = algo.construct_batch(trajs, None, log_rewards).to(dev)
        _, _, losses = algo.compute_batch_losses(model, batch)
        #losses.append(loss.item())
        avg_rewards.append((log_rewards).exp().mean().item())
    batch_idx = []
    states = []
    logprobs_bw=[]
    logprobs_fw=[]
    for j, t in enumerate(trajs):
        tl = len(t["traj"])
        batch_idx += [j]*tl
        for s in t["traj"]:
            states.append(ctx.graph_to_obj(s[0]))
        logprobs_bw.append(t["bck_logprobs"])
        # shift to next one
        logprobs_fw.append(torch.cat((torch.Tensor([0]), t["fwd_logprobs"][:-1])))
    return np.array(batch_idx), states, log_rewards.exp(), losses, torch.cat(logprobs_fw), torch.cat(logprobs_bw)
        
    

In [7]:
beta = 32
log_every = 2
iterations = 10
samples_per_log = 100
losses = []
avg_rewards = []
opt = torch.optim.Adam(model.parameters(), 3e-4)

for i in tqdm(range(iterations)):
    with torch.no_grad():  # We don't need to compute gradients here, they will be later
        trajs = algo.create_training_data_from_own_samples(model, 64)

        objs = [ctx.graph_to_obj(i['result']) for i in trajs]
        obj_props, _ = task.compute_obj_properties(objs)
        log_rewards = task.cond_info_to_logreward({'beta': torch.ones(len(trajs)) * beta}, obj_props)

    batch = algo.construct_batch(trajs, None, log_rewards).to(dev)
    loss, _, _ = algo.compute_batch_losses(model, batch)
    loss.backward()
    opt.step()
    opt.zero_grad()

    losses.append(loss.item())
    avg_rewards.append((log_rewards / beta).exp().mean().item())

    #logging
    if (i+1)%log_every==0:
        # sample new ones on policy
        batch_idx, states, rewards, loss, logprobs_fw, logprobs_bw = sample(samples_per_log)
        print(batch_idx.shape)
        print(len(states))
        print(rewards.shape)
        print(len(loss))
        print(logprobs_fw.shape)
        print(logprobs_bw.shape)
    
        logger.log(
            batch_idx=batch_idx,
            states=states,
            total_reward=rewards,
            loss = loss,
            iteration=i,
            logprobs_backward = logprobs_bw,
            logprobs_forward = logprobs_fw
        )
        logger.write_to_db()

  0%|          | 0/10 [00:00<?, ?it/s]

(1505,)
1505
torch.Size([100])
100
torch.Size([1505])
torch.Size([1505])
Building initial graph...
Initial graph built successfully

=== Graph Statistics ===
Total nodes: 866
Total edges: 915
Node types:
  final: 99
  standard: 766
  start: 1

Creating indexes...

Starting truncation...
Step 1: Identifying removable nodes...
  Found 738 removable nodes
Step 2: Creating successor lookup table...
Step 3: Bypassing chains...
  Iteration 1: Updating 84 edges...
  Iteration 2: Updating 77 edges...
  Iteration 3: Updating 64 edges...
  Iteration 4: Updating 60 edges...
  Iteration 5: Updating 51 edges...
  Iteration 6: Updating 49 edges...
  Iteration 7: Updating 46 edges...
  Iteration 8: Updating 44 edges...
  Iteration 9: Updating 40 edges...
  Iteration 10: Updating 40 edges...
  Iteration 11: Updating 38 edges...
  Iteration 12: Updating 37 edges...
  Iteration 13: Updating 36 edges...
  Iteration 14: Updating 36 edges...
  Iteration 15: Updating 28 edges...
  Iteration 16: Updating 16 