In [3]:
import sys
sys.path.append('CLIP-dissect') 

In [4]:
import os
import torch
import torchvision
import torch.nn as nn
import numpy as np
import pandas as pd
import cv2
import glob
import random
import copy
from sklearn.metrics import precision_score, f1_score, recall_score, confusion_matrix

import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import matplotlib
import similarity
import utils
import data_utils
import clip
from sentence_transformers import SentenceTransformer

%matplotlib inline

In [5]:
dataset_path = "FER2013"
class_names = os.listdir(dataset_path+"/train")
print(class_names)

['disgust', 'happy', 'angry', 'neutral', 'sad', 'surprise', 'fear']


In [6]:
emotions = []
for file_name in glob.glob(dataset_path+'/train/*/*'):
    emotion = file_name.split('/')[-2]
    if emotion not in emotions:
        img = cv2.imread(file_name)
    emotions.append(emotion)
    
emotions = []
for file_name in glob.glob(dataset_path+'/test/*/*'):
    emotion = file_name.split('/')[-2]
    if emotion not in emotions:
        img = cv2.imread(file_name)
    emotions.append(emotion)

In [7]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((48,48))])

train_dataset = ImageFolder(dataset_path+'/train',transform)
train_loader = DataLoader(dataset=train_dataset,batch_size=100*6)
#creating val data loaders
val_dataset = ImageFolder(dataset_path+'/test',transform)
val_loader = DataLoader(dataset=val_dataset,batch_size=100)

In [8]:
from clip_dissect_pipeline import dissect_pipeline

In [9]:
emotion_mapping = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [10]:
import models
model_path = 'VGGNet'
checkpoint = torch.load(model_path)
target_model = models.Vgg().to(device)
target_model.load_state_dict(checkpoint["params"])
net = target_model
net

Vgg(
  (conv1a): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2a): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2b): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3a): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3b): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4a): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4b): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn1a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn1b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2a): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn

In [11]:
def correct_count(output, target, topk=(1,)):
    """Computes the top k corrrect count for the specified values of k"""
    maxk = max(topk)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k)
    return res

In [12]:
def evaluate(net, dataloader, criterion):
    net = net.eval()
    loss_tr, n_samples = 0.0, 0.0

    y_pred = []
    y_gt = []

    correct_count1 = 0
    correct_count2 = 0

    for data in dataloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # fuse crops and batchsize
        #bs, ncrops, c, h, w = inputs.shape
        #bs, c, h, w = inputs.shape
        #inputs = inputs.view(-1, c, h, w)
        ncrops = 10
        bs = 5
        # forward
        outputs = net(inputs)

        # combine results across the crops
        #outputs = outputs.view(bs, ncrops, -1)
        #outputs = torch.sum(outputs, dim=1) / ncrops

        loss = criterion(outputs, labels)

        # calculate performance metrics
        loss_tr += loss.item()

        # accuracy
        counts = correct_count(outputs, labels, topk=(1, 2))
        correct_count1 += counts[0].item()
        correct_count2 += counts[1].item()

        _, preds = torch.max(outputs.data, 1)
        preds = preds.to("cpu")
        labels = labels.to("cpu")
        n_samples += labels.size(0)

        y_pred.extend(pred.item() for pred in preds)
        y_gt.extend(y.item() for y in labels)

    acc1 = 100 * correct_count1 / n_samples
    acc2 = 100 * correct_count2 / n_samples
    loss = loss_tr / n_samples
    print("--------------------------------------------------------")
    print("Top 1 Accuracy: %2.6f %%" % acc1)
    print("Top 2 Accuracy: %2.6f %%" % acc2)
    print("Loss: %2.6f" % loss)
    print("Precision: %2.6f" % precision_score(y_gt, y_pred, average='micro'))
    print("Recall: %2.6f" % recall_score(y_gt, y_pred, average='micro'))
    print("F1 Score: %2.6f" % f1_score(y_gt, y_pred, average='micro'))
    print("Confusion Matrix:\n", confusion_matrix(y_gt, y_pred), '\n')
    return acc1, acc2

In [11]:
criterion = nn.CrossEntropyLoss()
# Get data with no augmentation
#trainloader, valloader, testloader = get_dataloaders(augment=False)
dataset_path = "FER2013"
mu, st = 0, 1
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((40, 40)),
    transforms.functional.rgb_to_grayscale,
    #transforms.TenCrop(40),
    #transforms.Lambda(lambda tensors: torch.stack([transforms.Normalize(mean=(mu,), std=(st,))(t) for t in tensors])),
    #transforms.Lambda(lambda inputs: inputs.view(-1, inputs.shape[1], inputs.shape[2], inputs.shape[3]))
])
label_corrections = {4: 6, 5:4, 6:5}
target_transform = transforms.Lambda(lambda label: label_corrections[label] if label in label_corrections else label)
val_dataset = ImageFolder(dataset_path+'/test',transform, target_transform)
#plt.imshow(val_dataset[0][0][4].permute(1,2,0))
val_loader = DataLoader(dataset=val_dataset,batch_size=5, shuffle=True)
print("Test")
evaluate(net, val_loader, criterion)

Test




--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281748
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 



(69.42045137921427, 85.77598216773474)

In [12]:
net1 = copy.deepcopy(net)
weights = net.lin2.weight.data  # Get the weights of the lin1 layer

# Compute the sum of absolute values of weights for each neuron
weight_sums = torch.sum(torch.abs(weights), dim=1)

# Define a threshold to categorize neurons (this is arbitrary and for demonstration)
threshold = weight_sums.mean()  # Use the mean as a simple threshold

# Neurons with weight sum above threshold are "interpretable"
interpretable_indices = torch.where(weight_sums > threshold)[0]
uninterpretable_indices = torch.where(weight_sums <= threshold)[0]

# Modify weights: Increase for interpretable, decrease for uninterpretable
factor = 2  # Example factor to modify weights
for idx in interpretable_indices:
    net1.lin2.weight.data[idx] *= factor

for idx in uninterpretable_indices:
    net1.lin2.weight.data[idx] /= factor
print("modified...")
evaluate(net1, val_loader, criterion)

modified...




--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.748119 %
Loss: 0.399206
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 627    3   73   50  124   13   68]
 [  26   67    2    4    9    0    3]
 [ 123    0  523   34  175   69  100]
 [  28    1   19 1557   53   40   76]
 [ 117    0  105   57  754   19  195]
 [  22    1   82   43   19  643   21]
 [  56    1   42  102  202   11  819]] 



(69.51797157982725, 85.7481192532739)

In [13]:
d_probe = 'FER2013'
concept_set = 'CLIP-dissect/data/concept_set.txt' # concept set needs to be path to .txt file
similarity_fn = similarity.soft_wpmi
target_name = 'deep_emotion'
layers = ['conv1a', 'conv1b', 'conv2a', 'conv2b', 'conv3a', 'conv3b', 'conv4a', 'conv4b', 'lin1', 'lin2']
target_layer = 'conv4b'

predicted_emotions, neuron_image_indices, sims = dissect_pipeline(d_probe,concept_set, similarity_fn, target_name, 'lin2')


100%|██████████| 4096/4096 [00:00<00:00, 8856.52it/s]


torch.Size([4096, 25])


In [25]:
res = {}
for tau in np.arange(-1, 2, .25):
    for factor in np.arange(1, 4, .5):  # Example factor to modify weights
        net2 = copy.deepcopy(net)
        for idx in sims.keys():
            # Modify weights: Increase for interpretable, decrease for uninterpretable
            if sims[idx] > tau: # interpretable
                net2.lin2.weight.data[idx] *= factor
            else:               # uninterpretable
                net2.lin2.weight.data[idx] /= factor
        acc1, acc2 = evaluate(net2, val_loader, criterion)
        res[(tau, factor)] = acc1

--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281655
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.573697 %
Top 2 Accuracy: 85.748119 %
Loss: 0.380680
Precision: 0.695737
Recall: 0.695737
F1 Score: 0.695737
Confusion Matrix:
 [[ 621    3   75   54  113   14   78]
 [  26   67    2    4    9    0    3]
 [ 120    0  520   37  166   73  108]
 [  22    1   20 1560   42   40   89]
 [ 115    0  108   60  719   20  225]
 [  17    1   75   44   19  650   25]
 [  54    0   39  104  167   12  857]] 





--------------------------------------------------------
Top 1 Accuracy: 69.545834 %
Top 2 Accuracy: 85.692393 %
Loss: 0.483105
Precision: 0.695458
Recall: 0.695458
F1 Score: 0.695458
Confusion Matrix:
 [[ 622    3   75   54  103   18   83]
 [  26   67    2    4    9    0    3]
 [ 119    0  521   38  161   77  108]
 [  20    1   20 1559   42   43   89]
 [ 116    0  111   63  706   22  229]
 [  16    1   73   42   18  657   24]
 [  51    1   39  104  162   16  860]] 





--------------------------------------------------------
Top 1 Accuracy: 69.490109 %
Top 2 Accuracy: 85.636668 %
Loss: 0.587275
Precision: 0.694901
Recall: 0.694901
F1 Score: 0.694901
Confusion Matrix:
 [[ 621    3   77   54   98   20   85]
 [  25   69    2    5    7    0    3]
 [ 120    0  519   38  159   81  107]
 [  20    1   20 1559   41   43   90]
 [ 113    1  111   66  702   23  231]
 [  15    1   75   42   18  658   22]
 [  51    1   40  105  159   17  860]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.594873 %
Loss: 0.691347
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 615    7   78   56   97   21   84]
 [  24   70    2    5    7    0    3]
 [ 119    0  518   38  159   83  107]
 [  20    1   20 1559   40   45   89]
 [ 112    1  113   66  699   23  233]
 [  14    1   74   42   17  661   22]
 [  50    1   39  108  157   17  861]] 





--------------------------------------------------------
Top 1 Accuracy: 69.364726 %
Top 2 Accuracy: 85.622736 %
Loss: 0.796160
Precision: 0.693647
Recall: 0.693647
F1 Score: 0.693647
Confusion Matrix:
 [[ 614    7   80   56   96   21   84]
 [  24   72    2    4    7    0    2]
 [ 118    0  518   39  157   84  108]
 [  19    1   20 1561   40   47   86]
 [ 115    1  115   68  690   23  235]
 [  15    1   72   42   16  664   21]
 [  49    1   39  108  157   19  860]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281690
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.545834 %
Top 2 Accuracy: 85.775982 %
Loss: 0.380278
Precision: 0.695458
Recall: 0.695458
F1 Score: 0.695458
Confusion Matrix:
 [[ 621    3   75   54  114   14   77]
 [  26   67    2    4    9    0    3]
 [ 120    0  520   37  166   73  108]
 [  22    1   20 1560   42   40   89]
 [ 115    0  108   60  719   20  225]
 [  17    1   77   44   19  648   25]
 [  54    0   39  104  167   12  857]] 





--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.720256 %
Loss: 0.481788
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 621    3   75   54  105   18   82]
 [  26   67    2    4    9    0    3]
 [ 119    0  520   38  161   78  108]
 [  20    1   20 1559   42   43   89]
 [ 115    0  111   63  707   22  229]
 [  16    1   75   42   18  656   23]
 [  51    1   39  104  162   16  860]] 





--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.594873 %
Loss: 0.585159
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 620    3   78   55   98   20   84]
 [  25   69    2    4    8    0    3]
 [ 120    0  520   38  159   80  107]
 [  20    1   20 1559   41   44   89]
 [ 113    1  111   66  703   23  230]
 [  15    1   75   42   18  658   22]
 [  50    1   40  105  159   17  861]] 





--------------------------------------------------------
Top 1 Accuracy: 69.476177 %
Top 2 Accuracy: 85.594873 %
Loss: 0.688958
Precision: 0.694762
Recall: 0.694762
F1 Score: 0.694762
Confusion Matrix:
 [[ 615    7   78   56   97   21   84]
 [  24   71    2    4    7    0    3]
 [ 118    0  519   38  159   83  107]
 [  20    1   20 1559   40   45   89]
 [ 112    1  111   67  700   23  233]
 [  14    1   74   42   17  661   22]
 [  49    1   40  107  157   17  862]] 





--------------------------------------------------------
Top 1 Accuracy: 69.406520 %
Top 2 Accuracy: 85.608805 %
Loss: 0.793691
Precision: 0.694065
Recall: 0.694065
F1 Score: 0.694065
Confusion Matrix:
 [[ 613    7   80   56   96   21   85]
 [  24   72    2    4    7    0    2]
 [ 117    0  520   38  158   83  108]
 [  19    1   20 1561   40   47   86]
 [ 113    1  114   68  693   23  235]
 [  14    1   74   42   17  662   21]
 [  48    1   39  108  157   19  861]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281747
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.629423 %
Top 2 Accuracy: 85.762051 %
Loss: 0.378219
Precision: 0.696294
Recall: 0.696294
F1 Score: 0.696294
Confusion Matrix:
 [[ 619    3   75   54  114   15   78]
 [  26   67    2    4    9    0    3]
 [ 119    0  521   37  166   73  108]
 [  21    1   20 1559   43   40   90]
 [ 114    0  109   60  719   20  225]
 [  16    1   75   44   19  651   25]
 [  50    0   39  103  167   12  862]] 





--------------------------------------------------------
Top 1 Accuracy: 69.490109 %
Top 2 Accuracy: 85.692393 %
Loss: 0.478388
Precision: 0.694901
Recall: 0.694901
F1 Score: 0.694901
Confusion Matrix:
 [[ 615    3   77   54  107   18   84]
 [  26   67    2    4    9    0    3]
 [ 119    0  520   38  161   78  108]
 [  20    1   20 1559   41   43   90]
 [ 113    0  113   62  708   22  229]
 [  15    1   76   42   18  656   23]
 [  49    1   39  104  161   16  863]] 





--------------------------------------------------------
Top 1 Accuracy: 69.504040 %
Top 2 Accuracy: 85.664531 %
Loss: 0.580430
Precision: 0.695040
Recall: 0.695040
F1 Score: 0.695040
Confusion Matrix:
 [[ 615    5   78   54  100   20   86]
 [  24   70    2    4    8    0    3]
 [ 118    0  521   38  159   81  107]
 [  20    1   20 1558   41   45   89]
 [ 111    1  113   66  702   23  231]
 [  14    1   75   42   17  660   22]
 [  48    1   40  104  159   18  863]] 





--------------------------------------------------------
Top 1 Accuracy: 69.406520 %
Top 2 Accuracy: 85.608805 %
Loss: 0.683376
Precision: 0.694065
Recall: 0.694065
F1 Score: 0.694065
Confusion Matrix:
 [[ 611    7   81   54   98   21   86]
 [  24   71    2    4    7    0    3]
 [ 117    0  520   38  158   83  108]
 [  19    1   20 1560   40   46   88]
 [ 111    1  115   67  697   23  233]
 [  14    1   74   42   17  661   22]
 [  48    1   40  107  157   18  862]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.594873 %
Loss: 0.786580
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 612    7   81   55   97   21   85]
 [  24   72    2    4    7    0    2]
 [ 116    0  520   39  157   83  109]
 [  19    1   20 1560   40   47   87]
 [ 110    2  116   68  692   23  236]
 [  13    1   73   41   17  665   21]
 [  48    1   40  107  156   19  862]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281632
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.573697 %
Top 2 Accuracy: 85.762051 %
Loss: 0.375931
Precision: 0.695737
Recall: 0.695737
F1 Score: 0.695737
Confusion Matrix:
 [[ 617    3   75   54  113   15   81]
 [  26   67    2    4    9    0    3]
 [ 118    0  522   37  166   73  108]
 [  21    1   20 1559   43   40   90]
 [ 114    0  109   60  717   20  227]
 [  16    1   77   44   19  649   25]
 [  49    0   39  103  167   12  863]] 





--------------------------------------------------------
Top 1 Accuracy: 69.462246 %
Top 2 Accuracy: 85.678462 %
Loss: 0.474216
Precision: 0.694622
Recall: 0.694622
F1 Score: 0.694622
Confusion Matrix:
 [[ 613    3   77   54  107   19   85]
 [  26   67    2    4    9    0    3]
 [ 118    0  522   37  161   78  108]
 [  20    1   20 1557   41   44   91]
 [ 112    0  112   62  706   22  233]
 [  15    1   76   42   18  656   23]
 [  49    1   39  102  161   16  865]] 





--------------------------------------------------------
Top 1 Accuracy: 69.545834 %
Top 2 Accuracy: 85.636668 %
Loss: 0.574848
Precision: 0.695458
Recall: 0.695458
F1 Score: 0.695458
Confusion Matrix:
 [[ 612    6   79   54  101   20   86]
 [  24   70    2    4    8    0    3]
 [ 114    0  523   37  159   82  109]
 [  20    1   20 1558   41   45   89]
 [ 110    1  114   64  700   23  235]
 [  14    1   75   41   18  660   22]
 [  47    1   40  100  158   18  869]] 





--------------------------------------------------------
Top 1 Accuracy: 69.531903 %
Top 2 Accuracy: 85.594873 %
Loss: 0.675616
Precision: 0.695319
Recall: 0.695319
F1 Score: 0.695319
Confusion Matrix:
 [[ 608    7   84   54   98   21   86]
 [  24   71    2    3    8    0    3]
 [ 115    0  522   37  158   83  109]
 [  19    1   20 1559   40   46   89]
 [ 111    1  114   65  696   23  237]
 [  13    1   74   41   17  663   22]
 [  47    1   40  101  154   18  872]] 





--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.650599 %
Loss: 0.777432
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 608    7   84   54   98   21   86]
 [  24   72    2    4    7    0    2]
 [ 115    0  522   37  157   83  110]
 [  19    1   20 1559   40   48   87]
 [ 110    2  115   67  693   23  237]
 [  13    1   73   41   17  665   21]
 [  47    1   41  102  152   19  871]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281796
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.629423 %
Top 2 Accuracy: 85.762051 %
Loss: 0.371000
Precision: 0.696294
Recall: 0.696294
F1 Score: 0.696294
Confusion Matrix:
 [[ 615    3   77   54  112   16   81]
 [  26   67    2    4    9    0    3]
 [ 118    0  524   37  165   72  108]
 [  21    1   20 1557   43   41   91]
 [ 113    0  112   60  716   20  226]
 [  16    1   78   42   19  651   24]
 [  48    0   39  100  165   13  868]] 





--------------------------------------------------------
Top 1 Accuracy: 69.531903 %
Top 2 Accuracy: 85.706325 %
Loss: 0.465629
Precision: 0.695319
Recall: 0.695319
F1 Score: 0.695319
Confusion Matrix:
 [[ 609    3   81   53  106   20   86]
 [  25   68    2    4    9    0    3]
 [ 114    0  525   36  161   78  110]
 [  20    1   20 1557   41   44   91]
 [ 110    0  115   62  703   23  234]
 [  14    1   76   42   18  657   23]
 [  44    1   39   99  161   17  872]] 





--------------------------------------------------------
Top 1 Accuracy: 69.587629 %
Top 2 Accuracy: 85.636668 %
Loss: 0.562860
Precision: 0.695876
Recall: 0.695876
F1 Score: 0.695876
Confusion Matrix:
 [[ 605    6   85   54  102   20   86]
 [  24   70    2    4    8    0    3]
 [ 112    0  527   36  158   81  110]
 [  19    1   20 1557   41   46   90]
 [ 109    1  116   63  698   23  237]
 [  13    1   75   41   16  662   23]
 [  43    1   41  100  154   18  876]] 





--------------------------------------------------------
Top 1 Accuracy: 69.573697 %
Top 2 Accuracy: 85.706325 %
Loss: 0.661064
Precision: 0.695737
Recall: 0.695737
F1 Score: 0.695737
Confusion Matrix:
 [[ 601    7   86   54  102   21   87]
 [  23   72    2    3    8    0    3]
 [ 112    0  526   36  157   83  110]
 [  18    1   20 1558   39   48   90]
 [ 107    2  118   64  694   23  239]
 [  12    1   74   41   16  664   23]
 [  43    1   41  100  151   18  879]] 





--------------------------------------------------------
Top 1 Accuracy: 69.531903 %
Top 2 Accuracy: 85.636668 %
Loss: 0.759207
Precision: 0.695319
Recall: 0.695319
F1 Score: 0.695319
Confusion Matrix:
 [[ 599    7   89   53   99   23   88]
 [  23   73    2    3    8    0    2]
 [ 112    0  528   37  153   83  111]
 [  18    1   20 1557   39   50   89]
 [ 108    2  123   66  684   24  240]
 [  12    1   73   39   16  668   22]
 [  42    1   41  100  148   19  882]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281657
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.629423 %
Top 2 Accuracy: 85.748119 %
Loss: 0.366250
Precision: 0.696294
Recall: 0.696294
F1 Score: 0.696294
Confusion Matrix:
 [[ 615    3   77   54  113   16   80]
 [  26   67    2    4    9    0    3]
 [ 118    0  525   37  164   72  108]
 [  21    1   20 1557   43   41   91]
 [ 113    0  113   60  713   20  228]
 [  16    1   78   42   18  652   24]
 [  47    0   39  100  165   13  869]] 





--------------------------------------------------------
Top 1 Accuracy: 69.462246 %
Top 2 Accuracy: 85.664531 %
Loss: 0.456833
Precision: 0.694622
Recall: 0.694622
F1 Score: 0.694622
Confusion Matrix:
 [[ 610    3   82   53  104   20   86]
 [  25   68    2    4    9    0    3]
 [ 114    0  526   36  158   80  110]
 [  20    1   20 1553   41   45   94]
 [ 110    0  116   61  698   23  239]
 [  14    1   76   42   17  658   23]
 [  45    1   41   99  157   17  873]] 





--------------------------------------------------------
Top 1 Accuracy: 69.615492 %
Top 2 Accuracy: 85.720256 %
Loss: 0.550179
Precision: 0.696155
Recall: 0.696155
F1 Score: 0.696155
Confusion Matrix:
 [[ 608    5   84   53  101   21   86]
 [  24   70    2    4    8    0    3]
 [ 113    0  528   36  155   81  111]
 [  20    1   20 1554   39   47   93]
 [ 112    1  118   60  692   23  241]
 [  13    1   74   40   16  664   23]
 [  43    1   41   99  151   17  881]] 





--------------------------------------------------------
Top 1 Accuracy: 69.378657 %
Top 2 Accuracy: 85.594873 %
Loss: 0.645091
Precision: 0.693787
Recall: 0.693787
F1 Score: 0.693787
Confusion Matrix:
 [[ 601    7   88   53   97   23   89]
 [  23   72    2    3    8    0    3]
 [ 113    0  528   36  152   83  112]
 [  18    1   20 1554   39   50   92]
 [ 111    2  127   64  675   24  244]
 [  12    1   73   39   16  668   22]
 [  43    1   41   99  148   19  882]] 





--------------------------------------------------------
Top 1 Accuracy: 69.350794 %
Top 2 Accuracy: 85.622736 %
Loss: 0.741173
Precision: 0.693508
Recall: 0.693508
F1 Score: 0.693508
Confusion Matrix:
 [[ 600    7   90   53   95   23   90]
 [  23   73    2    4    7    0    2]
 [ 112    0  529   37  148   85  113]
 [  18    1   20 1554   38   51   92]
 [ 111    2  127   64  671   26  246]
 [  12    1   73   39   16  668   22]
 [  43    1   41   99  147   19  883]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281827
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.587629 %
Top 2 Accuracy: 85.748119 %
Loss: 0.356443
Precision: 0.695876
Recall: 0.695876
F1 Score: 0.695876
Confusion Matrix:
 [[ 611    3   79   53  112   18   82]
 [  26   67    2    4    9    0    3]
 [ 116    0  525   37  164   74  108]
 [  21    1   20 1553   42   44   93]
 [ 113    0  113   59  715   21  226]
 [  15    1   78   42   18  653   24]
 [  44    0   40  100  163   15  871]] 





--------------------------------------------------------
Top 1 Accuracy: 69.476177 %
Top 2 Accuracy: 85.734188 %
Loss: 0.439021
Precision: 0.694762
Recall: 0.694762
F1 Score: 0.694762
Confusion Matrix:
 [[ 605    4   84   52  104   21   88]
 [  25   68    2    4    9    0    3]
 [ 113    0  529   36  157   80  109]
 [  20    1   20 1552   41   48   92]
 [ 111    0  118   60  697   23  238]
 [  13    1   76   40   17  661   23]
 [  42    1   41   98  159   17  875]] 





--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.650599 %
Loss: 0.525767
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 601    5   89   53  100   22   88]
 [  23   72    2    3    8    0    3]
 [ 113    0  530   35  152   84  110]
 [  17    1   22 1551   39   51   93]
 [ 109    2  126   60  686   24  240]
 [  12    1   73   39   16  668   22]
 [  40    1   43   96  153   18  882]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.567010 %
Loss: 0.613561
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 596    7   91   53   99   22   90]
 [  23   72    2    3    8    0    3]
 [ 110    0  530   35  150   88  111]
 [  17    1   22 1552   37   52   93]
 [ 110    2  127   61  679   25  243]
 [  12    1   72   38   16  670   22]
 [  39    1   44   97  150   18  884]] 





--------------------------------------------------------
Top 1 Accuracy: 69.197548 %
Top 2 Accuracy: 85.525216 %
Loss: 0.702529
Precision: 0.691975
Recall: 0.691975
F1 Score: 0.691975
Confusion Matrix:
 [[ 590    8   96   53   96   24   91]
 [  23   73    3    3    7    0    2]
 [ 109    0  529   36  148   89  113]
 [  17    1   22 1551   36   53   94]
 [ 109    2  131   64  667   27  247]
 [  12    1   72   37   16  671   22]
 [  37    1   45   98  147   19  886]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281633
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.545834 %
Top 2 Accuracy: 85.887434 %
Loss: 0.344453
Precision: 0.695458
Recall: 0.695458
F1 Score: 0.695458
Confusion Matrix:
 [[ 608    3   81   52  115   19   80]
 [  26   67    2    4    9    0    3]
 [ 115    0  528   36  164   75  106]
 [  21    1   21 1550   43   45   93]
 [ 112    0  115   58  718   22  222]
 [  14    1   79   42   18  654   23]
 [  43    0   41   96  171   15  867]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.789914 %
Loss: 0.417295
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 600    4   89   51  109   22   83]
 [  24   70    2    3    9    0    3]
 [ 111    0  534   35  155   84  105]
 [  19    1   23 1546   41   50   94]
 [ 111    1  127   57  697   24  230]
 [  12    1   76   40   17  663   22]
 [  40    1   44   95  162   18  873]] 





--------------------------------------------------------
Top 1 Accuracy: 69.392588 %
Top 2 Accuracy: 85.650599 %
Loss: 0.494600
Precision: 0.693926
Recall: 0.693926
F1 Score: 0.693926
Confusion Matrix:
 [[ 592    6   96   52  103   22   87]
 [  23   73    3    3    7    0    2]
 [ 109    0  535   33  153   86  108]
 [  17    1   23 1546   40   52   95]
 [ 108    2  132   58  687   25  235]
 [  12    1   73   36   16  671   22]
 [  38    1   46   94  159   18  877]] 





--------------------------------------------------------
Top 1 Accuracy: 69.253274 %
Top 2 Accuracy: 85.539147 %
Loss: 0.574322
Precision: 0.692533
Recall: 0.692533
F1 Score: 0.692533
Confusion Matrix:
 [[ 587    8   97   53  101   23   89]
 [  23   73    3    3    7    0    2]
 [ 109    0  535   33  150   89  108]
 [  16    1   24 1545   39   54   95]
 [ 105    2  137   57  682   26  238]
 [  12    1   73   36   16  671   22]
 [  37    1   47   95  157   18  878]] 





--------------------------------------------------------
Top 1 Accuracy: 69.281137 %
Top 2 Accuracy: 85.330176 %
Loss: 0.655708
Precision: 0.692811
Recall: 0.692811
F1 Score: 0.692811
Confusion Matrix:
 [[ 587    8   97   53  101   23   89]
 [  22   73    3    3    7    1    2]
 [ 106    0  538   32  147   90  111]
 [  16    1   24 1545   38   56   94]
 [ 104    2  139   57  677   28  240]
 [  12    1   72   34   16  675   21]
 [  37    1   48   96  154   19  878]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281829
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.517972 %
Top 2 Accuracy: 85.817777 %
Loss: 0.327961
Precision: 0.695180
Recall: 0.695180
F1 Score: 0.695180
Confusion Matrix:
 [[ 608    3   83   49  117   20   78]
 [  26   67    2    4    9    0    3]
 [ 114    0  534   35  161   75  105]
 [  21    0   23 1549   43   45   93]
 [ 113    0  120   58  720   21  215]
 [  15    1   79   41   18  655   22]
 [  43    0   47   93  178   15  857]] 





--------------------------------------------------------
Top 1 Accuracy: 69.364726 %
Top 2 Accuracy: 85.692393 %
Loss: 0.387151
Precision: 0.693647
Recall: 0.693647
F1 Score: 0.693647
Confusion Matrix:
 [[ 596    4   95   49  113   22   79]
 [  24   71    3    3    8    0    2]
 [ 111    0  540   34  155   81  103]
 [  17    1   25 1548   42   51   90]
 [ 110    1  132   57  702   24  221]
 [  12    1   78   37   17  664   22]
 [  41    1   50   92  173   18  858]] 





--------------------------------------------------------
Top 1 Accuracy: 69.295068 %
Top 2 Accuracy: 85.511284 %
Loss: 0.452465
Precision: 0.692951
Recall: 0.692951
F1 Score: 0.692951
Confusion Matrix:
 [[ 588    7   99   52  111   24   77]
 [  22   73    3    3    7    1    2]
 [ 105    0  544   34  153   86  102]
 [  16    1   25 1547   39   55   91]
 [ 108    2  140   57  691   27  222]
 [  12    1   77   34   16  670   21]
 [  39    1   52   93  168   19  861]] 





--------------------------------------------------------
Top 1 Accuracy: 69.169685 %
Top 2 Accuracy: 85.483422 %
Loss: 0.520328
Precision: 0.691697
Recall: 0.691697
F1 Score: 0.691697
Confusion Matrix:
 [[ 584    8  102   52  107   24   81]
 [  21   74    4    3    6    1    2]
 [ 104    0  548   33  150   88  101]
 [  16    1   25 1543   39   58   92]
 [ 105    2  144   56  686   30  224]
 [  12    1   77   34   15  671   21]
 [  37    1   53   94  168   21  859]] 





--------------------------------------------------------
Top 1 Accuracy: 69.113959 %
Top 2 Accuracy: 85.371970 %
Loss: 0.590634
Precision: 0.691140
Recall: 0.691140
F1 Score: 0.691140
Confusion Matrix:
 [[ 582    8  104   52  104   25   83]
 [  19   77    4    3    5    1    2]
 [ 103    0  549   33  147   91  101]
 [  16    1   26 1542   37   59   93]
 [ 104    2  147   56  681   31  226]
 [  12    1   76   32   15  675   20]
 [  37    1   57   94  164   25  855]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281708
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.545834 %
Top 2 Accuracy: 85.734188 %
Loss: 0.311664
Precision: 0.695458
Recall: 0.695458
F1 Score: 0.695458
Confusion Matrix:
 [[ 612    3   87   46  119   19   72]
 [  26   67    2    4    9    0    3]
 [ 113    0  540   33  166   71  101]
 [  23    0   24 1543   48   47   89]
 [ 116    0  120   55  728   21  207]
 [  15    1   80   38   18  657   22]
 [  48    1   50   90  184   15  845]] 





--------------------------------------------------------
Top 1 Accuracy: 69.295068 %
Top 2 Accuracy: 85.539147 %
Loss: 0.357531
Precision: 0.692951
Recall: 0.692951
F1 Score: 0.692951
Confusion Matrix:
 [[ 601    4   99   43  118   23   70]
 [  23   72    2    3    9    0    2]
 [ 112    0  547   31  157   81   96]
 [  20    1   26 1537   48   54   88]
 [ 110    2  141   54  714   24  202]
 [  13    1   79   33   18  666   21]
 [  44    1   55   89  189   18  837]] 





--------------------------------------------------------
Top 1 Accuracy: 69.113959 %
Top 2 Accuracy: 85.469490 %
Loss: 0.410197
Precision: 0.691140
Recall: 0.691140
F1 Score: 0.691140
Confusion Matrix:
 [[ 595    7  103   41  117   25   70]
 [  23   73    4    3    6    0    2]
 [ 110    0  550   31  157   84   92]
 [  20    1   30 1532   46   59   86]
 [ 109    2  147   54  709   27  199]
 [  12    1   81   32   15  670   20]
 [  45    1   56   89  189   21  832]] 





--------------------------------------------------------
Top 1 Accuracy: 69.058233 %
Top 2 Accuracy: 85.385901 %
Loss: 0.466977
Precision: 0.690582
Recall: 0.690582
F1 Score: 0.690582
Confusion Matrix:
 [[ 591    8  107   40  116   26   70]
 [  18   77    4    3    6    1    2]
 [ 108    0  555   31  153   87   90]
 [  20    1   30 1528   45   63   87]
 [ 107    2  151   53  706   29  199]
 [  12    1   79   32   15  672   20]
 [  44    1   58   88  186   28  828]] 





--------------------------------------------------------
Top 1 Accuracy: 68.891056 %
Top 2 Accuracy: 85.330176 %
Loss: 0.525686
Precision: 0.688911
Recall: 0.688911
F1 Score: 0.688911
Confusion Matrix:
 [[ 587    8  110   39  116   28   70]
 [  18   78    4    3    5    1    2]
 [ 106    1  556   29  151   91   90]
 [  19    1   31 1523   44   68   88]
 [ 106    2  155   53  700   32  199]
 [  12    1   78   31   15  674   20]
 [  44    1   60   88  183   30  827]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281838
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.392588 %
Top 2 Accuracy: 85.762051 %
Loss: 0.297331
Precision: 0.693926
Recall: 0.693926
F1 Score: 0.693926
Confusion Matrix:
 [[ 609    3   90   45  122   19   70]
 [  26   67    2    4    9    0    3]
 [ 115    0  544   32  166   71   96]
 [  24    0   25 1538   48   50   89]
 [ 116    0  122   52  734   22  201]
 [  15    1   84   34   18  658   21]
 [  52    1   53   87  193   16  831]] 





--------------------------------------------------------
Top 1 Accuracy: 69.267205 %
Top 2 Accuracy: 85.469490 %
Loss: 0.331625
Precision: 0.692672
Recall: 0.692672
F1 Score: 0.692672
Confusion Matrix:
 [[ 598    5  102   41  120   24   68]
 [  23   72    2    3    9    0    2]
 [ 110    0  550   29  163   79   93]
 [  20    1   28 1530   50   58   87]
 [ 109    2  141   50  729   24  192]
 [  13    1   82   32   18  666   19]
 [  47    1   59   86  196   17  827]] 





--------------------------------------------------------
Top 1 Accuracy: 69.072165 %
Top 2 Accuracy: 85.413764 %
Loss: 0.373960
Precision: 0.690722
Recall: 0.690722
F1 Score: 0.690722
Confusion Matrix:
 [[ 594    7  106   38  121   26   66]
 [  20   75    4    3    6    1    2]
 [ 107    0  559   29  159   82   88]
 [  20    1   34 1518   47   65   89]
 [ 106    2  152   49  722   27  189]
 [  13    1   83   29   16  670   19]
 [  47    1   59   83  196   27  820]] 





--------------------------------------------------------
Top 1 Accuracy: 68.932850 %
Top 2 Accuracy: 85.316244 %
Loss: 0.420985
Precision: 0.689329
Recall: 0.689329
F1 Score: 0.689329
Confusion Matrix:
 [[ 587    8  111   37  121   29   65]
 [  18   77    4    3    6    1    2]
 [ 103    0  564   27  153   89   88]
 [  20    1   37 1510   46   71   89]
 [ 104    2  154   48  718   33  188]
 [  13    1   78   28   15  677   19]
 [  44    1   65   81  195   32  815]] 





--------------------------------------------------------
Top 1 Accuracy: 68.668153 %
Top 2 Accuracy: 85.218724 %
Loss: 0.470682
Precision: 0.686682
Recall: 0.686682
F1 Score: 0.686682
Confusion Matrix:
 [[ 581    9  114   37  122   32   63]
 [  18   78    4    3    5    1    2]
 [ 102    1  564   25  151   95   86]
 [  20    1   37 1503   46   78   89]
 [ 102    2  162   45  713   36  187]
 [  12    1   79   28   14  678   19]
 [  44    1   68   80  194   34  812]] 





--------------------------------------------------------
Top 1 Accuracy: 69.420451 %
Top 2 Accuracy: 85.775982 %
Loss: 0.281630
Precision: 0.694205
Recall: 0.694205
F1 Score: 0.694205
Confusion Matrix:
 [[ 625    3   72   52  119   12   75]
 [  28   61    2    5   11    0    4]
 [ 121    0  520   36  175   66  106]
 [  26    0   18 1558   47   36   89]
 [ 115    0  103   59  738   17  215]
 [  21    0   81   48   22  635   24]
 [  53    0   38  100  188    8  846]] 





--------------------------------------------------------
Top 1 Accuracy: 69.364726 %
Top 2 Accuracy: 85.706325 %
Loss: 0.282293
Precision: 0.693647
Recall: 0.693647
F1 Score: 0.693647
Confusion Matrix:
 [[ 611    3   90   40  129   19   66]
 [  26   67    2    4    9    0    3]
 [ 115    0  545   30  170   71   93]
 [  25    0   24 1535   55   49   86]
 [ 116    0  121   49  750   21  190]
 [  16    1   87   33   19  654   21]
 [  54    1   54   86  208   13  817]] 





--------------------------------------------------------
Top 1 Accuracy: 69.225411 %
Top 2 Accuracy: 85.399833 %
Loss: 0.303957
Precision: 0.692254
Recall: 0.692254
F1 Score: 0.692254
Confusion Matrix:
 [[ 600    5  103   38  128   22   62]
 [  23   72    2    3    9    0    2]
 [ 109    0  559   27  172   75   82]
 [  22    1   31 1523   54   57   86]
 [ 111    1  142   46  751   23  173]
 [  15    1   87   30   19  660   19]
 [  54    1   60   81  214   19  804]] 





--------------------------------------------------------
Top 1 Accuracy: 68.793536 %
Top 2 Accuracy: 85.371970 %
Loss: 0.335254
Precision: 0.687935
Recall: 0.687935
F1 Score: 0.687935
Confusion Matrix:
 [[ 592    7  111   36  126   26   60]
 [  20   76    3    3    7    0    2]
 [ 102    0  571   25  170   77   79]
 [  23    1   39 1502   53   66   90]
 [ 108    2  155   43  748   24  167]
 [  13    1   86   29   18  665   19]
 [  55    1   67   78  221   27  784]] 





--------------------------------------------------------
Top 1 Accuracy: 68.500975 %
Top 2 Accuracy: 85.246587 %
Loss: 0.371503
Precision: 0.685010
Recall: 0.685010
F1 Score: 0.685010
Confusion Matrix:
 [[ 585    8  115   36  126   29   59]
 [  18   77    4    3    6    1    2]
 [ 100    0  579   20  165   83   77]
 [  24    2   43 1488   54   73   90]
 [ 104    2  161   41  744   31  164]
 [  12    1   85   28   16  670   19]
 [  53    1   73   76  226   30  774]] 





--------------------------------------------------------
Top 1 Accuracy: 68.417386 %
Top 2 Accuracy: 85.093341 %
Loss: 0.411027
Precision: 0.684174
Recall: 0.684174
F1 Score: 0.684174
Confusion Matrix:
 [[ 580    9  121   33  126   32   57]
 [  18   78    4    3    5    1    2]
 [  99    1  585   18  159   86   76]
 [  24    2   45 1478   54   82   89]
 [ 102    2  171   39  743   34  156]
 [  11    1   81   25   15  680   18]
 [  51    1   78   76  225   35  767]] 



In [33]:
69.629 - 69.42

0.20900000000000318

In [29]:
for key in sorted(res, key=res.get, reverse=True):
    print(key, res[key])

(-0.5, 1.5) 69.62942323767066
(0.0, 1.5) 69.62942323767066
(0.25, 1.5) 69.62942323767066
(0.25, 2.5) 69.61549178044024
(0.0, 2.5) 69.58762886597938
(0.5, 1.5) 69.58762886597938
(-1.0, 1.5) 69.57369740874896
(-0.25, 1.5) 69.57369740874896
(0.0, 3.0) 69.57369740874896
(-1.0, 2.0) 69.5458344942881
(-0.75, 1.5) 69.5458344942881
(-0.25, 2.5) 69.5458344942881
(0.75, 1.5) 69.5458344942881
(1.25, 1.5) 69.5458344942881
(-0.25, 3.0) 69.53190303705767
(0.0, 2.0) 69.53190303705767
(0.0, 3.5) 69.53190303705767
(-0.75, 2.0) 69.51797157982725
(-0.75, 2.5) 69.51797157982725
(-0.25, 3.5) 69.51797157982725
(0.5, 2.5) 69.51797157982725
(1.0, 1.5) 69.51797157982725
(-0.5, 2.5) 69.50404012259682
(-1.0, 2.5) 69.49010866536639
(-0.5, 2.0) 69.49010866536639
(-0.75, 3.0) 69.47617720813597
(0.5, 2.0) 69.47617720813597
(-0.25, 2.0) 69.46224575090554
(0.25, 2.0) 69.46224575090554
(-1.0, 1.0) 69.42045137921427
(-1.0, 3.0) 69.42045137921427
(-0.75, 1.0) 69.42045137921427
(-0.5, 1.0) 69.42045137921427
(-0.5, 3.5) 69

In [None]:
plt.sim

In [None]:
min_sim = min(sims.values())
max_sim = max(sims.values())
min_sim, max_sim

In [None]:
res = {}
min_sim = min(sims.values())
max_sim = max(sims.values())

for factor in [2, 3, 4]:  # Example factor to modify weights
    net2 = copy.deepcopy(net)
    for idx in sims.keys():
        # Modify weights: Increase for interpretable, decrease for uninterpretable
        net2.lin2.weight.data[idx] *= factor * (sims[idx] - min_sim) / (max_sim - min_sim)
    acc1, acc2 = evaluate(net2, val_loader, criterion)
    res[(tau, factor)] = acc1

In [None]:
sorted(res.values())

In [None]:
res = {}
min_sim = min(sims.values())
max_sim = max(sims.values())
mean_sim = sum(sims.values()) / len(sims)
for factor in np.arange(1, 15):  # Example factor to modify weights
    net2 = copy.deepcopy(net)
    for idx in sims.keys():
        # Modify weights: Increase for interpretable, decrease for uninterpretable
        net2.lin2.weight.data[idx] *= factor * (sims[idx] - min_sim) / (max_sim - min_sim)
    acc1, acc2 = evaluate(net2, val_loader, criterion)
    res[factor] = acc1

In [None]:
sorted(res.values())

In [21]:
concept_labels = 'CLIP-dissect/data/concept_set_labels.txt'
with open(concept_labels, 'r') as f: 
    s = (f.read()).split('\n')
labels = {}
for item in s:
    items = item.split(':')
    labels[items[0]] = items[1].strip().split(',')
print(labels)

{'Furrowed Brows': ['angry'], 'Raised Brows': ['surprise'], 'Lowered Brows': ['angry'], 'Drooping Brows': ['sad'], 'Relaxed Brows': ['neutral'], 'Wide Open Eyes': ['surprise'], 'Narrowed Eyes': ['angry'], 'Twinkling or Crinkled Eyes': ['happy'], 'Tearful Eyes': ['sad'], 'Closed Eyes': ['sad', 'neutral', 'fear'], 'Relaxed Eyes': ['neutral'], 'Tense Mouth and Jaw': ['angry', 'disgust'], 'Relaxed Mouth and Jaw': ['happy', 'neutral', 'sad'], 'Smile': ['happy'], 'Frown': ['sad'], 'Downturned Mouth': ['sad'], 'Slightly Open Mouth': ['surprise'], 'Wide Open Mouth': ['surprise'], 'Slight Chin Raise': ['surprise', 'neutral'], 'Slightly Dropped Jaw': ['surprise'], 'Flared Nostrils': ['angry', 'disgust'], 'Tightened Facial Muscles': ['angry', 'fear', 'disgust'], 'Compressed Lips': ['angry', 'disgust'], 'Elevated Upper Eyelids': ['surprise'], 'Relaxed Facial Muscles': ['happy', 'neutral']}


In [16]:
from collections import Counter

def get_actual_labels(image_indices, pil_data):
    labels = []
    for idx in image_indices:
        path, _ = pil_data.samples[idx]
        label = path.split('/')[-2]  
        labels.append(label)
    return labels

def show_images(ids):
    num = len(ids)
    fig = plt.figure(figsize=[10, 5])
    axs = fig.subplots(nrows=num//5, ncols=5)
    for i, id in enumerate(ids):
        im, label = pil_data[id]
        im = im.resize([375,375])
        axs[i // 5, i % 5].imshow(im)
        axs[i // 5, i % 5].axis('off')

In [19]:

pil_data = data_utils.get_data(d_probe)
clip_name = 'ViT-B/16'
device = 'cuda'


In [None]:
net

In [22]:
from collections import defaultdict 
tau = 0
sim_thresh = .15
overall_accuracies = []
interpretables = []
uninterpretables = []
total_accs_interp = []
total_accs_uninterp = []
for layer in layers:
    predicted_emotions, neuron_image_indices, sims = dissect_pipeline(d_probe,concept_set, similarity_fn, target_name, layer)
    accuracies = defaultdict(dict)
    for neuron, predicted_emotion in zip(neuron_image_indices.keys(), predicted_emotions):
        neuron_id = neuron
        image_indices = neuron_image_indices[neuron]
        actual_labels = get_actual_labels(image_indices, pil_data)
        sim = sims[int(neuron_id.cpu().numpy())]
        accuracies[int(neuron_id)]['sim'] = sim
        correct_count = 0
        for label in actual_labels:
            if label in labels[predicted_emotion]:
                correct_count += 1
        # Calculate the accuracy for this neuron
        accuracy = correct_count / len(actual_labels)
        accuracies[int(neuron_id)]['accuracy'] = accuracy

    # Calculate overall accuracy
    overall_accuracy = sum(d['accuracy'] for d in accuracies.values()) / len(accuracies)
    print(overall_accuracy, len(accuracies))
    overall_accuracies.append(overall_accuracy)
    # Calculate interpretable accuracy
    interpretable = 0
    uninterpretable = 0
    acc_interpretable = 0
    acc_uninterpretable = 0
    for d in accuracies.values():
        if d['sim'] > tau:
            interpretable += 1
            acc_interpretable += d['accuracy']
        else:
            uninterpretable += 1
            acc_uninterpretable += d['accuracy']
    interpretables.append(interpretable)
    uninterpretables.append(uninterpretable)
    print(layer)
    total_acc_interp = acc_interpretable / interpretable if interpretable else 0
    total_acc_uninterp = acc_uninterpretable / uninterpretable if uninterpretable else 0
    total_accs_interp.append(total_acc_interp)
    total_accs_uninterp.append(total_acc_uninterp)
    

100%|██████████| 64/64 [00:00<00:00, 3559.44it/s]


torch.Size([64, 25])
0.21562499999999984 64
conv1a


100%|██████████| 64/64 [00:00<00:00, 4878.43it/s]


torch.Size([64, 25])
0.2999999999999999 64
conv1b


100%|██████████| 128/128 [00:00<00:00, 4108.19it/s]


torch.Size([128, 25])
0.21484375000000003 128
conv2a


100%|██████████| 128/128 [00:00<00:00, 1284.79it/s]


torch.Size([128, 25])
0.2765625 128
conv2b


100%|██████████| 256/256 [00:00<00:00, 5032.70it/s]


torch.Size([256, 25])
0.25976562500000017 256
conv3a


100%|██████████| 256/256 [00:00<00:00, 1773.13it/s]


torch.Size([256, 25])
0.2308593750000004 256
conv3b


100%|██████████| 512/512 [00:00<00:00, 6804.30it/s]


torch.Size([512, 25])
0.28242187499999954 512
conv4a


100%|██████████| 512/512 [00:00<00:00, 3026.60it/s]


torch.Size([512, 25])
0.2833984374999996 512
conv4b


100%|██████████| 4096/4096 [00:00<00:00, 7766.94it/s]


torch.Size([4096, 25])
0.3585693359374982 4096
lin1


100%|██████████| 4096/4096 [00:00<00:00, 7871.50it/s]


torch.Size([4096, 25])
0.4342285156249987 4096
lin2


In [34]:
print("f\n\nj")

f

j


In [46]:
df = pd.DataFrame({'num_interp':interpretables, 'num_uninterp':uninterpretables, \
                  'overall_acc':overall_accuracies, 'interp_acc':total_accs_interp, \
                  'uninterp_acc':total_accs_uninterp}, index=layers)
df['prop_interp'] = df['num_interp'] / (df['num_interp'] + df['num_uninterp'])
pd.options.display.float_format = "{:,.2f}".format
df

Unnamed: 0,num_interp,num_uninterp,overall_acc,interp_acc,uninterp_acc,prop_interp
conv1a,64,0,0.22,0.22,0.0,1.0
conv1b,64,0,0.3,0.3,0.0,1.0
conv2a,127,1,0.21,0.21,0.3,0.99
conv2b,127,1,0.28,0.27,0.9,0.99
conv3a,254,2,0.26,0.26,0.05,0.99
conv3b,255,1,0.23,0.23,0.3,1.0
conv4a,492,20,0.28,0.29,0.21,0.96
conv4b,494,18,0.28,0.28,0.32,0.96
lin1,3489,607,0.36,0.38,0.25,0.85
lin2,3835,261,0.43,0.45,0.24,0.94


In [None]:
tau = 1
sim_thresh = .15
detailed_results = {}

for neuron, predicted_emotion in zip(neuron_image_indices.keys(), predicted_emotions):
    neuron_id = neuron
    image_indices = neuron_image_indices[neuron]
    actual_labels = get_actual_labels(image_indices, pil_data)
    sim = sims[int(neuron_id.cpu().numpy())]

    detailed_results[int(neuron_id)] = {
        'predicted_emotion': predicted_emotion,
        'actual_labels': actual_labels,
        'top_image_indices': image_indices,
        'sim': sim
    }
    #label_embed = embedding_model.encode(actual_labels)
    #concept_embed = embedding_model.encode([predicted_emotion] * len(actual_labels))
    #mp_cos = np.mean(np.sum(label_embed*concept_embed, axis=1) > sim_thresh)
    correct_count = 0
    for label in actual_labels:
        if label in labels[predicted_emotion]:
            correct_count += 1
    # Calculate the accuracy for this neuron
    accuracy = correct_count / len(actual_labels)
    detailed_results[int(neuron_id)]['accuracy'] = accuracy

# Calculate overall accuracy
overall_accuracy = sum(d['accuracy'] for d in detailed_results.values()) / len(detailed_results)
# Calculate interpretable accuracy
interpretable = 0
uninterpretable = 0
acc_interpretable = 0
acc_uninterpretable = 0
for d in detailed_results.values():
    if d['sim'] > tau:
        interpretable += 1
        acc_interpretable += d['accuracy']
    else:
        uninterpretable += 1
        acc_uninterpretable += d['accuracy']
print(target_layer)
print(f'num interpretable: {interpretable}')
print(f'num uninterpretable: {uninterpretable}')
total_acc_interp = acc_interpretable / interpretable if interpretable else 0
total_acc_uninterp = acc_uninterpretable / uninterpretable if uninterpretable else 0
# Print the detailed results
i = 0
for neuron_id, results in detailed_results.items():
    i += 1
    if i >= 0:
        break
    print(f"Neuron {neuron_id}: Predicted Emotion = {results['predicted_emotion']}, Accuracy = {results['accuracy']}")
    for image_idx, actual_label in zip(results['top_image_indices'], results['actual_labels']):
        print(f"    Image Index {image_idx}: Actual Label = {actual_label}")

print(f"Overall accuracy: {overall_accuracy}")
print(f"Overall interpretable accuracy: {total_acc_interp}")
print(f"Overall uninterpretable accuracy: {total_acc_uninterp}")

    


In [None]:
ids = detailed_results[2275]['top_image_indices']
show_images(ids)