## Serving Llama-2 with the Transformers library + Flask + ngrok

In [None]:
!huggingface-cli login --token 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
!pip install flask
!pip install transformers
!pip install accelerate
!pip install pyngrok

In [None]:
import os
import threading

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Load the Llama-2-7b-chat model and its tokenizer from the Hugging Face Hub.
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

# Access token for the gated checkpoint, taken from the environment rather
# than hard-coded. When HF_TOKEN is unset this is None and the Hub client
# falls back to credentials cached by `huggingface-cli login`.
token = os.environ.get("HF_TOKEN")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keyword arguments shared by the model `from_pretrained` call.
# NOTE(review): float16 is intended for the GPU path; on a CPU-only runtime
# half precision may be slow or unsupported for some ops -- confirm.
params = {
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.float16,
    "token": token,
}

# Load the model configuration. (The former `stream=True` argument is not a
# config option; it was only stored as an unused attribute, so it is dropped.)
config = AutoConfig.from_pretrained(MODEL_ID, token=token)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token)

# Load the model weights, switch to inference mode and move them to `device`.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, config=config, **params)
model.eval()
model.to(device)

In [None]:
import torch
from flask import Flask, jsonify, request
from pyngrok import ngrok
from transformers import StoppingCriteria, StoppingCriteriaList

# Close every ngrok tunnel that is still open (e.g. from an earlier run of
# this cell) before opening a new one below.
stale_tunnels = ngrok.get_tunnels()
for tunnel in stale_tunnels:
    ngrok.disconnect(tunnel.public_url)


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the newest token(s) of the first sequence match a keyword.

    Args:
        keywords_ids: stop keywords expressed as token ids. Each entry is
            either a single id (compared with the last generated token, as
            before) or a sequence of ids (compared with the trailing tokens,
            so a multi-token keyword only stops generation once it is
            complete instead of on its first sub-token).
    """

    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids
        # Normalise every keyword to a non-empty list of Python ints.
        self._sequences = []
        for keyword in keywords_ids:
            try:
                sequence = [int(token_id) for token_id in keyword]
            except TypeError:  # a scalar id (int or 0-d tensor) is not iterable
                sequence = [int(keyword)]
            if sequence:
                self._sequences.append(sequence)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0 of ``input_ids`` ends with any keyword sequence.

        Assumes ``input_ids`` is (batch, seq_len); only the first row is
        inspected, matching the single-prompt endpoint that uses this class.
        """
        if not self._sequences:
            return False
        longest = max(len(sequence) for sequence in self._sequences)
        # One device->host copy of just the tokens that can take part in a match.
        tail = input_ids[0, -longest:].tolist()
        return any(tail[-len(sequence):] == sequence for sequence in self._sequences)


# Generation settings applied to every request; a request's "configs" object
# is merged on top of these and overrides them key by key.
DEFAULT_GENERATION_CONFIGS = dict(
    max_new_tokens=1024,
    # Keep only the single most likely token at each step so repeated
    # requests give the same output (relevant whenever sampling is enabled).
    top_k=1,
)

app = Flask(__name__)


def _stop_token_ids(stop_words):
    """Map each stop word to the id of its first token.

    ``tokenizer.encode`` prepends the BOS token by default, so taking ``[0]``
    of its result always produced the BOS id and stop words never fired;
    special tokens are therefore disabled here. A single string is accepted
    as well as a list, and words that encode to no tokens are skipped.

    NOTE(review): only the first sub-token is matched, and SentencePiece
    tokenizers may begin an encoding with a bare word-boundary piece --
    confirm against the stop words clients actually send.
    """
    if isinstance(stop_words, str):
        stop_words = [stop_words]
    stop_ids = []
    for word in stop_words:
        ids = tokenizer.encode(word, add_special_tokens=False)
        if ids:
            stop_ids.append(ids[0])
    return stop_ids


@app.route('/v1/completition', methods=['POST'])
def completition():
    """Complete ``prompt`` from a JSON body ``{"prompt": ..., "configs": {...}}``.

    ``configs`` is merged over DEFAULT_GENERATION_CONFIGS and forwarded to
    ``model.generate``; its optional ``stop`` entry (a string or a list of
    strings) is turned into a stopping criterion instead.

    Returns JSON with the generated text (prompt excluded) and token counts,
    HTTP 400 for a malformed request body, or HTTP 500 with the error message
    when generation fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({'error': "request body must be a JSON object with a 'prompt' field"}), 400
    overrides = data.get('configs') or {}
    if not isinstance(overrides, dict):
        return jsonify({'error': "'configs' must be a JSON object"}), 400

    try:
        prompt = data['prompt']
        configs = {**DEFAULT_GENERATION_CONFIGS, **overrides}

        # "stop" is not a `generate` argument: remove it and express it as a
        # stopping criterion instead.
        stopping_criteria = StoppingCriteriaList()
        stop_words = configs.pop("stop", None)
        if stop_words:
            stopping_criteria.append(KeywordsStoppingCriteria(_stop_token_ids(stop_words)))

        with torch.no_grad():
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            generate_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                stopping_criteria=stopping_criteria,
                **configs,
            )

        # input_ids is (batch, seq_len): the prompt length is dimension 1,
        # not len(), which would be the batch size.
        input_tokens = inputs.input_ids.shape[1]
        # generate() returns prompt + continuation; keep only the new tokens.
        new_ids = generate_ids[:, input_tokens:]
        response = tokenizer.batch_decode(
            new_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        output_tokens = new_ids.shape[1]

        return jsonify({
            'completition': response,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'total_tokens': input_tokens + output_tokens,
        })
    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        app.logger.exception("completition request failed")
        return jsonify({'error': str(e)}), 500

# Publish the local Flask port through a public ngrok tunnel.
port = 8086
tunnel_url = ngrok.connect(port).public_url
public_url = tunnel_url
print(f' * ngrok tunnel "{public_url}" -> "http://127.0.0.1:{port}"')

# Keep the public URL in the Flask config under BASE_URL.
app.config["BASE_URL"] = public_url

# app.run blocks, so serve from a background thread to let this cell finish;
# the reloader stays off because it cannot run outside the main thread.
server_kwargs = {"use_reloader": False, "port": port}
server_thread = threading.Thread(target=app.run, kwargs=server_kwargs)
server_thread.start()



 * ngrok tunnel "https://e440-34-147-75-23.ngrok.io" -> "http://127.0.0.1:8086"
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:8086
INFO:werkzeug:Press CTRL+C to quit
