
c-btm inference #50

Merged: 6 commits into huu4ontocord:main on Nov 9, 2023
Conversation

@NourFahmy (Collaborator) commented May 10, 2023

Inference code for c-BTM, replicating formulas 2 & 3 from the c-BTM paper; tested locally.

Kindly let me know if anything else is needed!

link to #40
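
For reference, the two formulas being replicated are, as I read the c-BTM paper, roughly the following (a paraphrase in my own notation, not a quote; the exact form, e.g. whether the distance is squared, should be checked against the paper):

```
p(x_t | x_<t)    =  Σ_j  p_j(x_t | x_<t) · P(D = j | x_<t)        (cf. eq. 2)
P(D = j | x_<t)  ∝  exp( −dist(emb(x_<t), c_j) / T )              (cf. eq. 3)
```

where p_j is expert j's next-token distribution, c_j is the cluster center of domain j, emb(·) embeds the context, T is a temperature, and only the top-k weights are kept and renormalized.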

Commit: inference code for c-btm
@mrcabbage972 (Collaborator) commented May 12, 2023

Hi @NourFahmy,
The goal is to have a script which the user can call with the arguments:

  • The names of the models
  • The input data file path
  • The output file path

The script should load the models, run inference on the input data and save the results. This would allow us to evaluate the performance of the method using perplexity and also on downstream tasks.

It would be very helpful if the PR solved this end to end. It's possible to break this up into a few PRs, if you prefer.
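
For concreteness, a minimal sketch of such a CLI (flag names are placeholders, not a spec, and `cbtm_generate` is a stub standing in for the ensemble inference routine):

```python
import argparse

def cbtm_generate(prompt, model_names):
    # Stand-in for the actual c-BTM ensemble inference routine.
    raise NotImplementedError

def parse_args():
    parser = argparse.ArgumentParser(description="c-BTM ensemble inference")
    parser.add_argument("--model-names", nargs="+", required=True,
                        help="HF hub names of the expert models")
    parser.add_argument("--input-file", required=True,
                        help="path to the input data file (one prompt per line)")
    parser.add_argument("--output-file", required=True,
                        help="path to write the generated results")
    return parser.parse_args()

def main():
    args = parse_args()
    with open(args.input_file) as fin, open(args.output_file, "w") as fout:
        for prompt in fin:
            fout.write(cbtm_generate(prompt.strip(), args.model_names) + "\n")

if __name__ == "__main__":
    main()
```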

@NourFahmy (Collaborator, Author) commented May 14, 2023

Hi @mrcabbage972 - thank you for your feedback! Will update accordingly by Wednesday.

@NourFahmy changed the title from "Add files via upload" to "c-btm inference" on May 19, 2023
@NourFahmy (Collaborator, Author) commented May 19, 2023

Kindly note: this still needs to be tested!
The latest commit allows the user to call the script with the input and output file paths and the names of the models.

As I understand it, the sequence of tasks that need to be implemented for c-BTM inference is (a sketch of this loop follows below):

  1. calculate the distance between the cluster center of each domain and the embedded prompt
  2. evaluate all the distances, retain the top k, and normalize them to sum to 1; these determine the most relevant domains given the context
  3. feed the prompt to the most relevant domain experts from step 2 to generate a single new token
  4. weight each expert's token probabilities by the normalized weights from step 2
  5. choose the next token in the sequence as the one with the highest combined probability
  6. append the token to the sequence, and continue until we reach an end token or the maximum sequence length

cc: @mrcabbage972
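
For reference, a minimal sketch of that loop. Everything here is illustrative: `embed`, `cluster_centers`, the tempered-softmax weighting, and the greedy decoding are assumptions layered on the steps above, and a single tokenizer shared across experts is assumed.

```python
import numpy as np
import torch

def cbtm_generate(prompt, models, tokenizer, embed, cluster_centers,
                  top_k=2, temperature=0.1, max_len=128):
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    while ids.shape[-1] < max_len:
        # Step 1: distance from each domain's cluster center to the context.
        context_emb = embed(tokenizer.decode(ids[0]))
        dists = np.array([np.linalg.norm(context_emb - c) for c in cluster_centers])
        # Step 2: keep the top-k closest domains, normalize weights to sum to 1.
        weights = np.exp(-dists / temperature)
        keep = np.argsort(-weights)[:top_k]
        norm_w = weights[keep] / weights[keep].sum()
        # Steps 3-4: each relevant expert proposes one next-token distribution,
        # weighted by its normalized closeness from step 2.
        mixed = None
        for wi, j in zip(norm_w, keep):
            with torch.no_grad():
                logits = models[j](ids).logits[:, -1, :]
            contrib = torch.softmax(logits, dim=-1) * float(wi)
            mixed = contrib if mixed is None else mixed + contrib
        # Step 5: greedily take the highest-probability token.
        next_id = mixed.argmax(dim=-1, keepdim=True)
        # Step 6: append and continue until EOS or the length cap.
        ids = torch.cat([ids, next_id], dim=-1)
        if next_id.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(ids[0], skip_special_tokens=True)
```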

@kenhktsui (Collaborator) commented

@NourFahmy @mrcabbage972
I will be finishing PR #61 by this weekend.
Since we do not split the dataset in an unsupervised manner, I trained a few classifiers which give us the weight of each dataset (and therefore of each expert trained on it), instead of clustering as in the c-BTM paper.
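
For what it's worth, a minimal sketch of that classifier-based weighting (the checkpoint name and the one-to-one class-to-expert mapping are assumptions):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical classifier checkpoint; its classes are assumed to map
# one-to-one onto the expert models.
clf_tokenizer = AutoTokenizer.from_pretrained("org/dataset-classifier")
clf_model = AutoModelForSequenceClassification.from_pretrained("org/dataset-classifier")

def classifier_expert_weights(prompt, top_k=2):
    inputs = clf_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        probs = torch.softmax(clf_model(**inputs).logits[0], dim=-1)
    top = torch.topk(probs, top_k)
    # Renormalize the retained weights to sum to 1, mirroring the
    # top-k step of the cluster-based variant.
    return top.indices, top.values / top.values.sum()
```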

@huu4ontocord (Owner) commented

Where are we on this, @NourFahmy @kenhktsui @mrcabbage972?

pass in embedder for prompt

tokenizers = []

for model_name in model_names:
    model = AutoModelForCausalLM.from_pretrained(model_name)
@NourFahmy (Collaborator, Author) commented on this diff:

Some issues with loading the models and maintaining HF credentials -- I had to load the models and tokenizers outside of the function.

@huu4ontocord (Owner) replied:

OK, good to know. Strange that you can't load them. I've made all the models public now.
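
For reference, a sketch of the workaround described above: load everything once at module scope (the model names are placeholders, and the auth argument name has changed across transformers versions, so treat it as an assumption):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAMES = ["org/expert-1", "org/expert-2"]  # placeholders

# Loaded once at import time, outside any function, so HF credentials are
# only needed here. For private repos, pass token=... (use_auth_token=...
# on older transformers versions) to both from_pretrained calls.
models = [AutoModelForCausalLM.from_pretrained(n) for n in MODEL_NAMES]
tokenizers = [AutoTokenizer.from_pretrained(n) for n in MODEL_NAMES]
```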

inputs = tokenizer(prompt)
print(inputs['input_ids'])
sizeOfInputs = len(inputs['input_ids'])
outputs = model(**inputs, max_new_tokens=1,
@NourFahmy (Collaborator, Author) commented on this diff:

max_new_tokens is not a parameter of geoptx -- how can I limit the number of tokens?

@huu4ontocord (Owner) replied:

Can't you do max_length?
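
For reference, two standard ways to cap this at a single token (ordinary transformers usage, not code from this PR):

```python
import torch

inputs = tokenizer(prompt, return_tensors="pt")

# Option 1: a plain forward pass never generates more than one step;
# just take the distribution at the last position.
with torch.no_grad():
    logits = model(**inputs).logits
next_token_probs = torch.softmax(logits[:, -1, :], dim=-1)

# Option 2: generate() accepts max_new_tokens (and max_length) in
# recent transformers versions.
out = model.generate(**inputs, max_new_tokens=1)
```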

@huu4ontocord merged commit c43a93a into huu4ontocord:main on Nov 9, 2023 (1 check failed)
@huu4ontocord (Owner) commented

I merged it. You can keep adding to it in another PR.
