Download the MTPrompt-PTM code and model checkpoint

In [None]:
!git clone https://github.com/hanye311/MTPrompt-PTM.git
%cd MTPrompt-PTM
!pip -q install gdown

FILE_ID = "1FfMepaY1JLUbKTZncE1u7-pm2d16IUuf"
OUT = "/content/MTPrompt-PTM/best_model_13ptm_final.pth"

!gdown --fuzzy "https://drive.google.com/uc?id={FILE_ID}" -O "{OUT}"

import os
print("size:", os.path.getsize(OUT), "bytes")

Install the required packages

In [None]:
# Remove any existing torch/torchvision/torchaudio and the CUDA 12 runtime wheels
# so the pinned builds below install cleanly; `|| true` discards pip's exit status.
!pip -q uninstall -y torch torchvision torchaudio \
  nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 nvidia-cuda-nvrtc-cu12 \
  nvidia-cuda-runtime-cu12 nvidia-cudnn-cu12 nvidia-cufft-cu12 \
  nvidia-curand-cu12 nvidia-cusolver-cu12 nvidia-cusparse-cu12 \
  nvidia-nccl-cu12 nvidia-nvjitlink-cu12 || true

# Upgrade packaging tools, then install the pinned PyTorch 2.5.1 build for CUDA 12.1.
!pip -q install -U pip setuptools wheel
!pip -q install "torch==2.5.1+cu121" "torchvision==0.20.1+cu121" "torchaudio==2.5.1+cu121" \
  --index-url https://download.pytorch.org/whl/cu121

# NOTE(review): these version bounds are presumably chosen for compatibility with
# the gradio pin below — confirm before changing them.
!pip -q install "websockets>=13,<15.1" "tqdm>=4.67"

# Remaining dependencies (presumably required by the repo's test.py — verify
# against the repository's own requirements).
!pip -q install fair-esm biopython
!pip -q install gradio==4.44.1

# Report the installed torch version, its CUDA version, and whether a GPU is visible.
# NOTE(review): if torch had already been imported in this kernel before the
# reinstall, restart the runtime so the newly installed build is the one loaded.
import torch, numpy as np
print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| cuda_available:", torch.cuda.is_available())


Run the prediction tool
1. Enter or upload sequences in FASTA format.
2. Choose the PTM type you want to predict.
3. Click the "Start Prediction" button.
4. Wait for the results. (Please set 'Runtime → Change runtime type' to a GPU runtime first.)

In [36]:
import io
import os
import subprocess
import sys
import tempfile
import traceback

import ipywidgets as widgets
import pandas as pd
from google.colab import files
from IPython.display import display, clear_output, HTML

class PTMPredictor:
    """Colab widget front-end for running MTPrompt-PTM's ``test.py``.

    The user pastes or uploads FASTA sequences, chooses a PTM type and clicks
    "Start Prediction". The sequences are written to a temporary FASTA file,
    ``test.py`` is run in a subprocess from the repository root, and the first
    (alphabetically) CSV it writes to a temporary output directory is
    previewed with download / show-all buttons.

    Args:
        repo_dir: Directory containing ``test.py``, the ``config`` folder and
            the checkpoint. Defaults to the current directory when it contains
            ``test.py`` (the setup cell ``%cd``s into the clone), otherwise
            ``DEFAULT_REPO_DIR``.
    """

    # PTM types offered in the dropdown; the selected value is passed verbatim
    # to test.py as --PTM_type.
    PTM_TYPES = ['Phosphorylation_S', 'Phosphorylation_T', 'Phosphorylation_Y',
                 'Methylation_K', 'Acetylation_K', 'Ubiquitination_K']
    # test.py inputs, relative to repo_dir (the subprocess working directory).
    CONFIG_PATH = 'config/PTM_config_prompt_tuning_test.yaml'
    MODEL_PATH = 'best_model_13ptm_final.pth'
    # Where the setup cell clones the repository on Colab.
    DEFAULT_REPO_DIR = '/content/MTPrompt-PTM'

    def __init__(self, repo_dir=None):
        # Initialise state before the widgets (and their callbacks) exist.
        self.temp_fasta_path = None
        self.result_csv_path = None
        self.repo_dir = repo_dir or self._default_repo_dir()
        self.setup_interface()

    @classmethod
    def _default_repo_dir(cls):
        """Return the cwd if it holds test.py, else the Colab clone location."""
        # The notebook's working directory may be reset (e.g. by a runtime
        # restart after reinstalling torch), so don't rely on the earlier %cd.
        cwd = os.getcwd()
        if os.path.isfile(os.path.join(cwd, 'test.py')):
            return cwd
        return cls.DEFAULT_REPO_DIR

    def setup_interface(self):
        """Build the input/parameter/result widgets, wire callbacks, display them."""
        print("🧬 PTM Prediction Tool")
        print("=" * 50)

        # FASTA sequence input area
        self.fasta_input = widgets.Textarea(
            value='',
            placeholder='Please enter FASTA format sequences, example:\n>sequence1\nMKVLLLLLLLLLLAAVAVA...\n>sequence2\nMETTLPRLLLLLLLLLLLL...',
            description='FASTA Sequence:',
            layout=widgets.Layout(width='100%', height='200px')
        )

        # PTM type selection (no empty entry: test.py needs a real PTM type)
        self.ptm_selector = widgets.Dropdown(
            options=list(self.PTM_TYPES),
            value='Phosphorylation_S',
            description='PTM Type:',
            style={'description_width': 'initial'}
        )

        # File upload option
        self.file_upload = widgets.FileUpload(
            accept='.fasta,.fa,.txt',
            multiple=False,
            description='FASTA file'
        )

        # Prediction button
        self.predict_button = widgets.Button(
            description='🚀 Start Prediction',
            button_style='primary',
            layout=widgets.Layout(width='200px', height='40px')
        )

        # Output area; every message and result table is rendered here.
        self.output_area = widgets.Output()

        # Bind events
        self.predict_button.on_click(self.run_prediction)
        self.file_upload.observe(self.handle_upload, names='value')

        # Display interface
        display(widgets.VBox([
            widgets.HTML("<h3>📝 Input Sequences</h3>"),
            self.fasta_input,
            widgets.HTML("<h3>📤 Or Upload File</h3>"),
            self.file_upload,
            widgets.HTML("<h3>⚙️ Parameters</h3>"),
            self.ptm_selector,
            widgets.HTML("<br>"),
            self.predict_button,
            widgets.HTML("<h3>📊 Results</h3>"),
            self.output_area
        ]))

    def handle_upload(self, change):
        """Copy an uploaded FASTA file's text into the sequence text area.

        Accepts both FileUpload value formats: ipywidgets 7 (a dict mapping the
        file name to ``{'content': bytes, ...}``) and ipywidgets 8 (a tuple of
        dicts with ``'name'`` and ``'content'`` given as a memoryview).
        """
        uploaded = change['new']
        if not uploaded:
            return
        if isinstance(uploaded, dict):
            name, info = next(iter(uploaded.items()))
        else:
            info = uploaded[0]
            name = info.get('name', 'uploaded file')
        # bytes() accepts both bytes and memoryview; utf-8-sig drops a leading
        # BOM that some editors add and otherwise decodes exactly like utf-8.
        content = bytes(info['content']).decode('utf-8-sig')
        self.fasta_input.value = content
        print(f"✅ File loaded: {name}")

    def validate_fasta(self, fasta_content):
        """Check that ``fasta_content`` looks like FASTA.

        Rules: non-blank input; at least one ``>`` header; the first non-blank
        line is a header; every header is followed by at least one sequence
        line. Any line-ending style is accepted.

        Returns:
            ``(is_valid, message)``.
        """
        if not fasta_content or not fasta_content.strip():
            return False, "Please enter FASTA sequences"

        lines = [line for line in fasta_content.splitlines() if line.strip()]
        if not any(line.startswith('>') for line in lines):
            return False, "FASTA format error: Missing sequence identifiers (lines starting with >)"
        if not lines[0].startswith('>'):
            return False, "FASTA format error: Sequence data found before the first identifier (line starting with >)"

        # A header directly followed by another header (or end of input) is an empty record.
        for line, next_line in zip(lines, lines[1:] + [None]):
            if line.startswith('>') and (next_line is None or next_line.startswith('>')):
                return False, f"FASTA format error: Record '{line.strip()}' has no sequence"

        return True, "Format is correct"

    def save_temp_fasta(self, fasta_content):
        """Write ``fasta_content`` to a new temporary FASTA file and return its path.

        Line endings are normalised to ``\\n`` (pasted/uploaded text may use
        ``\\r\\n``) and the file ends with a newline. The path is also stored in
        ``self.temp_fasta_path``.
        """
        temp_dir = tempfile.mkdtemp()
        self.temp_fasta_path = os.path.join(temp_dir, 'input_sequences.fasta')

        normalized = '\n'.join(fasta_content.splitlines()) + '\n'
        with open(self.temp_fasta_path, 'w', encoding='utf-8') as f:
            f.write(normalized)

        return self.temp_fasta_path

    def run_prediction(self, button):
        """Validate the input, run ``test.py`` on it and display the results.

        Bound to the "Start Prediction" button (``button`` is the clicked
        widget, unused). All messages go to the Results output area; errors
        are reported there rather than raised.
        """
        with self.output_area:
            clear_output()

            try:
                # Get input
                fasta_content = self.fasta_input.value
                ptm_type = self.ptm_selector.value

                # Validate input
                is_valid, message = self.validate_fasta(fasta_content)
                if not is_valid:
                    print(f"❌ Error: {message}")
                    return
                if not ptm_type:
                    print("❌ Error: Please choose a PTM type")
                    return
                if not os.path.isfile(os.path.join(self.repo_dir, 'test.py')):
                    print(f"❌ Error: test.py not found in {self.repo_dir}; please run the setup cells first")
                    return

                print("🔄 Processing...")
                print(f"PTM Type: {ptm_type}")

                # Save temporary FASTA file
                fasta_path = self.save_temp_fasta(fasta_content)
                print(f"✅ FASTA file saved: {fasta_path}")

                # Set output path
                output_dir = tempfile.mkdtemp()

                # Use the kernel's own interpreter rather than whatever `python`
                # is on PATH; relative config/model paths resolve against
                # repo_dir, which is passed as the subprocess cwd below.
                cmd = [
                    sys.executable, 'test.py',
                    '--config_path', self.CONFIG_PATH,
                    '--model_path', self.MODEL_PATH,
                    '--data_path', fasta_path,
                    '--PTM_type', ptm_type,
                    '--save_path', output_dir,
                ]

                print("🚀 Running prediction model...")
                print(f"Command: {' '.join(cmd)}")

                result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.repo_dir)

                if result.returncode == 0:
                    print("✅ Prediction completed!")

                    # Sort so the chosen CSV does not depend on directory listing order.
                    csv_files = sorted(f for f in os.listdir(output_dir) if f.endswith('.csv'))

                    if csv_files:
                        self.result_csv_path = os.path.join(output_dir, csv_files[0])
                        self.display_results()
                    else:
                        print("⚠️ No output CSV file found")
                        print("Standard output:", result.stdout)
                        print("Error output:", result.stderr)

                else:
                    print(f"❌ Prediction failed (return code: {result.returncode})")
                    print("Error message:")
                    print(result.stderr)
                    if result.stdout:
                        print("Output message:")
                        print(result.stdout)

            except Exception as e:
                print(f"❌ Error occurred: {str(e)}")
                # Keep the traceback visible in the Results area for debugging.
                traceback.print_exc()

    def display_results(self):
        """Preview ``self.result_csv_path`` and offer download / show-all buttons.

        Called while ``self.output_area`` is the active output context. The
        button callbacks run later, outside that context, so they re-enter
        ``self.output_area`` explicitly; otherwise their output (and
        ``clear_output()``) could act outside the Results area.
        """
        try:
            # Read CSV file
            df = pd.read_csv(self.result_csv_path)
            # Bind the path now so these buttons keep referring to this result.
            csv_path = self.result_csv_path

            print(f"📊 Prediction Results ({len(df)} records)")
            print("=" * 50)

            # Display first 10 rows preview
            print("🔍 Data Preview (first 10 rows):")
            display(HTML(df.head(10).to_html(index=False)))

            # Create download button
            download_button = widgets.Button(
                description='💾 Download Results',
                button_style='success',
                layout=widgets.Layout(width='200px', height='40px')
            )

            def download_results(_button):
                with self.output_area:
                    files.download(csv_path)
                    print("✅ File download started")

            download_button.on_click(download_results)
            display(download_button)

            # Show all data option
            show_all_button = widgets.Button(
                description='👁️ Show All Data',
                button_style='info',
                layout=widgets.Layout(width='200px', height='40px')
            )

            def show_all_data(_button):
                with self.output_area:
                    clear_output()
                    display(HTML(df.to_html(index=False)))
                    display(download_button)

            show_all_button.on_click(show_all_data)
            display(show_all_button)

        except Exception as e:
            print(f"❌ Error displaying results: {str(e)}")

# Launch interface: constructing PTMPredictor builds the widgets and displays
# them below this cell; the instance is kept as `ptm_predictor` for inspection.
print("Initializing PTM prediction interface...")
ptm_predictor = PTMPredictor()

Initializing PTM prediction interface...
🧬 PTM Prediction Tool


VBox(children=(HTML(value='<h3>📝 Input Sequences</h3>'), Textarea(value='', description='FASTA Sequence:', lay…

✅ File loaded: Phosphorylation_S_sequence.fasta
