# Imports and Installs

#### General Imports

Outside of SageMaker, the only packages we need are pandas plus the standard-library `os`, `pickle`, and `datetime` modules.

In [12]:
import os
import pickle
import pandas as pd
from datetime import date

#### SageMaker Imports

For inference, the only SageMaker pieces we need are the SDK itself and the `PyTorchModel` class.

In [3]:
import sagemaker as sm
from sagemaker.pytorch import PyTorchModel

#### SageMaker Parameters

We will set any parameters we need to pass to the model.

In [4]:
# Date stamp that is embedded in the endpoint name further down.
today             = date.today()
today_str         = today.strftime('%Y-%m-%d')
# IAM role and session used for every SageMaker call in this notebook.
role              = sm.get_execution_role()
sagemaker_session = sm.session.Session()
# Use the public accessor instead of the private `_region_name` attribute,
# which is an implementation detail of the SDK and may change.
region            = sagemaker_session.boto_region_name
bucket            = sagemaker_session.default_bucket()

#### Requirements File

This cell creates a requirements.txt file for any needed packages. The only package we need here is the transformers library.

In [5]:
# Local directory that is later passed to PyTorchModel as `source_dir`.
path = './source_dir'
# exist_ok=True keeps the cell re-runnable, while genuine failures
# (permission denied, a regular file in the way, ...) now raise instead of
# being silently swallowed by a blanket `except OSError: pass`.
os.makedirs(path, exist_ok=True)

In [6]:
%%writefile ./source_dir/requirements.txt
transformers

Overwriting ./source_dir/requirements.txt


# Create Inference Script
When adding our custom model for inference, we can override the functions that handle data input/output and prediction.
For real-time prediction, the functions we will override are:
1. `input_fn(request_body, request_content_type)`
2. `output_fn(prediction, response_content_type)`
3. `model_fn(model_dir)`
4. `predict_fn(input_object, model)`

In [7]:
%%writefile ./source_dir/medical_language_endpoint_cpu.py

import os
import joblib
import argparse
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler

from transformers import GPT2LMHeadModel
from transformers import GPT2Config
from transformers import AutoTokenizer


# Device the encoded prompt is moved to in predict_fn (CPU-only endpoint).
DEVICE = torch.device('cpu')
# Sub-directory of `model_dir` that model_fn loads the model and tokenizer from.
MODELNAME = '20230112_distilgpt2_medical_generator'

def input_fn(request_body, request_content_type):
    """Turn the raw request payload into the prompt string for `predict_fn`.

    Args:
        request_body: Payload of the request, either `str` or raw
            `bytes`/`bytearray`.
        request_content_type: MIME type of the request, e.g. ``'text/csv'``.
            Letter case and optional parameters such as
            ``'; charset=utf-8'`` are ignored when matching.

    Returns:
        The request body as `str` for ``text/csv`` requests; for any other
        (or missing) content type the default prompt ``'Letters | '``.
    """
    # Media types are case-insensitive and may carry parameters after ';'.
    media_type = (request_content_type or '').split(';', 1)[0].strip().lower()
    if media_type != 'text/csv':
        # Best-effort fallback: unknown content types get a default prompt.
        return 'Letters | '
    if isinstance(request_body, (bytes, bytearray)):
        # predict_fn tokenizes the prompt and strips it from the output with
        # string operations, so it must receive `str`, not raw bytes.
        return request_body.decode('utf-8')
    return request_body


def output_fn(prediction, response_content_type):
    """Serialize the prediction for the response.

    The prediction is always rendered with `str`; `response_content_type`
    is accepted but not consulted.
    """
    serialized = str(prediction)
    return serialized


def model_fn(model_dir):
    """Load the GPT-2 language model and its tokenizer.

    Both artifacts are read from the `MODELNAME` sub-directory of
    `model_dir`.

    Args:
        model_dir: Directory containing the unpacked model artifacts.

    Returns:
        dict with the keys 'model' and 'tokenizer'.
    """
    artifact_dir = os.path.join(model_dir, MODELNAME)
    # Special tokens handed to the tokenizer when it is loaded.
    special_tokens = {
        'bos_token': '<|startoftext|>',
        'eos_token': '<|endoftext|>',
        'pad_token': '<|pad|>',
    }
    return {
        'model': GPT2LMHeadModel.from_pretrained(artifact_dir),
        'tokenizer': AutoTokenizer.from_pretrained(artifact_dir, **special_tokens),
    }


def predict_fn(input_object, model):
    """Generate one sampled continuation of the prompt.

    Args:
        input_object: Prompt string produced by `input_fn`.
        model: Dict returned by `model_fn` with the keys 'model' and
            'tokenizer'.

    Returns:
        The decoded generated text with the echoed prompt removed from its
        start.
    """
    tokenizer = model['tokenizer']
    generator = model['model']

    # encode() yields a flat list of token ids; unsqueeze(0) adds the batch
    # dimension before the tensor is moved to the inference device.
    input_ids = torch.tensor(tokenizer.encode(input_object)).unsqueeze(0)
    input_ids = input_ids.to(DEVICE)
    sequences = generator.generate(
        input_ids,
        do_sample=True,
        top_k=75,
        max_length=300,
        top_p=0.99,
        num_return_sequences=1
    )
    # Only one sequence is requested, so decode just that one.
    text = tokenizer.decode(sequences[0], skip_special_tokens=True)
    # Strip only the prompt echoed at the start of the output. str.replace
    # would also delete any later occurrence of the prompt text inside the
    # generated passage.
    if text.startswith(input_object):
        text = text[len(input_object):]
    return text

Overwriting ./source_dir/medical_language_endpoint_cpu.py


### Inference Parameters



In [9]:
# Number of samples requested per label by the generation loop.
trials = 2

# Hosting configuration for the real-time endpoint.
instance_type = 'ml.m4.xlarge'
endpoint_name = f'distilGPT-Medical-Endpoint-{today_str}'

# Location of the packaged model in the session's default bucket
# (presumably the output of the training job named in the key -- confirm).
model_data    = f's3://{bucket}/distilGPT-Training-Job-test-2023-01-18-23-12-15-239/output/model.tar.gz'

# Wrap the artifact together with the custom inference script.
model = PyTorchModel(
    entry_point='medical_language_endpoint_cpu.py',
    source_dir='./source_dir',
    model_data=model_data,
    role=role,
    framework_version='1.8',
    py_version='py3',
)

#### Deploy Endpoint

In [10]:
# Host the model on a single instance under `endpoint_name`.
endpoint = model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
)

---------!

# Generate Examples

We can now use `invoke_endpoint()` to generate text.

In [13]:
# NOTE(security): unpickling can execute arbitrary code -- only load a
# labels.pkl that you produced yourself.
with open('labels.pkl', 'rb') as labels_file:
    labels = pickle.load(labels_file)

generated_samples = []
generated_labels = []

for lbl in labels:
    print(lbl)
    prompt = lbl + ' | '
    for _ in range(trials):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
            EndpointName=endpoint_name,
            # UTF-8 is a superset of ASCII: existing labels are encoded
            # identically, while non-ASCII labels no longer raise.
            Body=prompt.encode('utf-8'),
            ContentType='text/csv',
        )
        try:
            # Generated text may contain non-ASCII characters; decoding it as
            # ASCII used to turn such samples into empty strings.
            sample = response['Body'].read().decode('utf-8')
        except Exception as err:
            # Best effort: keep labels and samples aligned, but report the
            # failure instead of hiding it behind a bare `except`.
            print(f'Failed to read response for {lbl!r}: {err}')
            sample = ''
        generated_samples.append(sample)
        generated_labels.append(lbl)

df_generated = pd.DataFrame({'medical_specialty': generated_labels, 'description': generated_samples})
df_generated

 Allergy / Immunology
 Bariatrics
 Cardiovascular / Pulmonary
 Dentistry
 Urology
 General Medicine
 Surgery
 Speech - Language
 SOAP / Chart / Progress Notes
 Sleep Medicine
 Rheumatology
 Radiology
 Psychiatry / Psychology
 Podiatry
 Physical Medicine - Rehab
 Pediatrics - Neonatal
 Pain Management
 Orthopedic
 Ophthalmology
 Office Notes
 Obstetrics / Gynecology
 Neurosurgery
 Neurology
 Nephrology
 Letters
 Lab Medicine - Pathology
 IME-QME-Work Comp etc.
 Hospice - Palliative Care
 Hematology - Oncology
 Gastroenterology
 ENT - Otolaryngology
 Endocrinology
 Emergency Room Reports
 Discharge Summary
 Diets and Nutritions
 Dermatology
 Cosmetic / Plastic Surgery
 Consult - History and Phy.
 Chiropractic


Unnamed: 0,medical_specialty,description
0,Allergy / Immunology,A 23-year-old white female presents with comp...
1,Allergy / Immunology,A 23-year-old white female presents with comp...
2,Bariatrics,"Patient with a diagnosis of pancreatitis, dev..."
3,Bariatrics,Patient with a family history of premature co...
4,Cardiovascular / Pulmonary,Echocardiographic Examination Report. Angin...
...,...,...
73,Cosmetic / Plastic Surgery,Belly button piercing for insertion of belly ...
74,Consult - History and Phy.,"The patient is a 16-month-old boy, who had a ..."
75,Consult - History and Phy.,An 84-year-old woman with a history of hypert...
76,Chiropractic,"Extractable epilepsy, here for video EEG."


#### Delete Endpoint

In [14]:
# Tear down the endpoint together with its endpoint configuration.
endpoint.delete_endpoint(delete_endpoint_config=True)