In [2]:
from src.tokenize.constants import SpecialToken as ST
from src.tokenize.datasets import TimelineDataset

In [3]:
import torch
from torch.utils.data import DataLoader, Subset

In [4]:
train_dataset = TimelineDataset(
    input_dir="/workspace/ehr_stuff/EHR_FM/data/tokenized_datasets/mimic_train",
    n_positions=2048,
    is_encoder_decoder=False,
)

In [5]:
vocab = train_dataset.vocab

In [6]:
vocab_size = (len(vocab) // 64 + 1) * 64 if len(vocab) % 64 != 0 else len(vocab)
tokens_of_interest = [ST.DEATH, ST.ADMISSION, ST.DISCHARGE]
tokens_of_interest = {stoken: vocab.encode(stoken) for stoken in tokens_of_interest}

In [7]:
tokens_of_interest # from admissions table

{<SpecialToken.DEATH: 'MEDS_DEATH'>: 438,
 <SpecialToken.ADMISSION: 'HOSPITAL_ADMISSION'>: 142,
 <SpecialToken.DISCHARGE: 'HOSPITAL_DISCHARGE'>: 141}

In [8]:
val_size = int(6 * 1_000_000)
train_dataset, val_dataset = (
    Subset(train_dataset, indices=indices)
    for indices in torch.split_with_sizes(
        torch.arange(len(train_dataset)), [len(train_dataset) - val_size, val_size]
    )
)

In [9]:
def make_infinite_loader(loader):
    while True:
        yield from iter(loader)


In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
train_dataloader = make_infinite_loader(train_dataloader)

In [11]:
batch = next(iter(train_dataloader))

In [12]:
batch[0].shape

torch.Size([32, 2048])

In [23]:
vocab.encode("MEDS_DEATH")

438

In [15]:
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
val_dataloader = make_infinite_loader(val_dataloader)

In [16]:
eval_iters = len(val_dataset) // (32 * 2048) + 1

In [17]:
vocab.decode(batch[0][0][200:250])

['Q1',
 'LAB//50885//MG/DL',
 'Q1',
 'LAB//50893//MG/DL',
 'Q2',
 'LAB//50902//MEQ/L',
 'Q4',
 'LAB//50908//%',
 'Q4',
 'LAB//50910//IU/L',
 'Q9',
 'LAB//50911//NG/ML',
 'Q7',
 'LAB//50912//MG/DL',
 'Q8',
 'LAB//50931//MG/DL',
 'Q8',
 'LAB//50960//MG/DL',
 'Q4',
 'LAB//50970//MG/DL',
 'Q10',
 'LAB//50971//MEQ/L',
 'Q10',
 'LAB//50983//MEQ/L',
 'Q1',
 'LAB//50993//UIU/ML',
 'Q9',
 'LAB//51003//NG/ML',
 'Q7',
 'LAB//51006//MG/DL',
 'Q10',
 'LAB//51237//UNK',
 'Q5',
 'LAB//51274//SEC',
 'Q7',
 'LAB//51275//SEC',
 'Q8',
 'LAB//51221//%',
 'Q4',
 'LAB//51222//G/DL',
 'Q3',
 'LAB//51248//PG',
 'Q3',
 'LAB//51249//%',
 'Q2',
 'LAB//51250//FL',
 'Q5',
 'LAB//51265//K/UL',
 'Q7',
 'LAB//51277//%']

In [21]:
# vocab.decode(batch[1][1])

In [19]:
batch[1]

tensor([[-100, -100, -100,  ...,   24,  109,   59],
        [-100, -100, -100,  ...,   73,   28,   74],
        [-100, -100, -100,  ...,   31,   44,   27],
        ...,
        [-100, -100, -100,  ...,   34,   42,   30],
        [-100, -100, -100,  ...,   35,   24,   96],
        [-100, -100, -100,  ...,  119,   31,  155]])

### model

In [24]:
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, GPT2Config

In [25]:
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=2048,
    n_embd=64,
    n_layer=1, ## change this stuff later if the model is bad
    n_head=4,
    n_inner=None,
    activation_function="gelu",
    resid_pdrop=0, ## change this stuff later if the model is bad
    embd_pdrop=0, ## change this stuff later if the model is bad
    attn_pdrop=0, ## change this stuff later if the model is bad
    bias=False, # model doesn't perform well without bias
)

In [26]:
## model.py
import math
from collections import namedtuple
from functools import lru_cache

import torch
import torch.nn as nn
import transformers.activations
from torch.nn import functional as F
from transformers import GPT2Config

ModelOutput = namedtuple("ModelOutput", ["loss", "logits"])


class CausalSelfAttention(nn.Module):
    def __init__(self, config, attention_weights: list | None = None):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.attn_pdrop
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash or attention_weights is not None:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(config.n_positions, config.n_positions)).view(
                    1, 1, config.n_positions, config.n_positions
                ),
                persistent=False,
            )
        self.attention_weights = attention_weights

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the
        # batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash and self.attention_weights is None:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            self.attention_weights.append(att.detach().cpu())
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.activation = transformers.activations.get_activation(config.activation_function)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config, attention_weights: list | None = None):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config, attention_weights=attention_weights)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT2LMNoBiasModel(nn.Module):
    def __init__(
        self,
        config: GPT2Config,
        return_attention=False,
    ):
        super().__init__()
        self.config = config

        self.return_attention = return_attention
        self.attention_weights = [] if return_attention else None

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.n_positions, config.n_embd),
                drop=nn.Dropout(config.embd_pdrop),
                h=nn.ModuleList(
                    [Block(config, self.attention_weights) for _ in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        pos = torch.arange(0, config.n_positions, dtype=torch.long)
        self.register_buffer("pos", pos, persistent=False)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @lru_cache
    def num_parameters(self, exclude_embeddings=True):
        n_params = sum(p.numel() for p in self.parameters())
        if exclude_embeddings:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, input_ids, labels=None) -> ModelOutput:
        _, t = input_ids.size()
        if self.return_attention:
            self.attention_weights.clear()

        tok_emb = self.transformer.wte(input_ids)
        pos_emb = self.transformer.wpe(self.pos[:t])
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if labels is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        else:
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return ModelOutput(loss=loss, logits=logits)

    @torch.no_grad()
    def get_next_token(self, x: torch.Tensor, return_probs: bool = False, top_k: int | None = None):
        logits = self(x).logits
        logits = logits[:, -1, :]
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("Inf")
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        if return_probs:
            return next_token, probs
        return next_token

In [27]:
model = GPT2LMNoBiasModel(config)

In [28]:
model

GPT2LMNoBiasModel(
  (transformer): ModuleDict(
    (wte): Embedding(4480, 64)
    (wpe): Embedding(2048, 64)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=64, out_features=192, bias=False)
          (c_proj): Linear(in_features=64, out_features=64, bias=False)
          (attn_dropout): Dropout(p=0, inplace=False)
          (resid_dropout): Dropout(p=0, inplace=False)
        )
        (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=64, out_features=256, bias=False)
          (activation): GELUActivation()
          (c_proj): Linear(in_features=256, out_features=64, bias=False)
          (dropout): Dropout(p=0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_fe

In [29]:

## training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# scaler
scaler = torch.amp.GradScaler("float16")


In [30]:
# optimizer
optimizer = configure_optimizers(
    raw_model, cfg.weight_decay, cfg.lr, (cfg.beta1, cfg.beta2), device
)

NameError: name 'configure_optimizers' is not defined

In [31]:
from src.tokenize.vocabulary import Vocabulary
from src.tokenize.constants import SpecialToken as ST

In [32]:
vocab = Vocabulary("/workspace/ehr_stuff/EHR_FM/data/tokenized_datasets/mimic_train/vocab_t4432.csv")

In [33]:
vocab_size = (len(vocab) // 64 + 1) * 64 if len(vocab) % 64 != 0 else len(vocab)
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 64


In [34]:
vocab_size = len(vocab)

In [35]:
tokens_of_interest = [ST.DEATH, ST.ADMISSION, ST.DISCHARGE]

In [36]:
tokens_of_interest = {stoken: vocab.encode(stoken) for stoken in tokens_of_interest}


KeyError: <SpecialToken.DEATH: 'MEDS_DEATH'>

In [37]:
vocab.encode("MEDS_DEATH")

KeyError: 'MEDS_DEATH'

In [40]:
vocab.stoi("MEDS_DEATH")

TypeError: 'dict' object is not callable