In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import sys
import torch
import models.models as models
import utils.configuration as Conf
import numpy as np
import data
from torchvision import transforms
import time
import pickle
import utils.process_samples as Process
import matplotlib.pyplot as plt
import torchvision

# Load the run configuration and derive the constants used by the rest of the notebook.
config = Conf.Config()
config.update_options_from_path('configs/train.yaml')

MODEL_NAME = 'ViT_bidirectional_dualPatches_untie_inference'
TRANSFORM = config.val_dataset_options['transform']
MAX_SEQ_LEN = config.model_options['dec_seq_len']
MULTIPLE_GPU = config.base_options['multi_gpu']
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Model options overridden on top of configs/train.yaml before the model is built.
config.model_options['in_num_embeddings'] = 100
config.model_options['out_num_embeddings'] = 100
config.model_options['enc_dropout'] = 0.1
config.model_options['dec_dropout'] = 0.1
config.model_options['enc_seq_len'] = 128

# Always build the model. It used to be created only inside the multi-GPU
# branch, which left `model` undefined (NameError below) when multi_gpu was off.
model = getattr(models, MODEL_NAME)(config.model_options).to(DEVICE)
if MULTIPLE_GPU:
    model = torch.nn.DataParallel(model)

# map_location lets a checkpoint saved on GPU be restored on whatever DEVICE
# resolved to (including CPU). torch.load unpickles the file, so only load
# checkpoints from a trusted source.
state_dict = torch.load('./best_avg.pth', map_location=DEVICE)

# The checkpoint loads into the DataParallel wrapper, so its keys carry the
# wrapper's 'module.' prefix. Strip it when the model is not wrapped; only do
# so when every key has the prefix, to avoid touching an unprefixed checkpoint.
if not MULTIPLE_GPU and state_dict and all(key.startswith('module.') for key in state_dict):
    state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}

model.load_state_dict(state_dict)
model.eval()

DataParallel(
  (module): ViT_bidirectional_dualPatches_untie_inference(
    (linear_4by8): Linear(in_features=32, out_features=512, bias=True)
    (linear_8by4): Linear(in_features=32, out_features=512, bias=True)
    (image_enc): FeaturesEncoder_dualPatches_untie(
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0): TransformerEncoderLayer_untie(
          (attention): MultiHeadAttention_untie(
            (q_linear): Linear(in_features=512, out_features=512, bias=True)
            (v_linear): Linear(in_features=512, out_features=512, bias=True)
            (k_linear): Linear(in_features=512, out_features=512, bias=True)
            (out): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout_1): Dropout(p=0.1, inplace=False)
          (norm_1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
          (feedForward): FeedForward(
            (linear_1): Linear(in_features=512, out_features=2048, bias=True)
       

In [17]:
# Evaluate the model on one test set, collecting for each of the four
# (reading direction x patch shape) outputs: the accumulated accuracy term, the
# per-sample result array, the decoded texts and a per-word confidence.
dataset = 'cute80Dataset_Custom'

model.eval()
# Freeze the parameters. The original line assigned a misspelled attribute
# (`reqiures_grad`), which silently did nothing.
model.requires_grad_(False)
single_case = True
with_sym = False

corpus = config.corpus


def _sequence_confidences(step_probs, pred_texts):
    """Return one confidence value per predicted string.

    The confidence of a string is the product, over its first len(text)
    decoding steps, of the largest value along dim 2 of `step_probs`.

    step_probs: tensor indexed as [sample, step, class]
                (presumably per-class probabilities; verify against Process).
    pred_texts: sequence of decoded strings, one per sample.
    """
    best_per_step, _ = torch.max(step_probs, dim=2)
    # One device-to-host copy for the whole batch instead of one .item() per character.
    best_per_step = best_per_step.tolist()
    confidences = []
    for sample_probs, text in zip(best_per_step, pred_texts):
        confidence = 1.0
        for step in range(len(text)):
            confidence *= sample_probs[step]
        confidences.append(confidence)
    return confidences


GT_text_RL = []
pred_text_RL_4by8 = []
prob_list_RL_4by8 = []
pred_text_RL_8by4 = []
prob_list_RL_8by4 = []

GT_text_LR = []
pred_text_LR_4by8 = []
prob_list_LR_4by8 = []
pred_text_LR_8by4 = []
prob_list_LR_8by4 = []

with torch.no_grad():
    val_count = 0
    val_acc_word_LR_4by8 = 0
    val_acc_word_RL_4by8 = 0
    val_acc_word_LR_8by4 = 0
    val_acc_word_RL_8by4 = 0

    testDataset = getattr(data, dataset)(**config.test_dataset_options)
    testDataset.pad = False
    testDataset.width = 128
    testDataset.out_channel = 1
    loader = torch.utils.data.DataLoader(testDataset, batch_size=640, shuffle=False)

    # Per-sample result arrays, filled batch by batch in dataset order
    # (shuffle=False keeps positions aligned with the dataset).
    num_samples = len(testDataset)
    temp_pred_bool_arr_LR_4by8 = np.zeros((num_samples,))
    temp_pred_bool_arr_RL_4by8 = np.zeros((num_samples,))
    temp_pred_bool_arr_LR_8by4 = np.zeros((num_samples,))
    temp_pred_bool_arr_RL_8by4 = np.zeros((num_samples,))

    for test_output in loader:
        image_4by8, image_8by4, text_LR, text_RL = test_output[:4]

        # Follow DEVICE instead of hard-coding .cuda(), so the cell also runs on CPU.
        image_4by8 = image_4by8.to(DEVICE)
        image_8by4 = image_8by4.to(DEVICE)
        bs = image_4by8.size(0)

        output = Process.word_level_acc_bidirectional_dualPatches_untie_inference(
            model, image_4by8, image_8by4, text_LR, text_RL, corpus,
            MAX_SEQ_LEN, DEVICE, True, single_case, with_sym)

        # output[0:4]: per-batch accuracy terms, summed here and divided by the sample count below.
        val_acc_word_LR_4by8 += output[0]
        val_acc_word_RL_4by8 += output[1]
        val_acc_word_LR_8by4 += output[2]
        val_acc_word_RL_8by4 += output[3]

        # output[4:8]: per-sample results for this batch; val_count is the write offset.
        batch_slice = slice(val_count, val_count + bs)
        temp_pred_bool_arr_LR_4by8[batch_slice] = output[4]
        temp_pred_bool_arr_RL_4by8[batch_slice] = output[5]
        temp_pred_bool_arr_LR_8by4[batch_slice] = output[6]
        temp_pred_bool_arr_RL_8by4[batch_slice] = output[7]

        # output[8:14]: ground-truth and predicted texts.
        GT_text_LR += output[8]
        GT_text_RL += output[9]
        pred_text_LR_4by8 += output[10]
        pred_text_RL_4by8 += output[11]
        pred_text_LR_8by4 += output[12]
        pred_text_RL_8by4 += output[13]

        # output[14:18]: per-step score tensors, paired with the matching predicted texts.
        prob_list_LR_4by8 += _sequence_confidences(output[14], output[10])
        prob_list_RL_4by8 += _sequence_confidences(output[15], output[11])
        prob_list_LR_8by4 += _sequence_confidences(output[16], output[12])
        prob_list_RL_8by4 += _sequence_confidences(output[17], output[13])

        val_count += bs

    print('Acc LR 4by8: ', val_acc_word_LR_4by8/val_count)
    print('Acc RL 4by8: ', val_acc_word_RL_4by8/val_count)
    print('Acc LR 8by4: ', val_acc_word_LR_8by4/val_count)
    print('Acc RL 8by4: ', val_acc_word_RL_8by4/val_count)
    print('count: ', val_count)

    print()

# Combine the four outputs: for every sample keep the result of whichever
# output (column 0..3) has the highest confidence, then report how often each
# output was chosen and the accuracy of the combined choice.
n_test = len(testDataset)

# Column order: LR 4by8, RL 4by8, LR 8by4, RL 8by4.
prob_bi = np.zeros((n_test, 4))
for col, confidences in enumerate((prob_list_LR_4by8,
                                   prob_list_RL_4by8,
                                   prob_list_LR_8by4,
                                   prob_list_RL_8by4)):
    prob_bi[:, col] = np.array(confidences)

pred_bool_bi = np.zeros((n_test, 4))
for col, per_sample_result in enumerate((temp_pred_bool_arr_LR_4by8,
                                         temp_pred_bool_arr_RL_4by8,
                                         temp_pred_bool_arr_LR_8by4,
                                         temp_pred_bool_arr_RL_8by4)):
    pred_bool_bi[:, col] = per_sample_result

# Index of the most confident output for each sample.
arg_prob_bi = np.argmax(prob_bi, axis=1)

for col, label in enumerate(('Highest Prob in LR 4by8:',
                             'Highest Prob in RL 4by8:',
                             'Highest Prob in LR 8by4:',
                             'Highest Prob in RL 8by4:')):
    print(label, np.sum(arg_prob_bi == col))

# Pick, per row, the result belonging to the most confident output.
n_scored = len(prob_list_LR_4by8)
max_prob_pred_bool = pred_bool_bi[np.arange(n_scored), arg_prob_bi]
print('Acc:', np.sum(max_prob_pred_bool)/n_scored)


Acc LR 4by8:  0.8854166666666666
Acc RL 4by8:  0.8923611111111112
Acc LR 8by4:  0.90625
Acc RL 8by4:  0.8854166666666666
count:  288

Highest Prob in LR 4by8: 76
Highest Prob in RL 4by8: 50
Highest Prob in LR 8by4: 80
Highest Prob in RL 8by4: 82
Acc: 0.9166666666666666
