In [12]:
from sklearn.metrics import make_scorer
import numpy as np

# Misclassification cost table: entry [t][p] is the penalty charged when a
# sample of class t is predicted as class p, i.e. the label distance |t - p|.
num_classes = 20
cost_matrix = [
    [abs(true_cls - pred_cls) for pred_cls in range(num_classes)]
    for true_cls in range(num_classes)
]

def custom_scorer(y_true, y_pred, costs=None):
    """
    Total misclassification cost of a set of predictions.

    Every sample contributes ``costs[true_label][predicted_label]``; the
    per-sample costs are summed (not averaged), so the result grows with the
    number of samples being scored.

    :param y_true: true labels (sized iterable of integer class indices)
    :param y_pred: predicted labels, same length as ``y_true``
    :param costs: optional cost table indexed as ``costs[true][pred]``;
        when omitted, the module-level ``cost_matrix`` is used
    :return: total cost (0 for empty input)
    :raises ValueError: if ``y_true`` and ``y_pred`` have different lengths
    """
    if costs is None:
        # Backward-compatible default: fall back to the notebook-level table.
        costs = cost_matrix

    # zip() would silently truncate to the shorter input, so check up front.
    if len(y_true) != len(y_pred):
        raise ValueError(
            "y_true and y_pred must have the same length "
            f"(got {len(y_true)} and {len(y_pred)})"
        )

    # Iterate over the values instead of indexing with y_true[i]: for
    # label-indexed containers such as a pandas Series taken from a CV split,
    # [i] is a label lookup rather than a positional one.
    return sum(costs[true_cls][pred_cls] for true_cls, pred_cls in zip(y_true, y_pred))

# Wrap the cost function as a scikit-learn scorer. A lower cost is better,
# hence greater_is_better=False (reported scores are the negated cost).
scorer = make_scorer(
    score_func=custom_scorer,
    greater_is_better=False,
)

In [13]:
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification

# Synthetic 20-class problem: 1000 samples, 10 features (5 informative),
# one cluster per class, fixed seed so the data is reproducible.
X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=5,
    n_classes=20,
    n_clusters_per_class=1,
    random_state=42,
)

# Base estimator whose hyper-parameters are tuned below.
svm = SVC()

# Candidates: 10 log-spaced values of C between 1e-3 and 1e3, for both kernels.
param_grid = dict(
    C=np.logspace(-3, 3, 10),
    kernel=['linear', 'rbf'],
)

# Exhaustive search over the grid, ranking candidates with the custom cost scorer.
grid_search = GridSearchCV(
    estimator=svm,
    param_grid=param_grid,
    scoring=scorer,
)
grid_search.fit(X, y)

# Report the winning configuration and its score.
print(grid_search.best_params_)
print(grid_search.best_score_)

{'C': 2.154434690031882, 'kernel': 'rbf'}
-604.8


In [14]:
# Rebuild the |row - col| cost table for 20 classes, one row at a time,
# then print it for inspection.
num_classes = 20
cost_matrix = []
for row in range(num_classes):
    cost_matrix.append([abs(row - col) for col in range(num_classes)])

print(cost_matrix)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8], [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7], [13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6], [14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5],