In [1]:
import sys

sys.path.append('..')
sys.path.append('../..')

In [2]:
import numpy as np
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader

In [3]:
from dnf_layer import SemiSymbolicLayerType
from rule_learner import DNFClassifier
from test_common import SyntheticDataset
from utils import DeltaDelayedExponentialDecayScheduler

In [4]:
SYNTH_DATA_PATH = 'synth_data_in10_conj5.npz'
RNG_SEED = 75
BATCH_SIZE = 64
NUM_EPOCHS = 100

In [5]:
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)

In [6]:
dnpz = np.load(SYNTH_DATA_PATH)
full_nullary = dnpz['nullary']
full_target = dnpz['target']

In [7]:
nullary_full_train, nullary_test, target_full_train, target_test = train_test_split(
    full_nullary, full_target, random_state=RNG_SEED)
nullary_train, nullary_val, target_train, target_val = train_test_split(
    nullary_full_train, target_full_train, test_size=0.2, random_state=RNG_SEED
)

train_loader = DataLoader(SyntheticDataset(nullary_train, target_train), BATCH_SIZE)
val_loader = DataLoader(SyntheticDataset(nullary_val, target_val), BATCH_SIZE)
test_loader = DataLoader(SyntheticDataset(nullary_test, target_test), BATCH_SIZE)

In [8]:
model = DNFClassifier(10, 5, 1, 0.1)
delta_decay_scheduler = DeltaDelayedExponentialDecayScheduler(
    initial_delta=0.1,
    delta_decay_delay=10,
    delta_decay_steps=1,
    delta_decay_rate=1.1,
)
optimiser = torch.optim.Adam(model.parameters(), 0.001)
criterion = torch.nn.BCELoss()

In [9]:
for i in range(NUM_EPOCHS):
    model.train()
    epoch_loss = []
    for x, y in train_loader:
        optimiser.zero_grad()
        y_hat = (torch.tanh(model(x.float())) + 1) / 2
        loss = criterion(y_hat.squeeze(), (y == 1).float() )
        loss.backward()
        optimiser.step()

        epoch_loss.append(loss.item())
    
    model.eval()    
    val_correct = 0
    val_total = 0
    for x, y in val_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            y_pred = torch.where(y_hat > 0, 1, -1)
            comp = y_pred.squeeze() == y
            val_correct += torch.count_nonzero(comp).item()
            val_total += len(y)

    new_delta_val = delta_decay_scheduler.step(model, i)
    avg_loss = sum(epoch_loss) / len(epoch_loss)
    print(f'[{i + 1:3d}] Delta: {new_delta_val:.3f}  '
          f'Train avg loss: {avg_loss:.3f}  '
          f'Val acc: {val_correct / val_total:.3f}')

[  1] Delta: 0.100  Train avg loss: 0.549  Val acc: 0.744
[  2] Delta: 0.100  Train avg loss: 0.444  Val acc: 0.775
[  3] Delta: 0.100  Train avg loss: 0.417  Val acc: 0.785
[  4] Delta: 0.100  Train avg loss: 0.395  Val acc: 0.795
[  5] Delta: 0.100  Train avg loss: 0.379  Val acc: 0.804
[  6] Delta: 0.100  Train avg loss: 0.368  Val acc: 0.809
[  7] Delta: 0.100  Train avg loss: 0.359  Val acc: 0.809
[  8] Delta: 0.100  Train avg loss: 0.352  Val acc: 0.812
[  9] Delta: 0.100  Train avg loss: 0.345  Val acc: 0.811
[ 10] Delta: 0.100  Train avg loss: 0.331  Val acc: 0.853
[ 11] Delta: 0.100  Train avg loss: 0.306  Val acc: 0.893
[ 12] Delta: 0.110  Train avg loss: 0.287  Val acc: 0.897
[ 13] Delta: 0.121  Train avg loss: 0.265  Val acc: 0.906
[ 14] Delta: 0.133  Train avg loss: 0.244  Val acc: 0.917
[ 15] Delta: 0.146  Train avg loss: 0.224  Val acc: 0.928
[ 16] Delta: 0.161  Train avg loss: 0.205  Val acc: 0.935
[ 17] Delta: 0.177  Train avg loss: 0.186  Val acc: 0.937
[ 18] Delta: 0

In [10]:
torch.save(model.state_dict(), f'dnf_synth_{RNG_SEED}.pth')

In [11]:
# model.load_state_dict(torch.load(f'dnf_synth_{RNG_SEED}.pth')) 

In [12]:
def eval_binary_dnf(model, data_loader, print_error=False):
    model.eval()

    correct = 0
    total = 0

    for x, y in data_loader:
        with torch.no_grad():
            y_hat = torch.tanh(model(x.float()))
            y_pred = torch.where(y_hat > 0, 1, -1)
            comp = y_pred.squeeze() == y
            correct += torch.count_nonzero(comp).item()
            total += len(y)

            if torch.count_nonzero(comp) != len(y) and print_error:
                print(x)
                print(y_pred.squeeze())
                print(y)
                print()
    
    return correct, total


In [13]:
correct, total =  eval_binary_dnf(model, test_loader)
print(f'Total:   {total}')
print(f'Correct: {correct}')
print(f'Acc:     {correct / total:.3f}')

Total:   4995
Correct: 4992
Acc:     0.999


In [14]:
def prune_layer_weight(
    model,
    layer_type: SemiSymbolicLayerType,
    epsilon: float,
    data_loader: DataLoader,
) -> int:
    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        curr_weight = model.dnf.conjunctions.weights.data.clone()
    else:
        curr_weight = model.dnf.disjunctions.weights.data.clone()

    og_correct, og_total = eval_binary_dnf(model, data_loader)
    og_accuracy = og_correct / og_total

    prune_count = 0
    weight_device = curr_weight.device

    flatten_weight_len = len(torch.reshape(curr_weight, (-1,)))
    base_iterator = range(flatten_weight_len)
    for i in base_iterator:
        curr_weight_flatten = torch.reshape(curr_weight, (-1,))

        if curr_weight_flatten[i] == 0:
            continue

        mask = torch.ones(flatten_weight_len, device=weight_device)
        mask[i] = 0
        mask = mask.reshape(curr_weight.shape)

        masked_weight = curr_weight * mask

        if layer_type == SemiSymbolicLayerType.CONJUNCTION:
            model.dnf.conjunctions.weights.data = masked_weight
        else:
            model.dnf.disjunctions.weights.data = masked_weight

        new_correct, new_total = eval_binary_dnf(model, data_loader)
        new_accuracy = new_correct / new_total
        performance_drop = og_accuracy - new_accuracy
        if performance_drop < epsilon:
            prune_count += 1
            curr_weight *= mask

    if layer_type == SemiSymbolicLayerType.CONJUNCTION:
        model.dnf.conjunctions.weights.data = curr_weight
    else:
        model.dnf.disjunctions.weights.data = curr_weight
    return prune_count

In [15]:
prune_epsilon = 0.005
print('Prune disj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.DISJUNCTION, 
    prune_epsilon, val_loader)
new_correct, new_total = eval_binary_dnf(model, val_loader)
print(f'Pruned disj count:   {prune_count}')
print(f'New perf after disj: {new_correct / new_total:.3f}')

print('Prune conj layer')
prune_count = prune_layer_weight(model, SemiSymbolicLayerType.CONJUNCTION, 
    prune_epsilon, val_loader)
new_correct, new_total = eval_binary_dnf(model, val_loader)
print(f'Pruned conj count:   {prune_count}')
print(f'New perf after conj: {new_correct / new_total:.3f}')

Prune disj layer
Pruned disj count:   4
New perf after disj: 1.000
Prune conj layer
Pruned conj count:   43
New perf after conj: 0.997


In [16]:
pre_threshold_sd = model.state_dict()

In [17]:
def apply_threshold(
    model, og_conj_weight, og_disj_weight, t_val, const: float = 6.0,
) -> None:
    new_conj_weight = (
        (torch.abs(og_conj_weight) > t_val) * torch.sign(og_conj_weight) * const
    )
    model.dnf.conjunctions.weights.data = new_conj_weight

    new_disj_weight = (
        (torch.abs(og_disj_weight) > t_val) * torch.sign(og_disj_weight) * const
    )
    model.dnf.disjunctions.weights.data = new_disj_weight

In [18]:
# model.load_state_dict(pre_threshold_sd)

conj_min = torch.min(model.dnf.conjunctions.weights.data)
conj_max = torch.max(model.dnf.conjunctions.weights.data)
disj_min = torch.min(model.dnf.disjunctions.weights.data)
disj_max = torch.max(model.dnf.disjunctions.weights.data)

abs_max = torch.max(torch.abs(torch.Tensor([conj_min, conj_max, disj_min, disj_max])))

og_conj_weight = model.dnf.conjunctions.weights.data.clone()
og_disj_weight = model.dnf.disjunctions.weights.data.clone()

acc_scores = []
t_vals = torch.arange(0, abs_max + 0.01, 0.01)

for v in t_vals:
    apply_threshold(model, og_conj_weight, og_disj_weight,
                              v, 6.0)
    correct, total = eval_binary_dnf(model, val_loader)
    acc = correct / total
    acc_scores.append(acc)

best_acc_score = max(acc_scores)
best_t = t_vals[torch.argmax(torch.Tensor(acc_scores))]
print(f'Best t: {best_t.item():.3f}    Acc: {best_acc_score:.3f}')
apply_threshold(model, og_conj_weight, og_disj_weight, best_t)
final_correct, final_total = eval_binary_dnf(model, val_loader)
print(f'Jacc after threshold: {final_correct / final_total:.3f}')

Best t: 0.000    Acc: 0.997
Jacc after threshold: 0.997


In [19]:
final_sd = model.state_dict()

output_rules = []

# Get all conjunctions
conj_w = final_sd["dnf.conjunctions.weights"]
conjunction_map = dict()
for i, w in enumerate(conj_w):
    if torch.all(w == 0):
        # No conjunction is applied here
        continue

    conjuncts = []
    for j, v in enumerate(w):
        if v < 0:
            # Negative weight, negate the atom
            conjuncts.append(f"not has_attr_{j}")
        elif v > 0:
            # Positive weight, normal atom
            conjuncts.append(f"has_attr_{j}")

    conjunction_map[i] = conjuncts

# Get DNF
disj_w = final_sd["dnf.disjunctions.weights"]
not_covered_classes = []
for i, w in enumerate(disj_w):
    if torch.all(w == 0):
        # No DNF for class i
        not_covered_classes.append(i)
        continue

    disjuncts = []
    for j, v in enumerate(w):
        if v < 0 and j in conjunction_map:
            # Negative weight, negate the existing conjunction
            
            # Need to add auxilary predicate (conj_X) which is not yet
            # in the final rules list
            output_rules.append(
                f"conj_{j} :- {', '.join(conjunction_map[j])}."
            )
            output_rules.append(f"class({i}) :- not conj_{j}.")
        elif v > 0 and j in conjunction_map:
            # Postivie weight, add normal conjunction
            body = ", ".join(conjunction_map[j])
            output_rules.append(f"class({i}) :- {body}.")

In [20]:
final_sd['dnf.conjunctions.weights']

tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  6.,  6., -6.,  6.,  6., -6.,  6.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

In [21]:
final_sd['dnf.disjunctions.weights']

tensor([[0., 0., 0., 6., 0.]])

In [22]:
output_rules

['class(0) :- has_attr_3, has_attr_4, not has_attr_5, has_attr_6, has_attr_7, not has_attr_8, has_attr_9.']

In [23]:
dnpz['rule_str']

array(['t :- not nullary(1), not nullary(2), nullary(3), nullary(4), not nullary(5), nullary(6), nullary(7), not nullary(8), nullary(9).'],
      dtype='<U128')