forked from thuzax/ia-knn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
80 lines (63 loc) · 2.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from CalculadoraEstatisticas import CalculadoraEstatisticas
from MatrizConfusao import MatrizConfusao
from MatrizAcertos import MatrizAcertos
from LeitorArquivo import LeitorArquivo, LeitorArquivoComTestes
from KNN import KNN
from prettytable import PrettyTable
import sys
def main():
if len(sys.argv) < 4:
print("Número inválido de argumentos. Abortando")
print("Passe: \n 1 - Caminho do arquivo \n 2 - Porcentagem para usar de teste \n 3 - Números de K")
print("Ex: python3 main.py spambase/spambase.data 0.1 1 3 5 7 9 \n")
sys.exit()
dados = LeitorArquivoComTestes(sys.argv[1], float(sys.argv[2]))
knn = KNN(dados[0])
k_para_teste = sys.argv[3:]
for index, k in enumerate(k_para_teste):
k_para_teste[index] = int(k)
melhor_k = -1
melhor_acuracia = -1
todasEstatisticas = {}
for k in k_para_teste:
print("\n\n************* TESTANDO PARA K = " + str(k) + "**********************\n\n")
resultado_knn = knn.run(dados[1], k)
resultados_esperados = []
for dado in dados[1]:
resultados_esperados.append(dado[-1])
matrizConfusao = MatrizConfusao(dados[2])
matrizConfusao.geraMatriz(resultado_knn, resultados_esperados)
print(matrizConfusao)
calculadoraEstatisticas = CalculadoraEstatisticas(matrizConfusao)
acuracia = calculadoraEstatisticas.calculateOverralAccuracy()
if acuracia >= melhor_acuracia:
melhor_acuracia = acuracia
melhor_k = k
todasEstatisticas[k] = calculadoraEstatisticas
# print(calculadoraEstatisticas)
print("\n\n************* TÉRMINO DO TESTE PARA K = " + str(k) + "**********************\n\n")
escreveTabelaTodasEstatisticas(todasEstatisticas)
print("K da melhor acurácia: " + str(melhor_k))
print("Melhor acurácia: " + "{:.3f}".format(melhor_acuracia))
def escreveTabelaTodasEstatisticas(todasEstatisticas):
saida = PrettyTable()
header = ["", "Dados das classes", "Accuracy (Overall)"]
saida.field_names = header
for k in todasEstatisticas:
tabela = PrettyTable()
header_interno = ["", "Recall", "Specificity", "Precisao", "F-Score", "Acuracia"]
tabela.field_names = header_interno
classesRating = todasEstatisticas[k].getClassesRatings()
for classe in classesRating:
row = []
row.append(classe["nome"])
row.append("{:.3f}".format(classe["truePositive"]))
row.append("{:.3f}".format(classe["trueNegative"]))
row.append("{:.3f}".format(classe["precision"]))
row.append("{:.3f}".format(classe["fScore"]))
row.append("{:.3f}".format(classe["accuracy"]))
tabela.add_row(row)
row_saida = [str(k), tabela, "{:.3f}".format(todasEstatisticas[k].calculateOverralAccuracy())]
saida.add_row(row_saida)
print(saida)
main()