### BERT Embeddings

In [1]:
import pandas as pd
import torch
from transformers import BertForMaskedLM
from src.preprocessing import preprocess_df, random_train_test_split, TextEncoder, set_labels_features
from src.tokenizers import CustomBertTokenizer
from src.dataset import LabValuesDataset
from src.train import train_mlm
from src.embeddings import load_model, get_embeddings

  from .autonotebook import tqdm as notebook_tqdm


### Read dataset

In [2]:
FILE = 'data/morning_lab_values.csv'

df = pd.read_csv(FILE)

In [3]:
df.head()

Unnamed: 0,hadm_id,subject_id,itemid,charttime,charthour,storetime,storehour,chartday,valuenum,cnt
0,,10312413,51222,2173-06-05 08:20:00,8,2173-06-05 08:47:00,8,2173-06-05,12.8,8
1,25669789.0,10390828,51222,2181-10-26 07:55:00,7,2181-10-26 08:46:00,8,2181-10-26,9.4,8
2,26646522.0,10447634,51222,2165-03-07 06:55:00,6,2165-03-07 07:23:00,7,2165-03-07,11.1,8
3,27308928.0,10784877,51222,2170-05-11 06:00:00,6,2170-05-11 06:43:00,6,2170-05-11,10.3,8
4,28740988.0,11298819,51222,2142-09-13 07:15:00,7,2142-09-13 09:23:00,9,2142-09-13,10.2,8


### Preprocessing

In [4]:
COLUMNS = ['Bic', 'Crt', 'Pot', 'Sod', 'Ure', 'Hgb', 'Plt', 'Wbc']

In [5]:
mrl = preprocess_df(df, columns_to_scale=COLUMNS)

In [6]:
text_encoder = TextEncoder()
mrl, grouped_mrl = text_encoder.encode_text(mrl)

In [7]:
grouped_mrl

Unnamed: 0,hadm_id,nstr
0,20000019.0,"[BicAS CrtC PotR SodBI UreG HgbAQ PltH WbcB, B..."
1,20000024.0,[BicAU CrtC PotAF SodBL UreJ HgbAV PltI WbcA]
2,20000034.0,"[BicAS CrtG PotAC SodBM UreJ HgbAR PltG WbcB, ..."
3,20000041.0,"[BicAW CrtC PotU SodBI UreG HgbAO PltJ WbcB, B..."
4,20000057.0,"[BicAK CrtC PotY SodBI UreG HgbBG PltG WbcA, B..."
...,...,...
264582,29999625.0,"[BicAZ CrtC PotAC SodBS UreN HgbBM PltO WbcB, ..."
264583,29999670.0,"[BicAZ CrtD PotW SodBO UreJ HgbAV PltG WbcA, B..."
264584,29999723.0,[BicBB CrtC PotU SodBM UreI HgbBN PltJ WbcA]
264585,29999745.0,[BicBD CrtB PotX SodBP UreE HgbBM PltJ WbcA]


### Tokenize

In [8]:
special_tokens = {0: '[PAD]', 101: '[CLS]', 102: '[SEP]', 103: '[MASK]'}
vocab_list = mrl['nstr'].str.split(' ').explode().unique().tolist()

tokenizer = CustomBertTokenizer.create_bert_tokenizer(vocab_list, special_tokens=special_tokens)

In [9]:
text = grouped_mrl['nstr'].apply(lambda x: ' [SEP] '.join(x)).tolist()
train, test = random_train_test_split(text)

In [10]:
train_inputs = tokenizer(train, return_tensors='pt', max_length=100, truncation=True, padding='max_length')

test_inputs = tokenizer(test, return_tensors='pt', max_length=100, truncation=True, padding='max_length')

### Dataset Perparation

In [11]:
MASKING = 0.20
train_inputs, test_inputs = set_labels_features(train_inputs, test_inputs, parcentage=MASKING)

In [12]:
train_dataset = LabValuesDataset(train_inputs)
test_dataset = LabValuesDataset(test_inputs)

In [13]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

### Model

In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Apple slilicon:
# device = torch.device('mps') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.resize_token_embeddings(len(tokenizer))
# and move our model over to the selected device
model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(624, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
     

### Training

In [None]:
train_mlm(model, train_loader, test_loader, device, tokenizer)

### Load model

In [19]:
model, tokenizer = load_model(model_path="model/", tokenizer_path="tokenizer/")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'CustomBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.


### Generate embeddings

In [25]:
texts_to_embed = ["BicAS CrtC PotR SodBI UreG HgbAQ PltH WbcB", "BicBD CrtB PotX SodBP UreE HgbBM PltJ WbcA"]

embeddings = get_embeddings(model, tokenizer, texts_to_embed)
print("Embeddings Shape:", embeddings.shape)
print(embeddings)

Embeddings Shape: (2, 10, 768)
[[[-1.1674141  -0.03205423 -0.83196145 ... -0.25210625  0.2533358
    0.5407926 ]
  [-0.9985977   0.6987042  -0.11391093 ... -0.2438933   0.10005566
   -0.2770821 ]
  [ 1.2565708   0.17699586  0.22682366 ...  0.16105117 -0.22948685
    0.41730225]
  ...
  [ 1.2413157   1.1760116  -0.5074808  ... -0.47217977 -0.99362564
   -1.2414502 ]
  [ 0.24586165 -0.2888732   0.8921513  ...  0.29769474 -0.5381557
    0.30063152]
  [-0.09828736 -0.1483139   0.0121102  ...  0.6725325   0.39603958
   -0.43334928]]

 [[-1.2411058   0.35089105 -0.77869827 ... -0.31622928  0.13699816
    0.58438724]
  [-0.69782966  0.5594827   0.5022457  ... -0.35116678  0.5592405
    0.03209981]
  [ 0.10558576  0.47911486  1.0057693  ... -0.10532114 -0.09810198
   -0.15093239]
  ...
  [ 0.46017697  0.4877037   0.68133605 ... -0.04234336 -0.8137266
    0.19985583]
  [-0.9214135  -0.39549658  0.06390405 ...  0.22225861  0.50593334
    0.48475447]
  [-0.13972664  0.02047924  0.13895945 ...  0.