<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/model/MIMIC/MIMIC_simcse_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install sentence_transformers transformers datasets

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read the data archive and the
# model checkpoint from paths under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import json

import numpy as np
from datasets import load_dataset
from scipy.special import expit

In [None]:
!unzip /content/drive/MyDrive/mimic.jsonl.zip -d content
with open('content/mimic.jsonl') as f:
    data = [json.loads(line) for line in f]

In [None]:
# Clone the project repository; later cells expect it at /content/IR_LDC,
# so this presumably runs with /content as the working directory — TODO confirm.
!git clone https://github.com/danielsaggau/IR_LDC.git

In [None]:
import shutil
from pathlib import Path

# Move the extracted corpus into the directory of the dataset loading script
# (presumably where the script looks for it — verify against mimic-dataset.py).
mimic_src = Path("/content/content/mimic.jsonl")
mimic_dst = Path("/content/IR_LDC/model/MIMIC") / mimic_src.name

if mimic_src.exists():
    # Naming the destination *file* lets a re-run replace an earlier copy;
    # passing only the directory makes shutil.move raise once the file is already there.
    shutil.move(str(mimic_src), str(mimic_dst))
elif not mimic_dst.exists():
    # Neither the extracted file nor a previously moved copy is available.
    raise FileNotFoundError(f"{mimic_src} not found and {mimic_dst} does not exist")

dataset = load_dataset("/content/IR_LDC/model/MIMIC/mimic-dataset.py")

In [None]:
# Display the loaded dataset object as the cell output.
dataset

In [None]:
import os
from getpass import getpass

from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Keep the access token out of the notebook source: read it from the HF_TOKEN
# environment variable, or prompt for it without echoing. `login` stores the token
# so that `use_auth_token=True` and the later push to the Hub can pick it up.
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: "))

# Tokenizer saved alongside the pre-trained checkpoint on Drive.
tokenizer = AutoTokenizer.from_pretrained("/content/drive/MyDrive/bregman_mimic_clinical_longformer/20000", use_auth_token=True, use_fast=True)

In [None]:
# Display the dataset object again as the cell output.
dataset

In [None]:
from sentence_transformers import SentenceTransformer

# Disabled alternative, kept for reference:
#model = SentenceTransformer("danielsaggau/MIMIC_SIMCSE",  use_auth_token=True)

# Load the checkpoint from Drive as a sequence classifier configured for
# multi-label classification over 19 labels.
model = AutoModelForSequenceClassification.from_pretrained(
    "/content/drive/MyDrive/bregman_mimic_clinical_longformer/20000",
    num_labels=19,
    problem_type='multi_label_classification',
    use_auth_token=True,
)

In [None]:
from transformers import DataCollatorWithPadding
# Batch collator that pads with the tokenizer; padded lengths are rounded up to a
# multiple of 8, which the original author tagged "fp16" (training below sets fp16=True).
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) # fp16

In [None]:
# Keep a handle on the training split; its feature schema is inspected below.
train_dataset = dataset['train']

In [None]:
# Label metadata is read from the inner feature of the 'labels' column.
# `label_ids` and `label_names` both refer to the same list of class names;
# `label_list` holds the integer class indices 0..num_labels-1.
label_names = label_ids = train_dataset.features['labels'].feature.names
num_labels = train_dataset.features['labels'].feature.num_classes
label_list = [*range(num_labels)]

In [None]:
   def preprocess_function(examples):
       """Tokenize a batch of documents and attach multi-hot label vectors.

       Args:
           examples: Batch mapping as supplied by ``dataset.map(..., batched=True)``;
               ``examples["text"]`` and ``examples["labels"]`` are read.

       Returns:
           The tokenizer output for the batch with an added ``"label_ids"`` entry:
           one list of floats per example, holding ``1.0`` at position ``i`` when
           ``label_list[i]`` is among that example's labels and ``0.0`` otherwise.
       """
       # One tokenizer call truncates and pads every text to exactly 512 tokens
       # (already a multiple of 8), so a second `tokenizer.pad` pass over the
       # batch would only redo work without changing anything.
       batch = tokenizer(
           examples["text"],
           padding='max_length',
           max_length=512,
           pad_to_multiple_of=8,
           truncation=True)

       multi_hot = []
       for labels in examples["labels"]:
           # Set membership avoids rescanning the example's label list for every class.
           present = set(labels)
           multi_hot.append([1.0 if label in present else 0.0 for label in label_list])
       batch["label_ids"] = multi_hot
       return batch

In [None]:
# Tokenize every split in batches; the raw 'labels' column is dropped because
# preprocess_function emits the multi-hot 'label_ids' column in its place.
tokenized_data = dataset.map(
    preprocess_function,
    remove_columns=['labels'],
    batched=True,
)

In [None]:
from transformers import EvalPrediction
   
def _f1_from_counts(tp, fp, fn):
    """Return F1 = 2*TP / (2*TP + FP + FN) element-wise, using 0 where the denominator is 0."""
    tp, fp, fn = (np.atleast_1d(np.asarray(c, dtype=float)) for c in (tp, fp, fn))
    denom = 2.0 * tp + fp + fn
    return np.divide(2.0 * tp, denom, out=np.zeros_like(denom), where=denom > 0)


def compute_metrics(p: EvalPrediction):
    """Compute macro- and micro-averaged F1 for multi-label predictions.

    Logits are passed through a sigmoid and thresholded at 0.5; the reference
    labels are binarised at 0.5 as well. True-positive, false-positive and
    false-negative counts are accumulated along axis 0 (one count per label
    column — assumes inputs are laid out as (examples, labels); TODO confirm).
    A label with no positives in either predictions or references contributes
    an F1 of 0, i.e. the behaviour of scikit-learn's ``zero_division=0``.

    Args:
        p: Evaluation output exposing ``predictions`` (an array of logits, or a
            tuple whose first element is the logits) and ``label_ids``.

    Returns:
        Dict with float entries ``'macro_f1'`` (mean of per-label F1) and
        ``'micro_f1'`` (F1 over counts pooled across all labels).
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = expit(np.asarray(logits, dtype=np.float64)) > 0.5
    gold = np.asarray(p.label_ids) > 0.5

    tp = np.sum(preds & gold, axis=0)
    fp = np.sum(preds & ~gold, axis=0)
    fn = np.sum(~preds & gold, axis=0)

    macro_f1 = float(_f1_from_counts(tp, fp, fn).mean())
    micro_f1 = float(_f1_from_counts(tp.sum(), fp.sum(), fn.sum())[0])
    return {'macro_f1': macro_f1, 'micro_f1': micro_f1}

In [None]:
from transformers import TrainingArguments
from transformers import EarlyStoppingCallback
# Fine-tuning configuration: evaluate and checkpoint once per epoch, then reload
# the checkpoint with the best tracked metric when training ends.
training_args = TrainingArguments(
    output_dir="/clongformer_mimic_classification",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=8,
    weight_decay=0.01,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    push_to_hub=True,
    fp16=True,
    # Must name a key returned by compute_metrics ('macro_f1' / 'micro_f1');
    # the previous value "f1-micro" matched neither, so best-model selection
    # and early stopping could not find the metric.
    metric_for_best_model="micro_f1",
    greater_is_better=True,
    load_best_model_at_end = True
)

In [None]:
# Display the tokenized test split as the cell output.
tokenized_data['test']

In [None]:
import torch 
# Ask PyTorch to release the GPU memory its caching allocator holds but is not using.
torch.cuda.empty_cache()

In [None]:
import gc
import os

import torch
from scipy.special import expit

# Rebind the scratch name `obj` to None and run a garbage-collection pass
# before the trainer is built.
obj = None
gc.collect()

In [None]:
from transformers import Trainer

# NOTE(review): the 'test' split is used as the evaluation set, so it also drives
# early stopping and best-checkpoint selection — confirm this is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Stop once the tracked metric has not improved for 3 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()