In [4]:
# Read facial image frame & give it to the model
from models.PosterV2_7cls import *
import torch.nn.parallel
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import torch.utils.data
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, plot_confusion_matrix

# --- Run configuration -----------------------------------------------------
# Alternative dataset root: r"/home/hyojinju/Dataset/CAER-S"
datapath = r"/home/hyojinju/Dataset/test"
batch_size = 1
workers = 4
checkpoint_path = r"/home/hyojinju/POSTER_V2/checkpoint/caer-s-model_best.pth"

# Expose only physical GPU 3 to this process. This must happen before the
# `.cuda()` call below; NOTE(review): it has no effect if CUDA was already
# initialised in this kernel by an earlier cell.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

# --- Model -----------------------------------------------------------------
model = torch.nn.DataParallel(pyramid_trans_expr2(img_size=224, num_classes=7)).cuda()

# --- Data ------------------------------------------------------------------
# `traindir` is not used in this notebook; only the 'valid' split is loaded.
traindir, valdir = (os.path.join(datapath, split) for split in ('train', 'valid'))

test_dataset = datasets.ImageFolder(
    valdir,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean/std (the widely used ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
val_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep a deterministic order for evaluation
    num_workers=workers,
    pin_memory=True,
)

# Tick labels for the confusion-matrix plot.
# NOTE(review): 13 labels here, but the model is built with num_classes=7 --
# confirm which class names belong to this checkpoint/dataset.
labels = list('ABCFGHIJKLMNO')


class RecorderMeter1(object):
    """Hold one set of predicted/true labels and plot their confusion matrix.

    ``update`` stores the labels, ``matrix`` returns them as flat arrays and
    ``plot_confusion_matrix`` draws a row-normalised confusion matrix, saves it
    to ``./log/confusion_matrix.png`` and shows it. ``reset`` also allocates
    per-epoch loss/accuracy buffers (mirroring ``RecorderMeter``) that no
    method of this class reads.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Reset the epoch counter and zero the ``(total_epoch, 2)`` buffers."""
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32)  # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)  # [epoch, train/val]

    def update(self, output, target):
        """Store predicted labels ``output`` and ground-truth labels ``target``.

        NOTE(review): assumes ``output`` already holds class indices (e.g. an
        argmax over logits), not raw model outputs -- confirm at the call site.
        """
        self.y_pred = output
        self.y_true = target

    @staticmethod
    def _draw_matrix(cm, class_names, title, cmap):
        """Show ``cm`` as an image on the current axes with labelled class ticks."""
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        locations = np.arange(len(class_names))
        plt.xticks(locations, class_names, rotation=90)
        plt.yticks(locations, class_names)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

    def plot_confusion_matrix(self, cm=None, title='Confusion Matrix', cmap=plt.cm.binary,
                              class_names=None):
        """Plot, save and show a row-normalised confusion matrix.

        Args:
            cm: square count matrix (rows = true class, columns = predicted
                class). If None, it is computed from the labels stored by
                ``update`` with ``sklearn.metrics.confusion_matrix``.
            title: figure title.
            cmap: matplotlib colormap for the cells.
            class_names: tick labels, one per class; defaults to the
                module-level ``labels`` list. Its length fixes the matrix size.

        Raises:
            ValueError: if ``cm`` is not ``len(class_names)`` x ``len(class_names)``.
        """
        if class_names is None:
            class_names = labels
        n_classes = len(class_names)

        if cm is None:
            y_true, y_pred = self.matrix()
            # Fix the label set so the matrix is always n_classes x n_classes and
            # lines up with the tick labels, even if some classes never occur.
            cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
        cm = np.asarray(cm, dtype=float)
        if cm.shape != (n_classes, n_classes):
            raise ValueError(
                'confusion matrix has shape {} but {} class names were given'.format(
                    cm.shape, n_classes))

        # Normalise each row (true class) to sum to 1; rows with no samples stay
        # 0 instead of becoming NaN through 0/0.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

        # Create the figure first so every element below lands on the saved figure.
        plt.figure(figsize=(12, 8), dpi=120)
        self._draw_matrix(cm_normalized, class_names, title, cmap)

        # Write each cell's value, skipping near-zero cells to reduce clutter.
        for true_idx in range(n_classes):
            for pred_idx in range(n_classes):
                value = cm_normalized[true_idx, pred_idx]
                if value > 0.01:
                    plt.text(pred_idx, true_idx, "%0.2f" % (value,), color='red', fontsize=7,
                             va='center', ha='center')

        # offset the tick: minor ticks sit half a cell from the cell centres so
        # the minor grid draws lines between cells rather than through them.
        tick_marks = np.arange(n_classes) + 0.5
        ax = plt.gca()
        ax.set_xticks(tick_marks, minor=True)
        ax.set_yticks(tick_marks, minor=True)
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        plt.grid(True, which='minor', linestyle='-')
        plt.gcf().subplots_adjust(bottom=0.15)

        # show confusion matrix
        os.makedirs('./log', exist_ok=True)  # savefig fails if the folder is missing
        plt.savefig('./log/confusion_matrix.png', format='png')
        print('Saved figure')
        plt.show()

    def matrix(self):
        """Return ``(y_true, y_pred)`` from the last ``update`` as flat NumPy arrays."""
        y_true = np.asarray(self.y_true).flatten()
        y_pred = np.asarray(self.y_pred).flatten()
        return y_true, y_pred

class RecorderMeter(object):
    """Record per-epoch train/val loss and accuracy and plot them as curves.

    Losses are stored multiplied by 30 so they share the 0-100 y-axis with the
    accuracies when drawn by ``plot_curve``.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Forget all recorded values and size the buffers for ``total_epoch`` epochs."""
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # Column 0 holds the training value, column 1 the validation value.
        buffer_shape = (total_epoch, 2)
        self.epoch_losses = np.zeros(buffer_shape, dtype=np.float32)
        self.epoch_accuracy = np.zeros(buffer_shape, dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Store epoch ``idx``'s metrics and mark ``idx + 1`` epochs as recorded."""
        losses, accs = self.epoch_losses, self.epoch_accuracy
        # x30 matches the "-x30" legend labels used in plot_curve.
        losses[idx, 0], losses[idx, 1] = train_loss * 30, val_loss * 30
        accs[idx, 0], accs[idx, 1] = train_acc, val_acc
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):
        """Draw accuracy and x30 loss curves; save to ``save_path`` unless it is None."""
        dpi = 80
        fig = plt.figure(figsize=(1800 / float(dpi), 800 / float(dpi)))
        epochs = np.arange(self.total_epoch)

        step = 5
        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        plt.xticks(np.arange(0, self.total_epoch + step, step))
        plt.yticks(np.arange(0, 100 + step, step))
        plt.grid()
        plt.title('the accuracy/loss curve of train/val', fontsize=20)
        plt.xlabel('the training epoch', fontsize=16)
        plt.ylabel('accuracy', fontsize=16)

        # (values, colour, line style, legend label): green = train,
        # yellow = validation, solid = accuracy, dotted = scaled loss.
        curves = (
            (self.epoch_accuracy[:, 0], 'g', '-', 'train-accuracy'),
            (self.epoch_accuracy[:, 1], 'y', '-', 'valid-accuracy'),
            (self.epoch_losses[:, 0], 'g', ':', 'train-loss-x30'),
            (self.epoch_losses[:, 1], 'y', ':', 'valid-loss-x30'),
        )
        for values, colour, style, label in curves:
            # Each line gets its own float64 array rather than one shared buffer.
            plt.plot(epochs, values.astype(np.float64), color=colour, linestyle=style,
                     label=label, lw=2)
            plt.legend(loc=4, fontsize=10)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            print('Saved figure')
        plt.close(fig)


Mobilefacenet OK
Mobilefacenet OK 2
VisionTransformer OK
irback load OK
load_weight 304
irback OK
layers OK
layers OK


In [5]:
if os.path.isfile(checkpoint_path):
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    # Deserialize onto the CPU; load_state_dict() below copies every tensor onto
    # the model's own device. This also avoids failures when the checkpoint was
    # saved on a GPU index that is not visible under CUDA_VISIBLE_DEVICES.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Parameters of the "module.VIT" sub-network are removed from the checkpoint
    # before loading. NOTE(review): load_state_dict() below is strict, so this
    # only works if the current model does not register these parameters --
    # confirm why they are dropped.
    _vit_top_params = [
        "cls_token", "pos_embed",
        "se_block.linear1.weight", "se_block.linear1.bias",
        "se_block.linear2.weight", "se_block.linear2.bias",
        "patch_embed.proj.weight", "patch_embed.proj.bias",
        "head.linear.weight", "head.linear.bias",
        "eca_block.conv.weight", "CON1.weight",
        "IRLinear1.weight", "IRLinear1.bias",
        "IRLinear2.weight", "IRLinear2.bias",
    ]
    # The same 14 parameters exist in each of the two transformer blocks.
    _vit_block_params = [
        "norm1.weight", "norm1.bias",
        "conv.weight", "conv.bias",
        "attn.qkv.weight", "attn.qkv.bias",
        "attn.proj.weight", "attn.proj.bias",
        "norm2.weight", "norm2.bias",
        "mlp.fc1.weight", "mlp.fc1.bias",
        "mlp.fc2.weight", "mlp.fc2.bias",
    ]
    vit_keys = (
        ["module.VIT." + name for name in _vit_top_params]
        + ["module.VIT.blocks.{}.{}".format(block, name)
           for block in range(2) for name in _vit_block_params]
        + ["module.VIT.norm.weight", "module.VIT.norm.bias"]
    )
    state_dict = checkpoint['state_dict']
    for key in vit_keys:
        state_dict.pop(key, None)  # tolerate keys that are already absent
    best_acc = checkpoint['best_acc']
    # float() prints the same value whether best_acc is a 0-d tensor or a number.
    print(f'best_acc:{float(best_acc)}')
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(checkpoint_path))


def extract_feature(val_loader, model, *, device='cuda', return_all=False, verbose=True):
    """Run ``model`` over ``val_loader`` in eval mode and return its outputs.

    Each batch's images are moved to ``device`` and passed through ``model``
    under ``torch.no_grad()``. The output and the target are then squeezed,
    which drops size-1 dims such as the batch dim when ``batch_size == 1``.
    Both are converted to NumPy.

    Args:
        val_loader: iterable of ``(images, target)`` tensor pairs, e.g. a
            ``DataLoader`` over an ``ImageFolder``.
        model: module to evaluate; it is switched to ``eval()`` mode.
        device: device the images are moved to before the forward pass
            (``'cuda'`` matches the original ``images.cuda()``).
        return_all: if False (default, original behaviour) return only the
            last batch's result. If True, return one entry per batch so no
            batch's output is discarded.
        verbose: print the output shape and target of every batch.

    Returns:
        ``(output, target)`` NumPy arrays for the last batch, or, when
        ``return_all`` is True, ``(outputs, targets)`` lists of per-batch arrays.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    # switch to evaluate mode
    model.eval()
    outputs, targets = [], []
    last = None
    with torch.no_grad():
        for images, target in val_loader:
            output = model(images.to(device))

            # The target is only needed on the host, so it never visits the GPU.
            target = target.squeeze().cpu().numpy()
            output = output.squeeze().cpu().numpy()

            if verbose:
                print(np.shape(output), target)
            last = (output, target)
            if return_all:
                outputs.append(output)
                targets.append(target)

    if last is None:
        raise ValueError("val_loader yielded no batches")
    if return_all:
        return outputs, targets
    return last
# Run feature extraction over the validation loader. Nothing is written to
# disk here: the returned (output, target) pair (last batch only) is just
# displayed as this notebook cell's result.
# NOTE(review): the cell output shows a (147, 768) output per image; saving the
# feature tensors is not implemented yet.

extract_feature(val_loader, model)

=> loading checkpoint '/home/hyojinju/POSTER_V2/checkpoint/caer-s-model_best.pth'
best_acc:93.0068588256836
=> loaded checkpoint '/home/hyojinju/POSTER_V2/checkpoint/caer-s-model_best.pth' (epoch 181)
(147, 768) 0
(147, 768) 0
(147, 768) 0
(147, 768) 0
(147, 768) 0
(147, 768) 0
(147, 768) 0
(147, 768) 1
(147, 768) 1
(147, 768) 1
(147, 768) 1
(147, 768) 1
(147, 768) 1
(147, 768) 1


(array([[ 2.1857016 , -4.6753135 ,  1.9146696 , ..., -3.053631  ,
         -2.1800702 ,  1.8267777 ],
        [ 4.5367804 , -6.880458  ,  3.6301682 , ..., -4.810457  ,
         -3.3279278 ,  2.703135  ],
        [ 4.4297442 , -8.695255  ,  3.3854532 , ..., -5.882805  ,
         -3.3752832 ,  4.076563  ],
        ...,
        [ 0.09259506, -0.37725556,  0.53169596, ..., -0.6998611 ,
         -0.6563517 , -0.00916411],
        [-0.1701906 , -0.72928655,  0.2515726 , ..., -0.03692304,
         -0.12921822, -0.64030516],
        [-0.053938  , -0.6623482 ,  0.3417162 , ..., -1.1521099 ,
         -0.94193435, -0.38627896]], dtype=float32),
 array(1))