c-btm inference #50
Merged

Commits (6)
655ea6d  Add files via upload
2b39187  updated inference script to take in input and output file path and mo… (NourFahmy)
47314fb  updated c-btm inference script
bed3ec1  updated to generate one token at a time
74dadef  remove print statement
f1735cf  integrated with Nicolo's PR
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.nn import PairwiseDistance

def load_models(model_names):
    """
    Loads a model and its corresponding tokenizer for each name in model_names.
    Returns parallel lists of models and tokenizers to be used by the other functions.
    model_names: list of model names
    """
    models = []
    tokenizers = []

    for model_name in model_names:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Inference only: put the model in eval mode and freeze its parameters
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        models.append(model)
        tokenizers.append(tokenizer)

    return models, tokenizers

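The eval/freeze pattern used by load_models applies to any torch module. A minimal sketch with a toy linear layer standing in for a loaded expert (so no Hugging Face checkpoint needs to be downloaded):

```python
import torch
from torch import nn

# Toy stand-in for a loaded expert model
model = nn.Linear(4, 4)

# Same inference-only setup as load_models: eval mode, gradients off
model.eval()
for param in model.parameters():
    param.requires_grad = False

assert not model.training
assert all(not p.requires_grad for p in model.parameters())
```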
def generateNextTokenFromExpert(prompt, tokenizer, model):
    """
    Takes in a prompt, tokenizer and model, which should be passed from topKFilter
    as topKFilter determines which domains are most relevant given the context.
    Returns the generated text and the probability of the last chosen token.
    prompt: the string prompt from input_file
    tokenizer: the tokenizer trained on the most relevant domain
    model: the model trained on the most relevant domain
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_tokens = []

    # Note: this loop generates two tokens autoregressively; the returned
    # probability corresponds to the last token only
    for _ in range(2):
        with torch.no_grad():
            output = model(input_ids)

        next_token_logits = output.logits[:, -1, :]
        next_token_probabilities = next_token_logits.softmax(dim=-1)

        # Sample the next token from the distribution
        next_token = torch.multinomial(next_token_probabilities, num_samples=1)

        # Get the probability of the chosen token
        chosen_token_probability = next_token_probabilities[0, next_token.item()].item()

        generated_tokens.append(next_token.item())
        input_ids = torch.cat((input_ids, next_token), dim=-1)

    generated_text = tokenizer.decode(generated_tokens)
    return generated_text, chosen_token_probability

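The per-step sampling inside that loop can be exercised on its own with made-up logits, no model required; a sketch of the softmax-then-multinomial step, not the PR's code:

```python
import torch

torch.manual_seed(0)

# Toy logits for a 5-token vocabulary, shape (batch=1, vocab=5)
next_token_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.0]])
next_token_probabilities = next_token_logits.softmax(dim=-1)

# Sample one token id and look up its probability, as the loop above does
next_token = torch.multinomial(next_token_probabilities, num_samples=1)
chosen_token_probability = next_token_probabilities[0, next_token.item()].item()

assert abs(next_token_probabilities.sum().item() - 1.0) < 1e-6
assert 0 <= next_token.item() < 5
assert 0.0 < chosen_token_probability <= 1.0
```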
def findEnsembleWeights(embedder, prompt, T, clusterCenters):
    """
    Takes in the cluster centers, the embedder used in clustering, the prompt,
    and the temperature T, and computes a weight between the embedded context
    and each cluster center.
    embedder: the embedder that was used to train the clustering model
    prompt: the string prompt from input_file
    T: temperature parameter for the softmax
    clusterCenters: list of cluster centers
    """
    ensembleWeights = torch.tensor([])
    pdist = PairwiseDistance(p=2)
    embeddedInput = embedder.encode(prompt)

    # assuming modelList and clusterCenters have matching indices
    for domain, clusterCenter in enumerate(clusterCenters):
        # Accumulate as a 1-dim tensor so it can be concatenated directly
        ensembleWeight = torch.zeros(1)

        for token in embeddedInput[0]:
            token = torch.tensor(token)
            ensembleWeight = ensembleWeight + torch.exp(-1 * torch.pow(pdist(token, clusterCenter), 2) / T)

        ensembleWeights = torch.cat((ensembleWeights, ensembleWeight))

    return ensembleWeights

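The weight computed above is, per cluster, a sum over tokens of exp(-||token - center||^2 / T). A self-contained sketch with made-up 2-D token embeddings and two cluster centers (none of these values come from the PR):

```python
import torch
from torch.nn import PairwiseDistance

pdist = PairwiseDistance(p=2)
T = 1.0

# Made-up token embeddings (3 tokens, 2 dims) and two cluster centers
tokens = torch.tensor([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1]])
centers = torch.tensor([[0.0, 0.0], [5.0, 5.0]])

weights = torch.tensor([])
for center in centers:
    w = torch.zeros(1)
    for token in tokens:
        # pdist expects batched inputs, hence the unsqueeze
        d = pdist(token.unsqueeze(0), center.unsqueeze(0))
        w = w + torch.exp(-d.pow(2) / T)
    weights = torch.cat((weights, w))

# The center closest to the tokens gets the larger weight
assert weights.shape == (2,)
assert weights[0] > weights[1]
```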
def topKFilter(ensembleWeights, k):
    """
    Takes in the ensemble weights and the k parameter.
    Returns the top k ensemble weights, normalized to sum to 1, to determine
    the most relevant domains given the context, along with their indices.
    ensembleWeights: tensor of ensemble weights as calculated in findEnsembleWeights
    k: number of top experts to choose
    """
    topK = torch.topk(ensembleWeights, k=k)
    indices = topK.indices
    # Normalize the selected weights so they form a distribution
    values = topK.values / torch.sum(topK.values)

    return values, indices

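topKFilter's selection and normalization can be checked in isolation on a made-up weight vector:

```python
import torch

ensembleWeights = torch.tensor([0.1, 3.0, 0.5, 1.0])

topK = torch.topk(ensembleWeights, k=2)
values = topK.values / torch.sum(topK.values)  # normalize to a distribution

assert topK.indices.tolist() == [1, 3]        # indices of the two largest weights
assert abs(values.sum().item() - 1.0) < 1e-6  # normalized weights sum to 1
assert abs(values[0].item() - 0.75) < 1e-6    # 3.0 / (3.0 + 1.0)
```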
def findNextToken(embedder, models, tokenizers, prompt, k, T, clusterCenters):
    """
    Finds the k most relevant domains given the context and returns the most
    likely next token among the predictions of those domain experts.
    embedder: the embedder that was used to train the clustering model
    models: list of models trained on the domains
    tokenizers: list of corresponding tokenizers
    prompt: the string prompt from input_file
    k: number of most relevant domains to choose from
    T: temperature parameter for the softmax
    clusterCenters: list of cluster centers
    """
    ensembleWeights = findEnsembleWeights(embedder, prompt, T, clusterCenters)
    topKValues, topKIndices = topKFilter(ensembleWeights, k)
    expertsPredictedTokens = []
    for i, index in enumerate(topKIndices):
        predToken, predTokenProb = generateNextTokenFromExpert(
            prompt,
            tokenizers[index],
            models[index]
        )
        # Weight the expert's token probability by its ensemble weight;
        # store a plain float so max() can compare the scores
        weightedProb = float(predTokenProb * topKValues[i])
        expertsPredictedTokens.append((predToken, weightedProb))

    return max(expertsPredictedTokens, key=lambda x: x[1])

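The expert-selection step reduces to weighting each expert's token probability by its ensemble weight and taking the maximum. With made-up candidate tokens and weights (purely illustrative, not the real experts):

```python
# Hypothetical (token, probability) predictions from three experts
expertPredictions = [("cat", 0.9), ("dog", 0.6), ("fox", 0.8)]
# Hypothetical normalized ensemble weights from topKFilter
topKValues = [0.2, 0.5, 0.3]

expertsPredictedTokens = [
    (token, prob * weight)
    for (token, prob), weight in zip(expertPredictions, topKValues)
]
best = max(expertsPredictedTokens, key=lambda x: x[1])

assert best[0] == "dog"  # 0.6 * 0.5 = 0.30 beats 0.18 and 0.24
```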
def generateSequence(embedder, prompt, end_token, models, tokenizers, maxLength, k, T, clusterCenters):
    """
    Finds the most likely token from the most relevant domains based on the
    prompt, and builds up the sequence until the end token is generated or
    maxLength (in characters) is reached.
    embedder: the embedder that was used to train the clustering model
    models: list of models trained on the domains
    tokenizers: list of corresponding tokenizers
    prompt: the string prompt from input_file
    end_token: the end token, which ideally is uniform across all tokenizers
    maxLength: maximum length of the generated sequence, in characters
    k: number of most relevant domains to choose from
    T: temperature parameter for the softmax
    clusterCenters: list of cluster centers
    """
    currToken, currSequence = None, prompt
    while (len(currSequence) < maxLength) and (currToken != end_token):
        currToken, currTokenProb = findNextToken(embedder, models, tokenizers, currSequence, k, T, clusterCenters)
        currSequence = currSequence + currToken

    return currSequence

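The loop's termination logic can be checked with a stub in place of findNextToken (the stub and its fake tokens are hypothetical, used only to drive the loop):

```python
def stubNextToken(sequence):
    # Cycles through a few fake tokens, then emits the end token
    fake = ["a", "b", "<eos>"]
    return fake[min(len(sequence) - len("hi"), 2)]

end_token = "<eos>"
maxLength = 100
currToken, currSequence = None, "hi"
# Same loop shape as generateSequence: stop on end token or maxLength
while (len(currSequence) < maxLength) and (currToken != end_token):
    currToken = stubNextToken(currSequence)
    currSequence = currSequence + currToken

assert currSequence.endswith("<eos>")
assert len(currSequence) < maxLength
```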
def run_inference(embedder, model_names, input_file, output_file, end_token, maxLength, k, T, clusterCenters):
    """
    embedder: the embedder that was used to train the clustering model
    model_names: list of model names
    input_file: input file name; each line is a prompt
    output_file: file the generated sequences are written to
    end_token: end token to signify termination of a sequence
    maxLength: maximum length of a generated sequence
    k: number of most relevant domains to choose from
    T: temperature parameter for the softmax
    clusterCenters: list of cluster centers
    """
    models, tokenizers = load_models(model_names)

    # Note: readlines keeps the trailing newline on each prompt
    with open(input_file, 'r') as file:
        input_data = file.readlines()

    results = []
    for prompt in input_data:
        results.append(generateSequence(embedder, prompt, end_token, models, tokenizers, maxLength, k, T, clusterCenters))

    with open(output_file, 'w') as file:
        for result in results:
            file.write(f"{result}\n")
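run_inference's file handling (one prompt per line in, one generated sequence per line out) can be sketched end to end with a stub generator and temporary files; stubGenerate is a hypothetical stand-in for generateSequence:

```python
import os
import tempfile

def stubGenerate(prompt):
    # Hypothetical stand-in for generateSequence
    return prompt.strip() + " ... <eos>"

with tempfile.TemporaryDirectory() as tmp:
    input_file = os.path.join(tmp, "prompts.txt")
    output_file = os.path.join(tmp, "out.txt")
    with open(input_file, "w") as f:
        f.write("hello\nworld\n")

    # Mirrors run_inference: read prompts, generate, write one result per line
    with open(input_file, "r") as f:
        prompts = f.readlines()
    results = [stubGenerate(p) for p in prompts]
    with open(output_file, "w") as f:
        for result in results:
            f.write(f"{result}\n")

    with open(output_file) as f:
        lines = f.read().splitlines()

assert lines == ["hello ... <eos>", "world ... <eos>"]
```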
some issues with loading the models and maintaining HF credentials -- had to load models and tokenizers outside of the function
Ok, good to know. Strange that you can't load; I made all models public now.