In [1]:
%%writefile requirements.txt
sacremoses==0.0.53
sentencepiece==0.1.97
transformers==4.25.1
protobuf~=3.20
torch==1.13.0

Writing requirements.txt


In [None]:
!pip install -r requirements.txt -q

In [3]:
import os
import json
import torch
import requests
import warnings
import numpy as np

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

warnings.simplefilter('ignore',UserWarning)
torch.__version__

'1.13.0+cu117'

In [None]:
def try_model(model_name, text):
    n_min = int(np.median([len(t.split()) for t in text])// 4)
    n_max = int(np.median([len(t.split()) for t in text])// 1.5)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    seq2seq = pipeline("summarization", model=model,tokenizer=tokenizer)
    return seq2seq(text,max_length=n_max, min_length=n_min)

In [6]:
text = """
The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes.
"""
text_split = [s for s in text.splitlines() if s]
try_model("JulesBelveze/t5-small-headline-generator",text_split)

[{'summary_text': 'The Inflation Reduction Act Lowers Drug Costs, Health Care Costs'}]

In [7]:
!git clone https://github.com/datarobot/datarobot-user-models.git

Cloning into 'datarobot-user-models'...


remote: Enumerating objects: 13674, done.[K
remote: Counting objects:   0% (1/608)[Kremote: Counting objects:   1% (7/608)[Kremote: Counting objects:   2% (13/608)[Kremote: Counting objects:   3% (19/608)[Kremote: Counting objects:   4% (25/608)[Kremote: Counting objects:   5% (31/608)[Kremote: Counting objects:   6% (37/608)[Kremote: Counting objects:   7% (43/608)[Kremote: Counting objects:   8% (49/608)[Kremote: Counting objects:   9% (55/608)[Kremote: Counting objects:  10% (61/608)[Kremote: Counting objects:  11% (67/608)[Kremote: Counting objects:  12% (73/608)[Kremote: Counting objects:  13% (80/608)[Kremote: Counting objects:  14% (86/608)[Kremote: Counting objects:  15% (92/608)[Kremote: Counting objects:  16% (98/608)[Kremote: Counting objects:  17% (104/608)[Kremote: Counting objects:  18% (110/608)[Kremote: Counting objects:  19% (116/608)[Kremote: Counting objects:  20% (122/608)[Kremote: Counting objects:  21% (128/608)[Kremot

remote: Compressing objects:  10% (35/341)[Kremote: Compressing objects:  11% (38/341)[Kremote: Compressing objects:  12% (41/341)[Kremote: Compressing objects:  13% (45/341)[Kremote: Compressing objects:  14% (48/341)[Kremote: Compressing objects:  15% (52/341)[Kremote: Compressing objects:  16% (55/341)[Kremote: Compressing objects:  17% (58/341)[Kremote: Compressing objects:  18% (62/341)[Kremote: Compressing objects:  19% (65/341)[Kremote: Compressing objects:  20% (69/341)[Kremote: Compressing objects:  21% (72/341)[Kremote: Compressing objects:  22% (76/341)[Kremote: Compressing objects:  23% (79/341)[Kremote: Compressing objects:  24% (82/341)[Kremote: Compressing objects:  25% (86/341)[Kremote: Compressing objects:  26% (89/341)[Kremote: Compressing objects:  27% (93/341)[Kremote: Compressing objects:  28% (96/341)[Kremote: Compressing objects:  29% (99/341)[Kremote: Compressing objects:  30% (103/341)[Kremote: Compressing objects:  31%

Receiving objects:   7% (958/13674)Receiving objects:   8% (1094/13674)Receiving objects:   9% (1231/13674)Receiving objects:  10% (1368/13674)

Receiving objects:  11% (1505/13674), 31.37 MiB | 62.73 MiB/s

Receiving objects:  11% (1562/13674), 31.37 MiB | 62.73 MiB/sReceiving objects:  12% (1641/13674), 58.63 MiB | 58.62 MiB/sReceiving objects:  13% (1778/13674), 58.63 MiB | 58.62 MiB/s

Receiving objects:  14% (1915/13674), 58.63 MiB | 58.62 MiB/sReceiving objects:  15% (2052/13674), 58.63 MiB | 58.62 MiB/s

Receiving objects:  16% (2188/13674), 58.63 MiB | 58.62 MiB/s

Receiving objects:  17% (2325/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  18% (2462/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  19% (2599/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  20% (2735/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  21% (2872/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  22% (3009/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  23% (3146/13674), 83.99 MiB | 55.99 MiB/s

Receiving objects:  24% (3282/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  25% (3419/13674), 83.99 MiB | 55.99 MiB/s

Receiving objects:  26% (3556/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  27% (3692/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  28% (3829/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  29% (3966/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  30% (4103/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  31% (4239/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  32% (4376/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  33% (4513/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  34% (4650/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  35% (4786/13674), 83.99 MiB | 55.99 MiB/sReceiving objects:  36% (4923/13674), 83.99 MiB | 55.99 MiB/s

Receiving objects:  36% (5008/13674), 120.05 MiB | 60.02 MiB/s

Receiving objects:  37% (5060/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  38% (5197/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  39% (5333/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  40% (5470/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  41% (5607/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  42% (5744/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  43% (5880/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  44% (6017/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  45% (6154/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  46% (6291/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  47% (6427/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  48% (6564/13674), 120.05 MiB | 60.02 MiB/s

Receiving objects:  49% (6701/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  50% (6837/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  51% (6974/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  52% (7111/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  53% (7248/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  54% (7384/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  55% (7521/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  56% (7658/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  57% (7795/13674), 120.05 MiB | 60.02 MiB/sReceiving objects:  58% (7931/13674), 120.05 MiB | 60.02 MiB/s

Receiving objects:  59% (8068/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  60% (8205/13674), 157.68 MiB | 63.07 MiB/s

Receiving objects:  61% (8342/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  62% (8478/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  63% (8615/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  64% (8752/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  65% (8889/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  66% (9025/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  67% (9162/13674), 157.68 MiB | 63.07 MiB/s

Receiving objects:  68% (9299/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  69% (9436/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  70% (9572/13674), 157.68 MiB | 63.07 MiB/s

Receiving objects:  71% (9709/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  72% (9846/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  73% (9983/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  74% (10119/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  75% (10256/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  76% (10393/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  77% (10529/13674), 157.68 MiB | 63.07 MiB/s

Receiving objects:  78% (10666/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  79% (10803/13674), 157.68 MiB | 63.07 MiB/sReceiving objects:  80% (10940/13674), 157.68 MiB | 63.07 MiB/s

Receiving objects:  80% (11024/13674), 187.23 MiB | 62.41 MiB/s

Receiving objects:  81% (11076/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  82% (11213/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  83% (11350/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  84% (11487/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  85% (11623/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  86% (11760/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  87% (11897/13674), 216.86 MiB | 61.96 MiB/s

Receiving objects:  88% (12034/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  89% (12170/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  90% (12307/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  91% (12444/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  92% (12581/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  93% (12717/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  94% (12854/13674), 216.86 MiB | 61.96 MiB/s

Receiving objects:  95% (12991/13674), 216.86 MiB | 61.96 MiB/s

Receiving objects:  96% (13128/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  97% (13264/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  98% (13401/13674), 216.86 MiB | 61.96 MiB/sReceiving objects:  99% (13538/13674), 216.86 MiB | 61.96 MiB/sremote: Total 13674 (delta 340), reused 467 (delta 248), pack-reused 13066[K
Receiving objects: 100% (13674/13674), 216.86 MiB | 61.96 MiB/sReceiving objects: 100% (13674/13674), 242.65 MiB | 62.24 MiB/s, done.
Resolving deltas:   0% (0/9159)Resolving deltas:   1% (92/9159)Resolving deltas:   2% (184/9159)Resolving deltas:   3% (275/9159)Resolving deltas:   4% (367/9159)Resolving deltas:   5% (459/9159)Resolving deltas:   6% (550/9159)Resolving deltas:   7% (642/9159)Resolving deltas:   8% (733/9159)Resolving deltas:   9% (825/9159)Resolving deltas:  10% (916/9159)Resolving deltas:  11% (1008/9159)Resolving deltas:  12% (1100/9159)Resolving deltas:  13% (1191/9159)Resolving deltas:  14% (1283/9159)Resolving deltas: 

Resolving deltas:  65% (5954/9159)Resolving deltas:  66% (6045/9159)Resolving deltas:  67% (6137/9159)Resolving deltas:  68% (6229/9159)Resolving deltas:  69% (6320/9159)Resolving deltas:  70% (6412/9159)Resolving deltas:  71% (6503/9159)Resolving deltas:  72% (6595/9159)Resolving deltas:  73% (6687/9159)Resolving deltas:  74% (6778/9159)Resolving deltas:  75% (6870/9159)Resolving deltas:  76% (6961/9159)Resolving deltas:  77% (7053/9159)Resolving deltas:  78% (7145/9159)Resolving deltas:  79% (7236/9159)Resolving deltas:  80% (7328/9159)Resolving deltas:  81% (7419/9159)Resolving deltas:  82% (7511/9159)Resolving deltas:  83% (7602/9159)Resolving deltas:  84% (7694/9159)Resolving deltas:  85% (7786/9159)Resolving deltas:  86% (7877/9159)Resolving deltas:  87% (7969/9159)Resolving deltas:  88% (8060/9159)Resolving deltas:  89% (8152/9159)Resolving deltas:  90% (8244/9159)Resolving deltas:  91% (8335/9159)Resolving deltas:  92% (8427/9159)Resolving deltas:  9

Resolving deltas: 100% (9159/9159)Resolving deltas: 100% (9159/9159), done.


Updating files:  57% (419/725)Updating files:  58% (421/725)Updating files:  59% (428/725)Updating files:  60% (435/725)Updating files:  61% (443/725)Updating files:  62% (450/725)Updating files:  63% (457/725)Updating files:  64% (464/725)Updating files:  65% (472/725)Updating files:  66% (479/725)Updating files:  67% (486/725)Updating files:  68% (493/725)Updating files:  69% (501/725)Updating files:  70% (508/725)Updating files:  71% (515/725)Updating files:  72% (522/725)

Updating files:  73% (530/725)Updating files:  74% (537/725)Updating files:  75% (544/725)Updating files:  76% (551/725)Updating files:  77% (559/725)Updating files:  78% (566/725)Updating files:  79% (573/725)

Updating files:  80% (580/725)

Updating files:  81% (588/725)Updating files:  82% (595/725)

Updating files:  83% (602/725)Updating files:  84% (609/725)Updating files:  85% (617/725)Updating files:  86% (624/725)Updating files:  87% (631/725)Updating files:  88% (638/725)

Updating files:  89% (646/725)Updating files:  90% (653/725)Updating files:  91% (660/725)Updating files:  92% (667/725)Updating files:  93% (675/725)Updating files:  94% (682/725)

Updating files:  95% (689/725)Updating files:  96% (696/725)Updating files:  97% (704/725)

Updating files:  98% (711/725)Updating files:  99% (718/725)Updating files: 100% (725/725)Updating files: 100% (725/725), done.


In [8]:
!ls datarobot-user-models/public_dropin_environments/python3_pytorch/

Dockerfile  __init__.py		 env_info.json	requirements.txt
README.md   dr_requirements.txt  fit.sh		start_server.sh


In [9]:
%%writefile datarobot-user-models/public_dropin_environments/python3_pytorch/requirements.txt
sacremoses==0.0.53
sentencepiece==0.1.97
transformers==4.25.1
protobuf~=3.20
torch==1.13.0
numpy==1.22.0
pandas==1.3.0
scikit-learn==0.24.2

Overwriting datarobot-user-models/public_dropin_environments/python3_pytorch/requirements.txt


In [10]:
import datarobot as dr
from datarobot.client import Client

# Connect to DataRobot
Client()


<datarobot.rest.RESTClientObject at 0x7f782428df10>

In [None]:
environment_folder = "./datarobot-user-models/public_dropin_environments/python3_pytorch"
## Create the environment, which will eventually contain versions  ##
execution_environment = dr.ExecutionEnvironment.create(
    name="Python 3 HuggingFace1",
    description="This environment contains hf library.",
)

## Create the environment version ##
environment_version = dr.ExecutionEnvironmentVersion.create(
    execution_environment.id,
    environment_folder,
    max_wait=3600,  # 1 hour timeout
)

[ExecutionEnvironment('[DataRobot] Python 3.9 ONNX Drop-In'),
 ExecutionEnvironment('Python 3 YOLOv5'),
 ExecutionEnvironment('[DataRobot] Julia Drop-In'),
 ExecutionEnvironment('[DataRobot] Legacy Code Environment'),
 ExecutionEnvironment('[DEPRECATED] H2O Drop-In'),
 ExecutionEnvironment('[DataRobot] Python 3.9 PMML Drop-In'),
 ExecutionEnvironment('[DataRobot] R 4.2.1 Drop-In'),
 ExecutionEnvironment('[DataRobot] Python 3.9 PyTorch Drop-In'),
 ExecutionEnvironment('[DataRobot] Java 11 Drop-In (DR Codegen, H2O)'),
 ExecutionEnvironment('[DataRobot] Python 3.9 Scikit-Learn Drop-In'),
 ExecutionEnvironment('[DataRobot] Python 3.9 XGBoost Drop-In'),
 ExecutionEnvironment('[DataRobot] Python 3.9 Keras Drop-In')]

In [None]:
def save_model(model_name):
    folder_name = model_name.split("/")[1]
    os.mkdir(f"./{folder_name}")
    os.mkdir(f"./{folder_name}/model")
    os.mkdir(f"./{folder_name}/tokenizer")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.save_pretrained(f"./{folder_name}/model")
    tokenizer.save_pretrained(f"./{folder_name}/tokenizer")
    return folder_name
folder_name = save_model("JulesBelveze/t5-small-headline-generator")

In [14]:
!ls -l {folder_name}

total 0
drwxr-xr-x 2 nbx nbx  50 Dec 16 14:37 model
drwxr-xr-x 2 nbx nbx 108 Dec 16 14:37 tokenizer


In [15]:
%%writefile {folder_name}/custom.py
import time
import json
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM


def load_model(dummy):

    tokenizer1 = AutoTokenizer.from_pretrained("./tokenizer")
    model1 = AutoModelForSeq2SeqLM.from_pretrained("./model")
    seq2seq1 = pipeline("summarization", model=model1, tokenizer=tokenizer1)

    return seq2seq1


def score_unstructured(model, data, query, **kwargs):
    # Handle incoming data
    if not data:
        return data
    if isinstance(data, bytes):
        data = data.decode("utf8")

    time_before_model = time.perf_counter()

    data = [s for s in data.splitlines() if s]
    prediction = model(data)

    time_after_model = time.perf_counter()

    # Structure json output
    output_dict = {
        "prediction": prediction,
        "model_run_time_seconds": time_after_model - time_before_model,
    }

    # Serialized output
    serialized_output = json.dumps(output_dict)
    return serialized_output


if __name__ == "__main__":
    test = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said.  The policy includes the termination of accounts of anti-vaccine influencers.  Tech giants have been criticised for not doing more to counter false health information on their sites.  In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue.  YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines.  In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B.  "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""
    model = load_model(None)
    preds = score_unstructured(model=model, data=test, query=None)
    print(preds)


Writing t5-small-headline-generator/custom.py


In [16]:
%cd {folder_name}

/nbx/t5-small-headline-generator


In [17]:
!python custom.py

{"prediction": [{"summary_text": "YouTube says it's seen false claims about Covid jabs \"spill over into misinformation\" 'We're expanding our medical misinformation policies on YouTube,' it says"}], "model_run_time_seconds": 0.6249331739963964}


In [18]:
%cd ..

/nbx


In [None]:
custom_model_folder = f"/nbx/{folder_name}"
## Create the custom model ##
custom_model = dr.CustomInferenceModel.create(
    name=folder_name,
    target_type=dr.TARGET_TYPE.UNSTRUCTURED,
    language='python',
    maximum_memory=4294967296
)

## Create the custom model version ##
model_version = dr.CustomModelVersion.create_clean(
    custom_model_id=custom_model.id,
    folder_path=custom_model_folder,
    base_environment_id=execution_environment.id
)

In [None]:
custom_model_label = folder_name
prediction_server = dr.PredictionServer.list()[0]
deployment = dr.Deployment.create_from_custom_model_version(
    model_version.id,
    label=custom_model_label,
    default_prediction_server_id=prediction_server.id,
    max_wait=3600,  # 1 hour timeout
)

[PredictionServer(https://mlops.dynamic.orm.datarobot.com),
 PredictionServer(https://datarobot-cfds.dynamic.orm.datarobot.com),
 PredictionServer(https://cfds-ccm-prod.orm.datarobot.com)]

In [21]:
# Make predictions on the custom model deployment
test = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said.  
The policy includes the termination of accounts of anti-vaccine influencers.  Tech giants have been criticised for not doing more to counter false health information on their sites.  In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue.  YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines.  In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B.  "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""
url = '{}/predApi/v1.0/deployments/{}/predictionsUnstructured'.format(prediction_server.url, deployment.id)
headers = dr.client.get_client().headers
headers['datarobot-key'] = '544ec55f-61bf-f6ee-0caf-15c7f919a45d'
headers['Content-Type'] = 'text/plain;UTF-8'

response = requests.post(url, headers=headers, data=test)

predictions = response.json()
print(predictions)

{'prediction': [{'summary_text': 'The company says approved vaccines are dangerous and cause autism, cancer or infertility, according to a video released by the company .'}, {'summary_text': 'YouTube says it saw false claims about Covid jabs "spill over into misinformation" \'We\'re expanding our medical misinformation policies on YouTube,\' blog post says'}], 'model_run_time_seconds': 4.530242481967434}
