In [4]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from dgl import function as fn
from dgl.nn.pytorch.softmax import edge_softmax


class AttentiveGRU1(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU1, self).__init__()

        self.edge_transform = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(edge_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, edge_feats, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        edge_feats : float32 tensor of shape (E, M1)
            Previous edge features.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
        g.update_all(fn.copy_e('e', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))

class AttentiveGRU2(nn.Module):
    """Update node features with attention and GRU.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_hidden_size : int
        Size for the intermediate edge (bond) representations.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU2, self).__init__()

        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, edge_hidden_size)
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
        edge_logits : float32 tensor of shape (E, 1)
            The edge logits based on which softmax will be performed for weighting
            edges within 1-hop neighborhoods. E represents the number of edges.
        node_feats : float32 tensor of shape (V, M2)
            Previous node features.

        Returns
        -------
        float32 tensor of shape (V, M2)
            Updated node features.
        """
        g = g.local_var()
        g.edata['a'] = edge_softmax(g, edge_logits)
        g.ndata['hv'] = self.project_node(node_feats)

        g.update_all(fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'c'))
        context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))

class GetContext(nn.Module):
    """Generate context for each node (atom) by message passing at the beginning.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node (atom) features.
    edge_feat_size : int
        Size for the input edge (bond) features.
    graph_feat_size : int
        Size of the learned graph representation (molecular fingerprint).
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
        super(GetContext, self).__init__()

        self.project_node = nn.Sequential(
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge1 = nn.Sequential(
            nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )
        self.project_edge2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * graph_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
                                           graph_feat_size, dropout)

    def apply_edges1(self, edges):
        """Edge feature update."""
        return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def apply_edges2(self, edges):
        """Edge feature update."""
        return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}

    def forward(self, g, node_feats, edge_feats):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        edge_feats : float32 tensor of shape (E, N2)
            Input edge features. E for the number of edges and N2 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N3)
            Updated node features.
        """
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.ndata['hv_new'] = self.project_node(node_feats)
        temp_hv_new = g.ndata['hv_new']
        g.edata['he'] = edge_feats

        g.apply_edges(self.apply_edges1)
        temp_he1_new = g.edata['he1']
        g.edata['he1'] = self.project_edge1(g.edata['he1'])
        temp_ehe1_new = g.edata['he1']
        g.apply_edges(self.apply_edges2)
        temp_he2_new = g.edata['he2']
        logits = self.project_edge2(g.edata['he2'])

        return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])

class GNNLayer(nn.Module):
    """GNNLayer for updating node features.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    graph_feat_size : int
        Size for the input graph features.
    dropout : float
        The probability for performing dropout.
    """
    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GNNLayer, self).__init__()

        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU()
        )
        self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
        self.g = None

    def apply_edges(self, edges):
        """Edge feature update by concatenating the features of the destination
        and source nodes."""
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """
        Parameters
        ----------
        g : DGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.

        Returns
        -------
        float32 tensor of shape (V, N1)
            Updated node features.
        """
        if 'he' in g.edata:
            pass
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])

        self.g = g

        if 'he' in g.edata:
            pass

        return self.attentive_gru(g, logits, node_feats)


class Pka_acidic(nn.Module):

    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers,
                 graph_feat_size,
                 output_size,
                 dropout):
        super(Pka_acidic,self).__init__()

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):

        mask = torch.sum(g.ndata['h'][:,-4:],dim = 1) * (1 - g.ndata['h'][:,0])
        mask = 1/mask -1
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
        atom_pka = self.predict(node_feats)
        # atom_pka = -(atom_pka + mask.reshape(-1,1))
        g.ndata['h'] = torch.pow(10,atom_pka)
        g_feats = -torch.log10(dgl.sum_nodes(g, 'h')+10**-38)
        return g_feats


class Pka_basic(nn.Module):
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers,
                 graph_feat_size,
                 output_size,
                 dropout):
        super(Pka_basic, self).__init__()

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):
        mask = g.ndata['h'][:,1] * (1 - g.ndata['h'][:,61])
        mask = -1/mask +1
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
        atom_pka = self.predict(node_feats)
        atom_pka = (atom_pka + mask.reshape(-1,1))
        g.ndata['h'] = torch.pow(10,atom_pka)
        g_feats = torch.log10(dgl.sum_nodes(g, 'h'))
        return g_feats


class Pka_acidic_view(nn.Module):
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers,
                 graph_feat_size,
                 output_size,
                 dropout):
        super(Pka_acidic_view,self).__init__()

        self.device = torch.device("cpu")

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def train_mode(self, train_type: str):
        if train_type == 'all_layers':
            self.init_context.train()
            self.gnn_layers.train()
            self.predict.train()
        elif train_type == 'predictor_and_readout':
            self.init_context.eval()
            self.gnn_layers.train()
            self.predict.train()
        elif train_type == 'all_layers' or train_type is None:
            self.init_context.eval()
            self.gnn_layers.eval()
            self.predict.train()

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):

        mask = torch.sum(g.ndata['h'][:,-4:],dim = 1) * (1 - g.ndata['h'][:,0])
        mask = 1/mask -1
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
        atom_pka = self.predict(node_feats)
        atom_pka = -(atom_pka + mask.reshape(-1,1))
        g.ndata['h'] = torch.pow(10,atom_pka)
        g_feats = -torch.log10(dgl.sum_nodes(g, 'h'))
        atom_pka_out = atom_pka * -1
        atom_pka_out = torch.squeeze(atom_pka_out)

        return g_feats, atom_pka_out.detach().cpu().numpy().tolist()


    def lrp(self, g, node_feats, edge_feats):

        node_feats_original = node_feats.clone()
        edge_feats_original = edge_feats.clone()

        all_node_relevances = {}
        all_edge_relevances = {}
        # node relevance
        for node_feat_index in range(len(node_feats)):

            self.zero_grad()

            x_node = node_feats_original.clone()
            x_node = Variable(x_node.data, requires_grad=True)

            h0 = x_node

            mask = torch.zeros(x_node.shape).to(self.device)
            mask[node_feat_index] = 1

            x_node = x_node * mask + (1 - mask) * x_node.data

            # TODO
            # AGGREGATE STEP FOR WALKS(currently only nodes)

            # forward
            node_feats = self.init_context(g, x_node, edge_feats)
            for gnn in self.gnn_layers:
                node_feats = gnn(g, node_feats)
            atom_pka = self.predict(node_feats)

            g_feats = torch.log10(torch.sum(atom_pka))

            # backward
            g_feats.backward(retain_graph=True)

            all_node_relevances[node_feat_index] = h0.data * h0.grad
            h0.grad.data.zero_()

        for edge_feat_index in range(len(edge_feats)):

            self.zero_grad()

            x_edge = edge_feats_original.clone()
            x_edge = Variable(x_edge.data, requires_grad=True)

            e0 = x_edge

            mask = torch.zeros(x_edge.shape).to(self.device)
            mask[edge_feat_index] = 1

            x_edge = x_edge * mask + (1 - mask) * x_edge.data

            # forward
            node_feats = self.init_context(g, node_feats_original, x_edge)
            for gnn in self.gnn_layers:
                node_feats = gnn(g, node_feats)
            atom_pka = self.predict(node_feats)

            g_feats = torch.log10(torch.sum(atom_pka))

            # backward
            g_feats.backward(retain_graph=True)

            all_edge_relevances[edge_feat_index] = e0.data * e0.grad
            e0.grad.data.zero_()

        # nodes relevance preprocessing
        for idx, rel in all_node_relevances.items():
            relevance_score = rel.data.sum().item()
            all_node_relevances[idx] = relevance_score

        # edges relevance preprocessing
        for idx, rel in all_edge_relevances.items():
            relevance_score = rel.data.sum().item()
            all_edge_relevances[idx] = relevance_score

        return all_node_relevances, all_edge_relevances


class Pka_basic_view(nn.Module):

    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers,
                 graph_feat_size,
                 output_size,
                 dropout):
        super(Pka_basic_view,self).__init__()

        self.device = torch.device("cpu")

        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, output_size)
        )

    def train_mode(self, train_type: str):
        if train_type == 'all_layers':
            self.init_context.train()
            self.gnn_layers.train()
            self.predict.train()
        elif train_type == 'predictor_and_readout':
            self.init_context.eval()
            self.gnn_layers.train()
            self.predict.train()
        elif train_type == 'all_layers' or train_type is None:
            self.init_context.eval()
            self.gnn_layers.eval()
            self.predict.train()

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):

        mask = g.ndata['h'][:,1] * (1 - g.ndata['h'][:,61])
        #print(mask)
        mask = -1/mask +1
        #print(mask)
        node_feats = self.init_context(g, node_feats, edge_feats)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
        atom_pka = self.predict(node_feats)
        #print(atom_pka)
        atom_pKa_after_predict = atom_pka.clone()
        # atom_pka = (atom_pka + mask.reshape(-1,1))
        #print(atom_pka)
        g.ndata['h'] = torch.pow(10,atom_pka)
        g_feats = torch.log10(dgl.sum_nodes(g, 'h'))
        atom_pka_out = torch.squeeze(atom_pka)

        # print(atom_pka)

        return g_feats, atom_pka_out.detach().cpu().numpy().tolist()

    def lrp(self, g, node_feats, edge_feats):

        node_feats_original = node_feats.clone()
        edge_feats_original = edge_feats.clone()

        all_node_relevances = {}
        all_edge_relevances = {}
        # node relevance
        for node_feat_index in range(len(node_feats)):

            self.zero_grad()

            x_node = node_feats_original.clone()
            x_node = Variable(x_node.data, requires_grad=True)

            h0 = x_node

            mask = torch.zeros(x_node.shape).to(self.device)
            mask[node_feat_index] = 1

            x_node = x_node * mask + (1 - mask) * x_node.data

            # TODO
            # AGGREGATE STEP FOR WALKS(currently only nodes)

            # forward
            node_feats = self.init_context(g, x_node, edge_feats)
            for gnn in self.gnn_layers:
                node_feats = gnn(g, node_feats)
            atom_pka = self.predict(node_feats)

            g_feats = torch.log10(torch.sum(atom_pka))

            # backward
            g_feats.backward(retain_graph=True)

            all_node_relevances[node_feat_index] = h0.data * h0.grad
            h0.grad.data.zero_()

        for edge_feat_index in range(len(edge_feats)):

            self.zero_grad()

            x_edge = edge_feats_original.clone()
            x_edge = Variable(x_edge.data, requires_grad=True)

            e0 = x_edge

            mask = torch.zeros(x_edge.shape).to(self.device)
            mask[edge_feat_index] = 1

            x_edge = x_edge * mask + (1 - mask) * x_edge.data

            # forward
            node_feats = self.init_context(g, node_feats_original, x_edge)
            for gnn in self.gnn_layers:
                node_feats = gnn(g, node_feats)
            atom_pka = self.predict(node_feats)

            g_feats = torch.log10(torch.sum(atom_pka))

            # backward
            g_feats.backward(retain_graph=True)

            all_edge_relevances[edge_feat_index] = e0.data * e0.grad
            e0.grad.data.zero_()

        # nodes relevance preprocessing
        for idx, rel in all_node_relevances.items():
            relevance_score = rel.data.sum().item()
            all_node_relevances[idx] = relevance_score

        # edges relevance preprocessing
        for idx, rel in all_edge_relevances.items():
            relevance_score = rel.data.sum().item()
            all_edge_relevances[idx] = relevance_score

        return all_node_relevances, all_edge_relevances



In [5]:
import torch
import torch.nn as nn
from torch.autograd import Variable

class MLPRegressor(nn.Module):
    """MLP for regression (over multiple tasks) from molecule representations.

    Parameters
    ----------
    in_feats : int
        Number of input molecular graph features
    hidden_feats : int
        Hidden size for molecular graph representations
    n_tasks : int
        Number of tasks, also output size
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, in_feats, hidden_feats, n_tasks, dropout=0.):
        super(MLPRegressor, self).__init__()

        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_feats, hidden_feats),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_feats),
            nn.Linear(hidden_feats, n_tasks)
        )

    def forward(self, h):
        """Predict for regression.

        Parameters
        ----------
        h : FloatTensor of shape (B, M3)
            * B is the number of molecules in a batch
            * M3 is the input molecule feature size, must match in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
        """
        return self.predict(h)

class BaseGNNRegressor(nn.Module):
    """GNN based model for multitask molecular property prediction.
    We assume all tasks are regression problems.

    Parameters
    ----------
    readout_feats : int
        Size for molecular representations
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, readout_feats, n_tasks, regressor_hidden_feats=128, dropout=0.):
        super(BaseGNNRegressor, self).__init__()

        self.device = torch.device("cpu")

        self.regressor = MLPRegressor(readout_feats, regressor_hidden_feats, n_tasks, dropout)

    def forward(self, bg, node_feats, edge_feats):
        """Multi-task prediction for a batch of molecules

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of B graphs
        node_feats : FloatTensor of shape (N, D0)
            Initial features for all nodes in the batch of graphs
        edge_feats : FloatTensor of shape (M, D1)
            Initial features for all edges in the batch of graphs

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Prediction for all tasks on the batch of molecules
        """
        # Update node representations
        feats = self.gnn(bg, node_feats, edge_feats)

        # Compute molecule features from atom features
        h_g = self.readout(bg, feats)

        # Multi-task prediction
        return self.regressor(h_g)

    def lrp(self, bg, node_feats, edge_feats):
        node_feats_original = node_feats.clone()
        edge_feats_original = edge_feats.clone()

        all_node_relevances = {}
        all_edge_relevances = {}
        # node relevance
        for node_feat_index in range(len(node_feats)):

            self.zero_grad()

            x_node = node_feats_original.clone()
            x_node = Variable(x_node.data, requires_grad=True)

            h0 = x_node

            mask = torch.zeros(x_node.shape).to(self.device)
            mask[node_feat_index] = 1

            x_node = x_node * mask + (1 - mask) * x_node.data

            # forward
            feats = self.gnn(bg, x_node, edge_feats_original)
            h_g = self.readout(bg, feats)

            # x_cloned = h_g.repeat(2, 1)
            predictions = self.regressor(h_g)
            logP_prediction = predictions[0][-1]

            # backward
            logP_prediction.backward(retain_graph=True)

            all_node_relevances[node_feat_index] = h0.data * h0.grad
            h0.grad.data.zero_()

        # edge relevance
        for edge_feat_index in range(len(edge_feats)):

            self.zero_grad()

            x_edge = edge_feats_original.clone()
            x_edge = Variable(x_edge.data, requires_grad=True)

            e0 = x_edge

            mask = torch.zeros(x_edge.shape).to(self.device)
            mask[edge_feat_index] = 1

            x_edge = x_edge * mask + (1 - mask) * x_edge.data

            # forward
            feats = self.gnn(bg, node_feats_original, x_edge)
            h_g = self.readout(bg, feats)

            predictions = self.regressor(h_g)
            logP_prediction = predictions[0][-1]

            # backward
            logP_prediction.backward(retain_graph=True)

            all_edge_relevances[edge_feat_index] = e0.data * e0.grad
            e0.grad.data.zero_()

        # TODO: lrp for smarts

        # nodes relevance preprocessing
        for idx, rel in all_node_relevances.items():
            relevance_score = rel.data.sum().item()
            all_node_relevances[idx] = relevance_score

        # edges relevance preprocessing
        for idx, rel in all_edge_relevances.items():
            relevance_score = rel.data.sum().item()
            all_edge_relevances[idx] = relevance_score

        return all_node_relevances, all_edge_relevances

class BaseGNNRegressorBypass(nn.Module):
    """This architecture uses one GNN for each task (task-speicifc) and one additional GNN shared
    across all tasks. To predict for each task, we feed the input to both the task-specific GNN
    and the task-shared GNN. The resulted representations of the two GNNs are then concatenated
    and fed to a task-specific forward NN.

    Parameters
    ----------
    readout_feats : int
        Size for molecular representations
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, readout_feats, n_tasks, regressor_hidden_feats=128, dropout=0.):
        super(BaseGNNRegressorBypass, self).__init__()

        self.n_tasks = n_tasks
        self.task_gnns = nn.ModuleList()
        self.readouts = nn.ModuleList()
        self.regressors = nn.ModuleList()

        for _ in range(n_tasks):
            self.regressors.append(
                MLPRegressor(readout_feats, regressor_hidden_feats, 1, dropout))

    def forward(self, bg, node_feats, edge_feats):
        """Multi-task prediction for a batch of molecules

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of B graphs
        node_feats : FloatTensor of shape (N, D0)
            Initial features for all nodes in the batch of graphs
        edge_feats : FloatTensor of shape (M, D1)
            Initial features for all edges in the batch of graphs

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Prediction for all tasks on the batch of molecules
        """
        shared_repr = self.shared_gnn(bg, node_feats, edge_feats)
        predictions = []

        for t in range(self.n_tasks):
            task_repr = self.task_gnns[t](bg, node_feats, edge_feats)
            combined_repr = torch.cat([shared_repr, task_repr], dim=1)
            g_t = self.readouts[t](bg, combined_repr)
            predictions.append(self.regressors[t](g_t))

        # Combined predictions of all tasks
        return torch.cat(predictions, dim=1)


In [6]:
import torch.nn as nn

from dgllife.model import AttentiveFPGNN, AttentiveFPReadout

class AttentiveFPRegressor(BaseGNNRegressor):
    """AttentiveFP-based model for multitask molecular property prediction.
    We assume all tasks are regression problems.

    Parameters
    ----------
    in_node_feats : int
        Number of input node features
    in_edge_feats : int
        Number of input edge features
    gnn_out_feats : int
        The GNN output size
    num_layers : int
        Number of GNN layers
    num_timesteps : int
        Number of timesteps for updating molecular representations with GRU during readout
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
                 n_tasks, regressor_hidden_feats=128, dropout=0.):

        super(AttentiveFPRegressor, self).__init__(readout_feats=gnn_out_feats,
                                                   n_tasks=n_tasks,
                                                   regressor_hidden_feats=regressor_hidden_feats,
                                                   dropout=dropout)
        self.gnn = AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
                                  gnn_out_feats, dropout)
        self.readout = AttentiveFPReadout(gnn_out_feats, num_timesteps, dropout)

    def train_mode(self, train_type: str):
        if train_type == 'all_layers':
            self.gnn.train()
            self.readout.train()
        elif train_type == 'predictor_and_readout':
            self.gnn.train()
            self.readout.train()
        else:
            self.gnn.eval()
            self.readout.train()


class AttentiveFPRegressorBypass(BaseGNNRegressorBypass):
    """AttentiveFP-based model for bypass multitask molecular property prediction.
    We assume all tasks are regression problems.

    Parameters
    ----------
    in_node_feats : int
        Number of input node features
    in_edge_feats : int
        Number of input edge features
    gnn_out_feats : int
        The GNN output size
    num_layers : int
        Number of GNN layers
    num_timesteps : int
        Number of timesteps for updating molecular representations with GRU during readout
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
                 n_tasks, regressor_hidden_feats=128, dropout=0.):
        super(AttentiveFPRegressorBypass, self).__init__(
            readout_feats= 2 * gnn_out_feats, n_tasks=n_tasks,
            regressor_hidden_feats=regressor_hidden_feats,
            dropout=dropout)
        self.shared_gnn = AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
                                         gnn_out_feats, dropout)
        #self.shared_gnn.load_state_dict()
        for _ in range(n_tasks):
            self.task_gnns.append(AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
                                                 gnn_out_feats, dropout))
            self.readouts.append(AttentiveFPReadout(2 * gnn_out_feats, num_timesteps, dropout))


In [7]:
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


import dgl
import errno
import json
import os
import torch
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, ScaffoldSplitter, RandomSplitter
from functools import partial
import numpy
import numpy as np

import random
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def mkdir_p(path):
    """Create a folder for the given path.
    Parameters
    ----------
    path: str
        Folder to create
    """
    try:
        os.makedirs(path)
        print('Created directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print('Directory {} already exists.'.format(path))
        else:
            raise

def setup(args, random_seed=0):
    """Decide the device to use for computing, set random seed and perform sanity check."""

    os.environ['PYTHONHASHSEED']=str(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    dgl.seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    if args['n_tasks'] == 1:
        assert args['mode'] == 'parallel', \
            'Bypass architecture is not applicable for single-task experiments.'

    return args

def default(self, obj):
    if isinstance(obj, (numpy.int_, numpy.intc, numpy.intp, numpy.int8,
        numpy.int16, numpy.int32, numpy.int64, numpy.uint8,
        numpy.uint16,numpy.uint32, numpy.uint64)):
        return int(obj)
    elif isinstance(obj, (numpy.float_, numpy.float16, numpy.float32,
        numpy.float64)):
        return float(obj)
    elif isinstance(obj, (numpy.ndarray,)): # add this line
        return obj.tolist() # add this line
    return json.JSONEncoder.default(self, obj)


def get_label_mean_and_std(dataset):
    """Compute the mean and std of labels.

    Non-existing labels are excluded for computing mean and std.

    Parameters
    ----------
    dataset
        We assume that len(dataset) gives the number of datapoints
        in the dataset and dataset[i] gives the SMILES, RDKit molecule
        instance, DGLGraph, label and mask for the i-th datapoint.

    Returns
    -------
    labels_mean: float32 tensor of shape (T)
        Mean of the labels for all tasks
    labels_std: float32 tensor of shape (T)
        Std of the labels for all tasks
    """
    _, _, label, _ = dataset[0]
    n_tasks = label.shape[-1]
    task_values = {t: [] for t in range(n_tasks)}
    for i in range(len(dataset)):
        _, _, label, mask = dataset[i]
        for t in range(n_tasks):
            if mask[t].data.item() == 1.:
                task_values[t].append(label[t].data.item())

    labels_mean = torch.zeros(n_tasks)
    labels_std = torch.zeros(n_tasks)
    for t in range(n_tasks):
        labels_mean[t] = float(np.mean(task_values[t]))
        labels_std[t] = float(np.std(task_values[t]))

    return labels_mean, labels_std

def collate(data):
    """Batching a list of datapoints for dataloader in training GNNs.

    Returns
    -------
    smiles: list
        List of smiles
    bg: DGLGraph
        DGLGraph for a batch of graphs
    labels: Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks: Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels.
    """
    smiles, graphs, labels, masks = map(list, zip(*data))
    bg = dgl.batch(graphs)
    labels = torch.stack(labels, dim=0)
    masks = torch.stack(masks, dim=0)

    return smiles, bg, labels, masks

def load_model(exp_configure):

    if exp_configure['model'] == 'AttentiveFP':
        if exp_configure['mode'] == 'parallel':
            model_class = AttentiveFPRegressor
        else:
            model_class = AttentiveFPRegressorBypass
        model = model_class(in_node_feats=exp_configure['in_node_feats'],
                            in_edge_feats=exp_configure['in_edge_feats'],
                            gnn_out_feats=exp_configure['graph_feat_size'],
                            num_layers=exp_configure['num_layers'],
                            num_timesteps=exp_configure['num_timesteps'],
                            n_tasks=exp_configure['n_tasks'],
                            regressor_hidden_feats=exp_configure['regressor_hidden_feats'],
                            dropout=exp_configure['dropout'])
    return model

def init_featurizer(args):
    """Initialize node/edge featurizer

    Parameters
    ----------
    args : dict
        Settings

    Returns
    -------
    args : dict
        Settings with featurizers updated
    """
    if args['atom_featurizer_type'] == 'canonical':
        from dgllife.utils import CanonicalAtomFeaturizer
        args['node_featurizer'] = CanonicalAtomFeaturizer()
    elif args['atom_featurizer_type'] == 'attentivefp':
        from dgllife.utils import AttentiveFPAtomFeaturizer
        args['node_featurizer'] = AttentiveFPAtomFeaturizer()
    else:
        return ValueError(
            "Expect node_featurizer to be in ['canonical', 'attentivefp'], "
            "got {}".format(args['atom_featurizer_type']))


    if args['model'] in ['AttentiveFP']:
        if args['bond_featurizer_type'] == 'canonical':
            from dgllife.utils import CanonicalBondFeaturizer
            args['edge_featurizer'] = CanonicalBondFeaturizer()
        elif args['bond_featurizer_type'] == 'attentivefp':
            from dgllife.utils import AttentiveFPBondFeaturizer
            args['edge_featurizer'] = AttentiveFPBondFeaturizer()
    else:
        args['edge_featurizer'] = None

    return args

from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer

def get_configure(model, args=None):
    """Query for the manually specified configuration

    Parameters
    ----------
    model : str
        Model type

    Returns
    -------
    dict
        Returns the manually specified configuration
    """
    if args == None:
            with open('C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\model_configures/{}.json'.format(model), 'r') as f:
                config = json.load(f)
            return config
    else:
            with open('C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\model_configures/{}.json'.format(model), 'r') as f:
                config = json.load(f)
            return config

def mkdir_p(path):
    """Create a folder for the given path.

    Parameters
    ----------
    path: str
        Folder to create
    """
    try:
        os.makedirs(path)
        print('Created directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            print('Directory {} already exists.'.format(path))
        else:
            raise

def init_trial_path(args):
    """Initialize the path for a hyperparameter setting

    Parameters
    ----------
    args : dict
        Settings

    Returns
    -------
    args : dict
        Settings with the trial path updated
    """
    trial_id = 0
    path_exists = True
    while path_exists:
        trial_id += 1
        path_to_results = args['result_path'] + '/{:d}'.format(trial_id)
        path_exists = os.path.exists(path_to_results)
    args['trial_path'] = path_to_results
    mkdir_p(args['trial_path'])

    return args


def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.

    Parameters
    ----------
    data : list of 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and a binary
        mask indicating the existence of labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : DGLGraph
        The batched DGLGraph.
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels.
    """
    smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)

    if masks is None:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)

    return smiles, bg, labels, masks

def collate_molgraphs_unlabeled(data):
    """Batching a list of datapoints without labels

    Parameters
    ----------
    data : list of 2-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES and a DGLGraph.

    Returns
    -------
    smiles : list
        List of smiles
    bg : DGLGraph
        The batched DGLGraph.
    """
    smiles, graphs = map(list, zip(*data))
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)

    return smiles, bg



def predict(args, model, bg, default_weights=True):
    bg = bg.to(args['device'])
    if args['edge_featurizer'] is None:
        node_feats = bg.ndata.pop('h').to(args['device'])
        return model(bg, node_feats)
    elif args['bond_featurizer_type'] == 'pre_train':
        node_feats = [
            bg.ndata.pop('atomic_number').to(args['device']),
            bg.ndata.pop('chirality_type').to(args['device'])
        ]
        edge_feats = [
            bg.edata.pop('bond_type').to(args['device']),
            bg.edata.pop('bond_direction_type').to(args['device'])
        ]

        return model(bg, node_feats, edge_feats,get_node_weight=True)
    else:
        node_feats = bg.ndata['h'].to(args['device'])
        edge_feats = bg.edata['e'].to(args['device'])
        with bg.local_scope():
            pka1_model = Pka_acidic_view(node_feat_size = 74,
                    edge_feat_size = 12,
                    output_size = 1,
                    num_layers= 6,
                    graph_feat_size=200,
                    dropout=0).to(args['device'])

            pka1_model.eval()
            with torch.no_grad():
                if default_weights:
                    pka1_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\Trained_model/site_acidic.pkl',map_location='cpu'))
                else:
                    pka1_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\Trained_model/site_acidic_best_loss_distinctive-butterfly-11.pkl',map_location='cpu'))
                pka_acidic_prediction,pka1_atom_list = pka1_model(bg,bg.ndata['h'], bg.edata['e'])

        with bg.local_scope():
            pka2_model = Pka_basic_view(node_feat_size = 74,
                    edge_feat_size = 12,
                    output_size = 1,
                    num_layers= 6,
                    graph_feat_size=200,
                    dropout=0).to(args['device'])

            pka2_model.eval()
            with torch.no_grad():

                if default_weights:
                    pka2_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\Trained_model/site_basic.pkl',map_location='cpu'))
                else:
                    pka2_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\RT_LogP_with_pKa_model\RTlogD\Trained_model/site_amine_best_loss_sweet-capybara-11.pkl',map_location='cpu'))

                pka_basic_prediction,pka2_atom_list = pka2_model(bg,bg.ndata['h'], bg.edata['e'])

        pka1_atom_list=np.array(pka1_atom_list)
        pka1_atom_list[np.isinf(pka1_atom_list)]=15
        pka2_atom_list=np.array(pka2_atom_list)
        pka2_atom_list[np.isinf(pka2_atom_list)]=0

        pka1_feature = torch.Tensor(pka1_atom_list/11).to(args['device'])
        pka2_feature = torch.Tensor(pka2_atom_list/11).to(args['device'])

        pka1_feature=pka1_feature.unsqueeze(-1)
        pka2_feature=pka2_feature.unsqueeze(-1)

        node_feats = torch.cat([node_feats,pka1_feature,pka2_feature],dim = 1)

        return model(bg, node_feats, edge_feats), pka_acidic_prediction, pka_basic_prediction
def split_dataset(args, dataset):
    """Split the dataset
    Parameters
    ----------
    args : dict
        Settings
    dataset
        Dataset instance
    Returns
    -------
    train_set
        Training subset
    val_set
        Validation subset
    test_set
        Test subset
    """
    train_ratio, val_ratio, test_ratio = map(float, args['split_ratio'].split(','))
    if args['split'] == 'scaffold_decompose':
        train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,
            scaffold_func='decompose')
    elif args['split'] == 'scaffold_smiles':
        train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,
            scaffold_func='smiles')
    elif args['split'] == 'random':
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio)
    else:
        return ValueError("Expect the splitting method to be 'scaffold', got {}".format(args['split']))

    return train_set, val_set, test_set


In [8]:
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer

def load_dataset(args, df,name):
    dataset = MoleculeCSVDataset(df=df,
                                 smiles_to_graph=partial(smiles_to_bigraph, num_virtual_nodes=0),
                                 node_featurizer=CanonicalAtomFeaturizer(),
                                 edge_featurizer=CanonicalBondFeaturizer(),
                                 smiles_column=args['smiles_column'],
                                #  smiles_column="Smiles",
                                 cache_file_path=args['result_path'] +'/'+ str(name)+'_graph.bin',
                                 task_names=args['task'],
                                 n_jobs=args['num_workers'],
                                 load=False
                                )

    return dataset

In [11]:
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import time
import json
import pickle
import pandas as pd
import numpy as np
from functools import partial
import torch
import sys
import os

from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, CyclicLR

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from math import sqrt

from hyperopt import fmin, tpe, hp

from tqdm import tqdm
# import wandb

crossvals_list = []

class EarlyStopper:
    def __init__(self, patience=20, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = pow(10,10)

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > self.min_validation_loss:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

class Trainer:

    def __init__(self) -> None:
        self.space = {
            'lr': hp.loguniform('lr', -10, -3),
            'weight_decay': hp.loguniform('weight_decay', -10, -1),
            'train_type': hp.choice('train_type', ['only_predictor', 'predictor_and_readout'])
        }

        with open(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\configs\args.pickle', 'rb') as file:
            self.args =pickle.load(file)
            self.args['task'] = ['pKa']
        with open(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\configs\configure.json', 'r') as f:
            exp_config = json.load(f)
        self.args['device'] = torch.device('cpu')
        self.args = setup(self.args)

        self.args['smiles_column'] = "Smiles"

        self.train_set_basic = pd.read_csv(r'C:\work\DrugDiscovery\main_git\XAI_Chem\data\pKa_basicity_data\gnn_cv_canon_smiles\train_basic.csv', index_col=0)
        self.train_set_acid = pd.read_csv(r'C:\work\DrugDiscovery\main_git\XAI_Chem\data\pKa_basicity_data\gnn_cv_canon_smiles\train_acid.csv', index_col=0)
        self.train_set = pd.concat([self.train_set_basic, self.train_set_acid], axis=0)

        self.test_set_basic = pd.read_csv(r'C:\work\DrugDiscovery\main_git\XAI_Chem\data\pKa_basicity_data\gnn_cv_canon_smiles\test_basic.csv', index_col=0)
        self.test_set_acid = pd.read_csv(r'C:\work\DrugDiscovery\main_git\XAI_Chem\data\pKa_basicity_data\gnn_cv_canon_smiles\test_acid.csv', index_col=0)
        self.test_set = pd.concat([self.test_set_basic, self.test_set_acid], axis=0)

        best_hyperparameters = self.find_best_params_with_hyperopt()

        print(crossvals_list)
        cv_file = pd.DataFrame(crossvals_list)
        # cv_file.to_csv(rf'{link_to_colab_directory}/pKa_acidic(hopefully)_model_cv_saved_canon_smiles_all_dataset.csv')

        model = Trainer.load_pKa_acidic_model(self.args)

        train_set = load_dataset(self.args, self.train_set, "test")
        test_set = load_dataset(self.args, self.test_set, "test")

        loss_cv, metrics_oos, metrics_train, best_val_pred_values, best_train_pred_values, true_val_values, true_train_values = Trainer.train(args=self.args,
            model=model,
            _train_set=train_set,
            _test_set=test_set,
            lr=best_hyperparameters['lr'],
            weight_decay=best_hyperparameters['weight_decay'],
            train_type=best_hyperparameters['train_type'],
            save_best_model=True
        )

        print(f"OOS: {metrics_oos}")
        print(f"Train: {metrics_train}")
        print(f"True train values: {true_train_values}")
        print(f"Pred train values: {best_train_pred_values}")
        print(f"True val values: {true_val_values}")
        print(f"Pred val values: {best_val_pred_values}")

    @staticmethod
    def init_wandb(
        train_mode: str = None, batch_size: int = 32,
        lr: float = 10 ** (-2.5),
        weight_decay: float = 10 ** (-5.0),
        model_name: str = "",
        epochs: int = 200
    ):
        run = wandb.init(
            project="enamine-pKa-amine-acid-canon-smiles-combined-data",

            config={
                "learning_rate": lr,
                "train_mode": train_mode,
                "batch_size": batch_size,
                "architecture": model_name,
                "dataset": "data\pKa_basicity_data_canon_smiles",
                "epochs": epochs,
                "optimizer": f"Adam(lr={lr}, weight_decay={weight_decay})",
                "loss": "MSELoss",
                "scheduler": "ReduceLrOnPlateu(mode='min', patience=10, factor=0.5, min_lr=0.000002)",
                "EarlyStoppingPatience": 20
            }
        )

        return run.name

    def find_best_params_with_hyperopt(self):
        algo = tpe.suggest

        objective_partial = partial(Trainer.optimization_function, X_train=self.train_set)

        best_hyperparams = fmin(fn=objective_partial, space=self.space, algo=algo, max_evals=15, verbose=1)

        print("Найкращі гіперпараметри:", best_hyperparams)
        return best_hyperparams

    @staticmethod
    def optimization_function(params, X_train):
        global crossvals_list

        amount_of_cross_vals = 2

        cv_indices_dict = {0: [], 1: []}
        for index, row in X_train.iterrows():
            cv_indices_dict[row['fold_id']].append(index)

        cv_indices = [[cv_indices_dict[0], cv_indices_dict[1]], [cv_indices_dict[1], cv_indices_dict[0]]]

        cv_dict = {"params": params}

        total_loss = 0
        for cross_val_index in range(amount_of_cross_vals):
            # df.loc[df.index[index_list]]
            train_df_cv = X_train.loc[cv_indices[cross_val_index][0]]

            test_df_cv = X_train.loc[cv_indices[cross_val_index][1]]

            with open(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\configs\args.pickle', 'rb') as file:
                args =pickle.load(file)
                args['task'] = ['pKa']
            with open(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\configs\configure.json', 'r') as f:
                exp_config = json.load(f)
            args['device'] = torch.device('cpu')
            args = setup(args)

            args['smiles_column'] = "Smiles"

            acidic_model = Trainer.load_pKa_acidic_model(args)

            train_set = load_dataset(args,train_df_cv,"test")
            test_set = load_dataset(args,test_df_cv,"test")

            loss_cv, metrics_cv, metrics_cv_train, best_val_pred_values, best_train_pred_values, true_val_values, true_train_values = Trainer.train(
                args=args,
                model=acidic_model,
                _train_set=train_set,
                _test_set=test_set,
                save_best_model=True,
                cv_index=cross_val_index,
                lr=params['lr'],
                weight_decay=params['weight_decay'],
                train_type=params['train_type']
            )

            cv_dict[f"cv_{cross_val_index}_loss"] = loss_cv
            cv_dict[f"cv_{cross_val_index}_r^2"] = metrics_cv["r_score"]
            cv_dict[f"cv_{cross_val_index}_mse"] = metrics_cv["mse"]
            cv_dict[f"cv_{cross_val_index}_mae"] = metrics_cv["mae"]
            cv_dict[f"cv_{cross_val_index}_true_values"] = true_val_values
            cv_dict[f"cv_{cross_val_index}_pred_values"] = best_val_pred_values
            cv_dict[f"cv_{cross_val_index}_train_r^2"] = metrics_cv_train["r_score"]
            cv_dict[f"cv_{cross_val_index}_train_mse"] = metrics_cv_train["mse"]
            cv_dict[f"cv_{cross_val_index}_train_mae"] = metrics_cv_train["mae"]
            cv_dict[f"cv_{cross_val_index}_train_true_values"] = true_train_values
            cv_dict[f"cv_{cross_val_index}_train_pred_values"] = best_train_pred_values

            total_loss += loss_cv

        crossvals_list.append(cv_dict)
        return total_loss


    @staticmethod
    def load_pKa_acidic_model(args):
        pka1_model = Pka_acidic_view(node_feat_size = 74,
                                    edge_feat_size = 12,
                                    output_size = 1,
                                    num_layers= 6,
                                    graph_feat_size=200,
                                    dropout=0).to(args['device'])
        pka1_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\weights\pKa\site_acidic.pkl',map_location='cpu'))
        return pka1_model


    @staticmethod
    def load_pKa_basic_model(args):
        pka2_model = Pka_basic_view(node_feat_size = 74,
                                    edge_feat_size = 12,
                                    output_size = 1,
                                    num_layers= 6,
                                    graph_feat_size=200,
                                    dropout=0).to(args['device'])
        pka2_model.load_state_dict(torch.load(r'C:\work\DrugDiscovery\main_git\XAI_Chem\ml_part\weights\pKa\site_basic.pkl',map_location='cpu'))
        return pka2_model


    @staticmethod
    def calculate_metrics(true_values, pred_values):
        # print(true_values)
        # print(pred_values)
        mse = round(mean_squared_error(true_values, pred_values),3)
        mae = round(mean_absolute_error(true_values, pred_values),3)
        r_score = round(r2_score(true_values, pred_values),3)

        return {"mse": mse,
                "mae": mae,
                "r_score": r_score,}


    @staticmethod
    def train(args, model, _train_set, _test_set, num_epochs=200, use_wandb=False, save_best_model=True, cv_index=None,
              lr=10**(-2.5), weight_decay=0.0006897, batch_size=150, train_type: str = None,):
        run_name = ""
        if use_wandb is True:
            run_name = Trainer.init_wandb(train_mode=train_type,
                                        batch_size=batch_size,
                                        lr=lr,
                                        weight_decay=weight_decay,
                                        model_name="pKa attentive fp canonical acid",
                                        epochs=num_epochs)
        batch_size = batch_size
        train_loader = DataLoader(dataset=_train_set, batch_size=batch_size,
                                collate_fn=collate_molgraphs, num_workers=args['num_workers'])
        test_loader = DataLoader(dataset=_test_set, batch_size=batch_size,
                                collate_fn=collate_molgraphs, num_workers=args['num_workers'])

        baseline_lr = lr
        baseline_wd = weight_decay

        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.MSELoss()
        scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', patience=10, factor=0.5, min_lr=0.000002)
        early_stopper = EarlyStopper(patience=20, min_delta=0.001)

        best_vloss = pow(10, 10)
        best_train_metrics, best_val_metrics = None, None
        best_train_pred_values, best_val_pred_values = None, None
        for epoch in tqdm(range(num_epochs)):
            model.train_mode(train_type)
            running_loss = 0.0
            running_vloss = 0.0
            for i, batch_data in enumerate(train_loader):
                _, bg, labels, masks = batch_data
                optimizer.zero_grad()

                model_prediction, _ = model(bg,bg.ndata['h'], bg.edata['e'])
                loss = criterion(model_prediction, labels)
                loss.backward()

                optimizer.step()
                scheduler.step(loss)

                running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            true_train_values = labels.view(-1).tolist()
            pred_train_values = model_prediction.view(-1).tolist()
            train_metrics = Trainer.calculate_metrics(true_train_values, pred_train_values)

            model.eval()
            with torch.no_grad():
                for i, batch_data in enumerate(test_loader):
                    _, bg, labels, masks = batch_data
                    model_prediction, _ = model(bg,bg.ndata['h'], bg.edata['e'])
                    vloss = criterion(model_prediction, labels)

                    running_vloss += vloss.item()
            avg_vloss = running_vloss / (i + 1)

            true_val_values = labels.view(-1).tolist()
            pred_val_values = model_prediction.view(-1).tolist()
            val_metrics = Trainer.calculate_metrics(true_val_values, pred_val_values)

            lr = optimizer.param_groups[0]['lr']
            if avg_vloss < best_vloss:
                # if save_best_model is True:
                #     if cv_index is None:
                #         torch.save(model.state_dict(), rf'{link_to_colab_directory}/pKa/basic_best_loss_{run_name}.pkl')
                #     else:
                #         torch.save(model.state_dict(), rf'{link_to_colab_directory}/pKa/cv_{cv_index}_best_loss_lr_{baseline_lr}_wd_{baseline_wd}_train_type_{train_type}.pkl')
                best_train_metrics = train_metrics
                best_val_metrics = val_metrics
                best_val_pred_values = pred_val_values.copy()
                best_train_pred_values = pred_train_values.copy()
                pass

            if use_wandb is True:
                wandb.log({"loss/train": avg_loss,
                            "loss/val": avg_vloss,
                            "lr": lr,
                            "mse/train": train_metrics['mse'],
                            "mae/train": train_metrics['mae'],
                            "r^2/train": train_metrics['r_score'],
                            "mse/val": val_metrics['mse'],
                            "mae/val": val_metrics['mae'],
                            "r^2/val": val_metrics['r_score']})

            is_early_stop = early_stopper.early_stop(avg_vloss)
            if is_early_stop:
                break

            print('LOSS train: {} valid: {}, lr: {}'.format(loss, vloss, lr))

        if use_wandb is True:
            wandb.finish()

        return avg_vloss, best_val_metrics, best_train_metrics, best_val_pred_values, best_train_pred_values, true_val_values, true_train_values

import warnings
warnings.filterwarnings("ignore")

Trainer()


Processing dgl graphs from scratch...                 
Processing dgl graphs from scratch...                 
  0%|          | 0/15 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 0/200 [00:10<?, ?it/s]
job exception: DataLoader worker (pid(s) 13596) exited unexpectedly



  0%|          | 0/15 [00:18<?, ?trial/s, best loss=?]


RuntimeError: DataLoader worker (pid(s) 13596) exited unexpectedly

In [None]:
train_set_basic = pd.read_csv(r'/content/drive/MyDrive/colab/R&D/finetune_separate_amine_basic_models_on_all_dataset/train_basic.csv')
train_set_acid = pd.read_csv(r'/content/drive/MyDrive/colab/R&D/finetune_separate_amine_basic_models_on_all_dataset/train_acid.csv')
train_set = pd.concat([train_set_basic, train_set_acid], axis=0)

In [None]:
with open(r'/content/drive/MyDrive/colab/R&D/finetune_separate_amine_basic_models_on_all_dataset/args.pickle', 'rb') as file:
    args =pickle.load(file)
    args['task'] = ['pKa']

In [None]:
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer

def load_dataset(args, df,name):
    dataset = MoleculeCSVDataset(df=df,
                                 smiles_to_graph=partial(smiles_to_bigraph, num_virtual_nodes=0),
                                 node_featurizer=CanonicalAtomFeaturizer(),
                                 edge_featurizer=CanonicalBondFeaturizer(),
                                #  smiles_column=args['smiles_column'],
                                 smiles_column="Smiles",
                                 cache_file_path=args['result_path'] +'/'+ str(name)+'_graph.bin',
                                 task_names=args['task'],
                                 n_jobs=args['num_workers'],
                                 load=False
                                )

In [None]:
args['node_featurizer'].feat_size('h')

74

In [None]:
train_set = load_dataset(args,train_set,"test")

Processing dgl graphs from scratch...


In [None]:
cv_file = pd.DataFrame(crossvals_list)
cv_file.to_csv(rf'{link_to_colab_directory}/pKa_amine_model_cv_canon_smiles_all_dataset.csv')

In [None]:
crossvals_list

[{'params': {'lr': 0.00014063989290627956,
   'train_type': 'predictor_and_readout',
   'weight_decay': 0.03177412302918954},
  'cv_0_loss': 0.26258036494255066,
  'cv_0_r^2': 0.959,
  'cv_0_mse': 0.263,
  'cv_0_mae': 0.367,
  'cv_0_true_values': [9.449999809265137,
   10.579999923706055,
   9.770000457763672,
   6.559999942779541,
   10.420000076293945,
   5.960000038146973,
   10.380000114440918,
   9.40999984741211,
   6.369999885559082,
   10.859999656677246,
   10.399999618530273,
   10.529999732971191,
   8.869999885559082,
   7.070000171661377,
   7.110000133514404,
   6.739999771118164,
   6.21999979019165,
   7.010000228881836,
   8.600000381469727,
   9.170000076293945,
   7.989999771118164,
   8.960000038146973,
   9.550000190734863,
   10.319999694824219,
   9.75,
   10.380000114440918,
   8.670000076293945,
   9.569999694824219,
   9.020000457763672,
   10.050000190734863,
   10.460000038146973,
   10.029999732971191,
   7.199999809265137,
   8.369999885559082,
   8.130000