## Dependencies

In [4]:
pip install pandas numpy scipy scikit-learn seaborn matplotlib pillow joblib

^C
Note: you may need to restart the kernel to use updated packages.


## ThermAL

In [7]:
import itertools
import os
import re
import threading
import tkinter as tk
import tkinter.ttk as ttk
import warnings
from tkinter import filedialog, messagebox

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Rectangle
from PIL import Image, ImageTk
from sklearn.preprocessing import StandardScaler

try:
    from scipy.integrate import simpson
except ImportError:  # very old SciPy releases only ship the legacy name
    from scipy.integrate import simps as simpson

# `simps` was removed in SciPy 1.14; keep the legacy name available for
# any code that still refers to it.
simps = simpson

# Suppress every warning for the whole process, whichever library emits it.
warnings.filterwarnings("ignore")

def generate_variants(sequence):
    """Return the sequence followed by all of its single-residue substitutions.

    The first element is the unmodified sequence. The remaining elements are
    ordered by position, then by replacement residue in the order of the
    20-letter alphabet below; the residue already present at a position is
    skipped.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    substituted = [
        sequence[:position] + replacement + sequence[position + 1:]
        for position, current in enumerate(sequence)
        for replacement in alphabet
        if replacement != current
    ]
    return [sequence] + substituted

def write_excel(variants, output_file):
    """Write the variants to `output_file` as a single 'Sequence' column, without the index."""
    pd.DataFrame({'Sequence': variants}).to_excel(output_file, index=False)

def calculate_aac(sequence, amino_acids):
    """Return the amino acid composition of a sequence.

    Args:
        sequence: Residues to count.
        amino_acids: Iterable of residue letters to report on.

    Returns:
        A dict mapping every letter in `amino_acids` to its count divided by
        the total sequence length. Residues not listed in `amino_acids` are
        not reported but still contribute to the length. An empty sequence
        yields 0.0 for every letter instead of raising ZeroDivisionError.
    """
    n = len(sequence)
    aac_counts = {aa: 0 for aa in amino_acids}
    for aa in sequence:
        if aa in aac_counts:
            aac_counts[aa] += 1
    if n == 0:
        # Nothing to normalise by: every fraction is defined as zero.
        return {aa: 0.0 for aa in aac_counts}
    return {aa: count / n for aa, count in aac_counts.items()}

def calculate_dpc(sequence, amino_acids, dipeptides):
    """Return the dipeptide composition of a sequence.

    Args:
        sequence: Residues to scan.
        amino_acids: Unused; kept so existing callers keep working.
        dipeptides: Iterable of two-letter strings to report on.

    Returns:
        A dict mapping every entry of `dipeptides` to the number of times it
        occurs as an adjacent pair, divided by the number of adjacent pairs
        (len(sequence) - 1). Sequences shorter than two residues have no
        pairs and yield 0.0 for every dipeptide instead of dividing by zero
        or by a negative number.
    """
    n = len(sequence) - 1
    dpc_counts = {dipeptide: 0 for dipeptide in dipeptides}
    if n <= 0:
        return {dipeptide: 0.0 for dipeptide in dpc_counts}
    for first, second in zip(sequence, sequence[1:]):
        pair = first + second
        if pair in dpc_counts:
            dpc_counts[pair] += 1
    return {dipeptide: count / n for dipeptide, count in dpc_counts.items()}

def calculate_mean(window, df_reference):
    """Average the 'Value' entries of `df_reference` looked up for each letter of `window`."""
    total = sum(df_reference.loc[residue, 'Value'] for residue in window)
    return total / len(window)

def calculate_auc(window, df_reference):
    """Integrate the per-residue reference values across a window.

    Args:
        window: Residue letters; each one is looked up in the index of
            `df_reference`.
        df_reference: DataFrame with a 'Value' column indexed by residue
            letter.

    Returns:
        The Simpson's-rule integral of the looked-up values, sampled at
        x = 0, 1, ..., len(window) - 1.
    """
    values = [df_reference.loc[letter, 'Value'] for letter in window]
    x_values = list(range(len(window)))
    # `simpson` replaces `simps`, which was removed from SciPy 1.14.
    return simpson(values, x=x_values)

def convert_to_1_letter_code(variant_sequence, original_sequence):
    """Describe how a variant differs from the original, e.g. "A3G,C7D".

    Positions are 1-based and the two sequences are compared pairwise up to
    the length of the shorter one. Returns None when the variant is missing
    (NaN/None) or when no position differs.
    """
    if pd.isna(variant_sequence):
        return None
    changes = [
        f"{original_aa}{position}{variant_aa}"
        for position, (original_aa, variant_aa)
        in enumerate(zip(original_sequence, variant_sequence), start=1)
        if original_aa != variant_aa
    ]
    if not changes:
        return None
    return ",".join(changes)

def parse_mutation_code(mutation_code):
    """Parse a comma-separated mutation string into a list of dicts.

    Each well-formed entry (capital letter, digits, capital letter, such as
    "A12G") becomes {'mutation': <new residue>, 'position': <int>}. Entries
    that do not match are skipped. A missing or empty code gives [].
    """
    if pd.isna(mutation_code) or mutation_code == "":
        return []
    parsed = (
        re.match(r'^([A-Z])(\d+)([A-Z])$', entry.strip())
        for entry in mutation_code.split(',')
    )
    return [
        {'mutation': hit.group(3), 'position': int(hit.group(2))}
        for hit in parsed
        if hit
    ]

def color_labels(text):
    """Pick a tick-label colour from the residue letters found in `text`.

    Labels containing 'Mean' are black. Otherwise the first group below that
    has any letter occurring in `text` decides the colour; the groups are
    tested top to bottom. Labels matching no group are black.
    """
    if 'Mean' in text:
        return 'black'

    # Order matters: it mirrors the precedence of the checks.
    residue_palette = (
        ('P', 'orange'),
        ('QNCTS', 'green'),
        ('DE', 'red'),
        ('HRK', 'blue'),
        ('WYF', 'brown'),
        ('IMLVA', 'grey'),
        ('G', 'purple'),
    )
    for residues, colour in residue_palette:
        if any(residue in text for residue in residues):
            return colour
    return 'black'

class ProgressWindow:
    """Modal top-level window holding two captioned progress bars."""

    def __init__(self, parent):
        """Create the window as a transient child of `parent` and grab input."""
        win = tk.Toplevel(parent)
        win.title("Progress")
        win.geometry("400x200")
        win.transient(parent)
        win.grab_set()
        self.window = win

        # Widgets are created and packed top to bottom: caption, bar, caption, bar.
        self.label1, self.progress1 = self._add_section("Calculating AAC/DPC")
        self.label2, self.progress2 = self._add_section("Processing Features")

        self.completed = False

    def _add_section(self, caption):
        """Pack a caption label followed by a determinate bar; return both."""
        caption_label = tk.Label(self.window, text=caption)
        caption_label.pack(pady=10)
        bar = ttk.Progressbar(self.window, orient='horizontal', length=300, mode='determinate')
        bar.pack(pady=5)
        return caption_label, bar

    def update_progress1(self, value, maximum):
        """Set the upper bar's maximum, then its current value."""
        bar = self.progress1
        bar['maximum'] = maximum
        bar['value'] = value

    def update_progress2(self, value, maximum):
        """Set the lower bar's maximum, then its current value."""
        bar = self.progress2
        bar['maximum'] = maximum
        bar['value'] = value

    def update_label2(self, text):
        """Replace the caption shown above the lower bar."""
        self.label2.config(text=text)

    def close(self):
        """Destroy the window."""
        self.window.destroy()

def sanitize_filename(name, default="unnamed_sequence"):
    """Turn an arbitrary label into a string usable as a file or folder name.

    Args:
        name: Raw label, e.g. a FASTA header.
        default: Name returned when nothing usable is left after cleaning.

    Returns:
        `name` with the characters \\ / * ? : " < > | replaced by "_",
        surrounding whitespace removed, spaces replaced by "_", and the
        result cut to 255 characters. If that leaves "", "." or ".."
        (none of which name a new folder), `default` is returned instead.
    """
    sanitized = re.sub(r'[\\/*?:"<>|]', "_", name)
    sanitized = sanitized.strip().replace(' ', '_')[:255]
    if sanitized in ("", ".", ".."):
        return default
    return sanitized

def read_fasta_file(fasta_file_path):
    """Read a FASTA file into a list of (header, sequence) tuples.

    The header is the text after '>' on a header line. The sequence is every
    following line, stripped and concatenated, up to the next header. Lines
    that appear before the first header are ignored.
    """
    records = []
    header = None
    chunks = []
    with open(fasta_file_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text.startswith('>'):
                chunks.append(text)
                continue
            # A new header closes the record that was being collected.
            if header is not None:
                records.append((header, ''.join(chunks)))
            header, chunks = text[1:], []
    if header is not None:
        records.append((header, ''.join(chunks)))
    return records

def run_analysis(reference_sequence, job_name, root, progress_window):
    """Run the whole prediction pipeline for one sequence and show the plots.

    Steps, each writing its result below a folder named `job_name`:
      1. generate every single-residue variant of the sequence;
      2. compute amino acid and dipeptide composition for each variant;
      3. for each atlas workbook in 'required_docs', average the Simpson
         integral over all 9-residue windows and store the difference to
         the wild-type value as one feature column;
      4. load 'required_docs/ThermAL.joblib' and predict a fitness value
         for every variant;
      5. pivot the predictions into a mutation-by-position table, save it
         as workbooks and as a heatmap image;
      6. compute a centred 5-position rolling mean of the per-position
         average, save it as a workbook and as a line plot;
      7. open both images in new top-level windows.

    Args:
        reference_sequence: Wild-type sequence; must be non-empty and contain
            only the 20 standard residue letters.
        job_name: Folder name for all outputs; also used in window titles
            and error messages.
        root: Tk root used as parent of the image windows.
        progress_window: Object exposing `window`, `update_progress1`,
            `update_progress2`, `update_label2` and a `completed` attribute.

    Errors are reported through message boxes rather than raised, and
    `progress_window.completed` is set to True on every exit path.
    """
    try:
        amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
        # An empty sequence would pass the residue check below and then fail
        # later with an unhelpful error, so reject it explicitly.
        if not reference_sequence:
            messagebox.showerror("Invalid Sequence", f"The sequence {job_name} is empty.")
            return
        if not all(residue in amino_acids for residue in reference_sequence):
            messagebox.showerror("Invalid Sequence", f"The sequence {job_name} contains invalid amino acid residues.")
            return

        dipeptides = [''.join(pair) for pair in itertools.product(amino_acids, repeat=2)]

        # ---- 1. Variants -------------------------------------------------
        variants = generate_variants(reference_sequence)

        main_output_dir = job_name
        os.makedirs(main_output_dir, exist_ok=True)

        output_dir_variants = os.path.join(main_output_dir, "sequence_variants")
        os.makedirs(output_dir_variants, exist_ok=True)

        variant_output_file = os.path.join(output_dir_variants, "sequence_variants.xlsx")
        write_excel(variants, variant_output_file)

        df_input = pd.read_excel(variant_output_file)

        # ---- 2. Amino acid / dipeptide composition -----------------------
        columns = ['Sequence'] + list(amino_acids) + dipeptides

        total_sequences = len(df_input)
        # Collect plain dicts and build the frame once; concatenating inside
        # the loop copies the whole frame on every iteration.
        composition_rows = []
        for index, row in df_input.iterrows():
            sequence = row['Sequence']
            combined_scores = {
                **calculate_aac(sequence, amino_acids),
                **calculate_dpc(sequence, amino_acids, dipeptides),
            }
            combined_scores['Sequence'] = sequence
            composition_rows.append(combined_scores)
            progress_window.window.after(0, progress_window.update_progress1, index + 1, total_sequences)
        composition_df = pd.DataFrame(composition_rows, columns=columns)

        aac_dpc_output_dir = os.path.join(main_output_dir, "Sequence's_AAC_DPC")
        os.makedirs(aac_dpc_output_dir, exist_ok=True)
        aac_dpc_output_file = os.path.join(aac_dpc_output_dir, "Sequence's_AAC_DPC.xlsx")
        composition_df.to_excel(aac_dpc_output_file, index=False)

        # ---- 3. Atlas-based features -------------------------------------
        atlases = {
            '3_B_Atlas.xlsx': 'Bulkiness',
            '3_BT_Atlas.xlsx': 'Beta Turn Propensity',
            '3_cDR_Atlas.xlsx': 'Coli Propensity',
            '3_CF_Atlas.xlsx': 'Beta Sheet Propensity',
            '3_Kd_Atlas.xlsx': 'hydrophobicity',
            '3_P_Atlas.xlsx': 'Polarity',
            '3_DR_Atlas.xlsx': 'Alpha helicity'
        }

        df_input = pd.read_excel(aac_dpc_output_file)
        sequences = df_input['Sequence'].tolist()

        total_features = len(atlases)
        total_sequences = len(sequences)
        total_work_units = total_features * total_sequences
        current_work_unit = 0

        for atlas_file, feature_name in atlases.items():
            progress_window.window.after(0, progress_window.update_label2, f"Processing {feature_name}")

            atlas_path = os.path.join('required_docs', atlas_file)
            df_reference = pd.read_excel(atlas_path, index_col=0)

            # Only the averaged window integral feeds the feature column, so
            # per-window means are not computed here.
            mean_auc_per_sequence = []
            for sequence in sequences:
                auc_values = [
                    calculate_auc(sequence[i:i+9], df_reference)
                    for i in range(len(sequence) - 8)
                ]
                mean_auc_per_sequence.append(sum(auc_values) / len(auc_values) if auc_values else 0)

                current_work_unit += 1
                progress_window.window.after(0, progress_window.update_progress2, current_work_unit, total_work_units)

            auc_df = pd.DataFrame({'Sequence': sequences, 'AUC': mean_auc_per_sequence})

            # Row 0 is the unmodified sequence; features are stored relative to it.
            wild_type_auc = auc_df.loc[0, 'AUC']
            df_input[feature_name] = auc_df['AUC'] - wild_type_auc

        output_dir_final = os.path.join(main_output_dir, "sequences_with_features")
        os.makedirs(output_dir_final, exist_ok=True)

        output_file_final = os.path.join(output_dir_final, "sequences_with_features.xlsx")
        df_input.to_excel(output_file_final, index=False)

        # ---- 4. Prediction -----------------------------------------------
        new_df = pd.read_excel(output_file_final)

        # Treat infinities as missing, then drop incomplete rows in one pass.
        new_df = new_df.replace([np.inf, -np.inf], np.nan)
        if new_df.isnull().values.any():
            print("Warning: The new data contains missing values. Handling missing values...")
            new_df = new_df.dropna()

        model_path = 'required_docs/ThermAL.joblib'
        if not os.path.exists(model_path):
            messagebox.showerror("Model Not Found", f"The model file was not found at {model_path}.")
            return
        # NOTE(review): joblib.load unpickles the file; only load model files
        # from a trusted source.
        loaded_rf = joblib.load(model_path)

        features_used_for_training = loaded_rf.feature_names_in_

        new_df_features = new_df[features_used_for_training]

        # NOTE(review): the scaler is fitted on the variants being predicted,
        # not on the training data. Confirm this matches how the model was
        # trained; otherwise persist and reuse the training scaler.
        scaler = StandardScaler()
        x_new_scaled = scaler.fit_transform(new_df_features)

        x_new_scaled_df = pd.DataFrame(x_new_scaled, columns=features_used_for_training)

        y_pred_new = loaded_rf.predict(x_new_scaled_df)

        new_df['Predicted Variant Fitness'] = y_pred_new

        output_dir = os.path.join(main_output_dir, 'Predicted_fitness')
        os.makedirs(output_dir, exist_ok=True)
        output_file_path = os.path.join(output_dir, 'predicted_fitness.xlsx')
        new_df.to_excel(output_file_path, index=False)

        # ---- 5. Mutation labels, pivot table and heatmap -----------------
        result_df = pd.read_excel(output_file_path)

        # Derive each label from the row's own sequence. Joining a separately
        # loaded variant list by row position would shift the labels whenever
        # rows were dropped above.
        result_df['1_letter_mutation'] = result_df['Sequence'].apply(
            lambda x: convert_to_1_letter_code(x, reference_sequence)
        )
        result_df['1_letter_mutation'] = result_df['1_letter_mutation'].fillna('')

        mutations_parsed = result_df['1_letter_mutation'].apply(parse_mutation_code)
        unlabelled = [{'mutation': None, 'position': None}]
        mutations_expanded = []
        for idx, mutations in mutations_parsed.items():
            # Rows without a mutation (the wild type) are kept with empty fields.
            for mutation in mutations or unlabelled:
                mutation_row = result_df.loc[idx].copy()
                mutation_row['mutation'] = mutation['mutation']
                mutation_row['position'] = mutation['position']
                mutations_expanded.append(mutation_row)
        result_df = pd.DataFrame(mutations_expanded)

        output_dir = os.path.join(main_output_dir, 'Predicted_fitness_with_1_letter_mutations')
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, 'Predicted_fitness_with_1_letter_mutations.xlsx')
        result_df.to_excel(output_file, index=False)

        df = pd.read_excel(output_file)

        df = df.dropna(subset=['mutation', 'position'])

        df['position'] = df['position'].astype(int)

        mutation_order = list("GAVLMIFYWKRHDESTCNQP")

        pivot_table_simple = df.pivot_table(index='mutation', columns='position', values='Predicted Variant Fitness', aggfunc='mean')

        pivot_table_simple = pivot_table_simple.reindex(mutation_order)

        output_dir = os.path.join(main_output_dir, 'heatmap')
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, 'heatmap_simple.xlsx')
        pivot_table_simple.to_excel(output_path)

        data = pd.read_excel(output_path, index_col=0)

        wt_sequence = reference_sequence

        # Prefix every position column with its wild-type residue, e.g. 5 -> "K5".
        new_column_names = {}
        for pos in data.columns:
            if str(pos).isdigit():
                pos_int = int(pos)
                if pos_int <= len(wt_sequence):
                    new_column_names[pos] = f'{wt_sequence[pos_int - 1]}{pos}'
                else:
                    new_column_names[pos] = f'X{pos}'
            else:
                new_column_names[pos] = pos

        data.rename(columns=new_column_names, inplace=True)

        corrected_output_path = os.path.join(output_dir, 'heatmap_with_WT_sequence.xlsx')
        data.to_excel(corrected_output_path)

        colors = ["lightskyblue", "white", "red"]
        cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
        cmap.set_bad('#b0b0b0')

        pivot_table_data = pd.read_excel(corrected_output_path, sheet_name='Sheet1', engine='openpyxl').set_index('mutation')

        plt.figure(figsize=(10, 10))
        ax = sns.heatmap(pivot_table_data, cmap=cmap, center=0, annot=False, square=True,
                         linewidths=2, linecolor='white', cbar_kws={'label': 'Predicted Variant Fitness', 'shrink': 0.5})

        # Cover the cell where the "mutation" equals the wild-type residue
        # with a white, black-bordered box.
        for i, mutated_residue in enumerate(pivot_table_data.index):
            for j, col_label in enumerate(pivot_table_data.columns):
                match = re.match(r'^([A-Z])(\d+)$', col_label)
                if match and mutated_residue == match.group(1):
                    ax.add_patch(Rectangle((j, i), 1, 1, fill=True, edgecolor='black', facecolor='white', lw=3, zorder=10))

        cbar = ax.collections[0].colorbar
        cbar.outline.set_edgecolor('black')
        cbar.outline.set_linewidth(2)
        for _, spine in ax.spines.items():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(2)

        plt.xlabel('Position', fontsize=16)
        plt.ylabel('Mutation', fontsize=16)
        plt.xticks(rotation=90, fontsize=10)
        plt.yticks(rotation=0, fontsize=10)

        for label in ax.get_yticklabels():
            label.set_color(color_labels(label.get_text()))
        for label in ax.get_xticklabels():
            label.set_color(color_labels(label.get_text()))

        output_path_flipped = os.path.join(output_dir, 'heatmap.png')
        plt.tight_layout()
        plt.savefig(output_path_flipped, dpi=300)
        plt.close()

        # ---- 6. Sliding-window summary -----------------------------------
        df = pd.read_excel(corrected_output_path)

        numeric_df = df.select_dtypes(include=[np.number])

        df.loc['Mean'] = numeric_df.mean(axis=0)

        # Column 0 holds the mutation letters; the rest are positions.
        position_columns = df.columns[1:]

        mean_row = df.loc['Mean', position_columns]

        sliding_window_mean_corrected = mean_row.rolling(window=5, min_periods=5, center=True).mean()

        df.loc['Sliding Window', position_columns] = sliding_window_mean_corrected.values

        sliding_window_excel_path = os.path.join(output_dir, 'sliding_window.xlsx')
        df.to_excel(sliding_window_excel_path, index=True)

        # The plot shows the sign-flipped rolling mean.
        sliding_window_mean_corrected_neg = sliding_window_mean_corrected * -1

        fig, ax1 = plt.subplots(figsize=(10, 5))

        ax1.plot(sliding_window_mean_corrected_neg.index, sliding_window_mean_corrected_neg, marker='o', linestyle='-', color='b', markersize=6, label='Sliding Window Mean')

        ax1.set_xlabel('Centre of Sliding Window (n=5)', fontsize=14)
        ax1.set_ylabel('Average Predicted Fitness', fontsize=14)
        ax1.set_title('Predicted Stabilising Regions (Corrected Sliding Window)', fontsize=16)
        plt.xticks(sliding_window_mean_corrected_neg.index, position_columns, rotation=90, fontsize=10)

        for label in ax1.get_xticklabels():
            label.set_color(color_labels(label.get_text()))

        plt.grid(False)
        plt.tight_layout()

        sliding_window_output_path = os.path.join(output_dir, 'sliding_window_with_foldx.png')
        plt.savefig(sliding_window_output_path, dpi=300, bbox_inches='tight')
        plt.close()

        # ---- 7. Show both images -----------------------------------------
        try:
            heatmap_image = Image.open(output_path_flipped)
        except FileNotFoundError:
            messagebox.showerror("Error", f"Heatmap image not found at {output_path_flipped}")
            return

        try:
            sliding_window_image = Image.open(sliding_window_output_path)
        except FileNotFoundError:
            messagebox.showerror("Error", f"Sliding window image not found at {sliding_window_output_path}")
            return

        # Shrink in place so each image fits within 800x600, keeping aspect ratio.
        max_width = 800
        max_height = 600

        heatmap_image.thumbnail((max_width, max_height), Image.LANCZOS)
        sliding_window_image.thumbnail((max_width, max_height), Image.LANCZOS)

        heatmap_photo = ImageTk.PhotoImage(heatmap_image)
        sliding_window_photo = ImageTk.PhotoImage(sliding_window_image)

        heatmap_window = tk.Toplevel(root)
        heatmap_window.title(f"Heatmap - {job_name}")
        heatmap_label = tk.Label(heatmap_window, image=heatmap_photo)
        # Keep a reference on the widget so the image is not garbage-collected.
        heatmap_label.image = heatmap_photo
        heatmap_label.pack()

        sliding_window_window = tk.Toplevel(root)
        sliding_window_window.title(f"Sliding Window Plot - {job_name}")
        sliding_window_label = tk.Label(sliding_window_window, image=sliding_window_photo)
        sliding_window_label.image = sliding_window_photo
        sliding_window_label.pack()

    except Exception as e:
        messagebox.showerror("Error", f"An error occurred with sequence {job_name}:\n{e}")
    finally:
        progress_window.completed = True

def create_gui():
    """Build the main ThermAL window and run the Tk event loop.

    Publishes `root`, `run_button`, `fasta_file_path` and `fasta_file_label`
    as module globals, because the worker and polling functions defined at
    module level read them. Blocks until the window is closed.
    """
    global root, run_button, fasta_file_path, fasta_file_label
    root = tk.Tk()
    root.title("ThermAL")

    root.configure(bg='white')

    # Centre a 700x600 window on the screen.
    window_width = 700
    window_height = 600
    screen_width = root.winfo_screenwidth()
    screen_height = root.winfo_screenheight()
    x_cordinate = int((screen_width/2) - (window_width/2))
    y_cordinate = int((screen_height/2) - (window_height/2))
    root.geometry(f"{window_width}x{window_height}+{x_cordinate}+{y_cordinate}")

    header_font = ("Helvetica", 16, "bold")
    text_font = ("Helvetica", 12)
    button_font = ("Helvetica", 14, "bold")

    # The logo is optional: report a load failure and carry on without it.
    try:
        logo_path = 'ThermAL.png'

        logo_image = Image.open(logo_path)

        max_logo_width = 220
        max_logo_height = 150
        logo_image.thumbnail((max_logo_width, max_logo_height), Image.LANCZOS)

        logo_photo = ImageTk.PhotoImage(logo_image)
        logo_label = tk.Label(root, image=logo_photo, bg='white')
        # Keep a reference on the widget so the image is not garbage-collected.
        logo_label.image = logo_photo
        logo_label.pack(pady=10)
    except Exception as e:
        messagebox.showerror("Image Load Error", f"An error occurred while loading the logo image:\n{e}")

    tk.Label(root, text="Welcome to ThermAL! Please upload your FASTA file:", font=header_font, bg='white', fg='black').pack(pady=10)
    fasta_file_path = None

    def select_fasta_file():
        """Ask for a FASTA file and remember it; cancelling keeps the previous choice."""
        global fasta_file_path
        # Patterns within one file type are separated by spaces, which is the
        # list form Tk accepts on every platform.
        selected_path = filedialog.askopenfilename(
            title="Select FASTA File",
            filetypes=(("FASTA files", "*.fasta *.fa"), ("All files", "*.*")),
        )
        # The dialog returns an empty value on cancel. Only overwrite the
        # stored path on a real selection so it stays in sync with the label.
        if selected_path:
            fasta_file_path = selected_path
            fasta_file_label.config(text=f"Selected File: {fasta_file_path}")

    fasta_file_label = tk.Label(root, text="No file selected", font=text_font, bg='white', fg='black')
    fasta_file_label.pack(pady=5)

    select_file_button = tk.Button(root, text="Select FASTA File", command=select_fasta_file, font=button_font, bg='white', fg='black', activebackground='#DA5B2D', activeforeground='white')
    select_file_button.pack(pady=10)

    def on_run():
        """Start the analysis in a background thread and poll for its completion."""
        if not fasta_file_path:
            messagebox.showerror("Input Error", "Please select a FASTA file.")
        else:
            try:
                # Disabled while a run is in progress; re-enabled when it ends.
                run_button.config(state='disabled')
                analysis_thread = threading.Thread(target=process_sequences)
                analysis_thread.start()
                root.after(100, check_analysis_thread, analysis_thread)
            except Exception as e:
                messagebox.showerror("Error", f"An error occurred:\n{e}")
                run_button.config(state='normal')

    def on_enter(e):
        """Darken the Run button while the pointer is over it."""
        run_button['background'] = '#DA5B2D'

    def on_leave(e):
        """Set the Run button colour used once the pointer has left."""
        # NOTE(review): this differs from the initial 'white' background, so
        # the button stays orange after the first hover. Confirm intent.
        run_button['background'] = '#F37736'

    run_button = tk.Button(root, text="Run Analysis", command=on_run, font=button_font, bg='white', fg='black', activebackground='white', activeforeground='white')
    run_button.bind("<Enter>", on_enter)
    run_button.bind("<Leave>", on_leave)
    run_button.pack(pady=20)

    tk.Label(root, text="Note: Only standard amino acid residues (ACDEFGHIKLMNPQRSTVWY) are accepted.", font=text_font, fg='red', bg='white').pack(pady=5)

    root.mainloop()


def check_analysis_thread(thread):
    """Re-enable the Run button once `thread` has finished; otherwise poll again in 100 ms."""
    if not thread.is_alive():
        run_button.config(state='normal')
        return
    root.after(100, check_analysis_thread, thread)


def process_sequences():
    """Analyse every record of the selected FASTA file, one after another.

    Reads the file named by the module-level `fasta_file_path`, opens a
    progress window, and calls `run_analysis` for each record using the
    sanitised header as the job name. Failures are reported in a message
    box. The progress window is always closed and the Run button is always
    re-enabled, whatever happens.
    """
    progress_window = None

    def close_progress_window():
        """Destroy the progress window at most once."""
        nonlocal progress_window
        if progress_window is None:
            return
        window, progress_window = progress_window, None
        try:
            window.close()
        except tk.TclError:
            # Already destroyed (e.g. the application is shutting down).
            pass

    try:
        sequences = read_fasta_file(fasta_file_path)
        if not sequences:
            # Without this check an empty file would be reported as a success.
            messagebox.showerror("Input Error", "No sequences were found in the selected FASTA file.")
            return
        total_sequences = len(sequences)
        progress_window = ProgressWindow(root)
        progress_window.label1.config(text="Processing sequences")
        for index, (seq_name, sequence) in enumerate(sequences):
            progress_window.window.after(0, progress_window.update_label2, f"Processing sequence {index+1}/{total_sequences}: {seq_name}")
            job_name = sanitize_filename(seq_name)
            run_analysis(sequence, job_name, root, progress_window)
            # Reset both bars before the next sequence.
            progress_window.update_progress1(0, 1)
            progress_window.update_progress2(0, 1)
        close_progress_window()
        messagebox.showinfo("Analysis Complete", "Results are saved in the respective directories.\n Have a great day! :).")
    except Exception as e:
        # The progress window holds an input grab; release it before showing
        # the error so the application stays usable.
        close_progress_window()
        messagebox.showerror("Error", f"An error occurred:\n{e}")
    finally:
        close_progress_window()
        run_button.config(state='normal')

# Start the GUI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    create_gui()