# LSTM Predict Confusion Matrix
這個檔案用來分析 supervised learning model 的預測結果：依據 testing data 的 ground truth 與模型的預測結果，繪製 confusion matrix。

In [None]:
import sys
sys.path.insert(1, '../src')

import argparse
import logging
import os
from mypredictor import Predictor
from metric import Metric
import csv
import json
import pickle
from preprocess import Embedding, CSDataset
import torch
from sklearn.metrics import confusion_matrix
import seaborn as sn
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# --- Evaluation settings ----------------------------------------------------
# Directory holding the trained model's artefacts, and the checkpoint epoch
# whose weights are loaded below.
model_dir = os.path.join("..", "model", "moreData_clean_256")
epoch = 3

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

config_path = os.path.join(model_dir, "config.json")
model_path = os.path.join(model_dir, "model.pkl.{}".format(epoch))

# --- Configuration ----------------------------------------------------------
logging.info('Loading configuration file from {}'.format(config_path))
with open(config_path) as config_file:
    config = json.load(config_file)

# The config stores the pickle locations relative to the model directory.
embedding_pkl_path = os.path.join(model_dir, config["embedding_pkl_path"])
val_pkl_path = os.path.join(model_dir, config["val_pkl_path"])
labelEncoder_path = os.path.join(model_dir, config["labelEncoder_path"])

# --- Pickled artefacts ------------------------------------------------------
# Each unpickled object is stored in the model-parameter dict, which is later
# expanded into keyword arguments for Predictor.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
model_parameters = config["model_parameters"]

with open(embedding_pkl_path, "rb") as pkl_file:
    # Only the `.vectors` attribute of the pickled embedding object is kept.
    model_parameters["embedding"] = pickle.load(pkl_file).vectors
logging.info( "Load embedding from {}".format(embedding_pkl_path))

with open(val_pkl_path, "rb") as pkl_file:
    model_parameters["valid"] = pickle.load(pkl_file)
logging.info( "Load val from {}".format(val_pkl_path))

with open(labelEncoder_path, "rb") as pkl_file:
    model_parameters["labelEncoder"] = pickle.load(pkl_file)
logging.info( "Load labelEncoder from {}".format(labelEncoder_path))

# --- Label metadata ---------------------------------------------------------
le = model_parameters["labelEncoder"]
class_list = list(le.classes_)
num_classes = len(le.classes_)

print(class_list)

# --- Model ------------------------------------------------------------------
predictor = Predictor(metric=Metric(), **model_parameters)
predictor.load(model_path)

# --- Prediction -------------------------------------------------------------
logging.info("Loading testing data.")
# No separate test file is read here: the validation split loaded above
# doubles as the test data.
valid = model_parameters["valid"]  # expected to be a CSDataset
test = valid

logging.info("Predicting...")
predicts, solution = predictor.predict_dataset(test, test.collate_fn)


In [None]:
def _draw_confusion_heatmap(df, title, out_path, fmt, figsize=(80, 25)):
    """Render one annotated confusion-matrix heatmap and save it to *out_path*.

    Args:
        df: Square DataFrame; rows are ground-truth labels, columns are
            predicted labels.
        title: Figure title.
        out_path: Path of the image file to write.
        fmt: Format spec applied to every cell annotation.
        figsize: Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    plt.title(title, y=1.03, fontsize=25)
    heatmap = sn.heatmap(df, annot=True, fmt=fmt, annot_kws={"size": 16})
    plt.ylabel('Ground Truth', fontsize=20)
    plt.xlabel('Prediction', fontsize=20)
    heatmap.set_xticklabels(
        heatmap.get_xticklabels(), rotation=45, horizontalalignment="right"
    )
    plt.savefig(out_path, bbox_inches="tight")


def analysis(predicts, gt, labels):
    """Plot and save raw and row-normalised confusion matrices.

    Two figures are created and written to the current working directory:
    ``ConfusionMatrix_LSTM.png`` (sample counts) and
    ``Normalized_ConfusionMatrix_LSTM.png`` (each row divided by its total,
    rounded to two decimals). The figures are not closed afterwards.

    Args:
        predicts: Predicted label for every sample.
        gt: Ground-truth label for every sample, aligned with *predicts*.
        labels: Labels that define the row/column order of the matrices and
            the tick labels of both heatmaps.
    """
    matrix = confusion_matrix(gt, predicts, labels=labels)
    counts_df = pd.DataFrame(matrix, columns=labels, index=labels)
    # "d" keeps counts as plain integers; seaborn's default ".2g" would show
    # any count of 100 or more in scientific notation (e.g. "1.2e+02").
    _draw_confusion_heatmap(
        counts_df, 'LSTM Confusion Matrix', 'ConfusionMatrix_LSTM.png', fmt="d"
    )

    # Normalise each ground-truth row by its number of samples. Rows without
    # any samples stay all-zero instead of triggering a 0/0 RuntimeWarning.
    row_totals = matrix.sum(axis=1, keepdims=True)
    normalized_matrix = np.divide(
        matrix.astype(float),
        row_totals,
        out=np.zeros(matrix.shape, dtype=float),
        where=row_totals != 0,
    ).round(2)
    normalized_df = pd.DataFrame(normalized_matrix, columns=labels, index=labels)
    _draw_confusion_heatmap(
        normalized_df,
        'Normalized LSTM Confusion Matrix',
        'Normalized_ConfusionMatrix_LSTM.png',
        fmt=".2g",
    )

In [None]:
# Reduce each row to the index of its largest entry.
# NOTE(review): assumes `predicts` and `solution` are (num_samples, num_classes)
# tensors of per-class scores / one-hot targets — confirm against
# Predictor.predict_dataset.
predicts_index = torch.argmax(predicts, dim=1)
gt_index = torch.argmax(solution, dim=1)

# Map class indices back to label strings. The tensors are converted to CPU
# NumPy arrays explicitly so the label encoder also works when the
# predictions live on a non-CPU device.
predicts_str = list(le.inverse_transform(predicts_index.cpu().numpy()))
gt_str = list(le.inverse_transform(gt_index.cpu().numpy()))

analysis(predicts_str, gt_str, labels=class_list)