In [2]:
import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Seed the global torch RNG so the sampled data AND the weight initialisation in the
# following cells are reproducible; the same seed drives the train/test split.
SEED = 42
torch.manual_seed(SEED)

num_items = 1000
x = torch.rand(num_items, 2)
# Concept i is active when coordinate i exceeds 0.5 (x lies in [0, 1)).
c = x.round()
# XOR task: positive iff exactly one of the two concepts is active; kept as an (N, 1) column.
y = (c.sum(dim=1) == 1).float()[:, None]
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
    x, c, y, test_size=0.33, random_state=SEED
)

In [3]:
class ConceptEmbedding(torch.nn.Module):
    """Map input features to one embedding and one probability per concept.

    For every concept a shared linear layer produces a "positive" and a "negative"
    embedding of size ``emb_size``. A predictor shared across all concepts scores the
    concatenated (positive, negative) pair to obtain the concept probability ``p``,
    and the returned embedding is the mixture ``p * positive + (1 - p) * negative``.

    Args:
        in_features: size of the last dimension of the input tensor.
        n_concepts: number of concepts to predict.
        emb_size: size of each concept embedding.
    """

    def __init__(self, in_features, n_concepts, emb_size):
        super().__init__()
        self.n_concepts = n_concepts
        self.emb_size = emb_size

        # A single linear map produces the [positive | negative] embeddings for all concepts.
        self.concept_context_generator = torch.nn.Sequential(
            torch.nn.Linear(in_features, 2 * emb_size * n_concepts),
            torch.nn.LeakyReLU(),
        )
        # Shared across concepts: scores one (positive, negative) pair as a probability.
        self.concept_prob_predictor = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Compute concept embeddings and concept probabilities.

        Args:
            x: tensor of shape (..., in_features); any leading batch dims are kept.

        Returns:
            A tuple ``(concept_embs, concept_probs)`` with shapes
            (..., n_concepts, emb_size) and (..., n_concepts). Probabilities are
            sigmoid outputs, i.e. in [0, 1].
        """
        concept_embs = self.concept_context_generator(x)
        concept_embs_shape = x.shape[:-1] + (self.n_concepts, 2 * self.emb_size)
        concept_embs = concept_embs.view(*concept_embs_shape)
        concept_probs = self.concept_prob_predictor(concept_embs)  # (..., n_concepts, 1)
        concept_pos = concept_embs[..., :self.emb_size]
        concept_neg = concept_embs[..., self.emb_size:]
        # Probability-weighted mixture; concept_probs broadcasts over the embedding dim.
        concept_embs = concept_pos * concept_probs + concept_neg * (1 - concept_probs)
        return concept_embs, concept_probs.squeeze(-1)

In [4]:
embedding_size = 8
# Width of the shared feature layer. It is both the Linear output size and the
# ConceptEmbedding input size, so it is defined once to keep the two in sync.
hidden_size = 10
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], hidden_size),
    torch.nn.LeakyReLU(),
    ConceptEmbedding(hidden_size, c.shape[1], embedding_size),
)
# Linear head over all concept embeddings flattened into one vector; outputs a raw logit.
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1] * embedding_size, 1),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

In [5]:
model.train()
loss_form_c = torch.nn.BCELoss()  # concept head ends in a Sigmoid -> probabilities
loss_form_y = torch.nn.BCEWithLogitsLoss()  # task head is a bare Linear -> logits
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

# Full-batch joint training of the concept encoder and the linear task head.
for epoch in range(501):
    optimizer.zero_grad()

    # Forward pass: per-concept embeddings/probabilities, then the task logit
    # computed from the embeddings flattened to one vector per sample.
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb.reshape(len(c_emb), -1))

    # Objective: concept BCE plus the task BCE at half weight.
    task_loss = loss_form_y(y_pred, y_train)
    concept_loss = loss_form_c(c_pred, c_train)
    loss = concept_loss + 0.5 * task_loss

    loss.backward()
    optimizer.step()

In [6]:
# Inference: switch to eval mode, skip autograd, and call the modules directly
# (calling .forward() explicitly would bypass registered hooks).
model.eval()
with torch.no_grad():
    c_emb, c_pred = concept_encoder(x_test)
    y_pred = task_predictor(c_emb.reshape(len(c_emb), -1))

# y_pred holds logits (trained with BCEWithLogitsLoss), so 0 is the decision threshold.
task_accuracy = accuracy_score(y_test, y_pred > 0)
# c_pred holds probabilities. For the 2-column concept matrix, sklearn reports subset
# (exact-match) accuracy: a sample counts only if both concepts are correct.
concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
print(f"Task accuracy: {task_accuracy:.2f}, Concept accuracy: {concept_accuracy:.2f}")

Task accuracy: 0.98, Concept accuracy: 0.98


In [12]:
from bronze_age.models.concept_reasoner import ConceptReasoningLayer
import torch.nn.functional as F

def _to_one_hot(labels, num_classes=2):
    """Return ``labels`` as float one-hot rows of width ``num_classes``.

    Accepts an (N,) or (N, 1) tensor of integer-valued class ids. A tensor that is
    already (N, num_classes) is returned unchanged (cast to float), so this cell is
    safe to re-run. The previous in-place conversion re-encoded an already one-hot
    tensor on a second run and doubled the number of rows.
    """
    if labels.ndim == 2 and labels.shape[1] == num_classes:
        return labels.float()
    # Explicit num_classes: a split that happens to miss a class still gets every column.
    return F.one_hot(labels.long().ravel(), num_classes=num_classes).float()


# NOTE(review): this rebinds y_train / y_test. The binary-head cells above expect the
# (N, 1) column and will fail with a shape mismatch if re-run after this cell.
y_train = _to_one_hot(y_train)
y_test = _to_one_hot(y_test)

# Builds a freshly initialised reasoning head on top of the already-trained concept_encoder.
# Presumably the second argument is the number of task classes — confirm against bronze_age.
task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1], temperature=100)
model = torch.nn.Sequential(concept_encoder, task_predictor)

In [8]:
model.train()
# One BCE criterion for both objectives.
# NOTE(review): BCELoss assumes the reasoning layer outputs probabilities — confirm.
loss_form = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

# Full-batch training of the already-trained concept encoder together with the new reasoning layer.
for epoch in range(501):
    optimizer.zero_grad()

    # The reasoning layer consumes both the concept embeddings and the concept probabilities.
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb, c_pred)

    # Objective: concept BCE plus the task BCE at half weight.
    task_loss = loss_form(y_pred, y_train)
    concept_loss = loss_form(c_pred, c_train)
    loss = concept_loss + 0.5 * task_loss

    loss.backward()
    optimizer.step()

In [9]:
# Recompute the concept outputs for the training split with the final weights instead of
# reusing c_emb / c_pred leaked from the training loop. Those were produced before the
# last optimizer step and still carry the autograd graph.
model.eval()
with torch.no_grad():
    c_emb, c_pred = concept_encoder(x_train)
# 'sample-id' in the local explanations indexes rows of x_train.
local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

In [10]:
# Preview only: the full list holds one entry per explained sample, far too long to display.
local_explanations[:10]

[{'sample-id': 0,
  'class': 'y_1',
  'explanation': '~c_0 & c_1',
  'attention': [-1.0, 1.0]},
 {'sample-id': 1,
  'class': 'y_1',
  'explanation': '~c_0 & c_1',
  'attention': [-1.0, 1.0]},
 {'sample-id': 2,
  'class': 'y_0',
  'explanation': '~c_0 & ~c_1',
  'attention': [-1.0, -1.0]},
 {'sample-id': 3,
  'class': 'y_1',
  'explanation': 'c_0 & ~c_1',
  'attention': [1.0, -1.0]},
 {'sample-id': 4,
  'class': 'y_1',
  'explanation': '~c_0 & c_1',
  'attention': [-1.0, 1.0]},
 {'sample-id': 5,
  'class': 'y_0',
  'explanation': 'c_0 & c_1',
  'attention': [1.0, 1.0]},
 {'sample-id': 6,
  'class': 'y_0',
  'explanation': '~c_0 & ~c_1',
  'attention': [-1.0, -1.0]},
 {'sample-id': 7,
  'class': 'y_0',
  'explanation': '~c_0 & ~c_1',
  'attention': [-1.0, -1.0]},
 {'sample-id': 8,
  'class': 'y_1',
  'explanation': 'c_0 & ~c_1',
  'attention': [1.0, -1.0]},
 {'sample-id': 9,
  'class': 'y_0',
  'explanation': '~c_0 & ~c_1',
  'attention': [-1.0, -1.0]},
 {'sample-id': 10,
  'class': 'y_0

In [11]:
# One entry per (class, rule) pair. 'count' is presumably the number of samples each
# rule explains; the counts below sum to 670, the size of the training split.
global_explanations

[{'class': 'y_0', 'explanation': '~c_0 & ~c_1', 'count': 176},
 {'class': 'y_0', 'explanation': 'c_0 & c_1', 'count': 163},
 {'class': 'y_1', 'explanation': '~c_0 & c_1', 'count': 165},
 {'class': 'y_1', 'explanation': 'c_0 & ~c_1', 'count': 166}]