In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch

class GCNEmbedding(torch.nn.Module):
    """Two-layer GCN encoder that mean-pools node states into graph vectors.

    Args:
        input_dim: width of each input node feature vector.
        hidden_dim: width produced by the first GCN layer.
        output_dim: width of the pooled graph embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are state_dict keys of saved checkpoints; keep them.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """Return one embedding row per graph.

        Args:
            x: node feature matrix (presumably [num_nodes, input_dim] — TODO confirm).
            edge_index: connectivity passed unchanged to both GCN layers.
            batch: graph-assignment vector for pooling; for a single unbatched
                ``Data`` object this is ``None`` (as when called via ``ga_prediction``).
        """
        hidden = self.conv1(x, edge_index).relu()
        node_states = self.conv2(hidden, edge_index)
        return global_mean_pool(node_states, batch)
class GraphSimilarityModel(torch.nn.Module):
    """Siamese GCN that maps a pair of graphs to a similarity score in (0, 1).

    Both graphs are encoded by the same ``GCNEmbedding``. The head sees
    ``[e1, e2, |e1 - e2|]``. Because ``e1`` and ``e2`` occupy different input
    slots of a learned linear layer, ``forward(a, b)`` and ``forward(b, a)``
    are not guaranteed to be equal.

    Args:
        node_feature_dim: width of the node feature vectors.
        hidden_dim: hidden width of the GCN and of the scoring head.
        embedding_dim: width of each pooled graph embedding.
    """

    def __init__(self, node_feature_dim=10, hidden_dim=64, embedding_dim=32):
        super().__init__()
        # Shared encoder: both inputs go through the same weights.
        self.gcn = GCNEmbedding(node_feature_dim, hidden_dim, embedding_dim)
        # Input is the concatenation of three embedding-sized pieces.
        # Layer order/indices are state_dict keys of saved checkpoints.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * 3, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
            torch.nn.Sigmoid(),
        )

    def _embed(self, graph):
        # `graph` is any object exposing .x, .edge_index and .batch (PyG Data/Batch).
        return self.gcn(graph.x, graph.edge_index, graph.batch)

    def forward(self, data1, data2):
        """Score ``data1`` against ``data2``.

        Returns the sigmoid output with all size-1 dims squeezed, i.e. a 0-d
        tensor when each input holds a single graph.
        """
        left = self._embed(data1)
        right = self._embed(data2)
        pair_features = torch.cat((left, right, (left - right).abs()), dim=1)
        return self.fc(pair_features).squeeze()

In [2]:
import re
def sequence_to_node_edge(sequence):
    """Convert a ``--``-delimited residue string into graph nodes and edges.

    Example::

        ref_seq = 'Cys(1)(2)--Cys--OH-DL-Val(2)--4OH-Leu--OH-Ile(1)'

        ref_nodes, ref_edges = ga.sequence_to_node_edge(ref_seq)

    Parsing rules, applied to residue ``i`` (the ``i``-th piece of ``--``):

    * ``(k)`` tags are stripped from the residue name. The first residue that
      carries tag ``k`` is joined by an edge to every later residue carrying it.
    * A single ``-`` separates modifier prefixes from the core residue, e.g.
      ``OH-DL-Val`` -> prefixes ``OH``, ``DL`` and core ``Val``. Prefix ``j``
      becomes node ``"<i>_<j>"`` (label upper-cased), with an edge to node ``i``.
    * The core residue becomes node ``i`` (label capitalized). For ``i > 0`` it
      also gets an edge from node ``i - 1``.

    Edges are directed ``(source_id, target_id)`` pairs; no reverse edges
    are added.

    :param sequence: amino acid chain format.
    :type sequence: Format conversion by :py:func:`SequenceTransformer.read_sequence`
    :return: ``(nodes, edges)``. ``nodes`` is a list of ``(node_id, label)``;
        ``edges`` is a list of ``(node_id, node_id)``.
    :rtype: Display with :py:func:`create_graph`
    """
    nodes = []  # (node_id, label), e.g. [(0, 'Cys'), (1, 'Cys'), ...]
    edges = []  # (src_id, dst_id), e.g. [(0, 1), (1, 2), ...]
    first_holder = {}  # tag key -> index of the first residue carrying that tag

    for idx, residue in enumerate(sequence.split('--')):
        bare = re.sub(r"\(\d+\)", "", residue)
        *prefixes, core = [part.strip() for part in bare.split('-')]

        # Modifier prefixes hang off the core residue node.
        for pos, prefix in enumerate(prefixes):
            prefix_id = f"{idx}_{pos}"
            nodes.append((prefix_id, prefix.upper()))
            edges.append((prefix_id, idx))

        nodes.append((idx, core.capitalize()))
        if idx:
            # Backbone link from the previous core residue.
            edges.append((idx - 1, idx))

        # Tagged links (e.g. ring closures, bridges) between residues sharing a tag.
        for tag in re.findall(r"\((\d+)\)", residue):
            key = int(tag) - 1
            if key in first_holder:
                edges.append((first_holder[key], idx))
            else:
                first_holder[key] = idx

    return nodes, edges

def node2embedding(node):
    """One-hot encode a residue or modifier label.

    Returns a list of 29 ints: one slot per known label (28, in the order
    below) followed by a final "unknown" slot. Exactly one entry is 1.
    Matching is exact and case-sensitive.
    """
    known = ("Ala", "Arg", "Asn", "Asp", "Cys", "Gln", "Glu", "Gly", "His", "Ile", "Leu", "Lys",
             "Met", "Phe", "Pro", "Ser", "Thr", "Trp", "Tyr", "Val", "Orn", "Aile", "DL", "D", "Dap", "Athr",
             "4OH", "OH")
    vector = [0] * (len(known) + 1)
    # Unknown labels light up the trailing slot (index -1).
    vector[known.index(node) if node in known else -1] = 1
    return vector

def graph2data(nodes, edges):
    """Build a PyTorch Geometric ``Data`` object from parsed nodes and edges.

    :param nodes: list of ``(node_id, label)`` pairs. List order defines the row
        order of ``x``; each label is encoded with ``node2embedding``.
    :param edges: list of ``(src_id, dst_id)`` pairs referring to node ids.
    :return: ``Data(x=x, edge_index=edge_index)``. ``x`` is a float tensor with
        one row per node. ``edge_index`` is a long tensor of shape
        ``[2, num_edges]``, including ``[2, 0]`` when ``edges`` is empty.
    :raises KeyError: if an edge references an id that is not in ``nodes``.
    """
    # Map each node identifier (int core id or "i_j" prefix id) to its row in x.
    node_index_map = {node[0]: row for row, node in enumerate(nodes)}
    x = torch.tensor([node2embedding(node[1]) for node in nodes], dtype=torch.float)

    pairs = [[node_index_map[edge[0]], node_index_map[edge[1]]] for edge in edges]
    # reshape(-1, 2) keeps an [E, 2] layout even for E == 0. Without it, an
    # edgeless graph (e.g. a single residue) produces a 1-D tensor of shape [0],
    # which .t() leaves 1-D instead of the [2, 0] PyG expects.
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # NOTE(review): edges are used exactly as given (directed; no reverse edges
    # added). GCNConv propagates along edge direction by default, so orientation
    # affects embeddings. This is likely why a rotated cyclic sequence scores
    # slightly differently. Symmetrizing would change model inputs and should
    # only be done together with retraining.
    return Data(x=x, edge_index=edge_index)

def load_ga_gcn(model_path='GA_GCN.pth', map_location='cpu', weights_only=True):
    """Load a trained ``GraphSimilarityModel`` checkpoint and put it in eval mode.

    :param model_path: path to a file holding the model's ``state_dict``.
    :param map_location: forwarded to ``torch.load``. It defaults to CPU because
        ``ga_prediction`` builds inputs on CPU and converts outputs with
        ``.numpy()``, and a GPU-saved checkpoint would otherwise fail to load on
        a CPU-only machine.
    :param weights_only: forwarded to ``torch.load``. ``True`` restricts
        unpickling to tensors and plain containers, which is all a
        ``state_dict`` needs. Pass ``False`` only for a trusted file that
        requires it, because full unpickling can execute arbitrary code.
    :return: the model with weights loaded and ``eval()`` applied.
    """
    # 29 = the 28 known labels plus the "unknown" slot produced by node2embedding.
    model = GraphSimilarityModel(node_feature_dim=29)
    state_dict = torch.load(model_path, map_location=map_location, weights_only=weights_only)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def ga_prediction(model, seq1, seq2):
    """Score the similarity of two sequence strings with a loaded model.

    :param model: a ``GraphSimilarityModel`` (e.g. from ``load_ga_gcn``, which
        already calls ``eval()``).
    :param seq1: first sequence, in the format parsed by ``sequence_to_node_edge``.
    :param seq2: second sequence, same format.
    :return: numpy array holding the squeezed sigmoid score, i.e. a 0-d array
        for a single pair.
    """
    data1, data2 = (graph2data(*sequence_to_node_edge(seq)) for seq in (seq1, seq2))
    # Inference only: don't build an autograd graph that would just be detached.
    with torch.no_grad():
        score = model(data1, data2)
    return score.cpu().numpy()

# Test: identical and rotated cyclic sequences

In [3]:
# Load the trained checkpoint; the test cells below reuse `model`.
model = load_ga_gcn('GA_GCN.pth')

In [4]:
# Sanity check: a sequence compared with itself (observed score ≈ 1).
seq1 = seq2 = 'Ala(1)--Ala--Gly--Phe--Pro--Val--Phe--Phe(1)'
ga_prediction(model, seq1, seq2)

array(0.9999999, dtype=float32)

In [5]:
# seq2 is seq1 rotated to start at Pro. In both strings the (1) tags join the
# first and last residue, so they describe the same ring of residues. Edge
# directions follow string position, though, so the graphs are not identical
# and the score can fall slightly short of an exact self-match.
seq1, seq2 = (
    'Ala(1)--Ala--Gly--Phe--Pro--Val--Phe--Phe(1)',
    'Pro(1)--Val--Phe--Phe--Ala--Ala--Gly--Phe(1)',
)
ga_prediction(model, seq1, seq2)

array(0.9999136, dtype=float32)

# Case study: score a query against every reference sequence

In [6]:
import pandas as pd

# Reference library: each row's 'Seq-plus' value is scored against the query below.
df = pd.read_excel('../one-letter.xlsx')
ref_seqs = df['Seq-plus'].tolist()

In [7]:
# Reloads the same checkpoint as the Test section, presumably so this section can run on its own.
model = load_ga_gcn('GA_GCN.pth')

In [8]:
%%time
query = "Ala(1)--Ala--Gly--Phe--Pro--Val--Phe--Phe(1)"
similarities = []
for ref in ref_seqs:
    similarities.append(ga_prediction(model, query, ref))
df['sim'] = similarities
df.to_excel('GA_GCN_sim.xlsx')

CPU times: user 8.24 s, sys: 22.3 ms, total: 8.26 s
Wall time: 401 ms
