# In [5]:
import tkinter as tk
from tkinter import filedialog
from tkinter import messagebox
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.optim import Adam
import numpy as np

def load_file_path(entry):
    """Ask the user to pick a file and write its path into *entry*.

    Args:
        entry: A ``tk.Entry`` whose current text is replaced by the chosen path.

    If the dialog is cancelled, *entry* is left unchanged.
    """
    file_path = filedialog.askopenfilename()
    # askopenfilename returns an empty value when the dialog is cancelled;
    # in that case keep whatever path the user had already entered.
    if not file_path:
        return
    entry.delete(0, tk.END)  # clear the previous path
    entry.insert(0, file_path)  # insert the newly selected path

# Amino-acid letter -> integer token. Token 0 is reserved for padding, and
# 22 ('X') doubles as the code for any character not present in this table.
_AA_TO_INT = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'E': 6, 'Q': 7, 'G': 8, 'H': 9, 'I': 10,
              'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'S': 16, 'T': 17, 'W': 18, 'Y': 19,
              'V': 20, 'U': 21, 'X': 22}
_UNKNOWN_AA = 22
_PAD_TOKEN = 0
# Width of the Conv1d kernel in _CNNLSTM. Every sequence is padded to
# max_length, so max_length must be at least this large for the convolution
# to produce any output positions.
_CONV_KERNEL_SIZE = 20


def _encode_sequence(seq, max_length):
    """Encode *seq* as a LongTensor of tokens right-padded with 0 to *max_length*.

    Callers must ensure ``len(seq) <= max_length``.
    """
    encoded = [_AA_TO_INT.get(aa, _UNKNOWN_AA) for aa in seq]
    encoded.extend([_PAD_TOKEN] * (max_length - len(encoded)))
    return torch.tensor(encoded, dtype=torch.long)


class _ProteinDataset(Dataset):
    """Labelled sequences from a positive file (label 1) and a negative file (label 0).

    Each non-blank line is treated as one sequence; sequences longer than
    *max_length* are skipped.
    """

    def __init__(self, positive_file, negative_file, max_length):
        self.sequences = []
        self.labels = []
        self.max_length = max_length
        self._load(positive_file, 1)  # antioxidant
        self._load(negative_file, 0)  # non-antioxidant

    def _load(self, path, label):
        """Append every usable sequence in *path* with the given *label*."""
        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                seq = line.strip()
                # Skip blank lines: they would otherwise become all-padding samples.
                if seq and len(seq) <= self.max_length:
                    self.sequences.append(_encode_sequence(seq, self.max_length))
                    self.labels.append(label)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


class _CNNLSTM(nn.Module):
    """Embedding -> Conv1d -> LSTM -> Linear -> sigmoid binary classifier."""

    def __init__(self, vocab_size, embedding_dim, cnn_filters, lstm_hidden, num_classes,
                 kernel_size=_CONV_KERNEL_SIZE):
        super().__init__()
        # Attribute names are kept as-is so saved state_dict keys do not change.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, cnn_filters, kernel_size=kernel_size)
        self.lstm = nn.LSTM(cnn_filters, lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden, num_classes)

    def forward(self, x):
        # x: (batch, seq_len) token ids, as stacked by the DataLoader.
        x = self.embedding(x)                   # (batch, seq_len, embedding_dim)
        x = self.conv(x.permute(0, 2, 1))       # Conv1d wants channels first
        x, _ = self.lstm(x.permute(0, 2, 1))    # back to (batch, steps, channels)
        # NOTE(review): sequences are right-padded, so the last LSTM step has
        # usually consumed padding tokens. Left-padding or packed sequences
        # would likely help, but would change the encoding of saved models.
        return torch.sigmoid(self.fc(x[:, -1, :]))


def _read_training_inputs():
    """Read and validate the form fields.

    Returns:
        ``(max_length, positive_file, negative_file, num_epochs, model_name)``,
        or ``None`` after showing an error dialog if any field is invalid.
    """
    max_length = max_length_entry.get()
    positive_file = positive_file_entry.get()
    negative_file = negative_file_entry.get()
    num_epochs = num_epochs_entry.get()
    model_name = model_name_entry.get()

    # All fields must be filled in.
    if not (max_length and positive_file and negative_file and num_epochs and model_name):
        messagebox.showerror("Error", "All fields are required!")
        return None

    # Numeric fields must parse as integers.
    try:
        max_length = int(max_length)
        num_epochs = int(num_epochs)
    except ValueError:
        messagebox.showerror("Error", "Max length and num epochs must be valid integers!")
        return None

    # A shorter max_length makes the Conv1d fail; zero epochs saves an untrained model.
    if max_length < _CONV_KERNEL_SIZE or num_epochs < 1:
        messagebox.showerror(
            "Error",
            f"Max length must be at least {_CONV_KERNEL_SIZE} and num epochs at least 1!",
        )
        return None

    # Both sample files must exist.
    if not (os.path.isfile(positive_file) and os.path.isfile(negative_file)):
        messagebox.showerror("Error", "Invalid file paths!")
        return None

    return max_length, positive_file, negative_file, num_epochs, model_name


def _log(message):
    """Append *message* to the log pane and redraw it immediately."""
    loss_output.insert(tk.END, message)
    loss_output.see(tk.END)
    # Redraw only (no user-event processing) so progress is visible while
    # training blocks the event loop, without allowing a re-entrant click.
    loss_output.update_idletasks()


def train_model():
    """Validate the form, train the classifier, and save its weights.

    Trains a CNN-LSTM on the positive/negative sample files, writes the mean
    loss of every epoch to the log pane, and saves the model's state_dict to
    the path given in the "Model Name" field. Validation and I/O problems are
    reported through error dialogs. Training runs synchronously inside the
    button callback, so the Tk event loop is blocked until it finishes.
    """
    inputs = _read_training_inputs()
    if inputs is None:
        return
    max_length, positive_file, negative_file, num_epochs, model_name = inputs

    try:
        dataset = _ProteinDataset(positive_file, negative_file, max_length)
    except (OSError, UnicodeDecodeError) as err:
        messagebox.showerror("Error", f"Could not read sample files: {err}")
        return
    if len(dataset) == 0:
        messagebox.showerror(
            "Error", f"No sequences of length <= {max_length} found in the sample files!"
        )
        return
    data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _CNNLSTM(
        vocab_size=len(_AA_TO_INT) + 1,  # +1 because padding token 0 is also a "word"
        embedding_dim=8,
        cnn_filters=64,
        lstm_hidden=128,
        num_classes=1,  # a single probability output
    ).to(device)

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=1e-4)

    model.train()
    epoch_loss = float("nan")
    for epoch in range(num_epochs):
        loss_sum, sample_count = 0.0, 0
        for batch_inputs, targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            targets = targets.to(device, dtype=torch.float32)

            # squeeze(-1), not squeeze(): a final batch of size 1 must stay
            # shape (1,) to match the targets expected by BCELoss.
            outputs = model(batch_inputs).squeeze(-1)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a true per-sample mean.
            loss_sum += loss.item() * targets.size(0)
            sample_count += targets.size(0)

        epoch_loss = loss_sum / sample_count
        _log(f"Epoch {epoch + 1}/{num_epochs} - loss: {epoch_loss:.4f}\n")

    try:
        torch.save(model.state_dict(), model_name)
    except OSError as err:
        messagebox.showerror("Error", f"Could not save model: {err}")
        return

    _log(f"Training completed! Loss: {epoch_loss:.4f}\n")

# Build the main application window.
root = tk.Tk()
root.title("Protein Antioxidant Classification Trainer")


def _labeled_entry(row, caption, browsable=False):
    """Grid a caption label and an Entry on *row*; optionally add a Browse button.

    Returns the created Entry so callers can read its value later.
    """
    tk.Label(root, text=caption).grid(row=row, column=0)
    field = tk.Entry(root)
    field.grid(row=row, column=1)
    if browsable:
        # Each call has its own `field`, so the lambda binds the right Entry.
        browse = tk.Button(root, text="Browse", command=lambda: load_file_path(field))
        browse.grid(row=row, column=2)
    return field


# Input fields, one per grid row.
max_length_entry = _labeled_entry(0, "Max Sequence Length:")
positive_file_entry = _labeled_entry(1, "Positive Samples File:", browsable=True)
negative_file_entry = _labeled_entry(2, "Negative Samples File:", browsable=True)
num_epochs_entry = _labeled_entry(3, "Number of Epochs:")
model_name_entry = _labeled_entry(4, "Model Name:")

# Button that launches training.
start_button = tk.Button(root, text="Start Training", command=train_model)
start_button.grid(row=5, columnspan=3)

# Text pane where training progress / loss values are written.
loss_output = tk.Text(root, height=10, width=50)
loss_output.grid(row=6, columnspan=3)

# Hand control to the Tk event loop.
root.mainloop()
