
How can I extend this to multi-label classification? #264

Open
JeanHung opened this issue Apr 16, 2023 · 2 comments

Comments

@JeanHung

No description provided.

@andsteing
Collaborator

You would need to extend the script in two places:

First, in the input processing, replace the one-hot encoding

```python
label = tf.one_hot(data['label'], num_classes)  # pylint: disable=no-value-for-parameter
```

with something like

```python
label = tf.scatter_nd(
    data['label'][:, None],
    tf.ones(tf.shape(data['label'])[0]),
    (num_classes,))
```

to get multiple labels in multi-hot format.
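To illustrate what the scatter produces, here is the same multi-hot encoding in plain NumPy (a minimal sketch; `multi_hot`, `label_indices`, and `num_classes` are illustrative names, not from the repo):

```python
import numpy as np

def multi_hot(label_indices, num_classes):
    """Turn a list of class indices into a multi-hot vector,
    mirroring what the tf.scatter_nd call above does."""
    label = np.zeros(num_classes, dtype=np.float32)
    label[np.asarray(label_indices)] = 1.0
    return label

# An example with three active classes out of five:
print(multi_hot([0, 2, 4], num_classes=5))  # → [1. 0. 1. 0. 1.]
```

Unlike `tf.one_hot`, which produces exactly one active class per example, this vector can have several entries set to 1 at once.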

Second, use a sigmoid loss instead of the cross-entropy loss here:

```python
def loss_fn(params, images, labels):
  logits = apply_fn(
      dict(params=params),
      rngs=dict(dropout=dropout_rng),
      inputs=images,
      train=True)
  return cross_entropy_loss(logits=logits, labels=labels)
```
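As a sketch of what such a sigmoid loss computes, here is a numerically stable binary cross-entropy in plain NumPy — treating each class as an independent yes/no decision. In the actual script you would use the framework's own op instead (e.g. `optax.sigmoid_binary_cross_entropy`); `sigmoid_loss` and its arguments are illustrative names:

```python
import numpy as np

def sigmoid_loss(logits, labels):
    """Mean element-wise binary cross-entropy over all classes.
    Uses log-sum-exp so large |logits| do not overflow."""
    # log(sigmoid(x)) = -log(1 + exp(-x)); np.logaddexp is the stable form.
    log_p = -np.logaddexp(0.0, -logits)      # log sigmoid(logits)
    log_not_p = -np.logaddexp(0.0, logits)   # log (1 - sigmoid(logits))
    return -np.mean(labels * log_p + (1.0 - labels) * log_not_p)

logits = np.array([[2.0, -1.0, 0.5]])
labels = np.array([[1.0, 0.0, 1.0]])  # two classes active at once
print(sigmoid_loss(logits, labels))
```

The key difference from softmax cross-entropy is that the per-class sigmoids do not compete: probabilities need not sum to 1, so any number of labels can be positive per example.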

@JeanHung
Author

Thank you so much!
