# Goal

Use https://blog.paperspace.com/sentence-embeddings-pytorch-lightning/ with as few lines of code/text as possible :) 

Budget: $50 in Anthropic API credits!

Constraints: 2 hours of active "coding" or "prompting" time on the day of the talk (https://lu.ma/LAIxNYC).

Data from:

https://colab.research.google.com/github/onefact/electronic-health-records-analysis/blob/main/notebooks/loading_physionet_mimic_iii_data.ipynb

https://colab.research.google.com/github/onefact/electronic-health-records-analysis/blob/main/notebooks/loading_physionet_mimic_iv_data.ipynb 

https://colab.research.google.com/github/onefact/electronic-health-records-analysis/blob/main/notebooks/loading_physionet_mimic_iv_clinical_notes.ipynb

# Setup and prerequisites

Create a new text file, populated from `requirements.in` file in https://colab.research.google.com/github/jaanli/language-model-notebooks/blob/main/notebooks/getting-started.ipynb

In [1]:
%%capture

!pip install -r requirements.txt

# Load packages and extensions needed to get help from large language models like Claude and ChatGPT :)

In [1]:
%load_ext jupyter_ai

In [2]:
%load_ext dotenv

In [3]:
%dotenv 

In [4]:
# Load duckdb, which lets us efficiently load large files
import duckdb

# Load pandas, which lets us manipulate dataframes
import pandas as pd

# Import jupysql Jupyter extension to create SQL cells
%load_ext sql

# Set configrations on jupysql to directly output data to Pandas and to simplify the output that is printed to the notebook.
%config SqlMagic.autopandas = True

%config SqlMagic.feedback = False
%config SqlMagic.displaycon = False

# Allow named parameters (python variables) in SQL cells
%config SqlMagic.named_parameters=True

# Connect jupysql to DuckDB using a SQLAlchemy-style connection string. Either connect to an in memory DuckDB, or a file backed db.
%sql duckdb:///:memory:

Please use a valid option: "warn", "enabled", or "disabled". 
For more information, see the docs: https://jupysql.ploomber.io/en/latest/api/configuration.html#named-parameters


# Look at input data

In [5]:
!ls ~/data/mimic/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/discharge.parquet

/Users/me/data/mimic/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/discharge.parquet


In [13]:
%%capture discharge
%%bash
for file in /Users/me/data/mimic/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/discharge.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [14]:
print(discharge)

/Users/me/data/mimic/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/discharge.parquet
----------
┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ note_id     │ VARCHAR     │ YES     │         │         │         │
│ subject_id  │ INTEGER     │ YES     │         │         │         │
│ hadm_id     │ INTEGER     │ YES     │         │         │         │
│ note_type   │ VARCHAR     │ YES     │         │         │         │
│ note_seq    │ INTEGER     │ YES     │         │         │         │
│ charttime   │ TIMESTAMP   │ YES     │         │         │         │
│ storetime   │ TIMESTAMP   │ YES     │         │         │         │
│ text        │ VARCHAR     │ YES     │         │         │         │
└─────────────┴─────────────┴─────────┴─────────┴─────

In [15]:
!ls ~/data/mimic/mimic-iv-2.2/d_icd_diagnoses.parquet

/Users/me/data/mimic/mimic-iv-2.2/d_icd_diagnoses.parquet


In [16]:
%%capture d_icd_diagnoses
%%bash
for file in /Users/me/data/mimic/mimic-iv-2.2/d_icd_diagnoses.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [17]:
print(d_icd_diagnoses)

/Users/me/data/mimic/mimic-iv-2.2/d_icd_diagnoses.parquet
----------
┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ icd_code    │ VARCHAR     │ YES     │         │         │         │
│ icd_version │ VARCHAR     │ YES     │         │         │         │
│ long_title  │ VARCHAR     │ YES     │         │         │         │
└─────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┘



In [18]:
%%capture d_icd_procedures
%%bash
for file in /Users/me/data/mimic/mimic-iv-2.2/d_icd_procedures.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [19]:
print(d_icd_procedures)

/Users/me/data/mimic/mimic-iv-2.2/d_icd_procedures.parquet
----------
┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ icd_code    │ VARCHAR     │ YES     │         │         │         │
│ icd_version │ BIGINT      │ YES     │         │         │         │
│ long_title  │ VARCHAR     │ YES     │         │         │         │
└─────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┘



In [22]:
%%capture diagnoses_icd
%%bash
for file in /Users/me/data/mimic/mimic-iv-2.2/diagnoses_icd.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [23]:
print(diagnoses_icd)

/Users/me/data/mimic/mimic-iv-2.2/diagnoses_icd.parquet
----------
┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ subject_id  │ INTEGER     │ YES     │         │         │         │
│ hadm_id     │ INTEGER     │ YES     │         │         │         │
│ seq_num     │ INTEGER     │ YES     │         │         │         │
│ icd_code    │ VARCHAR     │ YES     │         │         │         │
│ icd_version │ INTEGER     │ YES     │         │         │         │
└─────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┘



In [24]:
%%capture procedures_icd
%%bash
for file in /Users/me/data/mimic/mimic-iv-2.2/procedures_icd.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [25]:
print(procedures_icd)

/Users/me/data/mimic/mimic-iv-2.2/procedures_icd.parquet
----------
┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ subject_id  │ BIGINT      │ YES     │         │         │         │
│ hadm_id     │ BIGINT      │ YES     │         │         │         │
│ seq_num     │ BIGINT      │ YES     │         │         │         │
│ chartdate   │ DATE        │ YES     │         │         │         │
│ icd_code    │ VARCHAR     │ YES     │         │         │         │
│ icd_version │ BIGINT      │ YES     │         │         │         │
└─────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┘



In [26]:
%%capture icu_icustays
%%bash
for file in /Users/me/data/mimic/mimic-iv-2.2/icu_icustays.parquet; do 
    echo "$file"
    echo "----------"
    duckdb -c "DESCRIBE SELECT * FROM parquet_scan('$file')"
    echo "=========="
done

In [27]:
print(icu_icustays)

/Users/me/data/mimic/mimic-iv-2.2/icu_icustays.parquet
----------
┌────────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│  column_name   │ column_type │  null   │   key   │ default │  extra  │
│    varchar     │   varchar   │ varchar │ varchar │ varchar │ varchar │
├────────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ subject_id     │ BIGINT      │ YES     │         │         │         │
│ hadm_id        │ BIGINT      │ YES     │         │         │         │
│ stay_id        │ BIGINT      │ YES     │         │         │         │
│ first_careunit │ VARCHAR     │ YES     │         │         │         │
│ last_careunit  │ VARCHAR     │ YES     │         │         │         │
│ intime         │ TIMESTAMP   │ YES     │         │         │         │
│ outtime        │ TIMESTAMP   │ YES     │         │         │         │
│ los            │ DOUBLE      │ YES     │         │         │         │
└────────────────┴─────────────┴─────────┴─────────┴──────

# Include documentation as context for newer libraries

In [28]:
%%capture lightning_docs

!curl -s "https://blog.paperspace.com/sentence-embeddings-pytorch-lightning/" | sed -e 's/<[^>]*>//g; /^$/d' | tr -s '\n'

In [35]:
example = """
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder
from sentence_transformers import util
predict_data=["I like to watch tv","Watching Television is my favorite time pass","The cat was running after the butterfly","It is so windy today"]
# Wrapping the prediction data inside a datamodule
datamodule = TextClassificationData.from_lists(
    predict_data=predict_data,
    batch_size=8,
)

# We are loading a pre-trained SentenceEmbedder
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

trainer = flash.Trainer(gpus=1)

#Since this task is tackled unsupervised, the predict method generates sentence embeddings using the prediction input
embeddings = trainer.predict(model, datamodule=datamodule)

for i in range(0,len(predict_data),2):
  embed1,embed2=embeddings[0][i],embeddings[0][i+1]
  # we are using cosine similarity to compute the similarity score
  cosine_scores = util.cos_sim(embed1, embed2)
  if cosine_scores>=0.5:
      label="Similar"
  else:
      label="Not Similar"
  print("sentence 1:{} | sentence 2:{}| prediction: {}".format(predict_data[i],predict_data[i+1],label))
"""

# Write prompts to get help from large language models for lightning stack

In [37]:
%%ai anthropic-chat:claude-3-opus-20240229 --format code

Given the schema in these files: 

{icu_icustays}
{procedures_icd}
{diagnoses_icd}
{d_icd_procedures}
{d_icd_diagnoses}
{discharge}

And this example code for using Lightning with sentence transformers (ignore the fine-tuning parts, but remember the key abstractions, methods, etc!), from https://blog.paperspace.com/sentence-embeddings-pytorch-lightning/:

```
```


And this example duckdb dialect SQL query that creates a histogram of length of stays in addition to their procedures and diagnoses (from https://colab.research.google.com/github/onefact/electronic-health-records-analysis/blob/main/notebooks/loading_physionet_mimic_iv_clinical_notes.ipynb): 

```
SELECT 
    i.hadm_id,
    ROUND(i.los) AS los_days,
    dp.icd_code AS procedure_icd_code,
    dp.long_title AS procedure_description,
    dd.icd_code AS diagnosis_icd_code,  
    dd.long_title AS diagnosis_description
FROM icu_icustays i
LEFT JOIN procedures_icd p ON i.hadm_id = p.hadm_id
LEFT JOIN d_icd_procedures dp ON p.icd_code = dp.icd_code
LEFT JOIN diagnoses_icd d ON i.hadm_id = d.hadm_id
LEFT JOIN d_icd_diagnoses dd ON d.icd_code = dd.icd_code
```

Proceed to create simple PyTorch Lightning code using the `jaxtyping` and `beartype` libraries for runtime static type checking. 

Ensure that critical methods are decorated using the `@beartype` decorator, and type hints are given using the jaxtyping syntax (https://kidger.site/thoughts/jaxtyping/).

Proceed step-by-step as an elite site reliability/devops/L20 principal warez engineer at google, returning as few tokens as possible, to debug this python & lightning code. Give the complete corrected code ready to be run in a Jupyter Notebook cell!

The output should be the following: 

* for every `procedure_description` and `diagnosis_description` in the data (`icd_code` in most tables/parquet files)
* segment the `text` field in the `discharge` table in the `hadm_id` (hospital admission identifer) that uniquely identifies an intensive care unit admission
* do this using the above example methods like sentence transformers, faiss, etc
* remember to use runtime static type checking with beartyping and jaxtyping throughout
* create code that takes a random sample of sentences from clinical notes, and computes the cosine similarity from every sentence to every `procedure_description` or `diagnosis_description` in the data
* then use the `torchmetrics` library to compute the accuracy / recall / standard metrics for this task and report the performance 
* do this for a small number of hospital admission identifiers (selected at random) and a small number of sentences (also selected at random) for demonstration purposes

You can do it! 

In [None]:
import random
from typing import List, Tuple, Dict

import duckdb
import pandas as pd
import torch
from beartype import beartype
from jaxtyping import Float, Int, Array, PyTree
from lightning.pytorch import LightningModule
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score


@beartype
def load_data(hadm_ids: List[Int[32]]) -> Tuple[pd.DataFrame, pd.DataFrame]:
    query = """
        SELECT 
            i.hadm_id,
            ROUND(i.los) AS los_days,
            dp.icd_code AS procedure_icd_code,
            dp.long_title AS procedure_description,
            dd.icd_code AS diagnosis_icd_code,  
            dd.long_title AS diagnosis_description
        FROM icu_icustays i
        LEFT JOIN procedures_icd p ON i.hadm_id = p.hadm_id
        LEFT JOIN d_icd_procedures dp ON p.icd_code = dp.icd_code
        LEFT JOIN diagnoses_icd d ON i.hadm_id = d.hadm_id
        LEFT JOIN d_icd_diagnoses dd ON d.icd_code = dd.icd_code
        WHERE i.hadm_id IN ({})
    """.format(', '.join(map(str, hadm_ids)))

    con = duckdb.connect()
    df = con.execute(query).df()

    query = """
        SELECT 
            hadm_id,
            text
        FROM discharge
        WHERE hadm_id IN ({})
    """.format(', '.join(map(str, hadm_ids)))

    notes_df = con.execute(query).df()

    return df, notes_df


@beartype
def get_embeddings(model: SentenceTransformer, sentences: List[str]) -> Array[Float[64]]:
    embeddings = model.encode(sentences)
    return embeddings


class SimilarityModel(LightningModule):
    def __init__(self, encoder_model: str):
        super().__init__()
        self.encoder = SentenceTransformer(encoder_model)
        self.metrics = MetricCollection([Accuracy(), Precision(), Recall(), F1Score()])

    @beartype
    def forward(self, sentences: List[str], labels: List[str]) -> PyTree:
        sentence_embeddings = get_embeddings(self.encoder, sentences)
        label_embeddings = get_embeddings(self.encoder, labels)

        similarities = cosine_similarity(sentence_embeddings, label_embeddings)
        predicted_labels = [labels[i] for i in similarities.argmax(axis=1)]

        self.metrics.update(predicted_labels, labels)

        return {"loss": torch.tensor(0), "predicted_labels": predicted_labels}

    def training_step(self, batch, batch_idx):
        sentences, labels = batch
        return self(sentences, labels)

    def validation_step(self, batch, batch_idx):
        sentences, labels = batch
        return self(sentences, labels)

    def test_step(self, batch, batch_idx):
        sentences, labels = batch
        return self(sentences, labels)

    def training_epoch_end(self, outputs):
        self.log_dict(self.metrics.compute())
        self.metrics.reset()

    def validation_epoch_end(self, outputs):
        self.log_dict(self.metrics.compute())
        self.metrics.reset()

    def test_epoch_end(self, outputs):
        self.log_dict(self.metrics.compute())
        self.metrics.reset()


@beartype
def main():
    random.seed(42)
    hadm_ids = random.sample(range(100000, 200000), 10)  # Random hospital admission IDs
    df, notes_df = load_data(hadm_ids)

    labels = list(df["procedure_description"].unique()) + list(df["diagnosis_description"].unique())

    sentences =