In [1]:
import tkinter as tk
from tkinter import messagebox
import torch
from torch import nn

# One-letter residue code -> token id (1..22). Id 0 is left free for padding;
# 'U' is selenocysteine and 'X' (22) is the catch-all "unknown residue" code.
aa_to_int = {
    residue: token_id
    for token_id, residue in enumerate("ARNDCEQGHILKMFPSTWYVUX", start=1)
}

# Make sure the class definition below is exactly the same as the one used
# when the model was saved: load_state_dict matches parameters by attribute
# name and tensor shape.
class CNNLSTM(nn.Module):
    """Embedding -> Conv1d -> LSTM -> Linear -> sigmoid classifier over token ids.

    Args:
        vocab_size: number of embedding rows; token ids must lie in [0, vocab_size).
        embedding_dim: embedding width, also the Conv1d input channel count.
        cnn_filters: Conv1d output channels (kernel size 20, no padding), also
            the LSTM input size.
        lstm_hidden: LSTM hidden size, also the input size of the final Linear.
        num_classes: number of independent sigmoid outputs.
    """

    def __init__(self, vocab_size, embedding_dim, cnn_filters, lstm_hidden, num_classes):
        super().__init__()
        # The attribute names below become the state_dict keys - do not rename.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, cnn_filters, kernel_size=20)
        self.lstm = nn.LSTM(cnn_filters, lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden, num_classes)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to probabilities of shape (batch, num_classes).

        seq_len must be at least 20 (the convolution kernel size).
        """
        embedded = self.embedding(x)                        # (batch, seq_len, emb)
        # Conv1d wants channels first: (batch, emb, seq_len).
        features = self.conv(embedded.transpose(1, 2))      # (batch, filters, seq_len - 19)
        # Back to time-major-in-dim-1 for the batch_first LSTM.
        outputs, _ = self.lstm(features.transpose(1, 2))    # (batch, seq_len - 19, hidden)
        # Classify from the hidden state at the final time step only.
        final_state = outputs[:, -1, :]
        return self.fc(final_state).sigmoid()


def load_model(model_path, model):
    """Load the state_dict stored at *model_path* into *model* and return it.

    Tensors are mapped onto the CPU. *model* is modified in place; the object
    returned is the same instance that was passed in.

    Raises:
        Whatever ``torch.load`` raises for a missing/unreadable file, and
        RuntimeError from ``load_state_dict`` when keys or shapes do not match
        the model (strict loading).
    """
    # NOTE(review): torch.load deserializes via pickle; depending on the torch
    # version's weights_only default this can execute code from the file, so
    # only point it at trusted checkpoints.
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)
    model.load_state_dict(state_dict)
    return model


def encode_sequence(seq, max_length):
    """Encode an amino-acid string as a LongTensor of shape (1, max_length).

    Processing steps:
      * all whitespace is removed, so pasted sequences containing spaces or
        line breaks encode as the residues alone;
      * letters are upper-cased, because the aa_to_int keys are upper-case
        (previously lower-case input mapped every residue to "unknown");
      * each residue is looked up in aa_to_int, and anything not in the table
        gets id 22 (the same id as 'X');
      * the result is truncated to max_length and right-padded with 0, so the
        output length is always exactly max_length.

    Args:
        seq: the amino-acid sequence as a string.
        max_length: fixed length of the encoded sequence.

    Returns:
        torch.LongTensor of shape (1, max_length).
    """
    residues = "".join(seq.split()).upper()[:max_length]
    encoded = [aa_to_int.get(aa, 22) for aa in residues]
    # Pad with id 0, which aa_to_int never assigns to a residue.
    encoded += [0] * (max_length - len(encoded))
    return torch.tensor([encoded], dtype=torch.long)


def predict(model, sequence, max_length):
    """Return the model's score for *sequence*, multiplied by 100.

    The model is switched to eval mode (this mutates *model*). The sequence is
    encoded with encode_sequence, and one forward pass runs without autograd.
    ``.item()`` requires the model output to contain exactly one element.
    For a sigmoid-terminated model such as CNNLSTM, the result lies in [0, 100].
    """
    model.eval()
    with torch.no_grad():
        batch = encode_sequence(sequence, max_length)
        score = model(batch).item()
    return score * 100


# GUI section
class Application(tk.Tk):
    """Main window: enter a checkpoint path and a sequence, then click to predict.

    Prediction uses the module-level ``cnn_lstm_model`` and ``max_length``.
    The checkpoint is reloaded on every click.
    """

    # Hint text pre-filled in the entries. These are kept as attributes so
    # run_prediction can tell an untouched hint from real user input.
    MODEL_PATH_HINT = "请输入模型路径"
    SEQUENCE_HINT = "请输入氨基酸序列"

    def __init__(self):
        super().__init__()
        self.title('抗氧化蛋白预测工具')
        self.geometry('500x300')
        self.create_widgets()

    @staticmethod
    def _add_hint(entry, hint):
        """Pre-fill *entry* with *hint* and clear it when the entry gains focus."""
        entry.insert(0, hint)

        def _clear_hint(_event):
            # Only clear if the user has not typed anything yet.
            if entry.get() == hint:
                entry.delete(0, tk.END)

        entry.bind("<FocusIn>", _clear_hint)

    @staticmethod
    def _read_entry(entry, hint):
        """Return the stripped text of *entry*, or "" if it is blank or still shows *hint*."""
        value = entry.get().strip()
        return "" if value == hint else value

    def create_widgets(self):
        # Entry: model checkpoint path
        self.entry_model_path = tk.Entry(self)
        self.entry_model_path.pack()
        self._add_hint(self.entry_model_path, self.MODEL_PATH_HINT)

        # Instruction label (reads the module-level max_length when the window is built)
        self.label_max_length = tk.Label(self, text=f"请输入氨基酸序列(最大长度 {max_length}):")
        self.label_max_length.pack()

        # Entry: sequence to classify
        self.entry_unknown_sequence = tk.Entry(self)
        self.entry_unknown_sequence.pack()
        self._add_hint(self.entry_unknown_sequence, self.SEQUENCE_HINT)

        # Predict button
        self.predict_button = tk.Button(self, text="预测", command=self.run_prediction)
        self.predict_button.pack()

        # Result display
        self.result_label = tk.Label(self, text="")
        self.result_label.pack()

    def run_prediction(self):
        """Validate the inputs, load the checkpoint, predict, and show the result."""
        # Untouched hint text counts as missing input. Otherwise the hint
        # itself would be loaded as a path or scored as a sequence.
        model_path = self._read_entry(self.entry_model_path, self.MODEL_PATH_HINT)
        sequence = self._read_entry(self.entry_unknown_sequence, self.SEQUENCE_HINT)
        if not model_path or not sequence:
            messagebox.showerror("错误", "模型路径和氨基酸序列为必填项")
            return
        try:
            # Load the model and predict
            model = load_model(model_path, cnn_lstm_model)
            percentage = predict(model, sequence, max_length)
        except Exception as e:  # GUI boundary: surface any load/predict failure to the user
            # Do not leave a previous result on screen next to a failed run.
            self.result_label.config(text="")
            messagebox.showerror("错误", str(e))
            return
        self.result_label.config(text=f"该氨基酸序列为抗氧化蛋白的可能性为：{percentage:.2f}%")

# Settings: fixed input length and the (untrained) model instance whose
# weights the GUI loads from a checkpoint. These hyper-parameters must match
# the ones used when the checkpoint was saved.
max_length = 1000
cnn_lstm_model = CNNLSTM(
    vocab_size=len(aa_to_int) + 1,  # +1 because id 0 is the padding id
    embedding_dim=8,
    cnn_filters=64,
    lstm_hidden=128,
    num_classes=1,
)

if __name__ == "__main__":
    # Build the window and run Tk's event loop; mainloop() blocks until the window is closed.
    app = Application()
    app.mainloop()
