Skip to content

Commit

Permalink
Merge pull request #882 from basetenlabs/bump-version-0.9.7
Browse files Browse the repository at this point in the history
Release 0.9.7
  • Loading branch information
squidarth committed Mar 27, 2024
2 parents b225023 + d389b1b commit f6ec6a4
Show file tree
Hide file tree
Showing 25 changed files with 2,676 additions and 67 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -72,3 +72,6 @@ yarn.lock

.eslintcache
/.pytest_cache/

# Slay workflow generated code.
**/.slay_generated/**
11 changes: 10 additions & 1 deletion .pre-commit-config.yaml
Expand Up @@ -41,5 +41,14 @@ repos:
entry: poetry run mypy
language: python
types: [python]
exclude: ^examples/|^truss/test.+/|model.py$
exclude: ^examples/|^truss/test.+/|model.py$|^slay.*
pass_filenames: true
- id: mypy
name: mypy-local (3.9)
entry: poetry run mypy
language: python
types: [python]
files: ^slay.*
args:
- --python-version=3.9
pass_filenames: true
324 changes: 284 additions & 40 deletions poetry.lock

Large diffs are not rendered by default.

33 changes: 20 additions & 13 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.6"
version = "0.9.7"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -24,49 +24,47 @@ keywords = [
[tool.poetry.dependencies]
blake3 = "^0.3.3"
boto3 = "^1.26.157"
fastapi = "^0.109.1"
fastapi = ">=0.109.1"
google-cloud-storage = "2.10.0"
httpx = "^0.24.1"
huggingface_hub = ">=0.19.4"
inquirerpy = "^0.3.4"
Jinja2 = "^3.1.2"
loguru = ">=0.7.2"
msgpack = ">=1.0.2"
msgpack-numpy = ">=0.4.7.1"
msgpack-numpy = ">=0.4.8"
numpy = ">=1.23.5"
packaging = ">=20.9"
pathspec = ">=0.9.0"
psutil = "^5.9.4"
psutil = ">=5.9.4"
pydantic = ">=1.10.0"
python = ">=3.8,<3.12"
python-json-logger = ">=2.0.2"
python-on-whales = "^0.68.0"
PyYAML = "^6.0"
PyYAML = ">=6.0"
rich = "^13.4.2"
rich-click = "^1.6.1"
single-source = "^0.3.0"
tenacity = "^8.0.1"
uvicorn = "^0.24.0"
uvloop = "^0.19.0"
watchfiles = "^0.19.0"

[tool.poetry.group.builder.dependencies]
blake3 = "^0.3.3"
boto3 = "^1.26.157"
click = "^8.0.3"
fastapi = "^0.109.1"
fastapi = ">=0.109.1"
google-cloud-storage = "2.10.0"
httpx = "^0.24.1"
huggingface_hub = ">=0.19.4"
Jinja2 = "^3.1.2"
loguru = ">=0.7.2"
packaging = ">=20.9"
pathspec = ">=0.9.0"
psutil = "^5.9.4"
psutil = ">=5.9.4"
python = ">=3.8,<3.12"
python-json-logger = ">=2.0.2"
PyYAML = "^6.0"
requests = "^2.28.1"
PyYAML = ">=6.0"
requests = ">=2.31"
single-source = "^0.3.0"
tenacity = "^8.0.1"
uvicorn = "^0.24.0"
Expand Down Expand Up @@ -95,8 +93,17 @@ flask = "^2.3.3"
httpx = { extras = ["cli"], version = "^0.24.1" }
mypy = "^1.0.0"
pytest-split = "^0.8.1"
requests-mock = "^1.11.0"
types-requests = "2.31.0.2"
requests-mock = ">=1.11.0"
types-requests = ">=2.31.0.2"
uvicorn = ">=0.24.0"
uvloop = ">=0.17.0"

[tool.poetry.group.slay.dependencies]
astroid = "^3.1.0"
datamodel-code-generator = "^0.25.4"
libcst = "<1.2.0"
autoflake = "<=2.2"


[build-system]
build-backend = "poetry.core.masonry.api"
Expand Down
4 changes: 4 additions & 0 deletions slay-examples/text_to_num/requirements.txt
@@ -0,0 +1,4 @@
git+https://github.com/basetenlabs/truss.git
httpx
libcst
pydantic
Empty file.
18 changes: 18 additions & 0 deletions slay-examples/text_to_num/user_package/shared_processor.py
@@ -0,0 +1,18 @@
import slay

# Base image for SplitText: installs the shared requirements file plus
# numpy, which SplitText.run imports at call time.
IMAGE_NUMPY = (
    slay.Image()
    .pip_requirements_file(slay.make_abs_path_here("../requirements.txt"))
    .pip_requirements(["numpy"])
)


class SplitText(slay.ProcessorBase):
    """Splits a string into roughly equal partitions.

    NOTE(review): the builtin-generic return annotation (``tuple[list[str], int]``)
    requires Python >= 3.9 at definition time — confirm against the package's
    declared ``python >= 3.8`` support.
    """

    default_config = slay.Config(image=IMAGE_NUMPY)

    async def run(self, data: str, num_partitions: int) -> tuple[list[str], int]:
        """Partition *data* into *num_partitions* chunks.

        Returns the list of chunks together with a fixed marker value (123).
        """
        # Deferred import: numpy is only installed in the deployed image.
        import numpy as np

        chunks = np.array_split(np.array(list(data)), num_partitions)
        joined = ["".join(chunk) for chunk in chunks]
        return joined, 123
201 changes: 201 additions & 0 deletions slay-examples/text_to_num/workflow.py
@@ -0,0 +1,201 @@
# This logging is needed for debugging class initialization.
# import logging

# log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s"
# date_format = "%m%d %H:%M:%S"
# logging.basicConfig(level=logging.DEBUG, format=log_format, datefmt=date_format)


import random
import string
from typing import Protocol

import pydantic
import slay
from truss import truss_config
from user_package import shared_processor

# Shared base image for the lightweight processors: installs only the
# workflow's requirements file.
IMAGE_COMMON = slay.Image().pip_requirements_file(
    slay.make_abs_path_here("requirements.txt")
)


class GenerateData(slay.ProcessorBase):
    """Produces a random alphanumeric string of a requested length."""

    default_config = slay.Config(image=IMAGE_COMMON)

    def run(self, length: int) -> str:
        """Return *length* characters sampled (with replacement) from
        ASCII letters and digits."""
        alphabet = string.ascii_letters + string.digits
        return "".join(random.choices(alphabet, k=length))


# GPU image for the Mistral processor: the common requirements plus pinned
# transformers/torch and their text-generation dependencies.
IMAGE_TRANSFORMERS_GPU = (
    slay.Image()
    .pip_requirements_file(slay.make_abs_path_here("requirements.txt"))
    .pip_requirements(
        ["transformers==4.38.1", "torch==2.0.1", "sentencepiece", "accelerate"]
    )
)


MISTRAL_HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# Cache only the files needed to load the model: configs (*.json),
# safetensors weights, and the sentencepiece ".model" tokenizer file.
MISTRAL_CACHE = truss_config.ModelRepo(
    repo_id=MISTRAL_HF_MODEL, allow_patterns=["*.json", "*.safetensors", ".model"]
)


class MistraLLMConfig(pydantic.BaseModel):
    """User config for the Mistral processor: which HF model to load."""

    # NOTE(review): class name looks like a typo ("MistraLLM" vs "MistralLLM");
    # kept unchanged because it is referenced as MistralLLM's generic parameter.
    hf_model_name: str


class MistralLLM(slay.ProcessorBase[MistraLLMConfig]):
    """Processor that serves Mistral-7B-Instruct text generation on a GPU."""

    default_config = slay.Config(
        image=IMAGE_TRANSFORMERS_GPU,
        # The 7B model in float16 needs GPU hardware; 2 CPUs for tokenization etc.
        compute=slay.Compute().cpu(2).gpu("A10G"),
        # Pre-cache the weights so deployments don't download them at cold start.
        assets=slay.Assets().cached([MISTRAL_CACHE]),
        user_config=MistraLLMConfig(hf_model_name=MISTRAL_HF_MODEL),
    )
    # Alternative: load the same config from a YAML file instead of inline.
    # default_config = slay.Config(config_path="mistral_config.yaml")

    def __init__(
        self,
        context: slay.Context[MistraLLMConfig] = slay.provide_context(),
    ) -> None:
        """Load the model and tokenizer once at processor start-up.

        Imports are deferred so the heavy GPU dependencies are only required
        inside the deployed image, not in the workflow-authoring environment.
        """
        super().__init__(context)
        import torch
        import transformers

        model_name = self.user_config.hf_model_name

        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # Fixed sampling settings reused for every generation call.
        self._generate_args = {
            "max_new_tokens": 512,
            "temperature": 1.0,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self._tokenizer.eos_token_id,
            "pad_token_id": self._tokenizer.pad_token_id,
        }

    def run(self, data: str) -> str:
        """Generate a completion for *data* using Mistral's [INST] prompt format.

        NOTE(review): returns the full decode of output[0], which includes the
        prompt and special tokens — confirm callers expect that.
        """
        import torch

        formatted_prompt = f"[INST] {data} [/INST]"
        # .cuda() assumes a CUDA device is present (guaranteed by the A10G
        # compute spec above).
        input_ids = self._tokenizer(
            formatted_prompt, return_tensors="pt"
        ).input_ids.cuda()
        with torch.no_grad():
            output = self._model.generate(inputs=input_ids, **self._generate_args)
        result = self._tokenizer.decode(output[0])
        return result


class MistralP(Protocol):
    """Structural interface matching MistralLLM.

    TextToNum depends on this Protocol rather than the concrete class, so a
    fake implementation can be injected (see the commented-out FakeMistralLLM
    in the __main__ section).
    """

    def __init__(self, context: slay.Context) -> None:
        ...

    def run(self, data: str) -> str:
        ...


class TextToNum(slay.ProcessorBase):
    """Maps text to a number: runs the LLM and sums the code points of its output."""

    default_config = slay.Config(image=IMAGE_COMMON)

    def __init__(
        self,
        context: slay.Context = slay.provide_context(),
        mistral: MistralP = slay.provide(MistralLLM),
    ) -> None:
        super().__init__(context)
        self._mistral = mistral

    def run(self, data: str) -> int:
        """Return the sum of the ordinals of the LLM's generated text for *data*."""
        generated = self._mistral.run(data)
        return sum(ord(char) for char in generated)


class Workflow(slay.ProcessorBase):
    """Entry-point processor: generate random text, split it, and score each part."""

    default_config = slay.Config(image=IMAGE_COMMON)

    def __init__(
        self,
        context: slay.Context = slay.provide_context(),
        data_generator: GenerateData = slay.provide(GenerateData),
        splitter: shared_processor.SplitText = slay.provide(shared_processor.SplitText),
        text_to_num: TextToNum = slay.provide(TextToNum),
    ) -> None:
        super().__init__(context)
        self._data_generator = data_generator
        self._data_splitter = splitter
        self._text_to_num = text_to_num

    async def run(self, length: int, num_partitions: int) -> tuple[int, str, int]:
        """Generate *length* characters, split into *num_partitions* parts,
        and return (summed scores, generated data, splitter marker)."""
        data = self._data_generator.run(length)
        parts, marker = await self._data_splitter.run(data, num_partitions)
        total = sum(self._text_to_num.run(part) for part in parts)
        return total, data, marker


if __name__ == "__main__":
    # Manual smoke-test entry point: configures logging, then deploys the
    # Workflow remotely. The commented-out sections below are kept as
    # debugging aids (local run with a fake LLM, and calling an already
    # deployed workflow).
    import logging

    from slay import utils

    # from slay.truss_compat import deploy

    # Apply a compact single-line format to whatever handlers are already
    # attached to the root logger.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s"
    date_format = "%m%d %H:%M:%S"
    formatter = logging.Formatter(fmt=log_format, datefmt=date_format)
    for handler in root_logger.handlers:
        handler.setFormatter(formatter)

    # class FakeMistralLLM(slay.ProcessorBase):
    #     def run(self, data: str) -> str:
    #         return data.upper()

    # import asyncio
    # with slay.run_local():
    #     text_to_num = TextToNum(mistral=FakeMistralLLM())
    #     wf = Workflow(text_to_num=text_to_num)
    #     result = asyncio.run(wf.run(length=123, num_partitions=123))
    #     print(result)

    # Deploy the full workflow; DEBUG logging shows the generated code steps.
    with utils.log_level(logging.DEBUG):
        remote = slay.deploy_remotely(
            Workflow, workflow_name="Test", generate_only=False
        )

    # remote = slay.definitions.BasetenRemoteDescriptor(
    #     b10_model_id="7qk59gdq",
    #     b10_model_version_id="woz52g3",
    #     b10_model_name="Workflow",
    #     b10_model_url="https://model-7qk59gdq.api.baseten.co/production",
    # )
    # with utils.log_level(logging.INFO):
    #     response = deploy.call_workflow_dbg(
    #         remote, {"length": 1000, "num_partitions": 100}
    #     )
    # print(response)
    # print(response.json())
17 changes: 17 additions & 0 deletions slay/__init__.py
@@ -0,0 +1,17 @@
# flake8: noqa F401
from slay.definitions import (
Assets,
Compute,
Config,
Context,
Image,
UsageError,
make_abs_path_here,
)
from slay.public_api import (
ProcessorBase,
deploy_remotely,
provide,
provide_context,
run_local,
)

0 comments on commit f6ec6a4

Please sign in to comment.