confusion matrix & curves

In [None]:
# Confusion matrix and test accuracy
# Reference: https://stackoverflow.com/questions/53290306/confusion-matrix-and-test-accuracy-for-pytorch-transfer-learning-tutorial

def confusion_matrix(model_ft, nb_classes, phase="test"):
    """
    Compute the confusion matrix of a pytorch classifier over one data split.

    Entry [t, p] of the result counts the samples whose true class is t and
    whose predicted class (arg-max of the model output along dim 1) is p,
    i.e. rows are true labels and columns are predicted labels.

    Arguments
    ---------
    model_ft:     pytorch model; it is called on every input batch

    nb_classes:   number of classes (side length of the returned matrix)

    phase:        key into the notebook-level `dataloaders` mapping that
                  selects the split to evaluate (default: "test")

    Returns
    -------
    torch tensor of shape (nb_classes, nb_classes) in the default float
    dtype, living on the CPU.

    Notes
    -----
    Uses the notebook-level globals `dataloaders` and `device`.
    The train/eval mode of the model is not touched here; call
    `model_ft.eval()` beforehand if that is what you want.
    """
    counts = torch.zeros(nb_classes, nb_classes)
    with torch.no_grad():
        for inputs, classes in dataloaders[phase]:
            outputs = model_ft(inputs.to(device))
            _, preds = torch.max(outputs, 1)
            # Bring both index vectors to the CPU, where `counts` lives.
            true_idx = classes.reshape(-1).long().cpu()
            pred_idx = preds.reshape(-1).long().cpu()
            # One batched scatter-add instead of a Python loop over samples;
            # accumulate=True makes repeated (t, p) pairs add up, and
            # out-of-range indices still raise just like plain indexing.
            counts.index_put_(
                (true_idx, pred_idx),
                torch.ones(true_idx.numel(), dtype=counts.dtype),
                accumulate=True,
            )
    return counts

In [None]:
# plot the confusion_matrix
# Reference: https://stackoverflow.com/questions/39033880/plot-confusion-matrix-sklearn-with-multiple-labels

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    Draw a confusion matrix as an annotated heat-map.

    Arguments
    ---------
    cm:           square 2-D array of counts (rows = true labels,
                  columns = predicted labels), e.g. the output of
                  sklearn.metrics.confusion_matrix

    target_names: class names used as tick labels, for example:
                  ['edible M', 'edible MS', 'poisonous M', 'poisonous MS'];
                  pass None to keep the default numeric ticks

    title:        the text to display at the top of the matrix

    cmap:         matplotlib colormap for the cells, e.g.
                  plt.get_cmap('jet') or plt.cm.Blues (default: 'Blues')
                  see http://matplotlib.org/examples/color/colormaps_reference.html

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions (each row divided by its sum)

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    cm = np.asarray(cm)

    # Accuracy is always computed from the raw counts, before any normalisation.
    total = float(np.sum(cm))
    accuracy = np.trace(cm) / total if total else float('nan')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A class with no samples keeps an all-zero row instead of turning into NaN.
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)

    # Normalisation happens before imshow so that the cell colours, the
    # printed values and the text-colour threshold all refer to the same data.
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells darker than the threshold get white text, the others black.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_format = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j] if normalize else int(cm[i, j])
        plt.text(j, i, cell_format.format(value),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # Called after the labels exist so they are included in the layout.
    plt.tight_layout()
    plt.show()

In [None]:
PATH = "entire_model_net152.pt"

# Restore the fine-tuned ResNet152 checkpoint onto the active device and
# switch it to inference mode (nn.Module.eval() returns the module itself).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model152_ft = torch.load(PATH, map_location=torch.device(device)).eval()

# Confusion matrix for ResNet152 (two classes, default "test" split)
#   rows    = true labels      (top to bottom): edible, poisonous
#   columns = predicted labels (left to right): edible, poisonous
cm152 = confusion_matrix(model152_ft, nb_classes=2)
print(cm152)

In [None]:
# Heat-map of the raw (un-normalised) ResNet152 counts
plot_confusion_matrix(cm=cm152.numpy(),
                      target_names=class_names,
                      title='Confusion matrix for ResNet152',
                      normalize=False)

In [None]:
PATH = "entire_model_net34.pt"

# Restore the fine-tuned ResNet34 checkpoint onto the active device and
# switch it to inference mode (nn.Module.eval() returns the module itself).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model34_ft = torch.load(PATH, map_location=torch.device(device)).eval()

# Confusion matrix for ResNet34 (two classes, default "test" split);
# rows are true labels, columns are predicted labels.
cm34 = confusion_matrix(model34_ft, nb_classes=2)
print(cm34)

In [None]:
# Heat-map of the raw (un-normalised) ResNet34 counts
plot_confusion_matrix(cm=cm34.numpy(),
                      target_names=class_names,
                      title='Confusion matrix for ResNet34',
                      normalize=False)

In [None]:
PATH = "entire_model_net18.pt"

# Restore the fine-tuned ResNet18 checkpoint onto the active device and
# switch it to inference mode (nn.Module.eval() returns the module itself).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model18_ft = torch.load(PATH, map_location=torch.device(device)).eval()

# Confusion matrix for ResNet18 (two classes, default "test" split);
# rows are true labels, columns are predicted labels.
cm18 = confusion_matrix(model18_ft, nb_classes=2)
print(cm18)

In [None]:
# Heat-map of the raw (un-normalised) ResNet18 counts
plot_confusion_matrix(cm=cm18.numpy(),
                      target_names=class_names,
                      title='Confusion matrix for ResNet18',
                      normalize=False)

In [None]:
# get probability instead of prediction from the model
# Reference: https://stackoverflow.com/questions/60182984/how-to-get-the-predict-probability

def predict_prob(model_ft, phase="test"):
    """
    Run a pytorch classifier over one data split and collect, for every
    sample, the softmax class probabilities together with the true label.

    Arguments
    ---------
    model_ft:     pytorch model; it is called on every input batch and its
                  output is soft-maxed along dim 1

    phase:        key into the notebook-level `dataloaders` mapping that
                  selects the split to evaluate (default: "test")

    Returns
    -------
    (probabilities, targets): two CPU torch tensors. `probabilities` stacks
    the per-batch softmax outputs row-wise; `targets` concatenates the
    per-batch labels in the same order.

    Notes
    -----
    Uses the notebook-level globals `dataloaders` and `device`.
    Results are moved to the CPU batch by batch, so callers can call
    `.numpy()` on them even when `device` is a GPU, and accelerator memory
    is not held for the whole split.
    """
    import torch.nn.functional as nnf
    prob_predict = []
    targets = []

    with torch.no_grad():
        for inputs, classes in dataloaders[phase]:
            outputs = model_ft(inputs.to(device))
            prob_predict.append(nnf.softmax(outputs, dim=1).cpu())
            # Labels are only collected, never fed to the model.
            targets.append(classes.cpu())

    return torch.vstack(prob_predict), torch.hstack(targets)

In [None]:
# Class probabilities and true labels for ResNet152 on the default "test" split
resnet152_prob, y_test = predict_prob(model152_ft)

# Both curves are scored with the probability of class index 1.
precision_res152, recall_res152, _ = precision_recall_curve(y_test.numpy(), resnet152_prob[:, 1].numpy())
fpr_res152, tpr_res152, _ = roc_curve(y_test.numpy(), resnet152_prob[:, 1].numpy())

# Areas under the PR curve and under the ROC curve
pr_auc_res152 = auc(recall_res152, precision_res152)
roc_auc_res152 = auc(fpr_res152, tpr_res152)

In [None]:
# Class probabilities and true labels for ResNet34 on the default "test" split
resnet34_prob, y_test = predict_prob(model34_ft)

# Both curves are scored with the probability of class index 1.
precision_res34, recall_res34, _ = precision_recall_curve(y_test.numpy(), resnet34_prob[:, 1].numpy())
fpr_res34, tpr_res34, _ = roc_curve(y_test.numpy(), resnet34_prob[:, 1].numpy())

# Areas under the PR curve and under the ROC curve
pr_auc_res34 = auc(recall_res34, precision_res34)
roc_auc_res34 = auc(fpr_res34, tpr_res34)

In [None]:
# Class probabilities and true labels for ResNet18 on the default "test" split
resnet18_prob, y_test = predict_prob(model18_ft)

# Both curves are scored with the probability of class index 1.
precision_res18, recall_res18, _ = precision_recall_curve(y_test.numpy(), resnet18_prob[:, 1].numpy())
fpr_res18, tpr_res18, _ = roc_curve(y_test.numpy(), resnet18_prob[:, 1].numpy())

# Areas under the PR curve and under the ROC curve
pr_auc_res18 = auc(recall_res18, precision_res18)
roc_auc_res18 = auc(fpr_res18, tpr_res18)

In [None]:
# Precision-recall curves of all models on one figure
plt.figure(figsize=(8, 6))

# One line per model (recall on x, precision on y); the legend shows each PR-AUC.
plt.plot(recall_res152, precision_res152,
         label="ResNet152 (area = {0:0.3f})".format(pr_auc_res152))
plt.plot(recall_res34, precision_res34,
         label="ResNet34 (area = {0:0.3f})".format(pr_auc_res34))
plt.plot(recall_res18, precision_res18,
         label="ResNet18 (area = {0:0.3f})".format(pr_auc_res18))
# plt.plot(recall_lr, precision_lr, label="LR")
plt.plot(recall_knn, precision_knn,
         label="KNN (area = {0:0.3f})".format(pr_auc_knn))
plt.plot(recall_rf, precision_rf,
         label="RF (area = {0:0.3f})".format(pr_auc_rf))

# Dashed reference lines at precision 0.8 and recall 0.95
plt.axhline(y=0.8, color='k', linestyle='--')
plt.axvline(x=0.95, color='k', linestyle='--')

# Axes, labels and legend
plt.title('PR Curves')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.xlim(0.0, 1.0)
plt.ylim(0.0, 1.0)
plt.legend(loc='lower right')
plt.show()

In [None]:
# ROC curves of all models on one figure
plt.figure(figsize=(8, 6))

# One line per model (false positive rate on x, true positive rate on y);
# the legend shows each ROC-AUC.
plt.plot(fpr_res152, tpr_res152,
         label="ResNet152 (area = {0:0.3f})".format(roc_auc_res152))
plt.plot(fpr_res34, tpr_res34,
         label="ResNet34 (area = {0:0.3f})".format(roc_auc_res34))
plt.plot(fpr_res18, tpr_res18,
         label="ResNet18 (area = {0:0.3f})".format(roc_auc_res18))
plt.plot(fpr_knn, tpr_knn,
         label="KNN (area = {0:0.3f})".format(roc_auc_knn))
plt.plot(fpr_rf, tpr_rf,
         label="RF (area = {0:0.3f})".format(roc_auc_rf))

# Dashed black diagonal from (0, 0) to (1, 1) as a visual baseline
plt.plot([0, 1], [0, 1], 'k--')

# Axes, labels and legend
plt.title('ROC Curves')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.xlim(0.0, 1.0)
plt.ylim(0.0, 1.0)
plt.legend(loc='lower right')
plt.show()

layers

In [None]:
# Weights of every collected convolution, in collection order
model_weights = []
# The convolution modules themselves, in the same order as model_weights
conv_layers = []

# Flatten the network into the list of its direct children.
# The guard keeps the cell re-runnable: after the first run `model_load`
# is already that list and has no .children() any more.
if isinstance(model_load, nn.Module):
    model_load = list(model_load.children())

# Collect convolutions found
#   * directly among the top-level children, and
#   * one level inside each block of a top-level nn.Sequential.
# Convolutions nested deeper than that are not picked up.
for module in model_load:
    if isinstance(module, nn.Conv2d):
        model_weights.append(module.weight)
        conv_layers.append(module)
    elif isinstance(module, nn.Sequential):
        for block in module:
            for child in block.children():
                if isinstance(child, nn.Conv2d):
                    model_weights.append(child.weight)
                    conv_layers.append(child)

# Number of convolution layers that were collected
counter = len(conv_layers)
print(f"Total convolution layers: {counter}")

In [None]:
from IPython.display import Image

# Dataset root below the current working directory
data_dir = os.getcwd() + '/data2'
# Despite its name this is the path of a single image file from the test split
image_dir = data_dir + '/test/edible/ncvc (395).jpg'

# Trailing expression, so the notebook renders the picture inline
Image(image_dir)

In [None]:
# NOTE: from here on `Image` is PIL's module and no longer IPython.display.Image
from PIL import Image

# Force three colour channels: grayscale, palette or RGBA files would
# otherwise not match the 3-channel Normalize below.
image = Image.open(image_dir).convert("RGB")

# Preprocessing pipeline: random 224-pixel crop and random horizontal flip,
# conversion to a tensor, then per-channel normalisation.
# NOTE(review): the crop and the flip are random, so the resulting tensor
# (and everything derived from it) differs from run to run - confirm this
# is intended.
transform_image = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [None]:
# Run the preprocessing pipeline; `image` is rebound to its result
image = transform_image(image)
print(f"Image shape before: {image.shape}")

# Insert a leading dimension of size 1 in front (the batch axis)
image = torch.unsqueeze(image, 0)
print(f"Image shape after: {image.shape}")

# Move the batch to the active device
image = image.to(device)

In [None]:
# Output of every collected convolution, and the matching layer descriptions
outputs = []
names = []

# Feed the image through the collected convolutions one after another:
# each layer receives the previous layer's output (`image` is rebound on
# every step). Only the layers in `conv_layers` are applied.
# no_grad: the feature maps are only inspected, so no autograd graph needs
# to be kept alive for each stored output.
with torch.no_grad():
    for layer in conv_layers:
        image = layer(image)
        outputs.append(image)
        names.append(str(layer))

In [None]:
# One 2-D numpy array per layer: the feature map averaged over its channels
processed = []
for feature_map in outputs:
    # Drop the leading batch axis (size 1); assumes (1, C, H, W) outputs - TODO confirm
    feature_map = torch.squeeze(feature_map, 0)
    # Sum over the first remaining axis and divide by its length
    gray_scale = feature_map.sum(dim=0) / feature_map.shape[0]
    processed.append(gray_scale.data.cpu().numpy())