Add a ltn.Predicate constructor that takes in a logits model #31

Closed
sbadredd opened this issue Aug 24, 2022 · 1 comment

sbadredd commented Aug 24, 2022

Constructors for ltn.Predicate

The constructor for ltn.Predicate accepts a model that outputs one truth degree in [0,1].

import tensorflow as tf
import ltn

class ModelThatOutputsATruthDegree(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        # The sigmoid squashes the single output into [0,1].
        self.dense2 = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)

    def call(self, x):
        x = self.dense1(x)
        return self.dense2(x)

model = ModelThatOutputsATruthDegree()
P1 = ltn.Predicate(model)
P1(x) # -> call with an ltn Variable x

Issue

Many models output several values simultaneously. For example, a model for the predicate P2 classifying images x into n classes type_1, ..., type_n will likely output n logits using the same hidden layers.

Ideally, we would call the corresponding predicate using the syntax P2(x,type). This requires two additional steps:

  1. Transforming the logits into values in [0,1],
  2. Indexing the class using the term type.

Because this is a common use case, we implemented the wrapper ltn.utils.LogitsToPredicateModel for convenience. It is used in some of the examples (cf. MNIST digit addition).
The syntax is:

logits_model(x) # how to call `logits_model`
# single_label is a parameter of LogitsToPredicateModel, not of ltn.Predicate
P2 = ltn.Predicate(ltn.utils.LogitsToPredicateModel(logits_model, single_label=True))
P2([x,type]) # how to call the predicate

It automatically adds a final argument for class indexing and applies a softmax (single_label=True, mutually exclusive classes) or a sigmoid (single_label=False) activation to the logits.
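The two steps can be sketched in plain Python, without TensorFlow or LTN, for a single row of logits. This is only an illustration of the idea, not the library's implementation; the function names `softmax`, `sigmoid`, and `truth_degree` are made up for this sketch.

```python
import math

def softmax(logits):
    """Map a row of logits to values in [0,1] that sum to 1."""
    highest = max(logits)
    exps = [math.exp(v - highest) for v in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(logit):
    """Map one logit to a value in [0,1], independently of the others."""
    return 1.0 / (1.0 + math.exp(-logit))

def truth_degree(logits, class_index, single_label=True):
    """Hypothetical helper: the truth degree of "x has class class_index"."""
    # Step 1: transform the logits into values in [0,1].
    if single_label:
        values = softmax(logits)               # exclusive classes
    else:
        values = [sigmoid(v) for v in logits]  # non-exclusive classes
    # Step 2: index the class.
    return values[class_index]
```

With `single_label=True` the values for all classes compete and sum to 1; with `single_label=False` each class gets its own independent truth degree.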

Proposition

It would be more elegant to have the functionality of creating a predicate from a logits model as a class constructor for ltn.Predicate.

A suggested syntax is:

P2 = ltn.Predicate.FromLogits(logits_model, activation_function="softmax", with_class_indexing=True)

  • The functionality comes as a new class constructor,
  • The activation function is more explicit than the single_label parameter in ltn.utils.LogitsToPredicateModel,
  • with_class_indexing=False still allows creating predicates of the form P1(x), as in the first example.
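As a rough picture of how such a constructor could behave, here is a plain-Python sketch. It is not LTN code: `from_logits` is a hypothetical stand-in for the proposed `ltn.Predicate.FromLogits`, the "model" is any function returning a list of logits, and only the two activation names mentioned in this issue are handled.

```python
import math

def _softmax(logits):
    highest = max(logits)
    exps = [math.exp(v - highest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def _sigmoid(logits):
    return [1.0 / (1.0 + math.exp(-v)) for v in logits]

_ACTIVATIONS = {"softmax": _softmax, "sigmoid": _sigmoid}

def from_logits(logits_model, activation_function="softmax", with_class_indexing=True):
    """Hypothetical sketch of the proposed constructor (names are assumptions)."""
    activation = _ACTIVATIONS[activation_function]

    if with_class_indexing:
        # Predicate of the form P2(x, type): the last argument selects the class.
        def predicate(x, class_index):
            return activation(logits_model(x))[class_index]
    else:
        # Predicate of the form P1(x): the model is expected to emit one logit.
        def predicate(x):
            values = activation(logits_model(x))
            if len(values) != 1:
                raise ValueError("expected a single logit without class indexing")
            return values[0]

    return predicate

# Toy "models" used only for this illustration.
P2 = from_logits(lambda x: [x, 0.0])
P1 = from_logits(lambda x: [x], activation_function="sigmoid", with_class_indexing=False)
```

Passing the activation by name makes the choice visible at the call site, which is the point of the second bullet above.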

Changes to the rest of the API

The proposition adds a new constructor but shouldn't change any other method of ltn.Predicate or any framework method in general.

@sbadredd sbadredd self-assigned this Aug 24, 2022
@sbadredd sbadredd added the enhancement New feature or request label Aug 24, 2022
@gaokun12

This comment was marked as off-topic.
