# Load libraries

In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, EarlyStoppingCallback
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from datasets import Dataset, DatasetDict, load_from_disk
from rouge import Rouge

# Run on the GPU when torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('device:', device)

# Checkpoint identifier handed to `from_pretrained`.
model_path = "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Seed every random number generator used in this notebook so runs are repeatable.
import random

SEED = 42
random.seed(SEED)
# scikit-learn has no `sklearn.random` module (the former call raised
# AttributeError); its estimators fall back to NumPy's global RNG when no
# `random_state` is given, so seeding NumPy covers it.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Data Preparation

In [None]:
# Folder constants used to locate the input files.
DATASET_DIR = 'path/to/Datasets'
PRED_DATA_DIR = 'path/to/sentence_predictions'

# Read the train/val prediction CSVs and expose the `prediction` column
# under the name `predicted_label`.
train_df = pd.read_csv(os.path.join(PRED_DATA_DIR, 'train_pred.csv')).rename(
    columns={'prediction': 'predicted_label'}
)
val_df = pd.read_csv(os.path.join(PRED_DATA_DIR, 'val_pred.csv')).rename(
    columns={'prediction': 'predicted_label'}
)

# Class distribution of both label columns in the training split.
print(train_df['label'].value_counts())
print(train_df['predicted_label'].value_counts())

In [None]:
# Collapse 'positive', 'negative' and 'unsure' into the single class 'SA';
# 'none' is kept as-is. Values missing from the mapping become NaN.
mapping = {**dict.fromkeys(('positive', 'negative', 'unsure'), 'SA'), 'none': 'none'}

train_df['label'] = train_df['label'].map(mapping)
val_df['label'] = val_df['label'].map(mapping)

# The predicted labels go through the same collapsing.
predicted_mapping = dict(mapping)

train_df['predicted_label'] = train_df['predicted_label'].map(predicted_mapping)
val_df['predicted_label'] = val_df['predicted_label'].map(predicted_mapping)

# Class distribution after collapsing.
print(train_df['label'].value_counts())
print(train_df['predicted_label'].value_counts())

In [None]:
def get_stay_df(df):
    """Group the rows of `df` by `hadmid`.

    Returns a frame with one row per `hadmid` and a `text` column holding the
    list of that group's `context` values (kept as a list, not joined).
    """
    contexts_per_stay = df.groupby('hadmid')['context'].agg(list)
    return contexts_per_stay.rename('text').reset_index()

# One row per hadmid for each split.
train_stay_df, val_stay_df = get_stay_df(train_df), get_stay_df(val_df)

# Summary statistics of how many context entries each training hadmid has.
print(train_stay_df['text'].apply(len).describe())

train_stay_df

In [None]:
def getRouge1(ref, pred, kind='r'):
    """Return one ROUGE-1 figure for `pred` scored against `ref`.

    Both strings are lower-cased before scoring. `kind` picks the entry read
    from the 'rouge-1' result; the default is 'r'.
    """
    scorer = Rouge()
    hypothesis, reference = pred.lower(), ref.lower()
    first_pair_scores = scorer.get_scores(hypothesis, reference)[0]
    return first_pair_scores['rouge-1'][kind]


def drop_similar_sents(sents):
    """Return `sents` in alphabetical order with near-duplicates removed.

    Sentences are visited in sorted order. A sentence is dropped when
    `getRouge1(sent, kept)` is at least 0.9 for any sentence already kept.

    Args:
        sents: iterable of sentence strings.

    Returns:
        List of the kept sentences, alphabetically sorted.
    """
    unique_sents = []
    # Plain string ordering: the former secondary key `-len(x)` could never
    # break a tie, because strings that compare equal have equal length.
    for sent in sorted(sents):
        # A generator lets `any` stop at the first match instead of scoring
        # the sentence against every kept sentence.
        is_near_duplicate = any(
            getRouge1(sent, kept) >= 0.9 for kept in unique_sents
        )
        if not is_near_duplicate:
            unique_sents.append(sent)

    return unique_sents

# Drop near-duplicate sentences within each row's sentence list.
train_stay_df['text'] = train_stay_df['text'].apply(drop_similar_sents)
val_stay_df['text'] = val_stay_df['text'].apply(drop_similar_sents)
# Sentence-count statistics for the training split after deduplication.
print(train_stay_df['text'].apply(len).describe())

# Flatten each sentence list into one space-separated string.
train_stay_df['text'] = train_stay_df['text'].apply(' '.join)
val_stay_df['text'] = val_stay_df['text'].apply(' '.join)

train_stay_df

In [None]:
# Folder holding the JSON annotation files for each split.
annotation_dir = 'path/to/annotations'

with open(os.path.join(annotation_dir, 'train.json')) as f:
    train_labels = json.load(f)
with open(os.path.join(annotation_dir, 'val.json')) as f:
    val_labels = json.load(f)

# The annotation dicts are indexed by the string form of hadmid; a hadmid
# missing from the annotations raises KeyError.
train_stay_df['label'] = [train_labels[str(hadmid)] for hadmid in train_stay_df['hadmid']]
val_stay_df['label'] = [val_labels[str(hadmid)] for hadmid in val_stay_df['hadmid']]

print(train_stay_df['label'].value_counts())
train_stay_df

In [None]:
import imblearn
from imblearn.over_sampling import RandomOverSampler

# Oversample the training rows so that 'positive', 'negative' and 'unsure'
# each reach the size of the largest class; 'neutral' keeps its current size.
label_counts = train_stay_df['label'].value_counts()
print(label_counts)

max_count = label_counts.max()
requested_counts = {
    'positive': max_count,
    'negative': max_count,
    'unsure': max_count,
    'neutral': label_counts.get('neutral', 0),
}
# imbalanced-learn raises ValueError when the strategy names a class that does
# not occur in the target, so only request classes that are actually present.
sampling_strategy = {
    label: count
    for label, count in requested_counts.items()
    if label in label_counts.index
}

sampler = RandomOverSampler(sampling_strategy=sampling_strategy, random_state=SEED)

train_stay_df, _ = sampler.fit_resample(train_stay_df, train_stay_df['label'])
print(train_stay_df['label'].value_counts())

# Shuffle the resampled rows and rebuild a clean 0..n-1 index.
train_stay_df = train_stay_df.sample(frac=1, random_state=SEED).reset_index(drop=True)

train_stay_df.shape

In [None]:
# Wrap both pandas frames as datasets under the split names 'train' and 'val'.
dataset = DatasetDict()
dataset['train'] = Dataset.from_pandas(train_stay_df)
dataset['val'] = Dataset.from_pandas(val_stay_df)

# Shuffle with the notebook-wide seed.
dataset = dataset.shuffle(seed=SEED)
dataset

In [None]:
# Truncation limit, in tokens, for the encoded input text.
MAX_LEN = 1024


def tokenize_function(batch):
    """Encode a batch of `text` / `label` strings for seq2seq training.

    Returns a dict with `input_ids` and `attention_mask` for the text, and
    `labels` holding the label token ids with every pad position set to -100.
    """
    encoded_text = tokenizer(
        batch["text"],
        max_length=MAX_LEN,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded_label = tokenizer(
        batch["label"],
        max_length=10,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    # Swap pad ids for -100, the same padding value this notebook passes to
    # the data collator as `label_pad_token_id`.
    pad_positions = encoded_label['input_ids'] == tokenizer.pad_token_id
    label_ids = encoded_label['input_ids'].masked_fill(pad_positions, -100)
    return {
        'input_ids': encoded_text['input_ids'],
        'attention_mask': encoded_text['attention_mask'],
        'labels': label_ids,
    }

# Tokenize both splits in batches; the raw 'text' and 'label' columns are
# dropped once their encoded versions exist.
for split_name in ('train', 'val'):
    dataset[split_name] = dataset[split_name].map(
        tokenize_function,
        batched=True,
        remove_columns=['text', 'label'],
    )

# Metrics

In [None]:
def metrics(goldens, predicts, avg='macro', mode='val'):
    """Compute precision, recall and F1 for predicted vs. reference labels.

    Args:
        goldens: reference labels.
        predicts: predicted labels, in the same order as `goldens`.
        avg: forwarded as `average` to the scikit-learn scoring functions.
        mode: unused; kept so callers that pass it keep working.

    Returns:
        Dict with the keys 'f1', 'p' (precision) and 'r' (recall).
    """
    # The previous nested helper accepted an `average` argument but silently
    # read `avg` from the enclosing scope; calling the scorers directly
    # removes that trap without changing the results.
    precision = precision_score(goldens, predicts, average=avg)
    recall = recall_score(goldens, predicts, average=avg)
    f1 = f1_score(goldens, predicts, average=avg)

    return {
        'f1': f1,
        'p': precision,
        'r': recall,
    }

In [None]:
def compute_metrics(eval_preds):
    """Decode generated ids and labels, then score them with `metrics`.

    Args:
        eval_preds: pair `(preds, labels)` of token-id arrays; `preds` may
            also be a tuple, in which case its first element is used.

    Returns:
        Dict with 'f1', 'p' and 'r' as percentages rounded to 4 decimals,
        plus 'gen_len', the mean number of non-pad tokens per prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a padding marker, not a vocabulary id, so it cannot be decoded.
    # It can show up in generated predictions as well as in labels, so both
    # arrays are mapped back to the pad id before decoding.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = metrics(predicts=decoded_preds, goldens=decoded_labels)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    # Counted after the -100 replacement so padding never inflates the length.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

# Model

In [None]:
# Load the pretrained seq2seq model, then move it to the selected device.
base_model1 = AutoModelForSeq2SeqLM.from_pretrained(model_path)
base_model1 = base_model1.to(device)

In [None]:
# LoRA adapter settings for a sequence-to-sequence language model.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the base model with the adapter and report the trainable parameters.
model = get_peft_model(base_model1, peft_config)
model.print_trainable_parameters()

In [None]:
# Pads each batch dynamically; label padding uses -100, the same value the
# tokenization step writes into `labels`.
data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=-100)

training_args = Seq2SeqTrainingArguments(
    # Stated explicitly because `output_dir` is a required argument on
    # transformers releases that predate its optional default; the value
    # matches that default ("trainer_output").
    output_dir="trainer_output",
    # Tie the trainer's seed to the notebook-wide seed.
    seed=SEED,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    # Evaluate with `generate()` so `compute_metrics` receives token ids.
    predict_with_generate=True,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.15,
    num_train_epochs=50,
    logging_strategy="steps",
    logging_steps=100,
    # Evaluation and checkpointing share the same 500-step cadence.
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=500,
    eval_steps=500,
    # "f1" is the key returned by `compute_metrics`; higher is better.
    metric_for_best_model="f1",
    greater_is_better=True,
    save_total_limit=1,
    load_best_model_at_end=True,
)

# Patience of 5 evaluations with a minimum-improvement threshold of 0.001.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5,
                                                early_stopping_threshold=0.001)

In [None]:
# Assemble the trainer from the model, arguments, tokenized splits, collator,
# metric function and early-stopping callback.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['val'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

In [None]:
# Run the training loop; as the last expression of the cell, the return value
# is what the notebook displays.
trainer.train()

In [None]:
import gc

# Drop unreachable Python objects first so their GPU tensors can be released.
gc.collect()
# The CUDA calls are guarded because this notebook also runs on CPU-only
# machines, where resetting the peak-memory statistics raises.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()