In [23]:
import pandas as pd
import numpy as np
from scipy.stats import multivariate_normal

In [25]:
# Labelled training set; later cells read its `class`, `x1` and `x2` columns.
TRAIN_CSV = 'naive_bayes_two_classes.csv'
data = pd.read_csv(TRAIN_CSV)

In [26]:
# Count the distinct labels in the `class` column (missing labels are not
# counted, per the nunique default) and report the total.
class_column = data['class']
num_classes = class_column.nunique()
print("\nNúmero de clases en los datos:", num_classes)


Número de clases en los datos: 2


In [27]:
# Per-class Gaussian parameters estimated from the training frame:
#   'media'                -> mean of x1 and x2 within the class
#   'covarianza'           -> 2x2 sample covariance of x1 and x2 within the class
#   'probabilidad_apriori' -> fraction of training rows that carry this label
# NOTE(review): the full covariance matrix is kept, so x1 and x2 are not
# treated as independent -- confirm this is intended given the "naive_bayes"
# file names.
feature_columns = ['x1', 'x2']
total_rows = len(data)
class_stats = {
    label: {
        'media': group[feature_columns].mean(),
        'covarianza': group[feature_columns].cov(),
        'probabilidad_apriori': len(group) / total_rows,
    }
    for label, group in data.groupby('class')
}

In [28]:
# Print the fitted parameters of every class: mean vector, covariance matrix
# and prior, each introduced by its own heading line.
for label, summary in class_stats.items():
    mean_vector = summary['media']
    covariance = summary['covarianza']
    prior = summary['probabilidad_apriori']

    print("\nClase:", label)
    # sep="\n" puts the heading and the pandas object on separate lines.
    print("Media:", mean_vector, sep="\n")
    print("\nMatriz de Covarianza:", covariance, sep="\n")
    print("\nProbabilidad a priori:", prior)


Clase: 1
Media:
x1    1.963530
x2    1.993345
dtype: float64

Matriz de Covarianza:
          x1        x2
x1  1.035336 -0.006257
x2 -0.006257  0.977647

Probabilidad a priori: 0.3333333333333333

Clase: 2
Media:
x1    9.965347
x2    9.985706
dtype: float64

Matriz de Covarianza:
          x1        x2
x1  1.057285  0.006016
x2  0.006016  0.976375

Probabilidad a priori: 0.6666666666666666


In [29]:
# Map each class label to the density function of a frozen multivariate normal
# built from that class's estimated mean vector and covariance matrix.
pdf_functions = {
    label: multivariate_normal(
        mean=summary['media'].to_numpy(),
        cov=summary['covarianza'].to_numpy(),
    ).pdf
    for label, summary in class_stats.items()
}

In [30]:
# Rows to classify; the prediction cell reads their `x1` and `x2` columns.
PREDICTION_INPUT_CSV = 'naive_bayes_two_classes_prediction.csv'
test_data = pd.read_csv(PREDICTION_INPUT_CSV)

In [32]:
# Classify every test row by the largest unnormalised posterior
#   p(x | class) * p(class)
# The rows are scored in one vectorised call per class instead of a Python
# loop over iterrows(): each frozen pdf accepts an (n_rows, 2) array of points
# and returns one density per row.
# NOTE(review): densities underflow to 0.0 for points very far from every class
# mean; all scores then tie and the first class wins by default (this was also
# true of the row-wise loop). Consider scoring in the log domain if such points
# can occur.
feature_matrix = test_data[['x1', 'x2']].to_numpy(dtype=float)
class_labels = list(pdf_functions)

if len(feature_matrix) == 0:
    # Nothing to classify; keep `predictions` an empty list.
    predictions = []
else:
    # One row of scores per class, one column per test point.
    # np.atleast_1d guards the single-row case, where scipy squeezes the
    # density down to a scalar.
    posterior_scores = np.vstack([
        np.atleast_1d(pdf_functions[label](feature_matrix))
        * class_stats[label]['probabilidad_apriori']
        for label in class_labels
    ])
    # argmax returns the first maximum, matching the tie-breaking of
    # max(dict, key=dict.get) over the same class order.
    best_class_index = posterior_scores.argmax(axis=0)
    predictions = [class_labels[i] for i in best_class_index]

In [33]:
# Attach the predicted label of each row as a new `predicted_class` column.
# This mutates `test_data` in place; `predictions` holds one entry per row,
# in row order.
test_data['predicted_class'] = predictions

In [34]:
# Persist the test rows together with their predicted class; index=False keeps
# the DataFrame's row index out of the file.
PREDICTIONS_CSV = 'naive_bayes_predictions.csv'
test_data.to_csv(PREDICTIONS_CSV, index=False)