Skip to content

Commit

Permalink
integrated with Nicolo's PR
Browse files Browse the repository at this point in the history
  • Loading branch information
NourFahmy committed Sep 28, 2023
1 parent 74dadef commit f1735cf
Showing 1 changed file with 74 additions and 14 deletions.
88 changes: 74 additions & 14 deletions scripts/c-btmInference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.nn import PairwiseDistance
import numpy

def load_models(model_names):
"""
takes in list of model names and loads model and corresponding tokenizers
model_names: takes in list of model names and loads model and corresponding tokenizers
returns separate lists of models and tokenizers to be accessed by other functions later
"""
models = []
Expand All @@ -25,6 +30,9 @@ def generateNextTokenFromExpert(prompt, tokenizer, model):
takes in prompt, tokenizer and model which should be passed from topKFilter
as topKFilter determines which domains are most relevant given the context
returns the predicted token of the model and its corresponding probability
prompt: the string prompt from input_file
tokenizer: the tokenizer that's been trained on the most relevant domain
model: the model that's been trained on the most relevant domain
"""
input_ids = tokenizer.encode(prompt, return_tensors="pt")
generated_tokens = []
Expand All @@ -33,7 +41,7 @@ def generateNextTokenFromExpert(prompt, tokenizer, model):
with torch.no_grad():
output = model(input_ids)

next_token_logits = output.logits[:, -1, :]
next_token_logits = output.logits[:, -1, :]
next_token_probabilities = next_token_logits.softmax(dim=-1)

# Sample the next token based on probabilities
Expand All @@ -48,27 +56,54 @@ def generateNextTokenFromExpert(prompt, tokenizer, model):
generated_text = tokenizer.decode(generated_tokens)
return generated_text, chosen_token_probability




def findEnsembleWeights(embedder, prompt, T, clusterCenters):
"""
takes in cluster centers and embedder using in clustering, and prompt, and T parameter
calculates weighted logit between context and cluster center
embedder: the embedder that was used to train the clustering model
prompt: the string prompt from input_file
T: temperature parameter for softmax
clusterCenters: list of cluster centers
"""
ensembleWeights = []
ensembleWeights = torch.tensor([])#[]
pdist = torch.nn.PairwiseDistance(p=2)
embeddedInput = embedder.encode(prompt)

for domain, clusterCenter in enumerate(clusterCenters): # assuming modelList and clusterCenters have matching indices
embeddedInput = embedder.encode(prompt)
clusterCenter = torch.tensor(clusterCenter)
embeddedInput = torch.tensor(embeddedInput) # TODO: do we need to tensor-ify them?
ensembleWeights.append(torch.exp(-1 * torch.pow(pdist(embeddedInput, clusterCenter),2) / T))
ensembleWeight = 0

for token in embeddedInput[0]:

token = torch.tensor(token)
ensembleWeight += torch.exp(-1 * torch.pow(pdist(token, clusterCenter),2) / T)
# Check if ensembleWeights is empty, if so, initialize it
if ensembleWeights is None:
ensembleWeights = ensembleWeight
else:
# Check if ensembleWeights and ensembleWeight have compatible shapes
if ensembleWeights.ndim == 0:
# If ensembleWeights is a scalar, convert it to a 1-dimensional tensor
ensembleWeights = ensembleWeights.unsqueeze(0)
if ensembleWeight.ndim == 0:
# If ensembleWeight is a scalar, convert it to a 1-dimensional tensor
ensembleWeight = ensembleWeight.unsqueeze(0)

return torch.tensor(ensembleWeights) # TODO: return torch.tensor instead of list?
# Concatenate the tensors
ensembleWeights = torch.cat((ensembleWeights, ensembleWeight))

return ensembleWeights



def topKFilter(ensembleWeights, k):
"""
takes in ensemble weights and k parameter
returns top k ensemble weights to determine most relevant domains given context
ensembleWeights: list of ensemble weights as calculated in findEnsembleWeights
k: number of top experts to choose
"""
topK = torch.topk(ensembleWeights, k=k)
indices = topK.indices
Expand All @@ -81,6 +116,13 @@ def findNextToken(embedder, models, tokenizers, prompt, k, T, clusterCenters):
takes in prompt, temperature T parameter and clusterCenters
finds k most relevant domains given context
returns most likely next token from predictions of most relevant domain experts
embedder: the embedder that was used to train the clustering model
models: list of models that were trained on most relevant domains
tokenizers: list of tokenizers that were trained on most relevant domains
prompt: the string prompt from input_file
k: number of most relevant domains to choose from
T: temperature parameter for softmax
clusterCenters: list of cluster centers
"""
ensembleWeights = findEnsembleWeights(embedder, prompt, T, clusterCenters)
topKValues, topKIndices = topKFilter(ensembleWeights, k)
Expand All @@ -102,22 +144,40 @@ def generateSequence(embedder, prompt, end_token, models, tokenizers, maxLength,
parameter k, temperature T and cluster centers
finds most likely token from most relevant domains based on prompt
builds sequence until end token is generated or maxLength is reached
embedder: the embedder that was used to train the clustering model
models: list of models that were trained on most relevant domains
tokenizers: list of tokenizers that were trained on most relevant domains
prompt: the string prompt from input_file
end_token: the token that ideally is uniform across all tokenizers
k: number of most relevant domains to choose from
T: temperature parameter for softmax
clusterCenters: list of cluster centers
"""
currToken, currSequence = None, prompt
while len(currSequence) < maxLength:# or currToken != end_token:
while (len(currSequence) < maxLength) and (currToken != end_token):
currToken, currTokenProb = findNextToken(embedder, models, tokenizers, currSequence, k, T, clusterCenters)
currSequence = currSequence + currToken

return generateSequence

return currSequence

def run_inference(embedder, model_names, input_file, output_file, end_token, maxLength, k, T, clusterCenters, models, tokenizers):

models, tokenizers = load_models(model_names)
def run_inference(embedder, model_names, input_file, output_file, end_token, maxLength, k, T, clusterCenters): #, models, tokenizers):
"""
embedder: the embedder that was used to train the clustering model
model_names: list of model names
input_file: input file name, contents are prompts
output_file: where generated sequences are written to
end_token: end token to signify termination of sequence
maxLength: max length of generated sequence
k: number of most relevant domains to choose from
T: temperature parameter for softmax
clusterCenters: list of cluster centers
"""
models, tokenizers = load_models(model_names)

with open(input_file, 'r') as file:
input_data = file.readlines()

print(input_data)
results = []
for prompt in input_data:
results.append(generateSequence(embedder, prompt, end_token, models, tokenizers, maxLength, k, T, clusterCenters))
Expand Down

0 comments on commit f1735cf

Please sign in to comment.