## Installation and Setup

In [None]:
!pip install "sagemaker>=2.140.0" "transformers==4.26.1" "datasets[s3]==2.10.1" --upgrade

In [None]:
import sagemaker

# Bootstrap session: only used to look up a bucket name when none is given.
sess = sagemaker.Session()

# Set this to the name of an existing bucket to use it instead of the
# session's default bucket.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    # Fall back to the bucket the SageMaker session reports as its default.
    sagemaker_session_bucket = sess.default_bucket()

# Key prefix under which this notebook stores its artifacts in the bucket.
s3_prefix = "patient-doctor-text-classifier"

# Re-create the session so that it is explicitly bound to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

## Preprocess

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the train / validation / test splits in a single call; the returned
# datasets come back in the same order as the requested split names.
train_dataset, validation_dataset, test_dataset = load_dataset(
    "LukeGPT88/patient-doctor-text-classifier-dataset",
    split=["train", "validation", "test"],
)

# Tokenizer for the "distilbert-base-uncased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    """Tokenize the "text" entries of ``batch`` with the module-level tokenizer.

    Sequences are padded to the tokenizer's maximum length and truncated when
    they exceed it. Returns whatever the tokenizer call returns.
    """
    texts = batch["text"]
    return tokenizer(texts, padding="max_length", truncation=True)

def _prepare_split(dataset):
    """Tokenize one split and expose it as PyTorch tensors.

    Applies ``tokenize`` in batches, renames the "label" column to "labels",
    and restricts the torch-formatted output to the input ids, attention mask
    and labels columns.
    """
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset


# Run the same preparation on each split.
train_dataset = _prepare_split(train_dataset)
validation_dataset = _prepare_split(validation_dataset)
test_dataset = _prepare_split(test_dataset)

## Upload datasets to S3 bucket

In [None]:
# Destination URI of each split: s3://<session bucket>/<s3_prefix>/<split>
training_input_path = f's3://{sess.default_bucket()}/{s3_prefix}/train'
validation_input_path = f's3://{sess.default_bucket()}/{s3_prefix}/validation'
test_input_path = f's3://{sess.default_bucket()}/{s3_prefix}/test'

# Write the splits to S3 in train, validation, test order.
for split_dataset, destination in (
    (train_dataset, training_input_path),
    (validation_dataset, validation_input_path),
    (test_dataset, test_input_path),
):
    split_dataset.save_to_disk(destination)

## Clean up the S3 bucket

In [None]:
import boto3    

# Clean up ONLY what this notebook uploaded: the objects stored under
# `s3_prefix` in the session's bucket. The scope is deliberately narrow --
# enumerating the account's buckets and deleting each of them would
# irreversibly destroy data that has nothing to do with this notebook.
s3 = boto3.resource('s3')

bucket = s3.Bucket(sess.default_bucket())

# The trailing "/" keeps the filter from matching sibling prefixes that merely
# start with the same characters. Going through object_versions removes every
# stored version of each matching key, not just the latest one.
bucket.object_versions.filter(Prefix=f"{s3_prefix}/").delete()

# NOTE: the bucket itself is intentionally left in place, since it may hold
# artifacts from other work -- verify it is empty and unused before removing
# it manually.