In [2]:
import torch
import numpy as np
from datasets import Dataset
import pandas as pd
from torch import nn
import transformers
from collections import Counter
import json
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from nbtools.utils import files

# Load Dataset

In [3]:
# read in data
data_pth = f'{files.project_root()}/data/census_data.xlsx'
census = pd.read_excel(data_pth, sheet_name='Census (Stays)')
diagnoses = pd.read_excel(data_pth, sheet_name='DXes')
cohorts = pd.read_excel(data_pth, sheet_name='Cohorts')

# perform inner join to see how many records match
"""
coh_cen = pd.merge(
    cohorts, census, how='inner', on='RESIDENT ID'
)
coh_diag = pd.merge(
    cohorts, diagnoses, how='inner', on='RESIDENT ID'
)

print(f'# of cohorts that map to census: {len(coh_cen)}')
print(f'# of cohorts that map to diagnoses: {len(coh_diag)}')
"""


"\ncoh_cen = pd.merge(\n    cohorts, census, how='inner', on='RESIDENT ID'\n)\ncoh_diag = pd.merge(\n    cohorts, diagnoses, how='inner', on='RESIDENT ID'\n)\n\nprint(f'# of cohorts that map to census: {len(coh_cen)}')\nprint(f'# of cohorts that map to diagnoses: {len(coh_diag)}')\n"

In [4]:
census_cohort = pd.merge(
    census, cohorts, how='inner', on='ADMISS. ID'
)

print(census_cohort)

     ADMISS. ID  FACILITY ID_x  RESIDENT ID_x        START DATE_x  \
0       6723036             12        6722939 2025-09-30 22:26:00   
1       6720202             70        6723127 2025-09-30 22:25:00   
2       6711299             80        6722918 2025-09-30 22:05:00   
3       6711220             78        6723143 2025-09-30 21:12:00   
4       6704842             80        6722788 2025-09-30 20:59:00   
..          ...            ...            ...                 ...   
945     3197032              9        6704667 2025-09-20 18:55:00   
946     3196243             21        6704767 2025-09-20 18:54:00   
947     3195746              8        6704806 2025-09-20 18:53:00   
948     3188664             90        6704980 2025-09-20 18:52:00   
949     3184824             66        6706560 2025-09-20 18:49:00   

             END DATE_x               START REASON  \
0   2025-10-19 10:33:00           Actual Admission   
1   2025-10-18 11:30:00           Actual Admission   
2   2025-

# Merge sheets and Convert to Dataset

In [13]:
# convert to dataset
df = pd.merge(
    census_cohort, diagnoses, how='inner', left_on='RESIDENT ID_x', right_on='RESIDENT ID'
)
df['END DATE_y'] = df['END DATE_y'].astype(str)
ds = Dataset.from_pandas(df)

# determine labels
label_names = list(set(ds['PMTS COHORT']))
labels = [label_names.index(el['PMTS COHORT']) for el in ds]
counter = Counter(labels)

label_counts = {label: counter[i] for i, label in enumerate(label_names)}

print(json.dumps(label_counts, indent=4))

if 'label' not in ds.features.keys():
    ds = ds.add_column('label', labels)
    print('added labels to dataset')

ds = ds.select_columns(['PrimaryDescription', 'Others', 'label'])
ds.set_format('torch')
ds = ds.filter(lambda x: x['PrimaryDescription'] is not None and x['Others'] is not None)

# TODO this is for sanity check. remove later
ds = ds.select(range(64))


{
    "Cardiac - Congestive Heart Failure": 130,
    "Transplant (Other than cardiac)": 4,
    "Pneumonia": 28,
    "Wound": 26,
    "Sepsis": 46,
    "- Other - ": 405,
    "Renal": 13,
    " ": 101,
    "Psych": 9,
    "Orthopedic - Total Hip / Total Knee Arthroplasty": 9,
    "Chronic Obstructive Pulmonary Disease": 37,
    "Orthopedic - Other (than Total Hip/Knee Arthroplasty)": 45,
    "Cancer": 18,
    "Stroke": 79
}
added labels to dataset


Filter:   0%|          | 0/950 [00:00<?, ? examples/s]

# Define a new model using BERT to fine-tune

In [8]:
class CohortClassifier(nn.Module):
    def __init__(self, n_class):
        super(CohortClassifier, self).__init__()

        # pretrained BERT model
        self.bert = transformers.BertModel.from_pretrained(
            'bert-base-uncased'
        )

        # classification head
        self.classifier = nn.Linear(768, n_class)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        out1 = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids
        )
        return self.classifier(out1.pooler_output)

# Training Parameters and Training Loop

In [22]:
# initialize model and tokenizer
tok = transformers.BertTokenizer.from_pretrained(
    'bert-base-uncased'
)
model = CohortClassifier(n_class=len(labels))

# Hyperparameters
batch_size = 8
epochs = 5
lr = 5e-4
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr,
    weight_decay=0
)
loss_fn = nn.CrossEntropyLoss()

# create dataloader
dl = DataLoader(
    ds, batch_size=batch_size,
    shuffle=True,
)

for epoch in tqdm(range(epochs), position=0, desc='Epoch'):
    for batch in tqdm(dl, position=1, desc='Batch', leave=False):

        text_input = [
            f'Primary Description: {pd}\nOthers: {oth}' 
            for pd, oth 
            in zip(batch['PrimaryDescription'], batch['Others'])
        ]

        inputs = tok(
            text_input,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=256
        )

        output = model(**inputs)

        loss = loss_fn(output, batch['label'])
        acc = torch.mean((torch.argmax(output, dim=-1) == batch['label']).float())*100
        print(f'loss={loss.item():.4f}, accuracy: {acc.item():.2f}%')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/8 [00:00<?, ?it/s]

loss=7.1337, accuracy: 0.00%
loss=6.6450, accuracy: 0.00%
loss=6.2053, accuracy: 50.00%
loss=4.9128, accuracy: 50.00%
loss=6.0846, accuracy: 25.00%
loss=5.4563, accuracy: 50.00%
loss=5.9948, accuracy: 0.00%
loss=3.6716, accuracy: 50.00%


Batch:   0%|          | 0/8 [00:00<?, ?it/s]

loss=3.1279, accuracy: 50.00%
loss=3.7798, accuracy: 37.50%
loss=3.1216, accuracy: 50.00%
loss=2.5267, accuracy: 50.00%
loss=3.4469, accuracy: 25.00%
loss=2.8608, accuracy: 37.50%
loss=2.7054, accuracy: 37.50%
loss=2.1461, accuracy: 50.00%


Batch:   0%|          | 0/8 [00:00<?, ?it/s]

loss=2.4546, accuracy: 37.50%
loss=1.4769, accuracy: 75.00%
loss=2.4194, accuracy: 25.00%
loss=1.8152, accuracy: 50.00%
loss=1.9571, accuracy: 50.00%
loss=2.6357, accuracy: 25.00%
loss=2.1576, accuracy: 37.50%
loss=2.3252, accuracy: 37.50%


Batch:   0%|          | 0/8 [00:00<?, ?it/s]

loss=1.8914, accuracy: 62.50%
loss=1.4300, accuracy: 75.00%
loss=2.3380, accuracy: 25.00%
loss=2.1455, accuracy: 25.00%
loss=1.4247, accuracy: 62.50%
loss=2.1015, accuracy: 37.50%
loss=2.5751, accuracy: 25.00%
loss=2.0942, accuracy: 25.00%


Batch:   0%|          | 0/8 [00:00<?, ?it/s]

loss=1.4800, accuracy: 62.50%
loss=2.2014, accuracy: 25.00%
loss=1.8956, accuracy: 25.00%
loss=1.8832, accuracy: 50.00%
loss=1.8713, accuracy: 50.00%
loss=1.7106, accuracy: 62.50%
loss=2.3607, accuracy: 50.00%
loss=2.3780, accuracy: 12.50%
