In [1]:
import re
import sys

sys.path.append('..')
sys.path.append('../..')

In [2]:
import clingo
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score

import torch
from torch.utils.data import DataLoader

In [3]:
from dnf_layer import SemiSymbolicLayerType
from rule_learner import DNFClassifierEO, DNFClassifier
from test_common import SyntheticDataset
from utils import DeltaDelayedExponentialDecayScheduler
from dnf_post_train import (
    remove_unused_conjunctions,
    remove_disjunctions_when_empty_conjunctions,
)

In [4]:
RNG_SEED = 73

NUM_IN = 150
NUM_CONJ = 75
NUM_CLASSES = 25

SYNTH_DATA_PATH = f'synth_multiclass_data_in{NUM_IN}_conj{NUM_CONJ}_out{NUM_CLASSES}.npz'

BATCH_SIZE = 256
NUM_EPOCHS = 100

In [5]:
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)

In [6]:
dnpz = np.load(SYNTH_DATA_PATH)
full_nullary = dnpz['nullary']
full_target = dnpz['target']

In [7]:
nullary_full_train, nullary_test, target_full_train, target_test = train_test_split(
    full_nullary, full_target, random_state=RNG_SEED)
nullary_train, nullary_val, target_train, target_val = train_test_split(
    nullary_full_train, target_full_train, test_size=0.2, random_state=RNG_SEED
)

train_loader = DataLoader(SyntheticDataset(nullary_train, target_train), BATCH_SIZE)
val_loader = DataLoader(SyntheticDataset(nullary_val, target_val), BATCH_SIZE)
test_loader = DataLoader(SyntheticDataset(nullary_test, target_test), BATCH_SIZE)

In [8]:
model = DNFClassifierEO(NUM_IN, NUM_CONJ, NUM_CLASSES, 0.1)
delta_decay_scheduler = DeltaDelayedExponentialDecayScheduler(
    initial_delta=0.1,
    delta_decay_delay=10,
    delta_decay_steps=1,
    delta_decay_rate=1.1,
)
optimiser = torch.optim.Adam(model.parameters(), 0.001)
criterion = torch.nn.CrossEntropyLoss()

In [9]:
for i in range(NUM_EPOCHS):
    model.train()
    epoch_loss = []
    for x, y in train_loader:
        optimiser.zero_grad()
        y_hat = model(x.float())
        loss = criterion(y_hat, y.float())
        loss.backward()
        optimiser.step()

        epoch_loss.append(loss.item())
    
    model.eval()    
    val_correct = 0
    val_total = 0
    for x, y in val_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            _, y_pred = torch.max(y_hat, 1)
            val_correct += torch.sum(y_pred == (torch.argmax(y, dim=1))).item()
            val_total += len(y)

    new_delta_val = delta_decay_scheduler.step(model, i)
    avg_loss = sum(epoch_loss) / len(epoch_loss)
    print(f'[{i + 1:3d}] Delta: {new_delta_val:.3f}  '
          f'Train avg loss: {avg_loss:.3f}  '
          f'Val acc: {val_correct / val_total:.3f}')

[  1] Delta: 0.100  Train avg loss: 2.862  Val acc: 0.042
[  2] Delta: 0.100  Train avg loss: 1.521  Val acc: 0.261
[  3] Delta: 0.100  Train avg loss: 0.435  Val acc: 0.042
[  4] Delta: 0.100  Train avg loss: 0.123  Val acc: 0.042
[  5] Delta: 0.100  Train avg loss: 0.062  Val acc: 0.042
[  6] Delta: 0.100  Train avg loss: 0.041  Val acc: 0.042
[  7] Delta: 0.100  Train avg loss: 0.031  Val acc: 0.042
[  8] Delta: 0.100  Train avg loss: 0.025  Val acc: 0.042
[  9] Delta: 0.100  Train avg loss: 0.020  Val acc: 0.042
[ 10] Delta: 0.100  Train avg loss: 0.017  Val acc: 0.042
[ 11] Delta: 0.100  Train avg loss: 0.015  Val acc: 0.042
[ 12] Delta: 0.110  Train avg loss: 0.013  Val acc: 0.042
[ 13] Delta: 0.121  Train avg loss: 0.013  Val acc: 0.042
[ 14] Delta: 0.133  Train avg loss: 0.013  Val acc: 0.042
[ 15] Delta: 0.146  Train avg loss: 0.013  Val acc: 0.042
[ 16] Delta: 0.161  Train avg loss: 0.014  Val acc: 0.042
[ 17] Delta: 0.177  Train avg loss: 0.015  Val acc: 0.042
[ 18] Delta: 0

In [None]:
# torch.save(model.state_dict(), f'dnfeo_multi_class_synth_{RNG_SEED}.pth')
# model.load_state_dict(torch.load(f'dnfeo_synth_{RNG_SEED}.pth')) 

In [10]:
model2 = DNFClassifier(NUM_IN, NUM_CONJ, NUM_CLASSES, 1)

sd = model.state_dict()
sd.pop('eo_layer.weights')
model2.load_state_dict(sd)

<All keys matched successfully>

In [11]:
def eval_multi_class_dnf(model, data_loader):
    # Return accuracy and Jaccard score
    correct = 0
    total = 0
    jaccard_scores = []
    for x, y in data_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            _, y_pred = torch.max(y_hat, 1)
            correct += torch.sum(y_pred == (torch.argmax(y, dim=1))).item()
            total += len(y)
            y_hat_jacc = (y_hat > 0).int()
            jacc = jaccard_score(y.detach().cpu().numpy(), y_hat_jacc.detach().cpu().numpy(), average='samples')
            jaccard_scores.append(jacc)

    return (correct / total), (sum(jaccard_scores) / len(jaccard_scores))

In [12]:
acc, jacc =  eval_multi_class_dnf(model2, test_loader)
print(f'Accuarcy: {acc:.3f}')
print(f'Jaccard:  {jacc:.3f}')

Accuarcy: 1.000
Jaccard:  0.995


# Post Training Process

In [13]:
def prune_layer_weight(
    model,
    layer_type: SemiSymbolicLayerType,
    epsilon: float,
    data_loader: DataLoader,
) -> int:
    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        curr_weight = model.dnf.conjunctions.weights.data.clone()
    else:
        curr_weight = model.dnf.disjunctions.weights.data.clone()

    _, og_jacc = eval_multi_class_dnf(model, data_loader)

    prune_count = 0
    weight_device = curr_weight.device

    flatten_weight_len = len(torch.reshape(curr_weight, (-1,)))
    base_iterator = range(flatten_weight_len)
    for i in base_iterator:
        curr_weight_flatten = torch.reshape(curr_weight, (-1,))

        if curr_weight_flatten[i] == 0:
            continue

        mask = torch.ones(flatten_weight_len, device=weight_device)
        mask[i] = 0
        mask = mask.reshape(curr_weight.shape)

        masked_weight = curr_weight * mask

        if layer_type == SemiSymbolicLayerType.CONJUNCTION:
            model.dnf.conjunctions.weights.data = masked_weight
        else:
            model.dnf.disjunctions.weights.data = masked_weight

        _, new_jacc = eval_multi_class_dnf(model, data_loader)
        performance_drop = og_jacc - new_jacc
        if performance_drop < epsilon:
            prune_count += 1
            curr_weight *= mask

    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        model.dnf.conjunctions.weights.data = curr_weight
    else:
        model.dnf.disjunctions.weights.data = curr_weight
    return prune_count

In [14]:
prune_epsilon = 0.005
print('Prune disj layer')
prune_count = prune_layer_weight(model2, SemiSymbolicLayerType.DISJUNCTION, 
    prune_epsilon, val_loader)
_, new_jacc = eval_multi_class_dnf(model2, val_loader)
print(f'Pruned disj count:   {prune_count}')
print(f'New perf after disj: {new_jacc:.3f}\n')

unused_conj = remove_unused_conjunctions(model)
print(f'Remove unused conjunctions: {unused_conj}')
print()

print('Prune conj layer')
prune_count = prune_layer_weight(model2, SemiSymbolicLayerType.CONJUNCTION, 
    prune_epsilon, val_loader)
_, new_jacc = eval_multi_class_dnf(model2, val_loader)
print(f'Pruned conj count:   {prune_count}')
print(f'New perf after conj: {new_jacc:.3f}\n')

removed_disj = remove_disjunctions_when_empty_conjunctions(model)
print(
    f'Remove disjunction that uses empty conjunctions: {removed_disj}'
)
print()

print('Prune disj layer again')
prune_count = prune_layer_weight(model2, SemiSymbolicLayerType.DISJUNCTION,
    prune_epsilon, val_loader)
_, new_jacc = eval_multi_class_dnf(model2, val_loader)
print(f'Pruned disj count:   {prune_count}')
print(f'New perf after disj: {new_jacc:.3f}')

Prune disj layer
Pruned disj count:   1844
New perf after disj: 0.993

Remove unused conjunctions: 0

Prune conj layer
Pruned conj count:   10203
New perf after conj: 0.988

Remove disjunction that uses empty conjunctions: 0

Prune disj layer again
Pruned disj count:   4
New perf after disj: 0.991


In [15]:
sd = model.state_dict()
sd.pop('eo_layer.weights')
model2.load_state_dict(sd)
pre_threshold_sd = model2.state_dict()

In [16]:
def apply_threshold(
    model, og_conj_weight, og_disj_weight, t_val, const: float = 6.0,
) -> None:
    new_conj_weight = (
        (torch.abs(og_conj_weight) > t_val) * torch.sign(og_conj_weight) * const
    )
    model.dnf.conjunctions.weights.data = new_conj_weight

    new_disj_weight = (
        (torch.abs(og_disj_weight) > t_val) * torch.sign(og_disj_weight) * const
    )
    model.dnf.disjunctions.weights.data = new_disj_weight

In [17]:
# model2.load_state_dict(pre_threshold_sd)

conj_min = torch.min(model2.dnf.conjunctions.weights.data)
conj_max = torch.max(model2.dnf.conjunctions.weights.data)
disj_min = torch.min(model2.dnf.disjunctions.weights.data)
disj_max = torch.max(model2.dnf.disjunctions.weights.data)

abs_max = torch.max(torch.abs(torch.Tensor([conj_min, conj_max, disj_min, disj_max])))

og_conj_weight = model2.dnf.conjunctions.weights.data.clone()
og_disj_weight = model2.dnf.disjunctions.weights.data.clone()

jacc_scores = []
t_vals = torch.arange(0, abs_max + 0.01, 0.01)

for v in t_vals:
    apply_threshold(model2, og_conj_weight, og_disj_weight,
                              v, 6.0)
    _, jacc = eval_multi_class_dnf(model2, val_loader)
    jacc_scores.append(jacc)

best_jacc_score = max(jacc_scores)
best_t = t_vals[torch.argmax(torch.Tensor(jacc_scores))]
print(f'Best t: {best_t.item():.3f}    Acc: {best_jacc_score:.3f}')
apply_threshold(model2, og_conj_weight, og_disj_weight, best_t)
_, final_jacc = eval_multi_class_dnf(model2, val_loader)
print(f'Jacc after threshold: {final_jacc:.3f}')

Best t: 0.150    Acc: 0.227
Jacc after threshold: 0.227


In [18]:
final_sd = model2.state_dict()

output_rules = []

# Get all conjunctions
conj_w = final_sd["dnf.conjunctions.weights"]
conjunction_map = dict()
for i, w in enumerate(conj_w):
    if torch.all(w == 0):
        # No conjunction is applied here
        continue

    conjuncts = []
    for j, v in enumerate(w):
        if v < 0:
            # Negative weight, negate the atom
            conjuncts.append(f"not has_attr_{j}")
        elif v > 0:
            # Positive weight, normal atom
            conjuncts.append(f"has_attr_{j}")

    conjunction_map[i] = conjuncts

# Get DNF
disj_w = final_sd["dnf.disjunctions.weights"]
not_covered_classes = []
for i, w in enumerate(disj_w):
    if torch.all(w == 0):
        # No DNF for class i
        not_covered_classes.append(i)
        continue

    disjuncts = []
    for j, v in enumerate(w):
        if v < 0 and j in conjunction_map:
            # Negative weight, negate the existing conjunction
            
            # Need to add auxilary predicate (conj_X) which is not yet
            # in the final rules list
            output_rules.append(
                f"conj_{j} :- {', '.join(conjunction_map[j])}."
            )
            output_rules.append(f"class({i}) :- not conj_{j}.")
        elif v > 0 and j in conjunction_map:
            # Postivie weight, add normal conjunction
            body = ", ".join(conjunction_map[j])
            output_rules.append(f"class({i}) :- {body}.")

In [19]:
show_statements = [f'#show class/1.']


jacc_scores = []
total_sample_count = 0

for x_batch, y_batch in test_loader:
    for i in range(len(x_batch)):
        x = x_batch[i]
        y = torch.where(y_batch[i] == 1)[0].item()
        x_asp = [f"has_attr_{j}." for j in range(len(x)) if x[j] == 1]
        ctl = clingo.Control(["--warn=none"])
        ctl.add("base", [], " ".join(x_asp + output_rules + show_statements))
        ctl.ground([("base", [])])
        with ctl.solve(yield_=True) as handle:  # type: ignore
            all_answer_sets = [str(a) for a in handle]

        target_class = f"class({y})"

        if len(all_answer_sets) != 1:
            # No or multiple answer sets, should not happen
            print('No or multiple answer sets when evaluating rules.')
            continue

        output_classes = all_answer_sets[0].split(" ")
        output_classes_set = set(output_classes)

        target_class_set = {target_class}

        jacc = len(output_classes_set & target_class_set) / len(
            output_classes_set | target_class_set
        )
        jacc_scores.append(jacc)
        total_sample_count += 1

print(f'Jaccard score: {sum(jacc_scores) / total_sample_count:.3f}')

Jaccard score: 0.226


In [20]:
output_rules

['class(0) :- has_attr_0, has_attr_5, has_attr_8, has_attr_13, has_attr_28, has_attr_41, has_attr_52, has_attr_63, has_attr_78, not has_attr_81, has_attr_83, has_attr_87, has_attr_117, has_attr_121, not has_attr_126, has_attr_131, not has_attr_140, has_attr_141, not has_attr_147.',
 'class(0) :- has_attr_12, has_attr_23, not has_attr_35, not has_attr_39, not has_attr_53, has_attr_66, has_attr_71, has_attr_77, not has_attr_82, has_attr_93, not has_attr_101, not has_attr_103, has_attr_116, not has_attr_120, has_attr_123, not has_attr_124.',
 'class(0) :- not has_attr_12, has_attr_25, not has_attr_26, not has_attr_31, not has_attr_33, not has_attr_36, not has_attr_48, not has_attr_54, not has_attr_58, not has_attr_67, not has_attr_75, not has_attr_81, has_attr_82, not has_attr_97, has_attr_100, not has_attr_101, not has_attr_106, not has_attr_107, has_attr_115, has_attr_119, not has_attr_124, not has_attr_126, has_attr_131, not has_attr_138, has_attr_149.',
 'class(0) :- has_attr_3, has_a

In [21]:
dnpz['rule_str']

array(['c0 :- a0, a2, a6, a7, a8, a10, a13, a14, a16, a17, a19, a24, a27, a28, a32, a33, a34, a37, a41, a42, a44, a46, a49, a51, a56, a57, a58, a62, a67, a68, a69, a71, a76, a77, a79, a81, a83, a86, a88, a92, a93, a95, a96, a98, a99, a103, a108, a111, a114, a115, a116, a117, a125, a129, a130, a135, a139, a140, a141, a142, a143, a144, a145, a146, a147, a148.',
       'c1 :- a0, a2, a3, a5, a6, a8, a9, a11, a12, a13, a14, a15, a16, a18, a19, a21, a24, a25, a26, a27, a30, a34, a36, a38, a39, a43, a45, a47, a48, a49, a52, a54, a56, a61, a63, a64, a68, a71, a73, a74, a76, a77, a78, a79, a82, a86, a88, a92, a94, a98, a100, a103, a104, a107, a108, a110, a111, a115, a117, a118, a122, a124, a125, a126, a127, a129, a130, a132, a133, a135, a137, a138, a139, a140, a141, a143, a145, a146, a148.',
       'c2 :- a0, a2, a4, a5, a6, a7, a8, a9, a11, a12, a13, a14, a15, a17, a19, a20, a21, a23, a26, a29, a30, a36, a37, a39, a42, a43, a45, a46, a47, a48, a54, a55, a57, a59, a60, a61, a66, a67, a72, a74,