In [None]:
from __future__ import print_function

import ast
import os

import accelerate
import evaluate
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import shap
import sklearn
import sklearn
import sklearn.ensemble
import sklearn.metrics
import torch
import torch.nn as nn
import tqdm.notebook as tq
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model, TaskType
from torch.optim import AdamW
from torch.utils.data import TensorDataset
from transformers import pipeline, AutoTokenizer, AutoModel, DataCollatorWithPadding, EvalPrediction, TrainingArguments, Trainer, OPTForSequenceClassification, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
# Hyperparameters
MODEL = "facebook/opt-350m"  # base checkpoint loaded for classification + LoRA
MAX_LEN = 2048               # every sequence is truncated / padded to this many tokens

# Per-split batch sizes (all identical)
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 4

EPOCHS = 5
LEARNING_RATE = 3e-05  # alternative value noted previously: 5e-05

In [None]:
# Load the tokenizer that belongs to MODEL. The classifier's embedding table is
# indexed by MODEL's own vocabulary; ids produced by a different model's tokenizer
# (previously Bio_ClinicalBERT) would be looked up as unrelated tokens.
# NOTE(review): any adapter trained with the Bio_ClinicalBERT tokenizer must be
# retrained after this change — confirm which tokenizer the saved checkpoint used.
tokenizer = AutoTokenizer.from_pretrained(MODEL, cache_dir='./model_ckpt/')


In [2]:
class TokenizerWrapper:
    """Tokenize note text and encode label lists as multi-hot float vectors.

    ``tokenize_function`` is written for ``datasets.Dataset.map(..., batched=True)``.
    """

    def __init__(self, tokenizer, MAX_LEN, classes=None):
        """
        Args:
            tokenizer: Hugging Face tokenizer, called on a batch of texts.
            MAX_LEN: length every sequence is truncated / padded to.
            classes: optional iterable of label codes that fixes the label order.
                When omitted, the ``icd_code`` column of the notebook-level
                ``labels_10_top50`` frame is used (the original behaviour).
        """
        self.tokenizer = tokenizer
        self.max_length = MAX_LEN
        if classes is None:
            # Backward-compatible default: read the notebook's label frame.
            classes = labels_10_top50["icd_code"]
        # Falsy entries (e.g. empty strings) are dropped; note float NaN is truthy
        # and is therefore kept.
        self.classes = [class_ for class_ in classes if class_]
        # Build both lookups from this instance's classes, not from notebook
        # globals that happen to share the same names.
        self.class2id = {class_: idx for idx, class_ in enumerate(self.classes)}
        self.id2class = {idx: class_ for class_, idx in self.class2id.items()}

    def multi_labels_to_ids(self, labels: list[str]) -> list[float]:
        """Return a multi-hot vector with 1.0 at the index of every label in ``labels``.

        Raises:
            KeyError: if a label is not in ``self.class2id``.
        """
        ids = [0.0] * len(self.class2id)  # BCELoss requires float as target type
        for label in labels:
            ids[self.class2id[label]] = 1.0
        return ids

    def tokenize_function(self, example):
        """Tokenize a batch of texts and attach a float multi-hot ``label`` tensor.

        ``example["label"]`` presumably holds string-encoded Python lists read
        from CSV (e.g. "['A', 'B']") — confirm against the data. They are parsed
        with ``ast.literal_eval``, which accepts literals only, so expressions
        embedded in the data are never executed.
        """
        result = self.tokenizer(
            example["text"],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        result["label"] = torch.tensor(
            [self.multi_labels_to_ids(ast.literal_eval(label)) for label in example["label"]]
        )
        return result

In [None]:
# Free cached GPU memory left over from earlier runs, then choose the device.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
data_dir = "output/"  # output directory; not referenced by the other cells shown here

In [None]:
# Full splits and the ICD-10 label vocabulary.
train_path = 'data/train_10_top50.csv'
test_path = 'data/test_10_top50.csv'
val_path = 'data/val_10_top50.csv'
labels_path = 'data/icd10_codes_top50.csv'

# Number of leading rows kept from each split for the shortened subsets.
N_SHORT_ROWS = 5000

train_10_top50 = pd.read_csv(train_path)
val_10_top50 = pd.read_csv(val_path)
test_10_top50 = pd.read_csv(test_path)

# Reuse the frames already in memory instead of re-reading every CSV.
train_10_top50_shorten = train_10_top50.head(N_SHORT_ROWS)
val_10_top50_shorten = val_10_top50.head(N_SHORT_ROWS)
test_10_top50_shorten = test_10_top50.head(N_SHORT_ROWS)

train_short_path = "data/train_10_top50_short.csv"
val_short_path = "data/val_10_top50_short.csv"
test_short_path = "data/test_10_top50_short.csv"

# The shortened splits are written to disk because load_dataset reads from files.
train_10_top50_shorten.to_csv(train_short_path, index=False)
val_10_top50_shorten.to_csv(val_short_path, index=False)
test_10_top50_shorten.to_csv(test_short_path, index=False)

labels_10_top50 = pd.read_csv(labels_path)

In [None]:
# Label vocabulary: every truthy ICD code in file order, plus lookups in both directions.
classes = list(filter(None, labels_10_top50["icd_code"]))
class2id = dict(zip(classes, range(len(classes))))
id2class = {idx: code for code, idx in class2id.items()}

In [None]:
# Tokenize the shortened CSV splits and expose them as torch-formatted columns.
data_files = dict(
    train=train_short_path,
    validation=val_short_path,
    test=test_short_path,
)
tokenizer_wrapper = TokenizerWrapper(tokenizer, MAX_LEN)
dataset = (
    load_dataset("csv", data_files=data_files)
    .map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=1)
    .with_format("torch")
)

In [None]:
# LoRA adapters on the attention query / value projections for sequence classification.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

# Classification config: one output per label, multi-label problem type, and the
# label <-> id mappings. unused_kwargs holds any kwargs AutoConfig did not consume.
config, unused_kwargs = AutoConfig.from_pretrained(
    MODEL,
    return_unused_kwargs=True,
    problem_type="multi_label_classification",
    num_labels=len(classes),
    label2id=class2id,
    id2label=id2class,
)

model = get_peft_model(
    OPTForSequenceClassification.from_pretrained(MODEL, config=config),
    lora_config,
)
model.print_trainable_parameters()

In [None]:
# Load the fine-tuned LoRA adapter for inference.
# PeftModel.from_pretrained expects a plain transformers model, but the `model`
# built in the previous cell is already a PeftModel (get_peft_model); wrapping it
# again would nest two PEFT wrappers. Rebuild a clean base model from the same
# classification config and attach the saved adapter to that instead.
# The adapter config is bound to `peft_config` so the classification `config`
# above is not overwritten.
# NOTE(review): the checkpoint directory is named "icd-9-..." while this notebook
# builds ICD-10 top-50 labels — confirm the saved head matches len(classes).
ADAPTER_DIR = "model_ckpt/icd-9-shard20-checkpoint-1182/"
peft_config = PeftConfig.from_pretrained(ADAPTER_DIR)
base_model = OPTForSequenceClassification.from_pretrained(MODEL, config=config)
model = PeftModel.from_pretrained(base_model,
                                  ADAPTER_DIR,
                                  is_trainable=False)

In [None]:
# Report trainable vs. total parameter counts for the current `model`; useful to
# check that the adapter loaded with is_trainable=False is frozen.
model.print_trainable_parameters()

In [None]:
# Inference pipeline for the multi-label classifier.
# top_k=None returns a score for every label; the default returns only the top-1
# label, which discards the other positive codes in a multi-label setting.
# truncation/max_length match the training tokenization so long notes do not
# exceed the model's input length.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device,
    top_k=None,
    truncation=True,
    max_length=MAX_LEN,
)

In [None]:
explainer = shap.DeepExplainer(model, data=dataset['validation'])

In [None]:
shap_values = explainer.shap_values(dataset['test'])

In [None]:
shap.summary_plot(shap_values, dataset['train'], feature_names=id2class)