Sequence of code: Parse FASTA → engineer biochemical and biophysical features → clean dataset -> train multiple ML models → evaluate → deploy interactive dashboard where users can predict protein localization.


In [81]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import gzip
import joblib
import streamlit as st
from sklearn.utils.class_weight import compute_sample_weight
from Bio import SeqIO
from Bio.SeqUtils.ProtParam import ProteinAnalysis



In [82]:
# load fasta.gz file
fasta_data = "uniprot_sprot.fasta.gz"

# parse fasta sequences
sequences = [] 
for i in SeqIO.parse(gzip.open(fasta_data, 'rt'), 'fasta'): #open in read text mode 
    sequences.append(i)
 

print(f"Total sequences: {len(sequences)}")
print(sequences[0].id)       # header
print(sequences[0].description)  
print(str(sequences[0].seq))     # amino acid sequence

Total sequences: 573661
sp|Q6GZX4|001R_FRG3G
sp|Q6GZX4|001R_FRG3G Putative transcription factor 001R OS=Frog virus 3 (isolate Goorha) OX=654924 GN=FV3-001R PE=4 SV=1
MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL


In [83]:
#extract protein sequences and localization 
data = [] #empty list to collect data 
for i in sequences: # loop through each parsed record, each i is a biopython SeqRecord object
    header = i.description
    seq = str(i.seq)
    
    localization = header.split('|')[-1]  
    
    data.append({
        "protein_id": i.id,
        "sequence": seq,
        "localization": localization
    })

protein_info = pd.DataFrame(data)

In [84]:
# feature engineering, raw aa sequences -> numerical vectors describing biochemical and physical properties 

# amino acids
valid_aa = set("ACDEFGHIKLMNPQRSTVWY")

#clean the sequences
def clean_sequence(seq):
    return "".join([aa for aa in seq.upper() if aa in valid_aa])

#feature engineering
def sequence_features(seq):
    seq = clean_sequence(seq)
    
    if len(seq) == 0:
        # return zeros if sequence is empty after cleaning
        return [0]*20 + [0]*10  # 20 aa comp + 10 physchem features

    pa = ProteinAnalysis(seq)
    
    # amino acid composition
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    aa_comp = [seq.count(aa)/len(seq) for aa in amino_acids]

    # physicochemical properties
    length = len(seq)
    molecular_weight = pa.molecular_weight()
    isoelectric_point = pa.isoelectric_point()
    aromaticity = pa.aromaticity()
    instability_index = pa.instability_index()
    gravy = pa.gravy()
    
    # group features
    polar = sum(seq.count(aa) for aa in "DEKRQN") / length
    nonpolar = sum(seq.count(aa) for aa in "AVLIFMW") / length
    charged = sum(seq.count(aa) for aa in "DEKR") / length
    hydrophobic = sum(seq.count(aa) for aa in "AILMFWV") / length
    
    return aa_comp + [length, molecular_weight, isoelectric_point, aromaticity,
                      instability_index, gravy, polar, nonpolar, charged, hydrophobic]

# apply to all sequences
feature_array = np.array([sequence_features(seq) for seq in protein_info['sequence']])

array([[0.05078125, 0.015625  , 0.06640625, ..., 0.34765625, 0.296875  ,
        0.34765625],
       [0.08125   , 0.05625   , 0.075     , ..., 0.296875  , 0.196875  ,
        0.296875  ],
       [0.03275109, 0.01310044, 0.10262009, ..., 0.31004367, 0.29475983,
        0.31004367],
       ...,
       [0.06315789, 0.09473684, 0.03157895, ..., 0.28421053, 0.21052632,
        0.28421053],
       [0.02105263, 0.07368421, 0.04210526, ..., 0.24210526, 0.24210526,
        0.24210526],
       [0.03157895, 0.07368421, 0.04210526, ..., 0.28421053, 0.25263158,
        0.28421053]])

In [85]:
#assign labels to feature array 
aa_cols = list("ACDEFGHIKLMNPQRSTVWY")
physchem_cols = [ 
    "length", "molecular_weight", "isoelectric_point", "aromaticity",
    "instability_index", "gravy", "polar", "nonpolar", "charged", "hydrophobic"]

feature_df = pd.DataFrame(feature_array, columns=aa_cols + physchem_cols)

Unnamed: 0,A,C,D,E,F,G,H,I,K,L,...,length,molecular_weight,isoelectric_point,aromaticity,instability_index,gravy,polar,nonpolar,charged,hydrophobic
0,0.050781,0.015625,0.066406,0.058594,0.031250,0.058594,0.035156,0.046875,0.113281,0.097656,...,256.0,29735.1007,9.370173,0.101562,31.205078,-0.538672,0.363281,0.347656,0.296875,0.347656
1,0.081250,0.056250,0.075000,0.018750,0.028125,0.071875,0.009375,0.031250,0.053125,0.043750,...,320.0,34642.0562,8.046315,0.081250,27.386562,-0.402187,0.265625,0.296875,0.196875,0.296875
2,0.032751,0.013100,0.102620,0.098253,0.048035,0.034934,0.021834,0.052402,0.045852,0.074236,...,458.0,53920.8273,4.383389,0.126638,52.245000,-0.773799,0.377729,0.310044,0.294760,0.310044
3,0.064103,0.051282,0.032051,0.025641,0.019231,0.057692,0.019231,0.038462,0.025641,0.076923,...,156.0,17043.1201,8.658056,0.083333,62.696795,-0.471795,0.224359,0.256410,0.141026,0.256410
4,0.066210,0.009132,0.066210,0.031963,0.045662,0.059361,0.022831,0.027397,0.043379,0.107306,...,438.0,48296.6452,6.369173,0.073059,41.416438,0.144749,0.257991,0.440639,0.189498,0.440639
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
573656,0.050000,0.080000,0.030000,0.060000,0.020000,0.040000,0.020000,0.070000,0.080000,0.080000,...,100.0,11122.7071,8.124709,0.060000,61.832000,-0.367000,0.320000,0.290000,0.200000,0.290000
573657,0.012658,0.037975,0.012658,0.037975,0.126582,0.037975,0.050633,0.101266,0.012658,0.139241,...,79.0,9127.7832,6.883111,0.164557,56.329241,0.807595,0.126582,0.493671,0.101266,0.493671
573658,0.063158,0.094737,0.031579,0.052632,0.021053,0.021053,0.021053,0.042105,0.052632,0.105263,...,95.0,10850.4013,8.631237,0.073684,56.357895,-0.529474,0.326316,0.284211,0.210526,0.284211
573659,0.021053,0.073684,0.042105,0.052632,0.021053,0.063158,0.021053,0.052632,0.063158,0.084211,...,95.0,10994.6309,8.904518,0.094737,44.261053,-0.617895,0.305263,0.242105,0.242105,0.242105


In [86]:
protein_features = pd.concat([protein_info, feature_df], axis=1)

Unnamed: 0,protein_id,sequence,localization,A,C,D,E,F,G,H,...,length,molecular_weight,isoelectric_point,aromaticity,instability_index,gravy,polar,nonpolar,charged,hydrophobic
0,sp|Q6GZX4|001R_FRG3G,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,001R_FRG3G Putative transcription factor 001R ...,0.050781,0.015625,0.066406,0.058594,0.031250,0.058594,0.035156,...,256.0,29735.1007,9.370173,0.101562,31.205078,-0.538672,0.363281,0.347656,0.296875,0.347656
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,002L_FRG3G Uncharacterized protein 002L OS=Fro...,0.081250,0.056250,0.075000,0.018750,0.028125,0.071875,0.009375,...,320.0,34642.0562,8.046315,0.081250,27.386562,-0.402187,0.265625,0.296875,0.196875,0.296875
2,sp|Q197F8|002R_IIV3,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,002R_IIV3 Uncharacterized protein 002R OS=Inve...,0.032751,0.013100,0.102620,0.098253,0.048035,0.034934,0.021834,...,458.0,53920.8273,4.383389,0.126638,52.245000,-0.773799,0.377729,0.310044,0.294760,0.310044
3,sp|Q197F7|003L_IIV3,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,003L_IIV3 Uncharacterized protein 003L OS=Inve...,0.064103,0.051282,0.032051,0.025641,0.019231,0.057692,0.019231,...,156.0,17043.1201,8.658056,0.083333,62.696795,-0.471795,0.224359,0.256410,0.141026,0.256410
4,sp|Q6GZX2|003R_FRG3G,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,003R_FRG3G Uncharacterized protein 3R OS=Frog ...,0.066210,0.009132,0.066210,0.031963,0.045662,0.059361,0.022831,...,438.0,48296.6452,6.369173,0.073059,41.416438,0.144749,0.257991,0.440639,0.189498,0.440639
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
573656,sp|Q6UY62|Z_SABVB,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,Z_SABVB RING finger protein Z OS=Sabia mammare...,0.050000,0.080000,0.030000,0.060000,0.020000,0.040000,0.020000,...,100.0,11122.7071,8.124709,0.060000,61.832000,-0.367000,0.320000,0.290000,0.200000,0.290000
573657,sp|P08105|Z_SHEEP,MSSSLEITSFYSFIWTPHIGPLLFGIGLWFSMFKEPSHFCPCQHPH...,Z_SHEEP Putative uncharacterized protein Z OS=...,0.012658,0.037975,0.012658,0.037975,0.126582,0.037975,0.050633,...,79.0,9127.7832,6.883111,0.164557,56.329241,0.807595,0.126582,0.493671,0.101266,0.493671
573658,sp|Q88470|Z_TACVF,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,Z_TACVF RING finger protein Z OS=Tacaribe viru...,0.063158,0.094737,0.031579,0.052632,0.021053,0.021053,0.021053,...,95.0,10850.4013,8.631237,0.073684,56.357895,-0.529474,0.326316,0.284211,0.210526,0.284211
573659,sp|A9JR22|Z_TAMVU,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,Z_TAMVU RING finger protein Z OS=Tamiami mamma...,0.021053,0.073684,0.042105,0.052632,0.021053,0.063158,0.021053,...,95.0,10994.6309,8.904518,0.094737,44.261053,-0.617895,0.305263,0.242105,0.242105,0.242105


In [87]:
#data cleaning for ML models

#extract accession (Q6GZX4) from the protein_id
protein_features["accession"] = protein_features["protein_id"].str.split("|").str[1]

# load the UniProt TSV metadata
uniprot_meta = pd.read_csv("uniprot_data.tsv", sep="\t")

# merge to bring in subcellular location
merged = protein_features.merge(uniprot_meta[["Entry", "Subcellular location [CC]"]], 
                            left_on="accession", right_on="Entry", how="left")


In [88]:
#filter out empty sequences 
merged = merged[merged['sequence'].str.len() > 0]

Unnamed: 0,protein_id,sequence,localization,A,C,D,E,F,G,H,...,aromaticity,instability_index,gravy,polar,nonpolar,charged,hydrophobic,accession,Entry,Subcellular location [CC]
0,sp|Q6GZX4|001R_FRG3G,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,001R_FRG3G Putative transcription factor 001R ...,0.050781,0.015625,0.066406,0.058594,0.031250,0.058594,0.035156,...,0.101562,31.205078,-0.538672,0.363281,0.347656,0.296875,0.347656,Q6GZX4,Q6GZX4,
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,002L_FRG3G Uncharacterized protein 002L OS=Fro...,0.081250,0.056250,0.075000,0.018750,0.028125,0.071875,0.009375,...,0.081250,27.386562,-0.402187,0.265625,0.296875,0.196875,0.296875,Q6GZX3,Q6GZX3,SUBCELLULAR LOCATION: Host membrane {ECO:00003...
2,sp|Q197F8|002R_IIV3,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,002R_IIV3 Uncharacterized protein 002R OS=Inve...,0.032751,0.013100,0.102620,0.098253,0.048035,0.034934,0.021834,...,0.126638,52.245000,-0.773799,0.377729,0.310044,0.294760,0.310044,Q197F8,Q197F8,
3,sp|Q197F7|003L_IIV3,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,003L_IIV3 Uncharacterized protein 003L OS=Inve...,0.064103,0.051282,0.032051,0.025641,0.019231,0.057692,0.019231,...,0.083333,62.696795,-0.471795,0.224359,0.256410,0.141026,0.256410,Q197F7,Q197F7,
4,sp|Q6GZX2|003R_FRG3G,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,003R_FRG3G Uncharacterized protein 3R OS=Frog ...,0.066210,0.009132,0.066210,0.031963,0.045662,0.059361,0.022831,...,0.073059,41.416438,0.144749,0.257991,0.440639,0.189498,0.440639,Q6GZX2,Q6GZX2,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
573656,sp|Q6UY62|Z_SABVB,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,Z_SABVB RING finger protein Z OS=Sabia mammare...,0.050000,0.080000,0.030000,0.060000,0.020000,0.040000,0.020000,...,0.060000,61.832000,-0.367000,0.320000,0.290000,0.200000,0.290000,Q6UY62,Q6UY62,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...
573657,sp|P08105|Z_SHEEP,MSSSLEITSFYSFIWTPHIGPLLFGIGLWFSMFKEPSHFCPCQHPH...,Z_SHEEP Putative uncharacterized protein Z OS=...,0.012658,0.037975,0.012658,0.037975,0.126582,0.037975,0.050633,...,0.164557,56.329241,0.807595,0.126582,0.493671,0.101266,0.493671,P08105,P08105,
573658,sp|Q88470|Z_TACVF,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,Z_TACVF RING finger protein Z OS=Tacaribe viru...,0.063158,0.094737,0.031579,0.052632,0.021053,0.021053,0.021053,...,0.073684,56.357895,-0.529474,0.326316,0.284211,0.210526,0.284211,Q88470,Q88470,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...
573659,sp|A9JR22|Z_TAMVU,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,Z_TAMVU RING finger protein Z OS=Tamiami mamma...,0.021053,0.073684,0.042105,0.052632,0.021053,0.063158,0.021053,...,0.094737,44.261053,-0.617895,0.305263,0.242105,0.242105,0.242105,A9JR22,A9JR22,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...


In [89]:
#clean subcellular location text from Uniprot 
import re
merged['clean_localization'] = merged['localization'].str.extract(r'(Mitochondrion|Nucleus|Secreted|Cytoplasm|Membrane|Endoplasmic reticulum)', expand=False)

def clean_location(text):
    if pd.isna(text):
        return np.nan
    
    # extract the main keyword(s)
    match = re.findall(r'SUBCELLULAR LOCATION:\s*([^.;]+)', text)
    if match:
        # clean punctuation and extra spaces
        loc = match[0].replace('{', '').replace('}', '').strip()
        return loc
    return np.nan

merged["clean_localization"] = merged["Subcellular location [CC]"].apply(clean_location)


#simplify location categories 
def simplify_loc(loc):
    if pd.isna(loc):
        return np.nan
    loc = loc.lower()
    if "nucleus" in loc: return "Nucleus"
    elif "mitochond" in loc: return "Mitochondrion"
    elif "cytoplasm" in loc: return "Cytoplasm"
    elif "membrane" in loc: return "Membrane"
    elif "secreted" in loc: return "Secreted"
    elif "virion" in loc: return "Virion"
    else: return "Other"

merged["loc_category"] = merged["clean_localization"].apply(simplify_loc)

In [90]:
ml_data = merged.dropna(subset=['loc_category'])
ml_data = ml_data.drop(['localization'], axis=1)

Unnamed: 0,protein_id,sequence,A,C,D,E,F,G,H,I,...,gravy,polar,nonpolar,charged,hydrophobic,accession,Entry,Subcellular location [CC],clean_localization,loc_category
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,0.081250,0.056250,0.075000,0.018750,0.028125,0.071875,0.009375,0.031250,...,-0.402187,0.265625,0.296875,0.196875,0.296875,Q6GZX3,Q6GZX3,SUBCELLULAR LOCATION: Host membrane {ECO:00003...,Host membrane ECO:0000305,Membrane
5,sp|Q6GZX1|004R_FRG3G,MNAKYDTDQGVGRMLFLGTIGLAVVVGGLMAYGYYYDGKTPSSGTS...,0.066667,0.000000,0.050000,0.000000,0.050000,0.150000,0.016667,0.016667,...,-0.153333,0.166667,0.316667,0.133333,0.316667,Q6GZX1,Q6GZX1,SUBCELLULAR LOCATION: Host membrane {ECO:00003...,Host membrane ECO:0000305,Membrane
15,sp|Q6GZW5|010R_FRG3G,MKMDTDCRHWIVLASVPVLTVLAFKGEGALALAGLLVMAAVAMYRD...,0.131387,0.007299,0.051095,0.036496,0.043796,0.094891,0.029197,0.014599,...,-0.021898,0.226277,0.430657,0.211679,0.430657,Q6GZW5,Q6GZW5,SUBCELLULAR LOCATION: Host membrane {ECO:00003...,Host membrane ECO:0000305,Membrane
17,sp|Q6GZW4|011R_FRG3G,MTSVKTIAMLAMLVIVAALIYMGYRTFTSMQSKLNELESRVNAPQL...,0.071429,0.000000,0.042857,0.071429,0.028571,0.014286,0.000000,0.071429,...,0.265714,0.271429,0.471429,0.200000,0.471429,Q6GZW4,Q6GZW4,SUBCELLULAR LOCATION: Host membrane {ECO:00003...,Host membrane ECO:0000305,Membrane
19,sp|Q197E7|013L_IIV3,MYYRDQYGNVKYAPEGMGPHHAASSSHHSAQHHHMTKENFSMDDVH...,0.077778,0.000000,0.044444,0.033333,0.066667,0.044444,0.100000,0.033333,...,-0.578889,0.244444,0.377778,0.177778,0.377778,Q197E7,Q197E7,SUBCELLULAR LOCATION: Host membrane {ECO:00003...,Host membrane ECO:0000305,Membrane
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
573655,sp|Q6RSS3|Z_PIRVV,MGLRYSKEVRERHGDKDLEGRVPMTLNLPQGLYGRFNCKSCWFANR...,0.031579,0.073684,0.031579,0.063158,0.042105,0.084211,0.021053,0.031579,...,-0.532632,0.315789,0.273684,0.242105,0.273684,Q6RSS3,Q6RSS3,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...,Virion ECO:0000255|HAMAP-Rule:MF_04087,Virion
573656,sp|Q6UY62|Z_SABVB,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,0.050000,0.080000,0.030000,0.060000,0.020000,0.040000,0.020000,0.070000,...,-0.367000,0.320000,0.290000,0.200000,0.290000,Q6UY62,Q6UY62,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...,Virion ECO:0000255|HAMAP-Rule:MF_04087,Virion
573658,sp|Q88470|Z_TACVF,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,0.063158,0.094737,0.031579,0.052632,0.021053,0.021053,0.021053,0.042105,...,-0.529474,0.326316,0.284211,0.210526,0.284211,Q88470,Q88470,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...,Virion ECO:0000255|HAMAP-Rule:MF_04087,Virion
573659,sp|A9JR22|Z_TAMVU,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,0.021053,0.073684,0.042105,0.052632,0.021053,0.063158,0.021053,0.052632,...,-0.617895,0.305263,0.242105,0.242105,0.242105,A9JR22,A9JR22,SUBCELLULAR LOCATION: Virion {ECO:0000255|HAMA...,Virion ECO:0000255|HAMAP-Rule:MF_04087,Virion


In [91]:
ml_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 360038 entries, 1 to 573660
Data columns (total 37 columns):
 #   Column                     Non-Null Count   Dtype  
---  ------                     --------------   -----  
 0   protein_id                 360038 non-null  object 
 1   sequence                   360038 non-null  object 
 2   A                          360038 non-null  float64
 3   C                          360038 non-null  float64
 4   D                          360038 non-null  float64
 5   E                          360038 non-null  float64
 6   F                          360038 non-null  float64
 7   G                          360038 non-null  float64
 8   H                          360038 non-null  float64
 9   I                          360038 non-null  float64
 10  K                          360038 non-null  float64
 11  L                          360038 non-null  float64
 12  M                          360038 non-null  float64
 13  N                          360038 

In [92]:
#scale numeric features and encode labels

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split

# separate features and labels
X = ml_data.select_dtypes(include=[np.number])  # all numeric columns
y = ml_data['loc_category']

# scale the numeric features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)


# encode the labels 
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)

# split into train/test sets
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
)

# check
print("Feature matrix shape:", X_scaled.shape)
print("Training samples:", X_train.shape[0])
print("Test samples:", X_test.shape[0])
print("Classes:", encoder.classes_)


Feature matrix shape: (360038, 30)
Training samples: 288030
Test samples: 72008
Classes: ['Cytoplasm' 'Membrane' 'Mitochondrion' 'Nucleus' 'Other' 'Secreted'
 'Virion']


In [93]:
#define models for looping 
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier

# Define models
models = {
    "Logistic Regression": LogisticRegression(
        max_iter=200, 
        class_weight='balanced', 
        solver='saga', 
        n_jobs=-1
    ),
    "Random Forest": RandomForestClassifier(
        n_estimators=50, 
        class_weight='balanced', 
        max_depth=10, 
        n_jobs=-1, 
        random_state=42
    ),
    "XGBoost": XGBClassifier(
        objective='multi:softmax', 
        num_class=None,  # will be inferred automatically
        eval_metric='mlogloss',
        use_label_encoder=False,
        n_estimators=200,
        max_depth=6,
        learning_rate=0.1,
        random_state=42,
        n_jobs=-1
    ),
    "HistGradientBoosting":HistGradientBoostingClassifier(
    max_iter=200,
    max_depth=5,
    learning_rate=0.1,
    random_state=42
    ),
   }



In [94]:
#train and evaluate different models for best prediction metrics 
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, f1_score, classification_report
results = []

for name, model in models.items():
    # pipeline  
    pipe = Pipeline([
        ('clf', model)
    ])
    
    pipe.fit(X_train, y_train)
    y_pred = pipe.predict(X_test)
    
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    
    results.append((name, acc, f1))
    print(f"\n{name}")
    print(classification_report(y_test, y_pred))





Logistic Regression
              precision    recall  f1-score   support

           0       0.80      0.67      0.73     32076
           1       0.77      0.43      0.55     19772
           2       0.19      0.41      0.26      3066
           3       0.45      0.61      0.52      6365
           4       0.16      0.20      0.18      3840
           5       0.44      0.56      0.49      5970
           6       0.09      0.56      0.15       919

    accuracy                           0.55     72008
   macro avg       0.41      0.49      0.41     72008
weighted avg       0.66      0.55      0.59     72008


Random Forest
              precision    recall  f1-score   support

           0       0.80      0.72      0.76     32076
           1       0.90      0.49      0.64     19772
           2       0.34      0.51      0.41      3066
           3       0.42      0.73      0.53      6365
           4       0.28      0.45      0.35      3840
           5       0.64      0.66      0.6

Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)



XGBoost
              precision    recall  f1-score   support

           0       0.78      0.92      0.84     32076
           1       0.81      0.78      0.79     19772
           2       0.77      0.40      0.52      3066
           3       0.62      0.64      0.63      6365
           4       0.86      0.36      0.51      3840
           5       0.79      0.71      0.74      5970
           6       0.90      0.45      0.60       919

    accuracy                           0.78     72008
   macro avg       0.79      0.61      0.66     72008
weighted avg       0.78      0.78      0.77     72008


HistGradientBoosting
              precision    recall  f1-score   support

           0       0.79      0.92      0.85     32076
           1       0.82      0.77      0.80     19772
           2       0.73      0.42      0.54      3066
           3       0.63      0.64      0.64      6365
           4       0.83      0.38      0.52      3840
           5       0.79      0.71      0.75    

Now you know that XGBoost has better prediction accuracy and deployment efficency. Now we need to isolate XGBoost and start building an interactive dashboard for this model


In [95]:
#save and load the XGBoost model 
best_model = models['XGBoost']
sample_weights = compute_sample_weight('balanced', y_train)
best_model.fit(X_train, y_train, sample_weight=sample_weights)
joblib.dump(best_model, "xgboost_protein_localization.pkl")
joblib.dump(scaler, 'scaler.pkl')
joblib.dump(encoder, 'label_encoder.pkl')

Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


['label_encoder.pkl']

In [97]:
#sample sequence for Streamlit
print(ml_data.iloc[2]['sequence'])


MKMDTDCRHWIVLASVPVLTVLAFKGEGALALAGLLVMAAVAMYRDRTEKKYSAARAPSPIAGHKTAYVTDPSAFAAGTVPVYPAPSNMGSDRFEGWVGGVLTGVGSSHLDHRKFAERQLVDRREKMVGYGWTKSFF
