In [1]:
from vit_pipeline import *

import os
import torch
import torch.utils.data
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pathlib
import typing

import context_handlers
import models
import utils
import data_preprocessing
from ltn_support import *

In [8]:
def compute_sat_testing_value(logits_to_predicate,
                         prediction, labels_coarse, labels_fine):
    """
    Evaluate the satisfaction of every logical rule on a set of predictions.

    Each rule is evaluated on its own; no aggregation across rules is done.
    Class names and the fine-to-coarse mapping are read from the
    `data_preprocessing` module. Labels are laid out as fine-grain classes
    first, followed by coarse-grain classes (coarse indices are offset by the
    number of fine-grain classes).

    argument:
      - logits_to_predicate: LTN predicate called as
        logits_to_predicate(variable, label_constant); gives the satisfaction
        of a variable given the label
      - prediction: output of fine tuner, one row per sample
      - labels_coarse, labels_fine: ground truth of coarse and fine label
        (tensors of class indices, aligned row-by-row with `prediction`)

    return:
      list of [rule_group, rule_description, satisfaction] entries, where
      rule_group is
        0 -> "coarse label implies not fine label" for every coarse/fine pair
             (pairs where the fine class belongs to the coarse class are
             reported with a "wrong rules: " prefix),
        1 -> samples of a class satisfy the predicate of that class,
        2 -> no sample satisfies two different labels of the same granularity
      and satisfaction is a Python float.
    """
    Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())
    And = ltn.Connective(ltn.fuzzy_ops.AndProd())
    Implies = ltn.Connective(ltn.fuzzy_ops.ImpliesReichenbach())
    Forall = ltn.Quantifier(
        ltn.fuzzy_ops.AggregPMeanError(p=4), quantifier="f")

    num_fine = len(data_preprocessing.fine_grain_classes)
    num_labels = num_fine + len(data_preprocessing.coarse_grain_classes)

    fine_label_dict = {name: label
                       for label, name in enumerate(data_preprocessing.fine_grain_classes)}
    coarse_label_dict = {name: label + num_fine
                         for label, name in enumerate(data_preprocessing.coarse_grain_classes)}
    labels_fine = labels_fine.detach().to('cpu')
    labels_coarse = labels_coarse.detach().to('cpu') + num_fine

    # One-hot constant per label. They are fixed during evaluation, so they are
    # not marked trainable (avoids building an autograd graph for nothing).
    label_constants = {}
    for label in range(num_labels):
        one_hot = torch.zeros(num_labels)
        one_hot[label] = 1.0
        label_constants[label] = ltn.Constant(one_hot, trainable=False)

    # `x` ranges over every prediction; x_variables[label] ranges only over the
    # predictions whose ground truth is `label`.
    # NOTE(review): a class without any sample gives an empty variable; the
    # quantifier result for it is probably not meaningful - confirm.
    x = ltn.Variable("x", prediction)
    x_variables = {}
    for name, label in fine_label_dict.items():
        x_variables[label] = ltn.Variable(
            name, prediction[labels_fine == label])
    for name, label in coarse_label_dict.items():
        x_variables[label] = ltn.Variable(
            name, prediction[labels_coarse == label])

    sat_agg_label = []

    def _record(rule_group, description, formula):
        # Store the truth value of an evaluated formula as a plain float.
        sat_agg_label.append([rule_group, description, formula.value.detach().item()])

    # Group 0: Forall x, P(x, coarse) -> Not P(x, fine), for every pair.
    # When the fine class actually belongs to the coarse class the rule is
    # expected to be false, so it is tagged "wrong rules: ".
    for coarse_label, i in coarse_label_dict.items():
        for fine_label, j in fine_label_dict.items():
            corresponding_coarse_label = data_preprocessing.fine_to_course_idx[j] + num_fine
            prefix = "wrong rules: " if corresponding_coarse_label == i else ""
            formula = Forall(x,
                             Implies(logits_to_predicate(x, label_constants[i]),
                                     Not(logits_to_predicate(x, label_constants[j]))))
            _record(0,
                    f"{prefix}for all x, P(x, l[{coarse_label}]) imply -P(x, l[{fine_label}])",
                    formula)

    # Coarse rules first, then fine rules (same order as the returned list).
    for label_dict in (coarse_label_dict, fine_label_dict):
        # Group 1: Forall x[label], P(x[label], l[label])
        for name, i in label_dict.items():
            formula = Forall(x_variables[i],
                             logits_to_predicate(x_variables[i], label_constants[i]))
            _record(1, f'for all {name}, P(x[{name}], l[{name}])', formula)

        # Group 2: Forall x, Not(P(x, l[a]) And P(x, l[b])) for a != b
        for name_1, i in label_dict.items():
            for name_2, j in label_dict.items():
                if i == j:
                    continue
                formula = Forall(x, Not(And(logits_to_predicate(x, label_constants[i]),
                                            logits_to_predicate(x, label_constants[j]))))
                _record(2, f"for all x, - (P(x, {name_1}) and P(x,{name_2}))", formula)

    return sat_agg_label


In [4]:
# Run the first fine-tuner over the test split, collecting the raw model
# outputs (for the rule-satisfaction analysis) and the argmax predictions.
fine_tuner = fine_tuners[0]
device = devices[0]

logits_to_predicate = ltn.Predicate(LogitsToPredicate()).to(ltn.device)

test_loader = loaders['test']
fine_tuner.to(device)
fine_tuner.eval()

test_fine_prediction = []
test_coarse_prediction = []

test_fine_ground_truth = []
test_coarse_ground_truth = []

name_list = []

# Per-batch tensors, concatenated once after the loop. Appending keeps the rows
# in the same order as name_list and the prediction / ground-truth lists, and
# avoids re-copying the accumulated tensor on every batch.
pred_batches = []
fine_label_batches = []
coarse_label_batches = []

num_fine_classes = len(data_preprocessing.fine_grain_classes)

print(f'Testing {fine_tuner} on {device}...')

with torch.no_grad():
    if utils.is_local():
        from tqdm import tqdm
        gen = tqdm(test_loader, total=len(test_loader))
    else:
        gen = test_loader

    for data in gen:
        # Batch layout: (inputs, fine labels, sample names, coarse labels)
        X, Y_fine_grain, names, Y_coarse_grain = data[0].to(device), data[1].to(device), data[2], data[3].to(device)

        Y_pred = fine_tuner(X)

        pred_batches.append(Y_pred)
        fine_label_batches.append(Y_fine_grain)
        coarse_label_batches.append(Y_coarse_grain)

        # Output columns: fine-grain classes first, then coarse-grain classes.
        Y_pred_fine_grain = Y_pred[:, :num_fine_classes]
        Y_pred_coarse_grain = Y_pred[:, num_fine_classes:]

        predicted_fine = torch.max(Y_pred_fine_grain, 1)[1]
        predicted_coarse = torch.max(Y_pred_coarse_grain, 1)[1]

        test_fine_ground_truth += Y_fine_grain.tolist()
        test_coarse_ground_truth += Y_coarse_grain.tolist()

        test_fine_prediction += predicted_fine.tolist()
        test_coarse_prediction += predicted_coarse.tolist()

        name_list += names

# Fail loudly instead of silently reusing stale tensors from an earlier run.
if not pred_batches:
    raise RuntimeError("test_loader yielded no batches; nothing to evaluate")

Y_pred_all = torch.cat(pred_batches, dim=0)
Y_fine_grain_all = torch.cat(fine_label_batches, dim=0)
Y_coarse_grain_all = torch.cat(coarse_label_batches, dim=0)

Testing vit_b_16 on mps...


In [10]:
# Per-rule satisfaction values over the collected test outputs. Keyword
# arguments spell out which tensor holds the coarse and which the fine labels.
rule_and_confidence_score_list = compute_sat_testing_value(
    logits_to_predicate,
    prediction=Y_pred_all,
    labels_coarse=Y_coarse_grain_all,
    labels_fine=Y_fine_grain_all,
)

# Accuracy report for the argmax predictions gathered in the test loop.
test_fine_accuracy, test_coarse_accuracy = get_and_print_metrics(
    fine_predictions=test_fine_prediction,
    coarse_predictions=test_coarse_prediction,
    loss="BCE",
)


with [94mBCE[0m loss

Fine-grain prior combined accuracy: [92m69.15[0m%, fine-grain prior combined average f1: [92m67.26[0m%
Coarse-grain prior combined accuracy: [92m84.33[0m%, coarse-grain prior combined average f1: [92m82.67[0m%

Total prior inconsistencies [91m60[0m/[91m1621[0m ([91m3.7[0m%)


In [22]:
# Keep only the group-0 rules whose description is tagged "wrong"
# (entries are [rule_group, description, satisfaction]).
filtered_list = [
    entry
    for entry in rule_and_confidence_score_list
    if entry[0] == 0 and "wrong" in entry[1]
]


In [23]:
# Display the group-0 entries tagged "wrong" together with their satisfaction values.
filtered_list

[[0,
  'wrong rules: for all x, P(x, l[Air Defense]) imply -P(x, l[30N6E])',
  0.7139595150947571],
 [0,
  'wrong rules: for all x, P(x, l[Air Defense]) imply -P(x, l[Iskander])',
  0.5634914636611938],
 [0,
  'wrong rules: for all x, P(x, l[Air Defense]) imply -P(x, l[Pantsir-S1])',
  0.6213214993476868],
 [0,
  'wrong rules: for all x, P(x, l[Air Defense]) imply -P(x, l[Rs-24])',
  0.5094217658042908],
 [0,
  'wrong rules: for all x, P(x, l[BMD]) imply -P(x, l[BMD])',
  0.6057076454162598],
 [0,
  'wrong rules: for all x, P(x, l[BMP]) imply -P(x, l[BMP-1])',
  0.7136292457580566],
 [0,
  'wrong rules: for all x, P(x, l[BMP]) imply -P(x, l[BMP-2])',
  0.6492296457290649],
 [0,
  'wrong rules: for all x, P(x, l[BMP]) imply -P(x, l[BMP-T15])',
  0.6144813299179077],
 [0,
  'wrong rules: for all x, P(x, l[BTR]) imply -P(x, l[BRDM])',
  0.5974348783493042],
 [0,
  'wrong rules: for all x, P(x, l[BTR]) imply -P(x, l[BTR-60])',
  0.49617326259613037],
 [0,
  'wrong rules: for all x, P(x, l[

In [16]:
# No-op loop: its only effect is to leave `item` bound to the last entry of the
# list (when the list is non-empty), which a later cell reads.
for item in rule_and_confidence_score_list:
  pass

In [18]:
# Number of rule entries returned by compute_sat_testing_value.
len(rule_and_confidence_score_list)

793

In [19]:
# Description of the last rule entry; index the list directly so this cell does
# not depend on a loop variable left over from another cell.
rule_and_confidence_score_list[-1][1]

'for all x, - (P(x, Tornado) and P(x,TOS-1))'