In [47]:
import pandas as pd
import os
from typing import List
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, hamming_loss, make_scorer, f1_score
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import resample
from sklearn.decomposition import PCA
import numpy as np
from matplotlib.colors import ListedColormap
import seaborn as sns
import re
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier

In [2]:
go_file = "/Users/kajolpatel/Desktop/Individual_Project/poc/dataset/go-basic.obo"

In [3]:
def parse_obo_file(file_path):
    
    data = []
    current_term = {}
    in_term_block = False
    
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line == '[Term]':  #starting a new term block
                if current_term:
                    data.append(current_term)
                current_term = {}
                in_term_block = True
            elif line == '':
                in_term_block = False  #end of a term block
            elif in_term_block:
                if ': ' in line:
                    key, value = line.split(': ', 1)
                    if key in current_term:  #handling multiple lines of the same key
                        if isinstance(current_term[key], list):
                            current_term[key].append(value)
                        else:
                            current_term[key] = [current_term[key], value]
                    else:
                        current_term[key] = value

    
    if current_term: #add the last term if file does not end with a newline
        data.append(current_term)

    return pd.DataFrame(data)

df = parse_obo_file(go_file)


In [4]:
df = df.rename(columns={'def': 'definition'}) 

In [5]:
df.shape

(47856, 14)

In [6]:
df.head(3)

Unnamed: 0,id,name,namespace,definition,synonym,is_a,alt_id,comment,is_obsolete,replaced_by,consider,xref,subset,relationship
0,GO:0000001,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","""mitochondrial inheritance"" EXACT []","[GO:0048308 ! organelle inheritance, GO:004831...",,,,,,,,
1,GO:0000002,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",,GO:0007005 ! mitochondrion organization,,,,,,,,
2,GO:0000003,obsolete reproduction,biological_process,"""OBSOLETE. The production of new individuals t...","""reproductive physiological process"" EXACT []",,"[GO:0019952, GO:0050876]",The reason for obsoletion is that this term is...,True,GO:0022414,,,,


In [7]:
df.iloc[0]['is_a']

['GO:0048308 ! organelle inheritance',
 'GO:0048311 ! mitochondrion distribution']

In [8]:
pd.reset_option('display.max_rows')
pd.reset_option('display.max_columns')


In [9]:
print(df[df['is_a'].isna() == False].shape)
print(df[df['is_a'].isna() == True].shape)

(42200, 14)
(5656, 14)


Total records = 47856

42200 records have is_a relationship present

5656 records do not have is_a relationship present

### Data Preprocessing

#### 1. Excluding the records which do not have is_a

In [10]:
df = df[df['is_a'].notna()]

In [11]:
df[df.is_obsolete.isna() == True]

Unnamed: 0,id,name,namespace,definition,synonym,is_a,alt_id,comment,is_obsolete,replaced_by,consider,xref,subset,relationship
0,GO:0000001,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","""mitochondrial inheritance"" EXACT []","[GO:0048308 ! organelle inheritance, GO:004831...",,,,,,,,
1,GO:0000002,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",,GO:0007005 ! mitochondrion organization,,,,,,,,
4,GO:0000006,high-affinity zinc transmembrane transporter a...,molecular_function,"""Enables the transfer of zinc ions (Zn2+) from...","[""high affinity zinc uptake transmembrane tran...",GO:0005385 ! zinc ion transmembrane transporte...,,,,,,,,
5,GO:0000007,low-affinity zinc ion transmembrane transporte...,molecular_function,"""Enables the transfer of a solute or solutes f...",,GO:0005385 ! zinc ion transmembrane transporte...,,,,,,,,
7,GO:0000009,"alpha-1,6-mannosyltransferase activity",molecular_function,"""Catalysis of the transfer of a mannose residu...","""1,6-alpha-mannosyltransferase activity"" EXACT []",GO:0000030 ! mannosyltransferase activity,,,,,,"Reactome:R-HSA-449718 ""Addition of a third man...",,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47851,GO:2001313,UDP-4-deoxy-4-formamido-beta-L-arabinopyranose...,biological_process,"""The chemical reactions and pathways involving...","""UDP-4-deoxy-4-formamido-beta-L-arabinopyranos...","[GO:0006040 ! amino sugar metabolic process, G...",,,,,,,,
47852,GO:2001314,UDP-4-deoxy-4-formamido-beta-L-arabinopyranose...,biological_process,"""The chemical reactions and pathways resulting...","[""UDP-4-deoxy-4-formamido-beta-L-arabinopyrano...",[GO:0009227 ! nucleotide-sugar catabolic proce...,,,,,,,,
47853,GO:2001315,UDP-4-deoxy-4-formamido-beta-L-arabinopyranose...,biological_process,"""The chemical reactions and pathways resulting...","[""UDP-4-deoxy-4-formamido-beta-L-arabinopyrano...",[GO:0009226 ! nucleotide-sugar biosynthetic pr...,,,,,,,,
47854,GO:2001316,kojic acid metabolic process,biological_process,"""The chemical reactions and pathways involving...","[""5-hydroxy-2-(hydroxymethyl)-4H-pyran-4-one m...",[GO:0034308 ! primary alcohol metabolic proces...,,,,,,,,


Checking how many unique values is_a has

In [12]:
exploded_df = df.explode('is_a')

In [13]:
exploded_df['is_a'].value_counts().head(10)

is_a
GO:0110165 ! cellular anatomical entity                                                                                                                                              431
GO:0016616 ! oxidoreductase activity, acting on the CH-OH group of donors, NAD or NADP as acceptor                                                                                   310
GO:0032991 ! protein-containing complex                                                                                                                                              277
GO:0016709 ! oxidoreductase activity, acting on paired donors, with incorporation or reduction of molecular oxygen, NAD(P)H as one donor, and incorporation of one atom of oxygen    261
GO:0016758 ! hexosyltransferase activity                                                                                                                                             208
GO:0048856 ! anatomical structure development                         

#### 15547 unique is_a values

Checking if a GO term can have more than 9 is_a values

In [14]:
df[df['is_a'].apply(lambda x: len(x) > 9 if isinstance(x, list) else False)]

Unnamed: 0,id,name,namespace,definition,synonym,is_a,alt_id,comment,is_obsolete,replaced_by,consider,xref,subset,relationship
39234,GO:0140872,viridicatumtoxin biosynthetic process,biological_process,"""The chemical reactions and pathways resulting...","[""viridicatumtoxin anabolism"" EXACT [], ""virid...","[GO:0030639 ! polyketide biosynthetic process,...",,,,,,MetaCyc:PWY-7659,,


#### 2. Converting the is_a values to only have GO term ids instead of names too :)

In [15]:
def extract_go_terms(s):
    go_terms = []

    if isinstance(s, list):
        for item in s:
            go_terms.extend(re.findall(r'GO:\d{7}', item))
    else:
        go_terms = re.findall(r'GO:\d{7}', s)
    return go_terms if len(go_terms) > 1 else (go_terms[0] if go_terms else None)

In [16]:
df.head(2)

Unnamed: 0,id,name,namespace,definition,synonym,is_a,alt_id,comment,is_obsolete,replaced_by,consider,xref,subset,relationship
0,GO:0000001,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","""mitochondrial inheritance"" EXACT []","[GO:0048308 ! organelle inheritance, GO:004831...",,,,,,,,
1,GO:0000002,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",,GO:0007005 ! mitochondrion organization,,,,,,,,


In [17]:
df = df[['id','definition','is_a']]

In [18]:
df['is_a'] = df['is_a'].apply(extract_go_terms)

In [19]:
df

Unnamed: 0,id,definition,is_a
0,GO:0000001,"""The distribution of mitochondria, including t...","[GO:0048308, GO:0048311]"
1,GO:0000002,"""The maintenance of the structure and integrit...",GO:0007005
4,GO:0000006,"""Enables the transfer of zinc ions (Zn2+) from...",GO:0005385
5,GO:0000007,"""Enables the transfer of a solute or solutes f...",GO:0005385
7,GO:0000009,"""Catalysis of the transfer of a mannose residu...",GO:0000030
...,...,...,...
47851,GO:2001313,"""The chemical reactions and pathways involving...","[GO:0006040, GO:0006793, GO:0009225]"
47852,GO:2001314,"""The chemical reactions and pathways resulting...","[GO:0009227, GO:0046348, GO:2001313]"
47853,GO:2001315,"""The chemical reactions and pathways resulting...","[GO:0009226, GO:0046349, GO:2001313]"
47854,GO:2001316,"""The chemical reactions and pathways involving...","[GO:0034308, GO:0042180, GO:0120254]"


#### 3. Remove list from definitions

In [20]:
df['definition'] = df['definition'].str.replace(r' \[.*?\]$', '', regex=True)

In [21]:
df['definition'][0]

'"The distribution of mitochondria, including the mitochondrial genome, into daughter cells after mitosis or meiosis, mediated by interactions between mitochondria and the cytoskeleton."'

In [22]:
df

Unnamed: 0,id,definition,is_a
0,GO:0000001,"""The distribution of mitochondria, including t...","[GO:0048308, GO:0048311]"
1,GO:0000002,"""The maintenance of the structure and integrit...",GO:0007005
4,GO:0000006,"""Enables the transfer of zinc ions (Zn2+) from...",GO:0005385
5,GO:0000007,"""Enables the transfer of a solute or solutes f...",GO:0005385
7,GO:0000009,"""Catalysis of the transfer of a mannose residu...",GO:0000030
...,...,...,...
47851,GO:2001313,"""The chemical reactions and pathways involving...","[GO:0006040, GO:0006793, GO:0009225]"
47852,GO:2001314,"""The chemical reactions and pathways resulting...","[GO:0009227, GO:0046348, GO:2001313]"
47853,GO:2001315,"""The chemical reactions and pathways resulting...","[GO:0009226, GO:0046349, GO:2001313]"
47854,GO:2001316,"""The chemical reactions and pathways involving...","[GO:0034308, GO:0042180, GO:0120254]"


In [23]:
df['is_a'] = df['is_a'].apply(lambda x: x if isinstance(x, list) else [x])


In [70]:
df

Unnamed: 0,id,definition,is_a
0,GO:0000001,"""The distribution of mitochondria, including t...","[GO:0048308, GO:0048311]"
1,GO:0000002,"""The maintenance of the structure and integrit...",[GO:0007005]
4,GO:0000006,"""Enables the transfer of zinc ions (Zn2+) from...",[GO:0005385]
5,GO:0000007,"""Enables the transfer of a solute or solutes f...",[GO:0005385]
7,GO:0000009,"""Catalysis of the transfer of a mannose residu...",[GO:0000030]
...,...,...,...
47851,GO:2001313,"""The chemical reactions and pathways involving...","[GO:0006040, GO:0006793, GO:0009225]"
47852,GO:2001314,"""The chemical reactions and pathways resulting...","[GO:0009227, GO:0046348, GO:2001313]"
47853,GO:2001315,"""The chemical reactions and pathways resulting...","[GO:0009226, GO:0046349, GO:2001313]"
47854,GO:2001316,"""The chemical reactions and pathways involving...","[GO:0034308, GO:0042180, GO:0120254]"


### Selecting only the records which have is_a as the 10 most frequent values

In [68]:
exploded_df = df.explode('is_a')

In [25]:
exploded_df['is_a'].value_counts().head(10)

is_a
GO:0110165    431
GO:0016616    310
GO:0032991    277
GO:0016709    261
GO:0016758    208
GO:0048856    202
GO:0098797    180
GO:0140513    172
GO:0016747    153
GO:0003006    151
Name: count, dtype: int64

In [26]:
is_a_of_interest = ["GO:0110165","GO:0016616", "GO:0032991", "GO:0016709", "GO:0016758", "GO:0048856", "GO:0098797", "GO:0140513", "GO:0016747", "GO:0003006"]
filtered_df = df[df['is_a'].apply(lambda x: any(item in is_a_of_interest for item in (x if isinstance(x, list) else [x])))]

In [69]:
filtered_df

Unnamed: 0,id,definition,is_a
25,GO:0000030,"""Catalysis of the transfer of a mannosyl group...",[GO:0016758]
26,GO:0000031,"""Catalysis of the transfer of a mannosylphosph...",[GO:0016758]
88,GO:0000109,"""Any complex formed of proteins that act in nu...",[GO:0140513]
97,GO:0000118,"""A protein complex that possesses histone deac...","[GO:0140513, GO:1902494]"
98,GO:0000120,"""A transcription factor complex that acts at a...","[GO:0005667, GO:0140513]"
...,...,...,...
46507,GO:1990909,"""A multiprotein protein complex containing mem...",[GO:0032991]
46514,GO:1990916,"""The outermost layers of the spore wall, as de...",[GO:0110165]
46521,GO:1990923,"""A protein complex that is composed of at leas...",[GO:0032991]
46550,GO:1990957,"""A protein complex that is located at the cili...",[GO:0032991]


#### Converting the labels to vectors for 10 labels

In [27]:
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(filtered_df['is_a'])
y_df = pd.DataFrame(y, columns=mlb.classes_) 

In [28]:
y_df = y_df[is_a_of_interest]

In [29]:
y_df

Unnamed: 0,GO:0110165,GO:0016616,GO:0032991,GO:0016709,GO:0016758,GO:0048856,GO:0098797,GO:0140513,GO:0016747,GO:0003006
0,0,0,0,0,1,0,0,0,0,0
1,0,0,0,0,1,0,0,0,0,0
2,0,0,0,0,0,0,0,1,0,0
3,0,0,0,0,0,0,0,1,0,0
4,0,0,0,0,0,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...
2325,0,0,1,0,0,0,0,0,0,0
2326,1,0,0,0,0,0,0,0,0,0
2327,0,0,1,0,0,0,0,0,0,0
2328,0,0,1,0,0,0,0,0,0,0


In [30]:
y_df.columns

Index(['GO:0110165', 'GO:0016616', 'GO:0032991', 'GO:0016709', 'GO:0016758',
       'GO:0048856', 'GO:0098797', 'GO:0140513', 'GO:0016747', 'GO:0003006'],
      dtype='object')

#### Converting definition attribute to feature vectors for 10 labels

Excluding the words which appear in less than 1% of definitions

In [31]:
vectorizer = CountVectorizer(stop_words='english', min_df = 0.01)
X_tfidf = vectorizer.fit_transform(filtered_df['definition'])
X_df = pd.DataFrame(X_tfidf.toarray(), columns=vectorizer.get_feature_names_out())


In [32]:
X_df

Unnamed: 0,11,12,acceptor,acetyl,acid,actin,activation,activity,acyl,alpha,...,time,tissue,transcription,transfer,transmembrane,transport,type,udp,wall,yeast
0,0,0,1,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2325,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2326,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
2327,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2328,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


#### 7. Normalised dataset

In [33]:
scaler = StandardScaler()

In [34]:
X_normalised_df = scaler.fit_transform(X_df)
X_normalised_df = pd.DataFrame(X_normalised_df, columns = X_df.columns)

In [35]:
X_normalised_df

Unnamed: 0,11,12,acceptor,acetyl,acid,actin,activation,activity,acyl,alpha,...,time,tissue,transcription,transfer,transmembrane,transport,type,udp,wall,yeast
0,-0.101002,-0.100096,7.569045,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,7.208096,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
1,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,7.208096,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
2,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
3,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,5.999450,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
4,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,5.649277,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2325,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
2326,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,5.687965,-0.124526
2327,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526
2328,-0.101002,-0.100096,-0.105400,-0.124494,-0.1092,-0.122618,-0.114208,-0.150448,-0.104787,-0.269993,...,-0.300539,-0.102018,-0.121357,-0.132316,-0.118212,-0.113109,-0.139853,-0.262318,-0.129839,-0.124526


### Random Forests

In [71]:
X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, test_size=0.2, random_state=42)

In [138]:
rf_model = RandomForestClassifier(random_state=42, class_weight='balanced')
rf_model.fit(X_train, y_train)
y_pred_rf = rf_model.predict(X_test)
print("F1-Score:", str(f1_score(y_test, y_pred_rf, average = 'micro')))
print(classification_report(y_test, y_pred_rf, zero_division=0))

F1-Score: 0.8693743139407245
              precision    recall  f1-score   support

           0       0.90      0.88      0.89        98
           1       1.00      1.00      1.00        76
           2       0.70      0.67      0.69        46
           3       1.00      0.96      0.98        50
           4       0.97      0.91      0.94        35
           5       0.96      0.92      0.94        53
           6       1.00      0.50      0.67        32
           7       0.71      0.63      0.67        27
           8       1.00      0.94      0.97        16
           9       0.70      0.68      0.69        38

   micro avg       0.90      0.84      0.87       471
   macro avg       0.89      0.81      0.84       471
weighted avg       0.90      0.84      0.87       471
 samples avg       0.83      0.84      0.84       471



Random Forest Optimization

In [132]:
rf_model = RandomForestClassifier(random_state=42, class_weight='balanced')
param_grid = {
    'n_estimators': list(range(20,300,50)),  # Number of trees in the forest
    'max_depth': list(range(3,60,3)),  # Maximum depth of the tree
    'min_samples_split': [2, 5, 10, 20, 30],  # Minimum number of samples required to split an internal node
    'min_samples_leaf': [1, 2, 4, 10]     # Minimum number of samples required to be at a leaf node
}
scoring = {'F1-Score': make_scorer(f1_score, average='micro', needs_proba=False)}
grid_search = GridSearchCV(rf_model, param_grid, cv=5, scoring=scoring, refit='F1-Score', n_jobs=-1)

# Perform the search
grid_search.fit(X_train, y_train)

# Best parameter set
print('Best parameters found: \n', grid_search.best_params_)

Best parameters found: 
 {'max_depth': 30, 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 220}


In [133]:
grid_search.best_estimator_

Testing tuned parameters

In [137]:
rf_model = RandomForestClassifier(class_weight='balanced', max_depth=30, n_estimators=220, random_state=42)
rf_model.fit(X_train, y_train)
y_pred_rf = rf_model.predict(X_test)
print("F1-Score:", str(f1_score(y_test, y_pred_rf, average = 'micro')))
print(classification_report(y_test, y_pred_rf, zero_division=0))

F1-Score: 0.8726467331118495
              precision    recall  f1-score   support

           0       0.90      0.87      0.89        98
           1       1.00      0.99      0.99        76
           2       0.78      0.70      0.74        46
           3       1.00      0.96      0.98        50
           4       0.97      0.91      0.94        35
           5       0.96      0.92      0.94        53
           6       1.00      0.50      0.67        32
           7       0.70      0.52      0.60        27
           8       1.00      0.94      0.97        16
           9       0.72      0.74      0.73        38

   micro avg       0.91      0.84      0.87       471
   macro avg       0.90      0.80      0.84       471
weighted avg       0.91      0.84      0.87       471
 samples avg       0.83      0.84      0.83       471



#### Support Vector Machines

In [145]:
X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, test_size=0.2, random_state=42)

In [147]:
svm_model = OneVsRestClassifier(SVC(random_state=42, class_weight='balanced'))
svm_model.fit(X_train, y_train)
y_pred_svm = svm_model.predict(X_test)
print(f1_score(y_test, y_pred_svm, average='micro'))
print(classification_report(y_test, y_pred_svm, zero_division=0))

0.7912713472485767
              precision    recall  f1-score   support

           0       0.86      0.90      0.88        98
           1       0.99      0.97      0.98        76
           2       0.48      0.78      0.60        46
           3       1.00      0.94      0.97        50
           4       0.87      0.97      0.92        35
           5       0.86      0.94      0.90        53
           6       0.92      0.69      0.79        32
           7       0.35      0.74      0.48        27
           8       0.54      0.94      0.68        16
           9       0.40      0.82      0.53        38

   micro avg       0.72      0.89      0.79       471
   macro avg       0.73      0.87      0.77       471
weighted avg       0.79      0.89      0.82       471
 samples avg       0.76      0.88      0.80       471



SVM Optimization

In [159]:
# Define the SVM model wrapped in OneVsRestClassifier
svm_model = OneVsRestClassifier(SVC(random_state=42, class_weight='balanced'))

param_grid = {
    'estimator__C': [0.1, 1, 10, 0.01, 0.5, 5, 50],
    'estimator__kernel': ['linear', 'rbf', 'poly'],
    'estimator__gamma': ['scale', 'auto',0.001, 0.01, 0.1, 1, 10]
}

scoring = {'F1-Score': make_scorer(f1_score, average='micro')}
grid_search = GridSearchCV(svm_model, param_grid, scoring=scoring, refit='F1-Score', cv=5, n_jobs=-1)

# Perform the search
grid_search.fit(X_train, y_train)

# Best parameter set
print('Best parameters found: \n', grid_search.best_params_)

Best parameters found: 
 {'estimator__C': 5, 'estimator__gamma': 'scale', 'estimator__kernel': 'rbf'}


Testing tuned parameters

In [160]:
svm_model = OneVsRestClassifier(SVC(C= 5, gamma = 'scale', random_state=42, class_weight='balanced'))
svm_model.fit(X_train, y_train)
y_pred_svm = svm_model.predict(X_test)
print(f1_score(y_test, y_pred_svm, average='micro'))
print(classification_report(y_test, y_pred_svm, zero_division=0))


0.8355739400206825
              precision    recall  f1-score   support

           0       0.88      0.88      0.88        98
           1       1.00      0.99      0.99        76
           2       0.62      0.72      0.67        46
           3       0.98      0.94      0.96        50
           4       0.92      0.97      0.94        35
           5       0.93      0.94      0.93        53
           6       1.00      0.53      0.69        32
           7       0.49      0.67      0.56        27
           8       0.71      0.94      0.81        16
           9       0.52      0.76      0.62        38

   micro avg       0.81      0.86      0.84       471
   macro avg       0.80      0.83      0.81       471
weighted avg       0.84      0.86      0.84       471
 samples avg       0.80      0.86      0.82       471

