In [1]:
import re
import sys

# Put the parent and grandparent directories on the module search path.
# NOTE(review): presumably this is what lets the project-local imports used
# further down (dnf_layer, rule_learner, ...) resolve — confirm repo layout.
sys.path.append('..')
sys.path.append('../..')

In [2]:
import clingo
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    multilabel_confusion_matrix,
    precision_recall_fscore_support,
    f1_score,
)
 
import torch
from torch.utils.data import DataLoader

In [3]:
from dnf_layer import SemiSymbolicLayerType
from rule_learner import DNFClassifier
from test_common import SyntheticDataset
from utils import DeltaDelayedExponentialDecayScheduler
from dnf_post_train import (
    remove_unused_conjunctions,
    remove_disjunctions_when_empty_conjunctions,
    apply_threshold,
    extract_asp_rules,
)

In [4]:
# Run-level settings: seed for every RNG, mini-batch size and epoch budget.
RNG_SEED, BATCH_SIZE, NUM_EPOCHS = 73, 256, 100

# Problem dimensions. They are handed to the classifier constructor and are
# also baked into the name of the dataset file below.
NUM_IN, NUM_CONJ, NUM_LABELS = 150, 75, 25

# Generate dataset with `multi_label_syn_data_gen.py` first
SYNTH_DATA_PATH = f'synth_multi_label_data_in{NUM_IN}_conj{NUM_CONJ}_label{NUM_LABELS}.npz'

In [5]:
# Seed both PyTorch and NumPy (in that order) with the same value so the
# run is repeatable.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(RNG_SEED)

In [6]:
# Keep the archive handle around: a later cell reads `dnpz['rule_str']`.
dnpz = np.load(SYNTH_DATA_PATH)
# Inputs and targets stored under the 'nullary' and 'target' keys.
full_nullary, full_target = dnpz['nullary'], dnpz['target']

In [7]:
# Two-stage split: first hold out a test set (train_test_split's default
# proportion), then take 20% of the remainder as the validation set.
(
    nullary_full_train,
    nullary_test,
    target_full_train,
    target_test,
) = train_test_split(full_nullary, full_target, random_state=RNG_SEED)
(
    nullary_train,
    nullary_val,
    target_train,
    target_val,
) = train_test_split(
    nullary_full_train,
    target_full_train,
    test_size=0.2,
    random_state=RNG_SEED,
)

# Training data is served in mini-batches of BATCH_SIZE samples.
train_loader = DataLoader(
    SyntheticDataset(nullary_train, target_train), batch_size=BATCH_SIZE
)
# Validation and test sets are each delivered as one single batch.
val_dataset = SyntheticDataset(nullary_val, target_val)
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
test_dataset = SyntheticDataset(nullary_test, target_test)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

In [8]:
# The model is built first so that its parameter initialisation is the first
# consumer of the seeded RNG.
# NOTE(review): the trailing 0.1 is presumably the initial delta, matching the
# scheduler's `initial_delta` — confirm against DNFClassifier's signature.
model = DNFClassifier(NUM_IN, NUM_CONJ, NUM_LABELS, 0.1)

# Multi-label objective applied to the raw model outputs.
criterion = torch.nn.BCEWithLogitsLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

# Delta schedule; it is stepped once per epoch by the training loop.
delta_decay_scheduler = DeltaDelayedExponentialDecayScheduler(
    initial_delta=0.1,
    delta_decay_delay=10,
    delta_decay_steps=1,
    delta_decay_rate=1.1,
)

# Training

In [9]:
for i in range(NUM_EPOCHS):
    # ---- optimisation pass over the training batches ----
    model.train()
    epoch_loss = []
    for x_train, y_train in train_loader:
        optimiser.zero_grad()
        y_hat = model(x_train.float())
        # No `.squeeze()` here: the validation code below feeds the raw model
        # output to f1_score against the same kind of targets, so the output
        # already has the target's shape. Squeezing would only change things
        # for a batch of one sample, where it drops the batch dimension and
        # makes the loss reject the mismatched shapes.
        loss = criterion(y_hat, (y_train == 1).float())
        loss.backward()
        optimiser.step()

        epoch_loss.append(loss.item())

    # ---- validation: macro F1 on the (single-batch) validation loader ----
    model.eval()
    macro_f1 = None

    with torch.no_grad():
        for x_val, y_val in val_loader:
            y_hat = torch.sigmoid(model(x_val.float()))
            y_pred = torch.where(y_hat > 0.5, 1, 0)
            macro_f1 = f1_score(y_val, y_pred, average='macro', zero_division=0)

    # Advance the delta schedule once per epoch and report progress.
    new_delta_val = delta_decay_scheduler.step(model, i)
    avg_loss = sum(epoch_loss) / len(epoch_loss)
    print(f'[{i + 1:3d}] Delta: {new_delta_val:.3f}  '
          f'Train avg loss: {avg_loss:.3f}  '
          f'Val macro f1: {macro_f1:.3f}')

[  1] Delta: 0.100  Train avg loss: 0.699  Val macro f1: 0.487
[  2] Delta: 0.100  Train avg loss: 0.588  Val macro f1: 0.431
[  3] Delta: 0.100  Train avg loss: 0.545  Val macro f1: 0.436
[  4] Delta: 0.100  Train avg loss: 0.515  Val macro f1: 0.468
[  5] Delta: 0.100  Train avg loss: 0.486  Val macro f1: 0.509
[  6] Delta: 0.100  Train avg loss: 0.457  Val macro f1: 0.563
[  7] Delta: 0.100  Train avg loss: 0.426  Val macro f1: 0.623
[  8] Delta: 0.100  Train avg loss: 0.396  Val macro f1: 0.686
[  9] Delta: 0.100  Train avg loss: 0.367  Val macro f1: 0.745
[ 10] Delta: 0.100  Train avg loss: 0.339  Val macro f1: 0.793
[ 11] Delta: 0.100  Train avg loss: 0.313  Val macro f1: 0.832
[ 12] Delta: 0.110  Train avg loss: 0.290  Val macro f1: 0.866
[ 13] Delta: 0.121  Train avg loss: 0.274  Val macro f1: 0.882
[ 14] Delta: 0.133  Train avg loss: 0.262  Val macro f1: 0.899
[ 15] Delta: 0.146  Train avg loss: 0.252  Val macro f1: 0.911
[ 16] Delta: 0.161  Train avg loss: 0.245  Val macro f1

In [10]:
# Score the trained model on the held-out test set. The tanh output is mapped
# from [-1, 1] to [0, 1] and cut at 0.5 to obtain binary label predictions.
with torch.no_grad():
    for x, y in test_loader:
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        y_pred = (y_hat > 0.5).long()
        macro_f1 = f1_score(y, y_pred, average='macro')
print(f'Test macro F1: {macro_f1:.3f}')
print()
# print(multilabel_confusion_matrix(y, y_pred))

Test macro F1: 0.987



In [11]:
# Persist the trained (not yet pruned) weights; the file name carries the
# seed so runs with different seeds do not overwrite each other.
torch.save(model.state_dict(), f'multi_label_dnf_synth_{RNG_SEED}.pth')

In [12]:
# sd = torch.load(f'multi_label_dnf_synth_{RNG_SEED}.pth')
# model.load_state_dict(sd)

# Post Training Process

In [13]:
def prune_layer_weight(
    model,
    layer_type: SemiSymbolicLayerType,
    epsilon: float,
    data_loader: DataLoader, # This should be val loader
) -> int:
    """Greedily zero weights of one DNF layer while macro F1 stays within `epsilon`.

    Every non-zero weight of the selected layer is visited once, in flattened
    order. The weight is set to zero and the model is re-scored on
    `data_loader`. If the macro F1 drop relative to the un-pruned model is
    below `epsilon` the weight stays zero, otherwise its value is restored.
    Accepted prunings accumulate, so later candidates are scored on top of
    the earlier ones.

    Args:
        model: Model exposing `model.dnf.conjunctions.weights` and
            `model.dnf.disjunctions.weights`.
        layer_type: `SemiSymbolicLayerType.CONJUNCTION` prunes the
            conjunction layer; any other value prunes the disjunction layer.
        epsilon: Largest tolerated drop in macro F1 (exclusive) for a weight
            to be pruned.
        data_loader: Data used for scoring (intended to be the validation
            loader). Predictions from all of its batches are scored together.

    Returns:
        The number of weights that were set to zero.

    Raises:
        ValueError: If `data_loader` yields no batches.
    """
    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        layer = model.dnf.conjunctions
    else:
        layer = model.dnf.disjunctions

    def _macro_f1():
        # Macro F1 of the model in its current state over the whole loader.
        targets, predictions = [], []
        with torch.no_grad():
            for x, y in data_loader:
                y_hat = (torch.tanh(model(x.float())) + 1) / 2
                predictions.append(torch.where(y_hat > 0.5, 1, 0))
                targets.append(y)
        if not targets:
            raise ValueError('data_loader yielded no batches')
        return f1_score(
            torch.cat(targets), torch.cat(predictions), average='macro'
        )

    # Reference score, measured before anything is modified.
    og_macro_f1 = _macro_f1()

    # Work on a contiguous copy installed as the layer's weight data. Each
    # candidate is then toggled in place through a flat view, instead of
    # building a full-size mask and a masked copy for every single weight.
    curr_weight = layer.weights.data.clone(memory_format=torch.contiguous_format)
    layer.weights.data = curr_weight
    flat_weight = curr_weight.view(-1)

    prune_count = 0
    for i in range(flat_weight.numel()):
        if flat_weight[i] == 0:
            continue

        original_value = flat_weight[i].clone()
        flat_weight[i] = 0

        performance_drop = og_macro_f1 - _macro_f1()
        if performance_drop < epsilon:
            prune_count += 1
        else:
            # Too costly to remove: put the weight back.
            flat_weight[i] = original_value

    return prune_count

In [14]:
prune_epsilon = 0.005


def _test_macro_f1():
    """Return the macro F1 of the current `model` on `test_loader`."""
    for x, y in test_loader:
        with torch.no_grad():
            y_hat = (torch.tanh(model(x.float())) + 1) / 2
            y_pred = torch.where(y_hat > 0.5, 1, 0)
            score = f1_score(y, y_pred, average='macro')
    return score


# Step 1: prune the disjunction layer (scored on validation data).
print('Prune disj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.DISJUNCTION,
    prune_epsilon, val_loader)
macro_f1 = _test_macro_f1()

print(f'Pruned disj count:        {prune_count}')
print(f'New test perf after disj: {macro_f1:.3f}')
# print(multilabel_confusion_matrix(y, y_pred))
print()

# Step 2: let the post-training helper remove conjunctions it deems unused.
unused_conj = remove_unused_conjunctions(model)
print(f'Remove unused conjunctions: {unused_conj}')
print()

# Step 3: prune the conjunction layer.
print('Prune conj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.CONJUNCTION,
    prune_epsilon, val_loader)
macro_f1 = _test_macro_f1()

print(f'Pruned conj count:        {prune_count}')
# This line reports the score after the *conjunction* pruning; it used to be
# labelled "after disj" by copy-paste.
print(f'New test perf after conj: {macro_f1:.3f}')
# print(multilabel_confusion_matrix(y, y_pred))
print()

# Step 4: let the helper remove disjunction weights attached to empty conjunctions.
removed_disj = remove_disjunctions_when_empty_conjunctions(model)
print(
    f'Remove disjunction that uses empty conjunctions: {removed_disj}'
)
print()

# Step 5: prune the disjunction layer once more.
print('Prune disj layer again')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.DISJUNCTION,
    prune_epsilon, val_loader)
macro_f1 = _test_macro_f1()

print(f'Pruned disj count:        {prune_count}')
print(f'New test perf after disj: {macro_f1:.3f}')
# print(multilabel_confusion_matrix(y, y_pred))
print()

Prune disj layer
Pruned disj count:        1815
New test perf after disj: 0.981

Remove unused conjunctions: 33

Prune conj layer
Pruned conj count:        5843
New test perf after disj: 0.976

Remove disjunction that uses empty conjunctions: 900

Prune disj layer again
Pruned disj count:        24
New test perf after disj: 0.975



In [15]:
# Test-set macro F1 of the model as it stands once all pruning steps are done.
with torch.no_grad():
    for x, y in test_loader:
        logits = model(x.float())
        y_hat = (torch.tanh(logits) + 1) / 2
        y_pred = (y_hat > 0.5).long()
        macro_f1 = f1_score(y, y_pred, average='macro')
print(f'Prune procedure final f1:    {macro_f1:.3f}')
print()
# print(multilabel_confusion_matrix(y, y_pred))

Prune procedure final f1:    0.975



In [16]:
# pre_threshold_sd = model.state_dict()
conj_min = torch.min(model.dnf.conjunctions.weights.data)
conj_max = torch.max(model.dnf.conjunctions.weights.data)
disj_min = torch.min(model.dnf.disjunctions.weights.data)
disj_max = torch.max(model.dnf.disjunctions.weights.data)

abs_max = torch.max(torch.abs(torch.Tensor([conj_min, conj_max, disj_min, disj_max])))

og_conj_weight = model.dnf.conjunctions.weights.data.clone()
og_disj_weight = model.dnf.disjunctions.weights.data.clone()

t_vals = torch.arange(0, abs_max + 0.01, 0.01)

In [17]:
# model.load_state_dict(pre_threshold_sd)
acc_scores = []
for v in t_vals:
    apply_threshold(model, og_conj_weight, og_disj_weight,
                              v, 6.0)
    for x, y in val_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            y_pred = torch.where(y_hat > 0, 1, 0)
            macro_f1 = f1_score(y, y_pred, average='macro')
    acc_scores.append(macro_f1)

best_acc_score = max(acc_scores)
best_t = t_vals[torch.argmax(torch.Tensor(acc_scores))]
print(f'Best t: {best_t.item():.3f}    Macro f1: {best_acc_score:.3f}')
apply_threshold(model, og_conj_weight, og_disj_weight, best_t)

Best t: 0.350    Macro f1: 0.962


In [18]:
# Test-set macro F1 of the thresholded model. A label is predicted when the
# tanh output is strictly positive. The score is printed once per batch
# (the test loader holds a single batch).
with torch.no_grad():
    for x, y in test_loader:
        y_hat = torch.tanh(model(x.float()))
        y_pred = (y_hat > 0).long()
        macro_f1 = f1_score(y, y_pred, average='macro')
        print(f'Test macro f1 after threshold: {macro_f1:.3f}')
print()
# print(multilabel_confusion_matrix(y, y_pred))

Test macro f1 after threshold: 0.963



In [19]:
# Turn the thresholded weights into ASP rule strings and list them one per line.
output_rules = extract_asp_rules(model.state_dict())
for rule in output_rules:
    print(rule)

conj_0 :- has_attr_123.
conj_2 :- not has_attr_39.
conj_3 :- has_attr_49.
conj_4 :- not has_attr_53.
conj_6 :- not has_attr_110, not has_attr_111.
conj_7 :- not has_attr_18, has_attr_19.
conj_14 :- has_attr_90, has_attr_114.
conj_21 :- not has_attr_134, has_attr_135.
conj_22 :- has_attr_145, not has_attr_147.
conj_25 :- has_attr_103.
conj_30 :- has_attr_56, not has_attr_57.
conj_31 :- has_attr_60, not has_attr_65.
conj_32 :- not has_attr_53, has_attr_56.
conj_34 :- has_attr_103, has_attr_114.
conj_36 :- has_attr_114.
conj_37 :- not has_attr_74, has_attr_75.
conj_39 :- has_attr_80, has_attr_114.
conj_40 :- not has_attr_48, has_attr_49, has_attr_123.
conj_42 :- not has_attr_22, not has_attr_23.
conj_43 :- has_attr_126, has_attr_127.
conj_44 :- not has_attr_4, not has_attr_5.
conj_46 :- has_attr_80, not has_attr_81.
conj_53 :- not has_attr_40, has_attr_41.
conj_56 :- not has_attr_27, has_attr_29.
conj_57 :- has_attr_16, has_attr_17.
conj_58 :- not has_attr_4, not has_attr_5, not has_attr_

In [20]:
# Only `label/1` atoms are shown, so the digits in an answer set are exactly
# the predicted label indices.
show_statements = ['#show label/1.']
label_id_pattern = re.compile(r"\d+")

y_true_all = []
y_pred_batch = []

for x_batch, y_batch in test_loader:
    for x, y in zip(x_batch, y_batch):
        # Encode the sample as facts: one `has_attr_j.` per attribute equal to 1.
        x_asp = [f"has_attr_{j}." for j in range(len(x)) if x[j] == 1]
        ctl = clingo.Control(["--warn=none"])
        ctl.add("base", [], " ".join(x_asp + output_rules + show_statements))
        ctl.ground([("base", [])])
        with ctl.solve(yield_=True) as handle:
            asp_model = handle.model()
            # Render the answer set while the solve handle is still open:
            # a clingo model must not be used after its search has ended.
            answer_set = str(asp_model) if asp_model else ""

        # No model, or an empty answer set, means no label is predicted.
        prediction_one_hot = torch.zeros(y.shape)
        predict_labels = [
            int(label) for label in label_id_pattern.findall(answer_set)
        ]
        if predict_labels:
            prediction_one_hot[predict_labels] = 1

        # Targets are collected alongside predictions so both always cover the
        # same samples, even if the loader yields more than one batch.
        y_true_all.append(y)
        y_pred_batch.append(prediction_one_hot)

y_true = torch.stack(y_true_all)
y_pred = torch.stack(y_pred_batch)
macro_f1 = f1_score(y_true, y_pred, average='macro')
print(f'Rules F1 {macro_f1}')
print()
# print(multilabel_confusion_matrix(y_true, y_pred))

Rules F1 0.9634756277512571



# Original Rules

In [21]:
# Show the rule strings stored in the dataset archive under 'rule_str', for
# side-by-side comparison with the rules extracted from the model above.
dnpz['rule_str']

array(['c0 :- a0.', 'c1 :- a2, not a3.', 'c2 :- not a4, not a5.',
       'c3 :- a6, a7.', 'c4 :- a8, not a9.', 'c5 :- not a11.',
       'c6 :- not a12, a13.', 'c7 :- a14, not a15.', 'c8 :- a16, a17.',
       'c9 :- not a18, a19.', 'c10 :- not a20, not a21.',
       'c11 :- not a22, not a23.', 'c12 :- a24.', 'c13 :- a27.',
       'c14 :- not a29.', 'c15 :- not a31.', 'c16 :- a32, a33.',
       'c17 :- not a35.', 'c18 :- not a36, a37.', 'c19 :- not a39.',
       'c20 :- not a40, a41.', 'c21 :- a42, not a43.',
       'c22 :- a44, not a45.', 'c23 :- a46, not a47.',
       'c24 :- not a48, a49.', 'c25 :- not a50, a51.',
       'c26 :- not a52, not a53.', 'c27 :- a55.', 'c28 :- a56, not a57.',
       'c29 :- not a58.', 'c30 :- not a60.', 'c31 :- not a62, a63.',
       'c32 :- a65.', 'c33 :- not a66, a67.', 'c34 :- not a68, not a69.',
       'c35 :- a71.', 'c36 :- not a72, a73.', 'c37 :- not a74, a75.',
       'c38 :- not a76.', 'c39 :- not a78, not a79.',
       'c40 :- a80, not a81.', 'c41 