In [1]:
import sys

sys.path.append('..')
sys.path.append('../..')

In [2]:
import clingo
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import multilabel_confusion_matrix, precision_recall_fscore_support, f1_score
 
import torch
from torch.utils.data import DataLoader

In [3]:
from dnf_layer import SemiSymbolicLayerType
from rule_learner import DNFClassifier
from test_common import SyntheticDataset
from utils import DeltaDelayedExponentialDecayScheduler

In [4]:
SYNTH_DATA_PATH = 'synth_multi_label_data_in15_conj5.npz'
RNG_SEED = 75
BATCH_SIZE = 64
NUM_EPOCHS = 100

In [5]:
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)

In [6]:
dnpz = np.load(SYNTH_DATA_PATH)
full_nullary = dnpz['nullary']
full_target = dnpz['target']

In [7]:
nullary_full_train, nullary_test, target_full_train, target_test = train_test_split(
    full_nullary, full_target, random_state=RNG_SEED)
nullary_train, nullary_val, target_train, target_val = train_test_split(
    nullary_full_train, target_full_train, test_size=0.2, random_state=RNG_SEED
)

train_loader = DataLoader(SyntheticDataset(nullary_train, target_train), BATCH_SIZE)
val_dataset = SyntheticDataset(nullary_val, target_val)
val_loader = DataLoader(val_dataset, len(val_dataset))
test_dataset = SyntheticDataset(nullary_test, target_test)
test_loader = DataLoader(test_dataset, len(test_dataset))

In [8]:
model = DNFClassifier(15, 5, 3, 0.1)
delta_decay_scheduler = DeltaDelayedExponentialDecayScheduler(
    initial_delta=0.1,
    delta_decay_delay=10,
    delta_decay_steps=1,
    delta_decay_rate=1.1,
)
optimiser = torch.optim.Adam(model.parameters(), 0.001)
criterion = torch.nn.BCELoss()

In [9]:
for i in range(NUM_EPOCHS):
    model.train()
    epoch_loss = []
    for x, y in train_loader:
        optimiser.zero_grad()
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        loss = criterion(y_hat.squeeze(), (y == 1).float() )
        loss.backward()
        optimiser.step()

        epoch_loss.append(loss.item())
    
    model.eval()
    macro_f1 = None

    for x, y in val_loader:
        with torch.no_grad():
            y_hat = (torch.tanh(model(x.float())) + 1) / 2
            y_pred = torch.where(y_hat > 0.5, 1, 0)
            macro_f1 = f1_score(y, y_pred, average='macro')

    new_delta_val = delta_decay_scheduler.step(model, i)
    avg_loss = sum(epoch_loss) / len(epoch_loss)
    print(f'[{i + 1:3d}] Delta: {new_delta_val:.3f}  '
          f'Train avg loss: {avg_loss:.3f}  '
          f'Val macro f1: {macro_f1:.3f}')

[  1] Delta: 0.100  Train avg loss: 0.601  Val macro f1: 0.913
[  2] Delta: 0.100  Train avg loss: 0.435  Val macro f1: 0.924
[  3] Delta: 0.100  Train avg loss: 0.333  Val macro f1: 0.938
[  4] Delta: 0.100  Train avg loss: 0.271  Val macro f1: 0.967
[  5] Delta: 0.100  Train avg loss: 0.230  Val macro f1: 0.975
[  6] Delta: 0.100  Train avg loss: 0.202  Val macro f1: 0.977
[  7] Delta: 0.100  Train avg loss: 0.182  Val macro f1: 0.978
[  8] Delta: 0.100  Train avg loss: 0.168  Val macro f1: 0.979
[  9] Delta: 0.100  Train avg loss: 0.156  Val macro f1: 0.980
[ 10] Delta: 0.100  Train avg loss: 0.147  Val macro f1: 0.980
[ 11] Delta: 0.100  Train avg loss: 0.140  Val macro f1: 0.980
[ 12] Delta: 0.110  Train avg loss: 0.134  Val macro f1: 0.980
[ 13] Delta: 0.121  Train avg loss: 0.128  Val macro f1: 0.979
[ 14] Delta: 0.133  Train avg loss: 0.123  Val macro f1: 0.977
[ 15] Delta: 0.146  Train avg loss: 0.119  Val macro f1: 0.977
[ 16] Delta: 0.161  Train avg loss: 0.115  Val macro f1

In [10]:
for x, y in test_loader:
    with torch.no_grad():
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        y_pred = torch.where(y_hat > 0.5, 1, 0)
        macro_f1 = f1_score(y, y_pred, average='macro')
print(f'Test macro F1: {macro_f1:.3f}')

Test macro F1: 1.000


In [11]:
torch.save(model.state_dict(), f'multi_label_dnf_synth_{RNG_SEED}.pth')

In [12]:
multilabel_confusion_matrix(y, y_pred)

array([[[ 330,    0],
        [   0, 2170]],

       [[  95,    0],
        [   0, 2405]],

       [[ 863,    0],
        [   0, 1637]]])

In [13]:
# sd = torch.load(f'multi_label_dnf_synth_{RNG_SEED}.pth')
# model.load_state_dict(sd)

In [14]:
def prune_layer_weight(
    model,
    layer_type: SemiSymbolicLayerType,
    epsilon: float,
    data_loader: DataLoader,
) -> int:
    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        curr_weight = model.dnf.conjunctions.weights.data.clone()
    else:
        curr_weight = model.dnf.disjunctions.weights.data.clone()

    for x, y in data_loader:
        with torch.no_grad():
            y_hat = (torch.tanh(model(x.float())) + 1) / 2
            y_pred = torch.where(y_hat > 0.5, 1, 0)
            og_macro_f1 = f1_score(y, y_pred, average='macro')

    prune_count = 0
    weight_device = curr_weight.device

    flatten_weight_len = len(torch.reshape(curr_weight, (-1,)))
    base_iterator = range(flatten_weight_len)
    for i in base_iterator:
        curr_weight_flatten = torch.reshape(curr_weight, (-1,))

        if curr_weight_flatten[i] == 0:
            continue

        mask = torch.ones(flatten_weight_len, device=weight_device)
        mask[i] = 0
        mask = mask.reshape(curr_weight.shape)

        masked_weight = curr_weight * mask

        if layer_type == SemiSymbolicLayerType.CONJUNCTION:
            model.dnf.conjunctions.weights.data = masked_weight
        else:
            model.dnf.disjunctions.weights.data = masked_weight

        for x, y in data_loader:
            with torch.no_grad():
                y_hat = (torch.tanh(model(x.float())) + 1) / 2
                y_pred = torch.where(y_hat > 0.5, 1, 0)
                new_macro_f1 = f1_score(y, y_pred, average='macro')

        performance_drop = og_macro_f1 - new_macro_f1
        if performance_drop < epsilon:
            prune_count += 1
            curr_weight *= mask

    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        model.dnf.conjunctions.weights.data = curr_weight
    else:
        model.dnf.disjunctions.weights.data = curr_weight
    return prune_count

In [15]:
prune_epsilon = 0.005
print('Prune disj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.DISJUNCTION, 
    prune_epsilon, val_loader)

for x, y in test_loader:
    with torch.no_grad():
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        y_pred = torch.where(y_hat > 0.5, 1, 0)
        macro_f1 = f1_score(y, y_pred, average='macro')

print(f'Pruned disj count:   {prune_count}')
print(f'New perf after disj: {macro_f1:.3f}')
print(multilabel_confusion_matrix(y, y_pred))

print('Prune conj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.CONJUNCTION, 
    prune_epsilon, val_loader)
correct = 0
total = 0

for x, y in test_loader:
    with torch.no_grad():
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        y_pred = torch.where(y_hat > 0.5, 1, 0)
        macro_f1 = f1_score(y, y_pred, average='macro')

print(f'Pruned conj count:   {prune_count}')
print(f'New perf after disj: {macro_f1:.3f}')
print(multilabel_confusion_matrix(y, y_pred))

Prune disj layer
Pruned disj count:   8
New perf after disj: 0.997
[[[ 330    0]
  [  34 2136]]

 [[  95    0]
  [   0 2405]]

 [[ 863    0]
  [   0 1637]]]
Prune conj layer
Pruned conj count:   62
New perf after disj: 0.994
[[[ 330    0]
  [   0 2170]]

 [[  95    0]
  [  80 2325]]

 [[ 863    0]
  [   0 1637]]]


In [16]:
for x, y in test_loader:
    with torch.no_grad():
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        y_pred = torch.where(y_hat > 0.5, 1, 0)
        macro_f1 = f1_score(y, y_pred, average='macro')
print(f'Pref:    {macro_f1:.3f}')

Pref:    0.994


In [17]:
pre_threshold_sd = model.state_dict()

In [18]:
def apply_threshold(
    model, og_conj_weight, og_disj_weight, t_val, const: float = 6.0,
) -> None:
    new_conj_weight = (
        (torch.abs(og_conj_weight) > t_val) * torch.sign(og_conj_weight) * const
    )
    model.dnf.conjunctions.weights.data = new_conj_weight

    new_disj_weight = (
        (torch.abs(og_disj_weight) > t_val) * torch.sign(og_disj_weight) * const
    )
    model.dnf.disjunctions.weights.data = new_disj_weight

In [19]:
conj_min = torch.min(model.dnf.conjunctions.weights.data)
conj_max = torch.max(model.dnf.conjunctions.weights.data)
disj_min = torch.min(model.dnf.disjunctions.weights.data)
disj_max = torch.max(model.dnf.disjunctions.weights.data)

abs_max = torch.max(torch.abs(torch.Tensor([conj_min, conj_max, disj_min, disj_max])))

og_conj_weight = model.dnf.conjunctions.weights.data.clone()
og_disj_weight = model.dnf.disjunctions.weights.data.clone()

t_vals = torch.arange(0, abs_max + 0.01, 0.01)

In [20]:
# model.load_state_dict(pre_threshold_sd)
acc_scores = []
for v in t_vals:
    apply_threshold(model, og_conj_weight, og_disj_weight,
                              v, 6.0)
    for x, y in val_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            y_pred = torch.where(y_hat > 0, 1, 0)
            macro_f1 = f1_score(y, y_pred, average='macro')
    acc_scores.append(macro_f1)

best_acc_score = max(acc_scores)
best_t = t_vals[torch.argmax(torch.Tensor(acc_scores))]
print(f'Best t: {best_t.item():.3f}    Macro f1: {best_acc_score:.3f}')
apply_threshold(model, og_conj_weight, og_disj_weight, best_t)

Best t: 0.340    Macro f1: 0.992


In [21]:
for x, y in test_loader:
    with torch.no_grad():
        y_hat = torch.tanh(model(x.float()))
        y_pred = torch.where(y_hat > 0, 1, 0)
        macro_f1 = f1_score(y, y_pred, average='macro')
    print(f'Test macro f1: {macro_f1:.3f}')

Test macro f1: 0.994


In [22]:
multilabel_confusion_matrix(y, y_pred)

array([[[ 330,    0],
        [   0, 2170]],

       [[   0,   95],
        [   0, 2405]],

       [[ 863,    0],
        [   0, 1637]]])

In [23]:
final_sd = model.state_dict()

output_rules = []

# Get all conjunctions
conj_w = final_sd["dnf.conjunctions.weights"]
conjunction_map = dict()
for i, w in enumerate(conj_w):
    if torch.all(w == 0):
        # No conjunction is applied here
        continue

    conjuncts = []
    for j, v in enumerate(w):
        if v < 0:
            # Negative weight, negate the atom
            conjuncts.append(f"not a{j}")
        elif v > 0:
            # Positive weight, normal atom
            conjuncts.append(f"a{j}")

    conjunction_map[i] = conjuncts

# Get DNF
disj_w = final_sd["dnf.disjunctions.weights"]
not_covered_classes = []
for i, w in enumerate(disj_w):
    if torch.all(w == 0):
        # No DNF for class i
        not_covered_classes.append(i)
        continue

    disjuncts = []
    for j, v in enumerate(w):
        if v < 0 and j in conjunction_map:
            # Negative weight, negate the existing conjunction
            
            # Need to add auxilary predicate (conj_X) which is not yet
            # in the final rules list
            output_rules.append(
                f"c{j} :- {', '.join(conjunction_map[j])}."
            )
            output_rules.append(f"l{i} :- not c{j}.")
        elif v > 0 and j in conjunction_map:
            # Postivie weight, add normal conjunction
            body = ", ".join(conjunction_map[j])
            output_rules.append(f"l{i} :- {body}.")

In [24]:
final_sd

OrderedDict([('dnf.conjunctions.weights',
              tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                        0.],
                      [ 6.,  0.,  6.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                        0.],
                      [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -6.,  0., -6.,  0.,  0.,
                        0.],
                      [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -6.,  6.,
                        6.],
                      [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  6.,  0.,  0.,  0.,  0.,
                       -0.]])),
             ('dnf.disjunctions.weights',
              tensor([[ 0.,  0.,  6.,  6.,  0.],
                      [ 0.,  6., -6.,  6., -6.],
                      [ 0.,  6.,  0.,  0.,  0.]]))])

In [25]:
output_rules

['l0 :- not a9, not a11.',
 'l0 :- not a12, a13, a14.',
 'l1 :- a0, a2.',
 'c2 :- not a9, not a11.',
 'l1 :- not c2.',
 'l1 :- not a12, a13, a14.',
 'c4 :- a9.',
 'l1 :- not c4.',
 'l2 :- a0, a2.']

In [26]:
show_statements = [f'#show l{i}/0.' for i in range(3)]


y_pred_batch = []

for x_batch, y_batch in test_loader:
    for i in range(len(x_batch)):
        x = x_batch[i]
        y = y_batch[i]
        x_asp = [f"a{j}." for j in range(len(x)) if x[j] == 1]
        ctl = clingo.Control(["--warn=none"])
        ctl.add("base", [], " ".join(x_asp + output_rules + show_statements))
        ctl.ground([("base", [])])
        sh = ctl.solve(yield_=True)
        if str(sh.model()) == '':
            prend_one_hot = torch.zeros(y.shape)
        else:
            pred_classes = [int(m[1:]) for m in str(sh.model()).split(' ')]
            pred_one_hot = torch.ones(y.shape)
            pred_one_hot[pred_classes] = 1
        y_pred_batch.append(pred_one_hot)
    y_pred = torch.cat(y_pred_batch).reshape(-1, 3)
    macro_f1 = f1_score(y_batch, y_pred, average='macro')
    print(macro_f1)
    print(multilabel_confusion_matrix(y_batch, y_pred))

0.9004543090242665
[[[   0  330]
  [   0 2170]]

 [[   0   95]
  [   0 2405]]

 [[   0  863]
  [   0 1637]]]


In [27]:
dnpz['rule_str']

array(['c0 :- a0, a2.', 'c1 :- not a3, not a4, not a5.',
       'c2 :- a6, a7, a8.', 'c3 :- not a9, not a11.',
       'c4 :- not a12, a13, a14.', 'l0 :- c3.', 'l0 :- c4.', 'l1 :- c0.',
       'l1 :- c1.', 'l1 :- c4.', 'l2 :- c0.'], dtype='<U29')