forked from bnb-chain/greenfield-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
basic_model.py
64 lines (47 loc) · 2.09 KB
/
basic_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#import boto3
import asyncio
import logging
import random
import string
import io
from greenfield_python_sdk.config import NetworkConfiguration, NetworkTestnet, get_account_configuration
from greenfield_python_sdk.greenfield_client import GreenfieldClient
from greenfield_python_sdk.key_manager import KeyManager
from greenfield_python_sdk.models.bucket import CreateBucketOptions
from greenfield_python_sdk.models.object import CreateObjectOptions, GetObjectOption, PutObjectOptions
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Show INFO-level progress messages from this script and the SDK client.
logging.basicConfig(level=logging.INFO)
# Account credentials for signing Greenfield transactions
# (presumably loaded from env/config by the SDK — verify against its docs).
config = get_account_configuration()
async def main():
    """Download a stored T5 model from Greenfield and run one sample translation.

    Side effects: network I/O against the Greenfield testnet, files written
    under ``path``, and the translated sentence printed to stdout.
    """
    network_configuration = NetworkConfiguration(**NetworkTestnet().model_dump())
    key_manager = KeyManager(private_key=config.private_key)
    logging.info("Main account address: %s", key_manager.address)

    async with GreenfieldClient(
        network_configuration=network_configuration, key_manager=key_manager
    ) as client:
        logging.info("---> TEST Greenfield <---")
        # NOTE(review): Greenfield bucket/object names are DNS-style and must not
        # contain underscores — confirm these placeholders before running.
        bucket_name = "bucket_name"
        object_name = "object_name"
        await client.async_init()

        # Download the serialized model directory from Greenfield to local disk.
        path = "/path/to/your/dir"
        logging.info("---> Get Object <---")
        # BUG FIX: the original logged the undefined name `fget_object`
        # (NameError). Capture and log the call's return value instead.
        result = await client.object.fget_object(
            bucket_name,
            object_name,
            path,
            opts=GetObjectOption(),
        )
        logging.info("Result: %s\n\n", result)

        # Load model and tokenizer from the just-downloaded directory.
        model = T5ForConditionalGeneration.from_pretrained(path)
        # BUG FIX: the original passed the undefined name `model_dir`
        # (NameError); the tokenizer lives in the same directory as the model.
        tokenizer = T5Tokenizer.from_pretrained(path, legacy=False)

        # Prepare and tokenize the input text.
        input_text = "translate English to French: The quick brown fox jumps over the lazy dog."
        input_ids = tokenizer.encode(input_text, return_tensors="pt")

        # Generate up to 50 new tokens, then decode and print the translation.
        output = model.generate(input_ids, max_new_tokens=50)
        translated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print(translated_text)
if __name__ == "__main__":
asyncio.run(main())