In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

num_items = 1000
x = torch.rand(num_items, 2)
c = x.round()
y = (c.sum(dim=1) == 1).float()[:, None]
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

In [2]:
class ConceptEmbedding(torch.nn.Module):
    def __init__(self, in_features, n_concepts, emb_size):
        super().__init__()
        self.n_concepts = n_concepts
        self.emb_size = emb_size

        self.concept_context_generator = torch.nn.Sequential(
                torch.nn.Linear(in_features, 2 * emb_size * n_concepts),
                torch.nn.LeakyReLU(),
            )
        self.concept_prob_predictor = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        concept_embs = self.concept_context_generator(x)
        concept_embs_shape = x.shape[:-1] + (self.n_concepts, 2 * self.emb_size)
        concept_embs = concept_embs.view(*concept_embs_shape)
        concept_probs = self.concept_prob_predictor(concept_embs)
        concept_pos = concept_embs[..., :self.emb_size]
        concept_neg = concept_embs[..., self.emb_size:]
        concept_embs = concept_pos * concept_probs + concept_neg * (1 - concept_probs)
        return concept_embs, concept_probs.squeeze(-1)



In [3]:
embedding_size = 8
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    ConceptEmbedding(10, c.shape[1], embedding_size),
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1]*embedding_size, 1),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

In [4]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()
loss_form_y = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(501):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb.reshape(len(c_emb), -1))
    # compute loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.5*task_loss

    loss.backward()
    optimizer.step()

In [5]:
c_emb, c_pred = concept_encoder.forward(x_test)
y_pred = task_predictor(c_emb.reshape(len(c_emb), -1))

task_accuracy = accuracy_score(y_test, y_pred > 0)
concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
print(f"Task accuracy: {task_accuracy:.2f}, Concept accuracy: {concept_accuracy:.2f}")

Task accuracy: 0.98, Concept accuracy: 0.98


In [6]:
y_train = F.one_hot(y_train.long().ravel()).float()
y_test = F.one_hot(y_test.long().ravel()).float()


In [None]:
class ConceptMemoryReasoningLayer(torch.nn.Module):
    def __init__(self, n_concepts, n_classes, emb_size=32, n_rules=2):
        super(ConceptMemoryReasoningLayer, self).__init__()
        self.n_concepts = n_concepts
        self.n_classes = n_classes
        self.n_rules = n_rules
        self.emb_size = emb_size

        self.rule_book = nn.Embedding(n_classes * n_rules, self.emb_size)

        self.rule_decoder = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.LeakyReLU(),
            nn.Linear(self.emb_size, 3 * n_concepts),
        )

        """self.rule_decoder_highend = nn.Embedding(n_classes * n_rules * n_concepts, 3)"""

        self.rule_selector = nn.Sequential(
            nn.Linear(self.n_concepts, self.emb_size),
            nn.LeakyReLU(),
            nn.Linear(self.emb_size, n_classes * n_rules),
        )
    
    def decode_rules(self):
        rule_embs = self.rule_book.weight.view(self.n_classes, self.n_rules, self.emb_size)
        rules_decoded = self.rule_decoder(rule_embs).view(self.n_classes, self.n_rules, self.n_concepts, 3)
        #rules_decoded = self.rule_decoder_highend.weight.view(self.n_classes, self.n_rules, self.n_concepts, 3)
        rules_decoded = F.softmax(rules_decoded, dim=-1)
        if not self.training:
            # argmax to get the most likely rule
            rules_decoded = F.one_hot(torch.argmax(rules_decoded, dim=-1), num_classes=3)
        return rules_decoded


    def forward(self, x, return_explanation=False, concept_names=None):
        
        rules_decoded = self.decode_rules()
        rule_scores = self.rule_selector(x).view(-1, self.n_classes, self.n_rules)
        rule_scores = F.softmax(rule_scores, dim=-1) # (batch_dim, n_classes, n_rules)
        if not self.training:
            # argmax to get the most likely rule
            rule_scores = F.one_hot(torch.argmax(rule_scores, dim=-1), num_classes=self.n_rules)
        
        agg_rules = (rules_decoded[None, ...] * rule_scores[..., None, None]) # # (batch_dim, n_classes, n_rules, n_concepts, 3)
        agg_rules = agg_rules.sum(dim=-3) 
        pos_rules = agg_rules[..., 0] 
        neg_rules = agg_rules[..., 1] 
        irr_rules = agg_rules[..., 2]
        x = x[..., None, :]
        # batch_dim, n_classes, n_concepts
        preds = (pos_rules * x + neg_rules * (1 - x) + irr_rules).prod(dim=-1)
        c_rec = 0.5 * irr_rules + pos_rules
        
        
        aux_loss = F.binary_cross_entropy(c_rec, x.repeat(1, c_rec.shape[1], 1),reduction="none").mean(dim=-1)
        aux_loss = (aux_loss * preds).mean()

        explanations = None
        assert not return_explanation or not self.training, "Explanation can only be returned in eval mode"
        if return_explanation:
            if concept_names is None:
                concept_names = [f"c_{i}" for i in range(self.n_concepts)]
            rule_counts = (rule_scores.round().to(torch.long) * preds[..., None].round().to(torch.long)).sum(dim=0) # (n_classes, n_rules)
            from collections import defaultdict
            rule_strings = defaultdict(list)
            explanations = {}
            for i in range(self.n_classes):
                class_name = f"y_{i}"
                for j in range(self.n_rules):
                    for c in range(self.n_concepts):
                        is_pos = rules_decoded[i, j, c, 0]
                        is_neg = rules_decoded[i, j, c, 1]
                        is_irr = rules_decoded[i, j, c, 2]
                        if is_pos > 0.5:
                            rule_strings[(i, j)].append(concept_names[c])
                        elif is_neg > 0.5:
                            rule_strings[(i, j)].append(f"~{concept_names[c]}")
                explanations[class_name] = []
                for j in range(self.n_rules):
                    if rule_counts[i, j] > 0:
                        rule_str = " & ".join(rule_strings[(i, j)])
                        explanations[class_name].append(f"Rule {j}: {rule_str} (Counts: {rule_counts[i, j]})")

        return preds, aux_loss, explanations
    


In [8]:

task_predictor = ConceptMemoryReasoningLayer(2, y_train.shape[1], emb_size=embedding_size)
model = torch.nn.Sequential(concept_encoder, task_predictor)

In [9]:
c_emb, c_pred = concept_encoder(x_train)
y_pred, aux_loss, _ = task_predictor(c_pred)

In [10]:
y_pred.shape, aux_loss.shape

(torch.Size([670, 2]), torch.Size([]))

In [11]:
concept_encoder(x_train)[0].shape, concept_encoder(x_train)[1].shape

(torch.Size([670, 2, 8]), torch.Size([670, 2]))

In [12]:
x_train.shape

torch.Size([670, 2])

In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form = torch.nn.BCELoss()

In [14]:
with torch.autograd.set_detect_anomaly(True):
    
    model.train()
    for epoch in range(512):
        optimizer.zero_grad()

        # generate concept and task predictions
        c_emb, c_pred = concept_encoder(x_train)
        y_pred, aux_loss, _ = task_predictor(c_pred)

        #aux_loss = (y_train * aux_loss).mean() 
        # compute loss
        concept_loss = loss_form(c_pred, c_train)
        task_loss = loss_form(y_pred, y_train)
        loss = concept_loss + 0.5*task_loss + 2 * aux_loss

        loss.backward()
        optimizer.step()

In [15]:
task_predictor.decode_rules().round()

tensor([[[[0., 1., 0.],
          [0., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.]]],


        [[[0., 1., 0.],
          [1., 0., 0.]],

         [[1., 0., 0.],
          [0., 1., 0.]]]], grad_fn=<RoundBackward0>)

In [16]:
task_predictor.eval()

x_test = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y_test = F.one_hot(torch.tensor([0, 1, 1, 0]), num_classes=2)
task_predictor(x_test)[0].round()

tensor([[1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.]])

In [17]:
task_predictor(x_test, return_explanation=True)[2]

{'y_0': ['Rule 0: ~c_0 & ~c_1 (Counts: 1)', 'Rule 1: c_0 & c_1 (Counts: 1)'],
 'y_1': ['Rule 0: ~c_0 & c_1 (Counts: 1)', 'Rule 1: c_0 & ~c_1 (Counts: 1)']}

In [18]:
c_emb, c_pred = concept_encoder.forward(x_test)
y_pred_, aux_loss_, explanations = task_predictor(c_pred)

aux_loss_ = (y_test * aux_loss_).mean()
task_accuracy = accuracy_score(y_test, y_pred_ > 0.5)
#concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
print(f"Task accuracy: {task_accuracy:.2f}", f"Auxillary loss: {aux_loss_:.2f}")

Task accuracy: 1.00 Auxillary loss: 0.00


In [19]:
local_explanations

NameError: name 'local_explanations' is not defined

In [None]:
global_explanations