In [None]:
from unimol_tools.models.unimolv2 import UniMolV2Model, LinearHead
from unimol_tools.data.conformer import UniMolV2Feature, ConformerGen, mol2unimolv2, inner_smi2coords, create_mol_from_atoms_and_coords
from unimol_tools.data.datahub import DataHub, MolDataReader, TargetScaler
from unimol_tools.tasks.trainer import Trainer
from unimol_tools.train import MolTrain
from unimol_tools.predict import MolPredict
from unimol_tools.models.nnmodel import NNModel, NNMODEL_REGISTER
from unimol_tools.utils import logger
import os
import math
import joblib
import json
import numpy as np
import torch
from torch import nn
from rdkit.Chem import Descriptors, rdEHTTools

class UniMolFusion(UniMolV2Model):
    """Uni-Mol v2 encoder whose CLS embedding is fused with molecule-level features.

    If ``graph_feature_dim`` is passed through ``params``, a projection
    (Linear -> ReLU -> Dropout) maps a per-molecule feature vector
    ``graph_feat`` of that width to ``encoder_embed_dim``. The projected
    vector is concatenated with the CLS representation before
    ``classification_head``. Without ``graph_feature_dim`` the head simply
    consumes the CLS representation.
    """

    def __init__(self, output_dim=2, model_size='84m', **params):
        super().__init__(output_dim, model_size, **params)

        self.graph_feature_dim = params.get('graph_feature_dim', None)
        # A single ``is not None`` test drives both the projection and the
        # head width so the two can never disagree.
        use_graph = self.graph_feature_dim is not None
        proj_dim = self.args.encoder_embed_dim
        if use_graph:
            self.graph_proj = nn.Sequential(
                nn.Linear(self.graph_feature_dim, proj_dim),
                nn.ReLU(),
                nn.Dropout(self.args.pooler_dropout),
            )

        final_input_dim = self.args.encoder_embed_dim + (proj_dim if use_graph else 0)
        # Overrides any head the parent constructor may have set, so the
        # head accepts the fused (CLS + projected features) width.
        self.classification_head = LinearHead(
            input_dim=final_input_dim,
            num_classes=self.output_dim,
            pooler_dropout=self.args.pooler_dropout,
        )

    def forward(
        self,
        atom_feat,
        atom_mask,
        edge_feat,
        shortest_path,
        degree,
        pair_type,
        attn_bias,
        src_tokens,
        src_coord,
        return_repr=False,
        return_atomic_reprs=False,
        graph_feat=None,
        **kwargs
    ):
        """Run the Uni-Mol v2 encoder and (optionally) fuse ``graph_feat``.

        The batch size and atom count are read from ``atom_feat.shape[:2]``.
        The remaining tensor shapes follow the parent model's batch format
        (TODO confirm against ``UniMolV2Model.batch_collate_fn``).

        :param graph_feat: optional (n_mol, graph_feature_dim) feature tensor.
            Required when the model was built with ``graph_feature_dim``
            unless ``return_repr`` is set.
        :param return_repr: return a dict of representations instead of logits.
        :param return_atomic_reprs: with ``return_repr``, also return per-atom
            tokens, coordinates and representations.
        :return: logits from ``classification_head``, or a representation dict.
        :raises ValueError: if ``graph_feat`` does not match how the model was built.
        """
        pos = src_coord

        n_mol, n_atom = atom_feat.shape[:2]
        token_feat = self.embed_tokens(src_tokens)
        x = self.atom_feature({'atom_feat': atom_feat, 'degree': degree}, token_feat)

        dtype = self.dtype

        x = x.type(dtype)

        # The incoming ``attn_bias`` is used as the attention mask; the pair
        # bias itself starts from zeros and is filled by ``edge_feature``.
        attn_mask = attn_bias.clone()
        attn_bias = torch.zeros_like(attn_mask)
        attn_mask = attn_mask.unsqueeze(1).repeat(
            1, self.args.encoder_attention_heads, 1, 1
        )
        attn_bias = attn_bias.unsqueeze(-1).repeat(1, 1, 1, self.args.pair_embed_dim)
        attn_bias = self.edge_feature(
            {'shortest_path': shortest_path, 'edge_feat': edge_feat}, attn_bias
        )
        attn_mask = attn_mask.type(self.dtype)

        # Prepend an always-on entry for the CLS token to the atom mask.
        atom_mask_cls = torch.cat(
            [
                torch.ones(n_mol, 1, device=atom_mask.device, dtype=atom_mask.dtype),
                atom_mask,
            ],
            dim=1,
        ).type(self.dtype)

        pair_mask = atom_mask_cls.unsqueeze(-1) * atom_mask_cls.unsqueeze(-2)

        def one_block(x, pos, return_x=False):
            delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2)
            dist = delta_pos.norm(dim=-1)
            attn_bias_3d = self.se3_invariant_kernel(dist.detach(), pair_type)
            new_attn_bias = attn_bias.clone()
            # Index 0 is the CLS token, so the 3D bias only covers the atoms.
            new_attn_bias[:, 1:, 1:, :] = new_attn_bias[:, 1:, 1:, :] + attn_bias_3d
            new_attn_bias = new_attn_bias.type(dtype)
            x, pair = self.encoder(
                x,
                new_attn_bias,
                atom_mask=atom_mask_cls,
                pair_mask=pair_mask,
                attn_mask=attn_mask,
            )
            node_output = self.movement_pred_head(
                x[:, 1:, :],
                pair[:, 1:, 1:, :],
                attn_mask[:, :, 1:, 1:],
                delta_pos.detach(),
            )
            if return_x:
                return x, pair, pos + node_output
            else:
                return pos + node_output

        x, pair, pos = one_block(x, pos, return_x=True)
        cls_repr = x[:, 0, :]  # CLS token repr
        all_repr = x[:, :, :]  # all token repr

        if graph_feat is not None:
            if self.graph_feature_dim is None:
                raise ValueError(
                    'graph_feat was provided but the model was built without graph_feature_dim'
                )
            # Collated features may arrive on CPU or in another float dtype;
            # align them with the projection's parameters first.
            proj_weight = self.graph_proj[0].weight
            graph_feat = graph_feat.to(device=proj_weight.device, dtype=proj_weight.dtype)
            graph_repr = self.graph_proj(graph_feat)
            cls_repr = torch.concat([cls_repr, graph_repr.type(cls_repr.dtype)], dim=-1)
        elif self.graph_feature_dim is not None and not return_repr:
            # Without the features the head input would be too narrow.
            raise ValueError(
                'model was built with graph_feature_dim={} but no graph_feat was provided'.format(
                    self.graph_feature_dim
                )
            )

        if return_repr:
            filtered_tensors = []
            filtered_coords = []

            for tokens, coord in zip(src_tokens, src_coord):
                filtered_tensor = tokens[
                    (tokens != 0) & (tokens != 1) & (tokens != 2)
                ]  # filter out BOS(0), EOS(1), PAD(2)
                filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)]
                filtered_tensors.append(filtered_tensor)
                filtered_coords.append(filtered_coord)

            lengths = [
                len(filtered_tensor) for filtered_tensor in filtered_tensors
            ]  # Compute the lengths of the filtered tensors
            if return_atomic_reprs:
                cls_atomic_reprs = []
                atomic_symbols = []
                for i in range(len(all_repr)):
                    # Skip the CLS slot (index 0) and keep only real atoms.
                    atomic_reprs = x[i, 1 : lengths[i] + 1, :]
                    atomic_symbol = filtered_tensors[i]
                    atomic_symbols.append(atomic_symbol)
                    cls_atomic_reprs.append(atomic_reprs)
                return {
                    'cls_repr': cls_repr,
                    'atomic_symbol': atomic_symbols,
                    'atomic_coords': filtered_coords,
                    'atomic_reprs': cls_atomic_reprs,
                }
            else:
                return {'cls_repr': cls_repr}

        logits = self.classification_head(cls_repr)
        return logits

    def batch_collate_fn(self, samples):
        """Collate like the parent model, adding a float32 ``graph_feat`` batch.

        NOTE(review): ``None`` entries (e.g. from a failed EHT calculation)
        cannot be converted to a tensor and will raise here.
        """
        batch, label = super().batch_collate_fn(samples)

        if 'graph_feat' in samples[0][0]:
            batch['graph_feat'] = torch.stack(
                [torch.as_tensor(s[0]['graph_feat'], dtype=torch.float32) for s in samples]
            )

        return batch, label

def calc_qc_properties(mol):
    """Compute extended-Hückel (EHT) descriptors for ``mol`` with RDKit.

    Runs ``rdEHTTools.RunMol`` on the molecule's default conformer. It
    returns ``[HOMO, LUMO, total_energy, |dipole|]``, where the dipole is
    sum(charge_i * position_i) over atoms.

    NOTE(review): RDKit's EHT tools expect explicit hydrogens; confirm the
    descriptors are meaningful for H-stripped molecules (``remove_hs=True``).

    :param mol: RDKit molecule with at least one conformer, or ``None``.
    :return: list of four values. Returns ``[None] * 4`` when ``mol`` is
        falsy, has no conformer, or the EHT calculation reports failure.
        HOMO/LUMO are ``None`` individually when that orbital does not exist.
    """
    missing = [None] * 4
    if not mol or mol.GetNumConformers() == 0:
        return missing

    passed, result = rdEHTTools.RunMol(mol)
    if not passed:
        return missing

    # Assumes orbital energies are sorted ascending. Orbitals are filled two
    # electrons at a time, so the first ceil(n/2) are occupied. HOMO is the
    # last of those (0-based index ceil(n/2) - 1) and LUMO is the next one.
    orbitals = result.GetOrbitalEnergies()
    n_occupied = math.ceil(result.numElectrons / 2)
    homo = orbitals[n_occupied - 1] if n_occupied > 0 else None
    lumo = orbitals[n_occupied] if n_occupied < len(orbitals) else None

    energy = result.totalEnergy
    # Same conformer RunMol used by default (confId=-1).
    positions = mol.GetConformer().GetPositions()
    charges = np.asarray(result.GetAtomicCharges())
    dipole = np.linalg.norm(charges @ positions)

    return [homo, lumo, energy, dipole]

def mol2unimolfusion(mol, max_atoms=128, remove_hs=True, gen_graph_features=True, **params):
    """Build UniMolFusion inputs: Uni-Mol v2 features plus a descriptor vector.

    The descriptor vector stored under ``'graph_feat'`` has 9 entries, in
    this order:
    the 4 values from ``calc_qc_properties``, the 2 Crippen descriptors
    (``CalcCrippenDescriptors``), TPSA, BalabanJ and BertzCT.

    :param mol: RDKit molecule (with coordinates) to featurize.
    :param max_atoms: forwarded to ``mol2unimolv2``.
    :param remove_hs: forwarded to ``mol2unimolv2``.
    :param gen_graph_features: when false, return the plain Uni-Mol v2 features.
    :return: the feature dict from ``mol2unimolv2``, possibly with ``'graph_feat'`` added.
    """
    features = mol2unimolv2(mol, max_atoms, remove_hs, **params)
    if not gen_graph_features:
        return features

    rd_desc = Descriptors.rdMolDescriptors
    qc_part = calc_qc_properties(mol)
    crippen_part = list(rd_desc.CalcCrippenDescriptors(mol))
    features['graph_feat'] = qc_part + crippen_part + [
        rd_desc.CalcTPSA(mol),
        Descriptors.BalabanJ(mol),
        Descriptors.BertzCT(mol),
    ]
    return features

class UniMolFusionFeature(UniMolV2Feature):
    """UniMolV2Feature variant whose outputs come from ``mol2unimolfusion``."""

    def single_process(self, smiles):
        """
        Generate a conformer for one SMILES string and featurize it for UniMolFusion.

        :param smiles: (str) SMILES string of the molecule.
        :return: feature dictionary produced by ``mol2unimolfusion``.
        :raises ValueError: if ``self.method`` is not ``'rdkit_random'``.
        """
        if self.method != 'rdkit_random':
            raise ValueError(
                'Unknown conformer generation method: {}'.format(self.method)
            )
        mol = inner_smi2coords(
            smiles,
            seed=self.seed,
            mode=self.mode,
            remove_hs=self.remove_hs,
            return_mol=True,
        )
        return mol2unimolfusion(mol, self.max_atoms, remove_hs=self.remove_hs)

    def transform_raw(self, atoms_list, coordinates_list):
        """Featurize molecules given as parallel lists of atoms and coordinates.

        :return: one ``mol2unimolfusion`` feature dict per (atoms, coordinates) pair.
        """
        return [
            mol2unimolfusion(
                create_mol_from_atoms_and_coords(atoms, coords),
                self.max_atoms,
                remove_hs=self.remove_hs,
            )
            for atoms, coords in zip(atoms_list, coordinates_list)
        ]

class MolTrainFusion(MolTrain):
    """MolTrain that can also train the ``'unimolfusion'`` model.

    Extra keyword arguments are forwarded to ``MolTrain``.
    ``graph_feature_dim`` is additionally written into ``self.config`` so it
    reaches the model constructor and is saved with the config.
    """

    def __init__(
        self,
        task='classification',
        data_type='molecule',
        epochs=10,
        learning_rate=0.0001,
        batch_size=16,
        early_stopping=5,
        metrics="none",
        split='random',
        split_group_col='scaffold',
        kfold=5,
        save_path='./exp',
        remove_hs=False,
        smiles_col='SMILES',
        target_cols=None,
        target_col_prefix='TARGET',
        target_anomaly_check=False,
        smiles_check="filter",
        target_normalize="auto",
        max_norm=5,
        use_cuda=True,
        use_amp=True,
        use_ddp=False,
        use_gpu="all",
        freeze_layers=None,
        freeze_layers_reversed=False,
        load_model_dir=None,
        model_name='unimolv1',
        model_size='84m',
        **params
    ):
        super().__init__(task, data_type, epochs, learning_rate, batch_size, early_stopping, metrics, split, split_group_col, kfold, save_path, remove_hs, smiles_col, target_cols, target_col_prefix, target_anomaly_check, smiles_check, target_normalize, max_norm, use_cuda, use_amp, use_ddp, use_gpu, freeze_layers, freeze_layers_reversed, load_model_dir, model_name, model_size, **params)
        # Make sure the descriptor width reaches the model (and the saved config).
        graph_feature_dim = params.get('graph_feature_dim', None)
        if graph_feature_dim is not None:
            self.config['graph_feature_dim'] = graph_feature_dim

    def fit(self, data):
        """Train on ``data`` and store inverse-scaled CV predictions in ``self.cv_pred``.

        For classification tasks a decision threshold is also dumped to
        ``<save_path>/threshold.dat``. Returns ``None``.
        """
        # The fusion model needs both its own featurizer and a model wrapper
        # that knows how to construct ``UniMolFusion``.
        if self.config.model_name == 'unimolfusion':
            datahub_cls, model_cls = DataHubFusion, NNModelFusion
        else:
            datahub_cls, model_cls = DataHub, NNModel

        self.datahub = datahub_cls(
            data=data, is_train=True, save_path=self.save_path, **self.config
        )
        self.data = self.datahub.data
        self.update_and_save_config()
        self.trainer = Trainer(save_path=self.save_path, **self.config)
        self.model = model_cls(self.data, self.trainer, **self.config)
        self.model.run()
        scalar = self.data['target_scaler']
        y_pred = self.model.cv['pred']
        y_true = np.array(self.data['target'])
        metrics = self.trainer.metrics
        if scalar is not None:
            y_pred = scalar.inverse_transform(y_pred)
            y_true = scalar.inverse_transform(y_true)

        if self.config["task"] in ['classification', 'multilabel_classification']:
            threshold = metrics.calculate_classification_threshold(y_true, y_pred)
            joblib.dump(threshold, os.path.join(self.save_path, 'threshold.dat'))

        self.cv_pred = y_pred
        return
    
class DataHubFusion(DataHub):
    """DataHub that can also build inputs for the ``'unimolfusion'`` model."""

    def _init_data(self, **params):
        """Read the raw data, prepare targets per task, and featurize molecules.

        Sets ``self.data['target_scaler']``, ``self.data['target']`` and
        ``self.data['unimol_input']``.

        :raises ValueError: for an unknown task or ``model_name``.
        """
        self.data = MolDataReader().read_data(self.data, self.is_train, **params)
        self.data['target_scaler'] = TargetScaler(
            self.ss_method, self.task, self.save_path
        )
        if self.task == 'regression':
            target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.float32)
            if self.is_train:
                # The scaler is fitted on training targets only.
                self.data['target_scaler'].fit(target, self.save_path)
                self.data['target'] = self.data['target_scaler'].transform(target)
            else:
                self.data['target'] = target
        elif self.task == 'classification':
            target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
            self.data['target'] = target
        elif self.task == 'multiclass':
            target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
            self.data['target'] = target
            if not self.is_train:
                self.data['multiclass_cnt'] = self.multiclass_cnt
        elif self.task == 'multilabel_regression':
            target = (
                np.array(self.data['raw_target'])
                .reshape(-1, self.data['num_classes'])
                .astype(np.float32)
            )
            if self.is_train:
                self.data['target_scaler'].fit(target, self.save_path)
                self.data['target'] = self.data['target_scaler'].transform(target)
            else:
                self.data['target'] = target
        elif self.task == 'multilabel_classification':
            target = (
                np.array(self.data['raw_target'])
                .reshape(-1, self.data['num_classes'])
                .astype(np.int32)
            )
            self.data['target'] = target
        elif self.task == 'repr':
            self.data['target'] = self.data['raw_target']
        else:
            raise ValueError('Unknown task: {}'.format(self.task))

        # One featurizer class per supported model; an unknown name is an
        # explicit error rather than an unbound ``no_h_list``.
        model_name = params.get('model_name', None)
        featurizers = {
            'unimolv1': ConformerGen,
            'unimolv2': UniMolV2Feature,
            'unimolfusion': UniMolFusionFeature,
        }
        if model_name not in featurizers:
            raise ValueError('Unknown model: {}'.format(model_name))
        featurizer = featurizers[model_name](**params)

        if 'atoms' in self.data and 'coordinates' in self.data:
            no_h_list = featurizer.transform_raw(
                self.data['atoms'], self.data['coordinates']
            )
        else:
            no_h_list = featurizer.transform(self.data["smiles"])

        self.data['unimol_input'] = no_h_list

class NNModelFusion(NNModel):
    """NNModel whose model factory also understands ``'unimolfusion'``."""

    def _init_model(self, model_name, **params):
        """Instantiate the network for ``model_name`` and apply layer freezing.

        Registry models take precedence; ``'unimolfusion'`` maps to
        ``UniMolFusion``. ``freeze_layers`` (a list or a comma-separated string
        of parameter-name prefixes) selects the parameters to freeze. With
        ``freeze_layers_reversed`` the selection is inverted: only the matched
        parameters stay trainable.

        :raises ValueError: if ``model_name`` is unknown.
        """
        if self.task in ['regression', 'multilabel_regression']:
            params['pooler_dropout'] = 0
            logger.debug("set pooler_dropout to 0 for regression task")
        freeze_layers = params.get('freeze_layers', None)
        freeze_layers_reversed = params.get('freeze_layers_reversed', False)

        if model_name in NNMODEL_REGISTER:
            model_cls = NNMODEL_REGISTER[model_name]
        elif model_name == 'unimolfusion':
            model_cls = UniMolFusion
        else:
            raise ValueError('Unknown model: {}'.format(self.model_name))
        model = model_cls(**params)

        if isinstance(freeze_layers, str):
            freeze_layers = freeze_layers.replace(' ', '').split(',')
        if isinstance(freeze_layers, list):
            for param_name, param in model.named_parameters():
                matched = any(param_name.startswith(prefix) for prefix in freeze_layers)
                param.requires_grad = not (freeze_layers_reversed ^ matched)
        return model
    
class MolPredictFusion(MolPredict):
    """MolPredict that can also load and run ``'unimolfusion'`` models."""

    def predict(self, data, save_path=None, metrics='none'):
        """Predict on ``data`` with the model stored at ``self.load_model``.

        :param data: input for the DataHub. When it is a path string, the
            file stem is used as the output prefix.
        :param save_path: directory for predictions and metrics; nothing is
            written when falsy.
        :param metrics: overrides the saved metric config unless it is the
            string ``'none'``.
        :return: predictions, inverse-transformed when a target scaler is set.
        """
        self.save_path = save_path
        # Equivalent to ``not metrics or metrics != 'none'``.
        if metrics != 'none':
            self.config.metrics = metrics

        # Pick the featurizer and model wrapper for the saved model type.
        if self.config.model_name == 'unimolfusion':
            datahub_cls, model_cls = DataHubFusion, NNModelFusion
        else:
            datahub_cls, model_cls = DataHub, NNModel

        ## load test data
        self.datahub = datahub_cls(
            data=data, is_train=False, save_path=self.load_model, **self.config
        )
        # Disable DDP for prediction (the config key is ``use_ddp``).
        self.config.use_ddp = False
        self.trainer = Trainer(save_path=self.load_model, **self.config)
        self.model = model_cls(self.datahub.data, self.trainer, **self.config)

        self.model.evaluate(self.trainer, self.load_model)

        y_pred = self.model.cv['test_pred']
        scalar = self.datahub.data['target_scaler']
        if scalar is not None:
            y_pred = scalar.inverse_transform(y_pred)

        df = self.datahub.data['raw_data'].copy()
        predict_cols = ['predict_' + col for col in self.target_cols]
        if self.task == 'multiclass' and self.config.multiclass_cnt is not None:
            prob_cols = ['prob_' + str(i) for i in range(self.config.multiclass_cnt)]
            df[prob_cols] = y_pred
            df[predict_cols] = np.argmax(y_pred, axis=1).reshape(-1, 1)
        elif self.task in ['classification', 'multilabel_classification']:
            # joblib opens and closes the file itself, so no handle is leaked.
            threshold = joblib.load(os.path.join(self.load_model, 'threshold.dat'))
            prob_cols = ['prob_' + col for col in self.target_cols]
            df[prob_cols] = y_pred
            df[predict_cols] = (y_pred > threshold).astype(int)
        else:
            prob_cols = predict_cols
            df[predict_cols] = y_pred
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        # All-(-1) target columns are treated as "no labels": skip metrics.
        if not (df[self.target_cols] == -1.0).all().all():
            metrics = self.trainer.metrics.cal_metric(
                df[self.target_cols].values, df[prob_cols].values
            )
            logger.info("final predict metrics score: \n{}".format(metrics))
            if self.save_path:
                joblib.dump(metrics, os.path.join(self.save_path, 'test_metric.result'))
                with open(os.path.join(self.save_path, 'test_metric.json'), 'w') as f:
                    json.dump(metrics, f)
        else:
            df.drop(self.target_cols, axis=1, inplace=True)
        if self.save_path:
            prefix = (
                data.split('/')[-1].split('.')[0] if isinstance(data, str) else 'test'
            )
            self.save_predict(df, self.save_path, prefix)
            logger.info("pipeline finish!")

        return y_pred

In [None]:
from rdkit import RDLogger

# Silence RDKit log output (rdApp.*) during featurization.
RDLogger.DisableLog('rdApp.*')

clf = MolTrainFusion(
    task='regression',
    data_type='molecule',
    model_name='unimolfusion',
    epochs=1,
    batch_size=32,
    metrics='mse',
    target_cols=['LogP'],
    use_cuda=False,
    kfold=1,
    remove_hs=True,
    target_anomaly_check=True,
    # 9 = 4 EHT values + 2 Crippen descriptors + TPSA + BalabanJ + BertzCT,
    # the length of the vector built by mol2unimolfusion.
    # NOTE(review): the descriptors include Crippen logP; if the 'LogP'
    # target is itself a Crippen estimate this leaks the target — confirm.
    graph_feature_dim=9,
)
# fit() returns None; the cross-validation predictions are kept on clf.cv_pred.
clf.fit(data='clean_train.csv')
pred = clf.cv_pred