# Prepare the labels of a SMILES dataset for multi-level multi-task learning
Use this notebook to convert the labels of a SMILES-based dataset by removing the label entries that correspond to hydrogen atoms.

In [8]:
import pandas as pd
import numpy as np
from rdkit import Chem

In [11]:
def get_non_hydrogen_node_labels(smiles, node_labels):
    """Drop the node (atom) labels that belong to hydrogen atoms.

    Parameters:
        smiles: SMILES string of the molecule the labels refer to.
        node_labels: one label per atom, indexed by the RDKit atom index obtained
            when parsing ``smiles`` with explicit hydrogens kept (presumably the
            SMILES atom order -- TODO confirm against the dataset). Must be a
            list or np.ndarray.
    Returns:
        The labels of the non-hydrogen atoms, as the same container type as
        ``node_labels`` (an np.ndarray keeps the original dtype). If the molecule
        has no non-hydrogen atom, all of ``node_labels`` is returned instead.
    Raises:
        ValueError: if ``node_labels`` is not a list / np.ndarray, or if
            ``smiles`` cannot be parsed by RDKit.
    """
    if not isinstance(node_labels, (list, np.ndarray)):
        raise ValueError("Labels need to be list or np.ndarray")

    # RDKit's SMILES parser removes explicit hydrogens by default, which would
    # shift every atom index below out of alignment with the labels.
    params = Chem.SmilesParserParams()
    params.removeHs = False
    mol = Chem.MolFromSmiles(smiles, params)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    non_hydrogen_node_labels = [
        node_labels[i] for i, atom in enumerate(mol.GetAtoms()) if atom.GetAtomicNum() != 1
    ]

    # Fall back to the full label set rather than returning nothing
    # (e.g. a hydrogen-only molecule such as "[H][H]").
    if len(non_hydrogen_node_labels) == 0:
        non_hydrogen_node_labels = node_labels

    if isinstance(node_labels, np.ndarray):
        return np.array(non_hydrogen_node_labels, dtype=node_labels.dtype)
    return non_hydrogen_node_labels

def get_non_hydrogen_edge_labels(smiles, edge_labels):
    """Drop the edge (bond) labels of bonds that touch a hydrogen atom.

    Parameters:
        smiles: SMILES string of the molecule the labels refer to.
        edge_labels: one label per bond, indexed by the RDKit bond index obtained
            when parsing ``smiles`` with explicit hydrogens kept (presumably the
            SMILES bond order -- TODO confirm against the dataset). Must be a
            list or np.ndarray.
    Returns:
        For every bond between two non-hydrogen atoms, its label emitted twice
        (once per bond direction), as the same container type as
        ``edge_labels`` (an np.ndarray keeps the original dtype). If no such
        bond exists, ``edge_labels`` is returned as-is, NOT duplicated.
        NOTE(review): that fallback yields a different length convention than
        the normal path -- confirm this is what the consumer expects.
    Raises:
        ValueError: if ``edge_labels`` is not a list / np.ndarray, or if
            ``smiles`` cannot be parsed by RDKit.
    """
    if not isinstance(edge_labels, (list, np.ndarray)):
        raise ValueError("Labels need to be list or np.ndarray")

    # RDKit's SMILES parser removes explicit hydrogens (and their bonds) by
    # default, which would shift bond indices out of alignment with the labels.
    params = Chem.SmilesParserParams()
    params.removeHs = False
    mol = Chem.MolFromSmiles(smiles, params)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    non_hydrogen_edge_labels = []
    for i, bond in enumerate(mol.GetBonds()):
        if bond.GetBeginAtom().GetAtomicNum() == 1 or bond.GetEndAtom().GetAtomicNum() == 1:
            continue
        # Duplicate the label so both directions of the bond get one.
        non_hydrogen_edge_labels.extend((edge_labels[i], edge_labels[i]))

    if len(non_hydrogen_edge_labels) == 0:
        non_hydrogen_edge_labels = edge_labels

    if isinstance(edge_labels, np.ndarray):
        return np.array(non_hydrogen_edge_labels, dtype=edge_labels.dtype)
    return non_hydrogen_edge_labels

def get_non_hydrogen_nodepair_labels(smiles, nodepair_labels):
    """Drop the node-pair labels that involve at least one hydrogen atom.

    Parameters:
        smiles: SMILES string of the molecule the labels refer to.
        nodepair_labels: flat labels treated as a row-major n_atoms x n_atoms
            matrix (pair (i, j) at index ``i * n_atoms + j``), where atoms are
            indexed by RDKit after parsing ``smiles`` with explicit hydrogens
            kept. Must be a list or np.ndarray.
    Returns:
        The labels of all (i, j) pairs of non-hydrogen atoms, still row-major,
        as the same container type as ``nodepair_labels`` (an np.ndarray keeps
        the original dtype). If the molecule has no non-hydrogen atom, all of
        ``nodepair_labels`` is returned instead.
    Raises:
        ValueError: if ``nodepair_labels`` is not a list / np.ndarray, or if
            ``smiles`` cannot be parsed by RDKit.
    """
    if not isinstance(nodepair_labels, (list, np.ndarray)):
        raise ValueError("Labels need to be list or np.ndarray")

    # RDKit's SMILES parser removes explicit hydrogens by default, which would
    # change both n_atoms and the atom indices used in the flat indexing below.
    params = Chem.SmilesParserParams()
    params.removeHs = False
    mol = Chem.MolFromSmiles(smiles, params)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # Computed once instead of re-enumerating the atoms for every row.
    n_atoms = mol.GetNumAtoms()
    heavy_atom_indices = [
        i for i, atom in enumerate(mol.GetAtoms()) if atom.GetAtomicNum() != 1
    ]

    non_hydrogen_nodepair_labels = [
        nodepair_labels[i * n_atoms + j]
        for i in heavy_atom_indices
        for j in heavy_atom_indices
    ]

    if len(non_hydrogen_nodepair_labels) == 0:
        non_hydrogen_nodepair_labels = nodepair_labels

    if isinstance(nodepair_labels, np.ndarray):
        return np.array(non_hydrogen_nodepair_labels, dtype=nodepair_labels.dtype)
    return non_hydrogen_nodepair_labels

def extract_hydrogen_from_labels(data: pd.DataFrame, smiles_col: str, node_labels=None, edge_labels=None, nodepair_labels=None):
    """Remove hydrogen entries from the named label columns of ``data``.

    Each row is processed independently. The SMILES string in ``smiles_col`` is
    passed, together with the cell value, to the matching
    ``get_non_hydrogen_*_labels`` helper for every column named in
    ``node_labels``, ``edge_labels`` or ``nodepair_labels``. Columns not
    explicitly named are passed through unchanged. If a column is named in
    several lists, node > edge > nodepair takes precedence.

    Parameters:
        data: input data, one molecule per row.
        smiles_col: name of the column holding the SMILES string.
        node_labels: names of node-level label columns to convert (default: none).
        edge_labels: names of edge-level label columns to convert (default: none).
        nodepair_labels: names of node-pair-level label columns to convert
            (default: none).
    Returns:
        The DataFrame produced by ``data.apply`` with the named columns converted.
    """
    # None defaults avoid sharing mutable default lists between calls.
    node_labels = [] if node_labels is None else node_labels
    edge_labels = [] if edge_labels is None else edge_labels
    nodepair_labels = [] if nodepair_labels is None else nodepair_labels

    def _extract_hydrogen(label_name, smiles, labels):
        """Convert one cell according to the level its column was registered under."""
        if label_name in node_labels:
            return get_non_hydrogen_node_labels(smiles, labels)
        elif label_name in edge_labels:
            return get_non_hydrogen_edge_labels(smiles, labels)
        elif label_name in nodepair_labels:
            return get_non_hydrogen_nodepair_labels(smiles, labels)
        return labels

    def _extract_hydrogen_by_graph(graph: pd.Series):
        """Convert every cell of one row (one molecule)."""
        graph = graph.to_dict()
        smiles = graph[smiles_col]
        return pd.Series({k: _extract_hydrogen(k, smiles, v) for k, v in graph.items()})

    return data.apply(_extract_hydrogen_by_graph, axis="columns")

Read a `.parquet` file, remove the hydrogen entries from its label columns, and write the result back to disk.

In [12]:
# Load the test fixture, strip hydrogen entries from the list- and
# ndarray-typed label columns, and save the converted copy next to it.
data = pd.read_parquet("../tests/fake_multilevel_data.parquet")
converted_data = extract_hydrogen_from_labels(
    data,
    "ordered_smiles",
    node_labels=["node_label_list", "node_label_np"],
    edge_labels=["edge_label_list", "edge_label_np"],
    nodepair_labels=["nodepair_label_list", "nodepair_label_np"],
)
converted_data.to_parquet("../tests/converted_fake_multilevel_data.parquet")