## Imports and Installs

In [14]:
import os
import json
import boto3
import pickle
import pandas as pd
import sagemaker as sm

from datetime import date

from sagemaker.pytorch import PyTorch
from sagemaker.pytorch import PyTorchModel
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.processing import ProcessingInput
from sagemaker.processing import ProcessingOutput

#### SageMaker Parameters

In [3]:
# Today's date and its YYYY-MM-DD string form.
today = date.today()
today_str = today.strftime('%Y-%m-%d')

# Execution role and session shared by the processing, training and deployment cells below.
role = sm.get_execution_role()
sagemaker_session = sm.session.Session()
# Read the region through the public `boto_region_name` property instead of the
# private `_region_name` attribute, which is an implementation detail of Session.
region = sagemaker_session.boto_region_name
account = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account']

# Default bucket of the session; every S3 URI in this notebook is built from it.
bucket = sagemaker_session.default_bucket()
print(bucket)

sagemaker-us-east-1-047840628716


#### Requirements File

In [35]:
%%writefile ./source_dir/requirements.txt
transformers

Overwriting ./source_dir/requirements.txt


## Data Preparation

#### Create Processing Script

In [4]:
%%writefile ./source_dir/medical_language_processing.py

import os
import torch
import random
import pickle
import argparse
import pandas as pd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# Single seed shared by Python's `random`, torch, and (via `random_state`) the splits in process_data().
RANDOM_STATE = 2023
random.seed(RANDOM_STATE)
# NumPy seeding is left disabled (numpy is not imported in this script).
# np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)


class GPT2Dataset(Dataset):
    """Tokenized text samples exposed as (input_ids, attention_mask) tensor pairs.

    Each text in `txt_list` is wrapped in `bos_token` / `eos_token` and passed to
    `tokenizer` with truncation and padding to `max_length`. `gpt2_type` is accepted
    but not used.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=768, bos_token='<|startoftext|>', eos_token='<|endoftext|>'):
        self.tokenizer = tokenizer
        self.input_ids, self.attn_masks = [], []

        for text in txt_list:
            wrapped = bos_token + text + eos_token
            encoded = tokenizer(wrapped, truncation=True, max_length=max_length, padding="max_length")
            ids_tensor = torch.tensor(encoded['input_ids'])
            mask_tensor = torch.tensor(encoded['attention_mask'])
            self.input_ids.append(ids_tensor)
            self.attn_masks.append(mask_tensor)

    def __len__(self):
        """Number of encoded samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the (input_ids, attention_mask) pair stored at position `idx`."""
        pair = (self.input_ids[idx], self.attn_masks[idx])
        return pair


def save_df(df, filepath, labels=None, labels_path=None):
    """Write `df` to `filepath` as CSV and, when given, pickle `labels` to `labels_path`.

    Nothing is pickled when `labels` is None.
    """
    df.to_csv(filepath)
    if labels is None:
        return
    with open(labels_path, "wb") as handle:
        pickle.dump(labels, handle)


def save_dataset(dataset, filepath):
    """Serialize `dataset` to `filepath` using `torch.save`."""
    torch.save(obj=dataset, f=filepath)


def load_data(filename, left_col, right_col, index_col=0, input_dir="/opt/ml/processing/input/"):
    """Read a CSV and build the "<left> | <right>" training texts plus the label list.

    Args:
        filename: Name of the CSV file inside `input_dir`.
        left_col: Column whose value becomes the prefix of every text (and the labels).
        right_col: Column whose value follows the ' | ' separator.
        index_col: Passed through to `pandas.read_csv`.
        input_dir: Directory containing `filename`; defaults to the processing
            container's input mount so existing callers are unaffected.

    Returns:
        (texts, labels): `texts` is a Series of "left | right" strings with a fresh
        0..n-1 index; `labels` is the list of distinct `left_col` values in order of
        first appearance.
    """
    filepath = os.path.join(input_dir, filename)
    # NOTE(review): dropna() discards rows with a missing value in ANY column, not only
    # left_col/right_col — confirm that this is intended for the source file.
    df = pd.read_csv(filepath, index_col=index_col).dropna().reset_index(drop=True)
    # Missing values were already removed above, so no second dropna is needed here.
    labels = df[left_col].drop_duplicates().tolist()
    texts = df[left_col] + ' | ' + df[right_col]
    return texts, labels


def process_data(df, test_size, valid_size):
    """Split `df` by index value into training, testing and validation subsets.

    `test_size` is the fraction of unique index values held out of training;
    `valid_size` is the fraction of that held-out part that becomes the validation
    set, the remainder being the test set. Both splits are shuffled with
    RANDOM_STATE. Returns `(df_train, df_test, df_valid)`.
    """
    split_options = {"shuffle": True, "random_state": RANDOM_STATE}
    unique_ids = list(df.index.drop_duplicates().values.ravel())
    train_ids, holdout_ids = train_test_split(unique_ids, test_size=test_size, **split_options)
    test_ids, valid_ids = train_test_split(holdout_ids, test_size=valid_size, **split_options)

    def _rows_for(ids):
        # Keep the rows whose index value belongs to `ids`.
        return df[df.index.isin(ids)]

    df_train = _rows_for(train_ids)
    df_valid = _rows_for(valid_ids)
    df_test = _rows_for(test_ids)

    for caption, subset in (
        ("Training Data Length:   ", df_train),
        ("Testing Data Length:    ", df_test),
        ("Validation Data Length: ", df_valid),
    ):
        print(f"{caption}{len(subset)}")
    return df_train, df_test, df_valid


def create_dataset(df, pretrained_path, max_length, batch_size=2, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>'):
    """Tokenize the texts in `df` and wrap them in a randomly-sampled DataLoader.

    The tokenizer is loaded from `pretrained_path` with the given special tokens;
    each text is encoded by GPT2Dataset with padding/truncation to `max_length`.
    Returns the DataLoader (not the underlying dataset).
    """
    special_tokens = {"bos_token": bos_token, "eos_token": eos_token, "pad_token": pad_token}
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path, **special_tokens)
    encoded = GPT2Dataset(df, tokenizer, max_length=max_length, bos_token=bos_token, eos_token=eos_token)
    return DataLoader(encoded, sampler=RandomSampler(encoded), batch_size=batch_size)


if __name__ == '__main__':
    # Command-line interface. Numeric options are typed so that argparse rejects bad
    # values with a usage error instead of failing later in a manual float()/int() cast.
    # argparse's default prefix matching is left enabled, so an unambiguous abbreviation
    # such as `--max_len` is still accepted for `--max_length`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_size', type=float, default=0.2)
    parser.add_argument('--valid_size', type=float, default=0.5)
    parser.add_argument('--left_col', type=str, default='')
    parser.add_argument('--right_col', type=str, default='')
    parser.add_argument('--pretrained_path', type=str, default='distilgpt2')
    parser.add_argument('--filename', type=str, default='')
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--bos_token', type=str, default='<|startoftext|>')
    parser.add_argument('--eos_token', type=str, default='<|endoftext|>')
    parser.add_argument('--pad_token', type=str, default='<|pad|>')
    parser.add_argument('--batch_size', type=int, default=3)
    args = parser.parse_args()

    # Directory every artifact of this script is written to.
    output_dir = '/opt/ml/processing/output/'

    df, labels = load_data(args.filename, args.left_col, args.right_col)
    df_train, df_test, df_valid = process_data(df, args.test_size, args.valid_size)

    # Insertion order (train, test, valid) fixes the order of tokenization and saving.
    splits = {'train': df_train, 'test': df_test, 'valid': df_valid}
    dataloaders = {
        name: create_dataset(split, args.pretrained_path, args.max_length, args.batch_size, args.bos_token, args.eos_token, args.pad_token)
        for name, split in splits.items()
    }

    for name, split in splits.items():
        csv_path = os.path.join(output_dir, f'df_{name}.csv')
        if name == 'train':
            # The label list is written once, alongside the training split.
            save_df(split, csv_path, labels=labels, labels_path=os.path.join(output_dir, 'labels.pkl'))
        else:
            save_df(split, csv_path)

    for name, dataloader in dataloaders.items():
        save_dataset(dataloader, os.path.join(output_dir, f'dataset_{name}.bin'))
    print("Completed Processing!")

Overwriting ./source_dir/medical_language_processing.py


#### Processing Parameters

In [5]:
# HuggingFaceProcessor only supports GPU for now, so a PyTorchProcessor on a CPU
# instance is used to run the processing script instead.
pytorch_processor = PyTorchProcessor(
    framework_version='1.8',
    role=role,
    instance_type='ml.m5.xlarge',
    instance_count=1,
    base_job_name='distilGPT-Processing-Job'
)

# Raw text under s3://<default bucket>/data/raw/text is made available to the
# script at /opt/ml/processing/input.
source = f"s3://{bucket}/data/raw/text"
destination = "/opt/ml/processing/input"
inputs=[ProcessingInput(source=source, destination=destination)]


# `source` / `destination` are reused for the output mapping: files written to
# /opt/ml/processing/output are uploaded to s3://<default bucket>/data/processed/text.
source = "/opt/ml/processing/output"
destination = f"s3://{bucket}/data/processed/text"
outputs = [ProcessingOutput(source=source, destination=destination)]


# Command-line arguments for medical_language_processing.py, as flat flag/value pairs.
# Flags use the exact names declared by the script's argparse parser; the previous
# `--max_len` only worked because argparse accepts unambiguous prefixes.
arguments = [
    "--test_size", "0.2",
    "--valid_size", "0.5",
    "--left_col", "medical_specialty",
    "--right_col", "description",
    "--pretrained_path", "distilgpt2",
    "--filename", "mtsamples.csv",
    "--max_length", "128",
    "--bos_token", "<|startoftext|>",
    "--eos_token", "<|endoftext|>",
    "--pad_token", "<|pad|>",
    "--batch_size", "3",
]

#### Run Processing

In [6]:
# `if True:` is a manual switch: set it to False to skip launching the processing job
# when re-running the notebook.
if True:
    # Runs medical_language_processing.py from ./source_dir with the inputs, outputs
    # and arguments configured in the previous cell.
    pytorch_processor.run(
        code='medical_language_processing.py',
        inputs=inputs,
        outputs=outputs,
        arguments=arguments,
        source_dir='./source_dir'
    )

## Model

#### Create Training Script

In [7]:
%%writefile ./source_dir/medical_language_training.py

import os
import torch
import random
import pickle
import argparse
import pandas as pd

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler

from torch.optim import AdamW
from transformers import GPT2LMHeadModel
from transformers import GPT2Config
from transformers import AutoTokenizer
from transformers import get_linear_schedule_with_warmup

# Prompts sampled by train_model() when printing example generations. Starts as a single
# empty placeholder; load_data() replaces it with the label list when a labels file is given.
EXAMPLE_PROMPTS = ['']

class GPT2Dataset(Dataset):
    """Tokenized text samples exposed as (input_ids, attention_mask) tensor pairs.

    Each text in `txt_list` is wrapped in `bos_token` / `eos_token` and passed to
    `tokenizer` with truncation and padding to `max_length`. `gpt2_type` is accepted
    but not used.

    NOTE(review): this class is not instantiated in this script; presumably it is kept
    so that objects saved by the processing script can be unpickled by torch.load —
    confirm before removing.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=768, bos_token='<|startoftext|>', eos_token='<|endoftext|>'):
        self.tokenizer = tokenizer
        self.input_ids, self.attn_masks = [], []

        for text in txt_list:
            wrapped = bos_token + text + eos_token
            encoded = tokenizer(wrapped, truncation=True, max_length=max_length, padding="max_length")
            ids_tensor = torch.tensor(encoded['input_ids'])
            mask_tensor = torch.tensor(encoded['attention_mask'])
            self.input_ids.append(ids_tensor)
            self.attn_masks.append(mask_tensor)

    def __len__(self):
        """Number of encoded samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the (input_ids, attention_mask) pair stored at position `idx`."""
        pair = (self.input_ids[idx], self.attn_masks[idx])
        return pair


def load_data(filename, path, label_filename=None, label_path=None):
    """Load a saved dataloader and optionally register the label list as example prompts.

    Args:
        filename: Name of the file written with `torch.save`, located in `path`.
        path: Directory containing `filename`.
        label_filename: Optional pickle file holding the list of labels.
        label_path: Directory containing `label_filename`.

    Returns:
        The object restored by `torch.load`.

    Side effects:
        When `label_filename` is given, the labels are appended to the module-level
        EXAMPLE_PROMPTS and its empty placeholder entry is removed if still present.
    """
    filepath = os.path.join(path, filename)
    # NOTE(review): torch.load / pickle.load can execute arbitrary code; only load
    # artifacts produced by a trusted job.
    dataloader = torch.load(filepath)

    if label_filename is not None:
        # Context manager guarantees the labels file is closed after reading.
        with open(os.path.join(label_path, label_filename), 'rb') as fp:
            labels = pickle.load(fp)
        # The placeholder exists only before the first labelled load; removing it
        # unconditionally would raise ValueError on any later call.
        if '' in EXAMPLE_PROMPTS:
            EXAMPLE_PROMPTS.remove('')
        EXAMPLE_PROMPTS.extend(labels)
    return dataloader


def create_model(pretrained_path, device, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>'):
    """Load the GPT-2 LM-head model and tokenizer from `pretrained_path`.

    The tokenizer is loaded with the given special tokens, the model's token
    embeddings are resized to the tokenizer's length, and the model is moved to
    `device`. Returns `(model, tokenizer)`.
    """
    model_config = GPT2Config.from_pretrained(pretrained_path, output_hidden_states=False)
    special_tokens = {"bos_token": bos_token, "eos_token": eos_token, "pad_token": pad_token}
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path, **special_tokens)
    language_model = GPT2LMHeadModel.from_pretrained(pretrained_path, config=model_config)
    # Make room in the embedding matrix for the special tokens added to the tokenizer.
    language_model.resize_token_embeddings(len(tokenizer))
    return language_model.to(device), tokenizer


def set_training_parameters(dataloader, learning_rate, epsilon, warmup_steps, epochs, model=None):
    """Build the AdamW optimizer and linear warmup scheduler for a training run.

    Args:
        dataloader: Training dataloader; its length times `epochs` gives the total steps.
        learning_rate: Optimizer learning rate.
        epsilon: AdamW `eps` value.
        warmup_steps: Number of scheduler warmup steps.
        epochs: Number of passes over `dataloader`.
        model: Model whose parameters are optimized. Optional for backward
            compatibility: when omitted, the module-level `model` is used, as before.

    Returns:
        (optimizer, scheduler)

    Raises:
        NameError: if `model` is omitted and no module-level `model` exists.
    """
    if model is None:
        # Existing callers rely on a global `model`; keep that working while
        # letting new callers pass the model explicitly.
        try:
            model = globals()['model']
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=epsilon)
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    print(f"This training run will process {total_steps} steps in total.")
    return optimizer, scheduler


def _print_samples(model, tokenizer, device):
    """Print five sampled generations for a random prompt from EXAMPLE_PROMPTS."""
    model.eval()
    prompt_text = random.choice(EXAMPLE_PROMPTS) + ' | '
    prompt_embedding = torch.tensor(tokenizer.encode(prompt_text)).unsqueeze(0).to(device)
    # NOTE(review): pad_token_id is hard-coded to 50256 although the tokenizer is loaded
    # with a custom pad token — confirm whether tokenizer.pad_token_id should be used.
    generated_batch = model.generate(
        prompt_embedding,
        pad_token_id=50256,
        do_sample=True,
        top_k=50,
        max_length=128,
        top_p=0.99,
        num_return_sequences=5
    )
    for example in generated_batch:
        print(tokenizer.decode(example, skip_special_tokens=True))
    print('\n')
    # Restore training mode for the remainder of the step.
    model.train()


def train_model(model, dataloader, tokenizer, epochs, optimizer, scheduler, device, sample_every):
    """Fine-tune `model` on `dataloader` for `epochs` passes and return it.

    Args:
        model: Language model called as `model(ids, labels=..., attention_mask=...)`
            whose first output is the loss.
        dataloader: Iterable of `(input_ids, attention_mask)` batches.
        tokenizer: Tokenizer used to encode prompts and decode sampled generations.
        epochs: Number of passes over `dataloader`.
        optimizer: Optimizer stepped after every batch.
        scheduler: Learning-rate scheduler stepped after every batch.
        device: Device the batches are moved to.
        sample_every: Example generations are printed whenever `step % sample_every == 0`.

    Returns:
        The trained model (same object as `model`).

    Each epoch ends with an `Epochs=<n>; TotLoss=<sum>;` line so that per-epoch values
    are emitted; with `epochs == 0` nothing is printed and the model is returned as is.
    """
    for epoch_i in range(epochs):
        model.train()
        loss_total = 0.

        # Iterate the dataloader that was passed in rather than a module-level global.
        for step, batch in enumerate(dataloader):
            model.zero_grad()

            ids_batch = batch[0].to(device)
            # Language-model objective: the inputs double as the labels.
            # NOTE(review): padded positions are not masked out of the labels — confirm.
            labels_batch = ids_batch
            mask_batch = batch[1].to(device)

            output = model(ids_batch, labels=labels_batch, attention_mask=mask_batch)
            loss = output[0]
            loss_total += loss.item()

            if (step % sample_every) == 0:
                _print_samples(model, tokenizer, device)

            loss.backward()
            optimizer.step()
            scheduler.step()

        # Printed inside the epoch loop so every epoch reports its own total loss.
        print(f'Epochs={epoch_i+1}; TotLoss={loss_total};')
    return model


if __name__=='__main__':
    parser = argparse.ArgumentParser()
    # Channel directories default to the SM_CHANNEL_* environment variables and fall back
    # to the fixed /opt/ml/input/data/<channel>/ paths, so a missing channel no longer
    # raises KeyError while the parser is being built.
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '/opt/ml/input/data/train/'))
    parser.add_argument('--valid', type=str, default=os.environ.get('SM_CHANNEL_VALID', '/opt/ml/input/data/valid/'))
    parser.add_argument('--labels', type=str, default=os.environ.get('SM_CHANNEL_LABELS', '/opt/ml/input/data/labels/'))
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--filename_train', type=str, default='')
    parser.add_argument('--filename_valid', type=str, default='')
    parser.add_argument('--filename_test', type=str, default='')
    parser.add_argument('--filename_labels', type=str, default='labels.pkl')
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=100)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--sample_every', type=int, default=100)
    parser.add_argument('--bos_token', type=str, default='<|startoftext|>')
    parser.add_argument('--eos_token', type=str, default='<|endoftext|>')
    parser.add_argument('--pad_token', type=str, default='<|pad|>')
    parser.add_argument('--pretrained_path', type=str, default='distilgpt2')
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    args = parser.parse_args()

    # There is no --test option; resolve the test channel the same way as the others.
    test_dir = os.environ.get('SM_CHANNEL_TEST', '/opt/ml/input/data/test/')

    device = torch.device(args.device)
    # Use the parsed channel directories instead of repeating hard-coded paths.
    dataloader_train =  load_data(args.filename_train, path=args.train, label_filename=args.filename_labels, label_path=args.labels)
    # NOTE(review): the validation and test dataloaders are loaded but not used below.
    if len(args.filename_valid) > 0:
        dataloader_valid =  load_data(args.filename_valid, path=args.valid)
    if len(args.filename_test) > 0:
        dataloader_test =  load_data(args.filename_test, path=test_dir)
    model, tokenizer = create_model(args.pretrained_path, device=device, bos_token=args.bos_token, eos_token=args.eos_token, pad_token=args.pad_token)
    optimizer, scheduler = set_training_parameters(dataloader_train, args.learning_rate, args.epsilon, args.warmup_steps, args.epochs)
    model = train_model(model, dataloader_train, tokenizer, args.epochs, optimizer, scheduler, device, args.sample_every)

    # Model and tokenizer are exported to the same sub-directory of the model dir.
    export_dir = os.path.join(args.model_dir, "20230112_distilgpt2_medical_generator/")
    model.save_pretrained(export_dir)
    tokenizer.save_pretrained(export_dir)
    print("Completed Training!")

Overwriting ./source_dir/medical_language_training.py


#### Training Parameters

In [8]:
# S3 locations of the processed artifacts; each dictionary key becomes a training
# input channel of the same name.
train_path = f"s3://{bucket}/data/processed/text/dataset_train.bin"
valid_path = f"s3://{bucket}/data/processed/text/dataset_valid.bin"
labels_path = f"s3://{bucket}/data/processed/text/labels.pkl"
inputs = {"train": train_path, "valid": valid_path, "labels":labels_path}


# Passed to medical_language_training.py as command-line options of the same names.
# An empty "filename_test" makes the script skip loading a test set.
hyperparameters = {
    "device":"cpu",
    "filename_train":"dataset_train.bin",
    "filename_valid":"dataset_valid.bin",
    "filename_test":"",
    "max_length":128,
    "batch_size":4,
    "epochs":10,
    "learning_rate":5e-4,
    "warmup_steps":100,
    "epsilon":1e-8,
    "sample_every":100,
    "pretrained_path":"distilgpt2",
}


# The metric regexes match the "Epochs=<n>; TotLoss=<value>;" line printed by the
# training script, so that line's format must stay in sync with them.
pytorch_estimator = PyTorch(
    entry_point="medical_language_training.py",
    source_dir='./source_dir',
    role=role,
    py_version="py36",
    framework_version='1.8',
    instance_count=1,
    instance_type="ml.m5.xlarge",
    hyperparameters=hyperparameters,
    base_job_name='distilGPT-Training-Job-test',
    metric_definitions=[
        {'Name': 'train:epoch', 'Regex': 'Epochs=(.*?);'},
        {'Name': 'train:loss', 'Regex': 'TotLoss=(.*?);'}]
)

#### Run Training

In [9]:
# `if True:` is a manual switch: set it to False to skip launching the training job
# when re-running the notebook.
if True:
    pytorch_estimator.fit(inputs)

#### Create Inference Script

In [10]:
%%writefile ./source_dir/medical_language_endpoint_cpu.py

import os
import joblib
import argparse
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler

from transformers import GPT2LMHeadModel
from transformers import GPT2Config
from transformers import AutoTokenizer


# Device that predict_fn moves the encoded prompt to.
DEVICE = torch.device('cpu')
# Sub-directory of the model directory from which model_fn loads the model and tokenizer.
MODELNAME = '20230112_distilgpt2_medical_generator'

def input_fn(request_body, request_content_type):
    """Turn the request payload into the prompt string used for generation.

    Args:
        request_body: Payload of the request; `str`, `bytes` or `bytearray`.
        request_content_type: Content type declared by the caller.

    Returns:
        The payload as text when the content type is 'text/csv' (byte payloads are
        decoded as UTF-8 so that tokenization always receives a string); for any
        other content type the fixed default prompt 'Letters | '.
    """
    if request_content_type == 'text/csv':
        if isinstance(request_body, (bytes, bytearray)):
            return request_body.decode('utf-8')
        return request_body
    # Unsupported content types fall back to a fixed default prompt.
    return 'Letters | '


def output_fn(prediction, response_content_type):
    """Return the prediction converted to `str`; `response_content_type` is ignored."""
    serialized = str(prediction)
    return serialized


def model_fn(model_dir):
    """Load the fine-tuned model and its tokenizer from `<model_dir>/<MODELNAME>`.

    Returns a dict with the keys 'model' and 'tokenizer'.
    """
    artifact_dir = os.path.join(model_dir, MODELNAME)
    special_tokens = dict(bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')
    return {
        'model': GPT2LMHeadModel.from_pretrained(artifact_dir),
        'tokenizer': AutoTokenizer.from_pretrained(artifact_dir, **special_tokens),
    }


def predict_fn(input_object, model):
    """Generate one sampled continuation for the prompt `input_object`.

    `model` is the dict returned by model_fn ('model' and 'tokenizer' keys). The
    prompt is encoded, moved to DEVICE and used to sample a single sequence of at
    most 300 tokens; every occurrence of the prompt text is then removed from the
    decoded result before it is returned.
    """
    tokenizer = model['tokenizer']
    generator = model['model']

    token_ids = tokenizer.encode(input_object)
    prompt_tensor = torch.tensor(token_ids).unsqueeze(0).to(DEVICE)

    sampling_options = {
        'do_sample': True,
        'top_k': 75,
        'max_length': 300,
        'top_p': 0.99,
        'num_return_sequences': 1,
    }
    sequences = generator.generate(prompt_tensor, **sampling_options)

    decoded = []
    for sequence in sequences:
        text = tokenizer.decode(sequence, skip_special_tokens=True)
        decoded.append(text.replace(input_object, ''))
    return decoded[0]

Overwriting ./source_dir/medical_language_endpoint_cpu.py


#### Inference Parameters

In [11]:
# Number of generations requested per label in the evaluation section.
trials = 50
instance_type = 'ml.m4.xlarge'
# NOTE(review): this artifact path points at one specific, timestamped training job;
# update it after retraining (or confirm whether the estimator's output should be used).
model_data = f's3://{bucket}/distilGPT-Training-Job-test-2023-01-18-23-12-15-239/output/model.tar.gz'
endpoint_name = 'distilGPT-Medical-Endpoint'

# Serves the trained artifact with the handlers in medical_language_endpoint_cpu.py.
model = PyTorchModel(
    model_data=model_data,
    role=role,
    source_dir='./source_dir',
    entry_point='medical_language_endpoint_cpu.py',
    framework_version='1.8',
    py_version='py3',
)

#### Deploy Endpoint

In [12]:
# `if True:` is a manual switch: set it to False to skip deployment when the endpoint
# already exists.
if True:
    # Creates a single-instance endpoint named `endpoint_name`.
    endpoint = model.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        endpoint_name=endpoint_name,
    )

----------!

## Evaluate

#### Generate Examples

In [29]:
# NOTE(review): 'labels.pkl' must already exist in the working directory — presumably a
# local copy of the labels file written by the processing job; verify its origin, since
# pickle.load can execute arbitrary code from an untrusted file.
with open('labels.pkl', 'rb') as fp:
    labels = pickle.load(fp)
generated_samples = []
generated_labels = []


# Ask the endpoint for `trials` generations per label, prompting with "<label> | ".
for lbl in labels:
    print(lbl)
    prompt = lbl + ' | '
    for _ in range(trials):
        response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(EndpointName=endpoint_name, Body=prompt.encode(encoding="utf-8"), ContentType="text/csv")
        try:
            # UTF-8 keeps generations containing non-ASCII characters instead of
            # discarding them as undecodable.
            generated_samples.append(response['Body'].read().decode('utf-8'))
        except Exception as err:
            # Best effort: record an empty sample so both lists stay aligned, but report
            # the failure and let KeyboardInterrupt / SystemExit propagate.
            print(f"Could not read response for '{lbl}': {err!r}")
            generated_samples.append('')
        generated_labels.append(lbl)
df_generated = pd.DataFrame({'medical_specialty':generated_labels, 'description':generated_samples})
df_generated

 Allergy / Immunology
 Bariatrics
 Cardiovascular / Pulmonary
 Dentistry
 Urology
 General Medicine
 Surgery
 Speech - Language
 SOAP / Chart / Progress Notes
 Sleep Medicine
 Rheumatology
 Radiology
 Psychiatry / Psychology
 Podiatry
 Physical Medicine - Rehab
 Pediatrics - Neonatal
 Pain Management
 Orthopedic
 Ophthalmology
 Office Notes
 Obstetrics / Gynecology
 Neurosurgery
 Neurology
 Nephrology
 Letters
 Lab Medicine - Pathology
 IME-QME-Work Comp etc.
 Hospice - Palliative Care
 Hematology - Oncology
 Gastroenterology
 ENT - Otolaryngology
 Endocrinology
 Emergency Room Reports
 Discharge Summary
 Diets and Nutritions
 Dermatology
 Cosmetic / Plastic Surgery
 Consult - History and Phy.
 Chiropractic


Unnamed: 0,medical_specialty,description
0,Allergy / Immunology,A 23-year-old white female presents with comp...
1,Allergy / Immunology,"The patient is a 17-year-old female, who pres..."
2,Allergy / Immunology,Possible inflammatory bowel disease. Polyp o...
3,Allergy / Immunology,A 23-year-old white female presents with comp...
4,Allergy / Immunology,A 23-year-old white female presents with comp...
...,...,...
1945,Chiropractic,Diagnostic fiberoptic bronchoscopy.
1946,Chiropractic,The patient with epigastric and right upper q...
1947,Chiropractic,Acute acalculous cholecystitis. Open cholecy...
1948,Chiropractic,Excision of left breast mass. The mass was i...


In [None]:
# Clean-up: delete the endpoint once generation is finished. `if True:` is a manual
# switch; set it to False to keep the endpoint.
if True:
    endpoint.sagemaker_session.delete_endpoint(endpoint_name)

#### Save Generations

In [30]:
# Drop exact duplicate (label, description) rows, then write the rest to CSV; the frame's
# index is written as the first column.
df_generated.drop_duplicates().to_csv('./mtsamples_generated_modern.csv', sep=',')