Skip to content

Conversation

TheAthleticCoder
Copy link
Contributor

Partially resolves #947
Currently working on sampler.py and added a small unit test to top_k file.

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

self,
k=5,
seed=None,
temperature=1.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you do the changes suggested above, you won't have to do anything beyond passing a **kwargs through.


self.next = next
self.sampler = TopKSampler(k=5)
self.sampler_temperature = TopKSampler(k=5, temperature=1e-5)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just define this in the test below, as it is only used for one test.

next=next,
prompt=prompt,
)
output_ids = set(output[0].numpy())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to do set here, just self.assertAllEqual(output, [0] * self.length) or something like that.

@mattdangerw mattdangerw self-assigned this Apr 1, 2023
@mattdangerw
Copy link
Member

/gcbrun

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great! I think we just need to merge this with the latest changes on master.

For beam search, top-k, greedy, contrastive, we can probably just very pass temperature=1.0 in setup for the test sampler, as the output won't change for these "top" samplers. That way our test files stay short.

For top-p and random we can include a simple test that assert that a very low temp will lead to deterministic output.

output_ids = set(output[0].numpy())
self.assertContainsSubset(output_ids, range(3))

def test_outputs_in_top_p_with_temperature(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could make this simpler, we are just trying to test that temperature is working. Something like (haven't run this)...

def test_temperature(self):
    def next(prompt, state, index):
        # Return a distribution where each id is progressively less likely.
        logits = tf.range(self.vocab_size, 0, -1, dtype="float32")
        logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0)
        return tf.constant(logits), state

    prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
    # Use a very low temperature, so we always sample the first token.
    output = TopPSampler(p=0.5, temperature=1e-6)(
        next=next,
        prompt=prompt,
    )
    self.assertAllEqual(output, tf.zero_like(prompt))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for the top_p_sampler. However, when I apply the same for the random_sampler, it only passes 75% of the tests as seen here.
image
Not sure why.

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome this looks great! Found one last small comment, and this needs to be merged or rebased with the master branch. Then let's land!

@mattdangerw
Copy link
Member

Thanks! Will resolve merge conflicts as I merge this.

@TheAthleticCoder
Copy link
Contributor Author

@mattdangerw Should I add the temperature arg to the contrastive sampler as well in this thread? Or will we be opening another issue for the same?

@mattdangerw
Copy link
Member

@mattdangerw Should I add the temperature arg to the contrastive sampler as well in this thread? Or will we be opening another issue for the same?

Follow up is fine, I will merge this now.

@mattdangerw
Copy link
Member

/gcbrun

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a temperature argument to the base sampler class
2 participants