<a href="https://colab.research.google.com/github/logannye/logannye/blob/main/clinical_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

### ------------------------------------------------------------------------- ###
### -------------------------- HYPERPARAMETERS ------------------------------ ###
### ------------------------------------------------------------------------- ###

# --- data / batching ---
batch_size = 128  # independent sequences processed in parallel per optimizer step
block_size = 512  # maximum context length (in characters) the model can attend over

# --- optimization schedule ---
max_iters = 10000     # total number of optimizer steps
learning_rate = 3e-4  # AdamW learning rate
eval_interval = 500   # estimate train/val loss every this many steps
eval_iters = 200      # number of random batches averaged per loss estimate

# --- model architecture ---
n_embd = 384   # embedding / residual-stream width
n_head = 6     # attention heads per block; each head gets n_embd // n_head dims
n_layer = 8    # number of stacked transformer blocks
dropout = 0.4  # dropout probability used throughout the model

# --- runtime ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(1738) # Fetty, baby
device


'cuda'

In [None]:
### ------------------------------------------------------------------------- ###
### --------------------------- DATASET SPECS ------------------------------- ###
### ------------------------------------------------------------------------- ###

# You can use wget to download the dataset from the web here.
# Read in the text corpus from the .txt file in your directory.
# Dataset: https://github.com/socd06/medical-nlp (a corpus of ~5000 clinical notes)
with open('/content/clinical_notes.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# All the unique characters that occur in this corpus, in sorted order.
# A character's index in this list is its integer token id.
chars = sorted(set(text))
vocab_size = len(chars)

# Bidirectional mappings between characters and integer token ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string as a list of integer token ids, one per character.

    Raises KeyError for any character that does not occur in the corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)


# Train / validation split: the first 85% of the encoded corpus is used for
# training, the remaining 15% for validation. Adjust the ratio to taste.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.85 * len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    split == 'train' draws from train_data; any other value draws from val_data.
    Returns (x, y), each stacked from batch_size windows of block_size tokens and
    moved to `device`; y[b, t] is the token that follows x[b, t] in the corpus.
    """
    source = val_data if split != 'train' else train_data
    # random window start offsets; the upper bound leaves room for the shifted target
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global `model` on the train and val splits.

    For each split, averages the loss over `eval_iters` random batches from
    get_batch, with gradients disabled and the model in eval mode (dropout off).
    The model's previous train/eval mode is restored afterwards, even if an
    exception is raised, instead of always being forced back into train mode.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out


In [None]:
### ------------------------------------------------------------------------- ###
### ------------------------ SELF ATTENTION HEAD ---------------------------- ###
### ------------------------------------------------------------------------- ###

class Head(nn.Module):
  """ one head of causal (masked) self-attention

  Projects the input to keys, queries and values of width `head_size`, lets each
  position attend only to itself and earlier positions (lower-triangular mask),
  and returns the attention-weighted sum of the values.

  NOTE: scores are scaled by 1/sqrt(head_size), the key/query width. A previous
  version scaled by 1/sqrt(C) with C = n_embd, which shrinks the scores by an extra
  factor of sqrt(n_head) when head_size = n_embd // n_head and flattens the softmax.
  Checkpoints trained with that scaling will behave differently here; retrain them.
  """
  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    # causal mask; registered as a buffer so it follows .to(device) and is saved in state_dict
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    T = x.shape[1]
    k = self.key(x)   # (B,T,hs)
    q = self.query(x) # (B,T,hs)
    # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
    wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5) # (B,T,hs) @ (B,hs,T) -> (B,T,T)
    # mask out future positions so position t only attends to positions <= t
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)
    wei = F.softmax(wei, dim=-1) # (B,T,T)
    wei = self.dropout(wei)
    # perform the weighted aggregation of the values
    v = self.value(x) # (B,T,hs)
    out = wei @ v # (B,T,T) @ (B,T,hs) -> (B,T,hs)
    return out


### ------------------------------------------------------------------------- ###
### ------------------------ MULTI-HEAD ATTENTION --------------------------- ###
### ------------------------------------------------------------------------- ###


class MultiHeadAttention(nn.Module):
  """ multiple heads of self-attention in parallel

  Runs `num_heads` independent Head modules on the same input, concatenates their
  outputs along the last dimension, projects back to width n_embd, and applies
  dropout to the projected result.
  """

  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(num_heads * head_size, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1) # concatenate the head outputs: (..., num_heads*head_size)
    # linearly project back to n_embd, then apply the dropout layer
    # (it was constructed in __init__ but previously never used)
    out = self.dropout(self.proj(out))
    return out


### ------------------------------------------------------------------------- ###
### ------------------------ FEED FORWARD NETWORK --------------------------- ###
### ------------------------------------------------------------------------- ###


class FeedForward(nn.Module):
  """ position-wise MLP: expand to 4x width, ReLU, project back, then dropout """

  def __init__(self, n_embd):
    super().__init__()
    hidden_dim = 4 * n_embd  # inner width of the MLP
    expand = nn.Linear(n_embd, hidden_dim)
    project = nn.Linear(hidden_dim, n_embd)  # projection back into the residual pathway
    # same nn.Sequential layout and layer order as before, so state_dict keys stay net.0 / net.2
    self.net = nn.Sequential(expand, nn.ReLU(), project, nn.Dropout(dropout))

  def forward(self, x):
    result = self.net(x)
    return result


### ------------------------------------------------------------------------- ###
### ------------------------- TRANSFORMER BLOCKS ---------------------------- ###
### ------------------------------------------------------------------------- ###


class Block(nn.Module):
  """ Transformer block: communication followed by computation

  Pre-norm residual layout: x + attention(LayerNorm(x)), then x + MLP(LayerNorm(x)).
  """

  def __init__(self, n_embd, n_head):
    # n_embd: embedding dimension, n_head: number of heads we'd like
    super().__init__()
    # per-head width; with this notebook's hyperparameters that is 384 // 6 = 64
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size) # communication
    self.ffwd = FeedForward(n_embd) # computation
    self.layer_norm1 = nn.LayerNorm(n_embd)
    self.layer_norm2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.layer_norm1(x)) # fork off, communicate, and then add back
    x = x + self.ffwd(self.layer_norm2(x)) # fork off, compute, and then add back
    return x


In [None]:
# Hyperparameters recorded in the W&B run config, so runs can be compared.
# NOTE: 'epochs' actually holds max_iters (optimizer steps, not passes over the
# data); the key is kept for continuity with earlier runs, and 'max_iters' is
# logged explicitly alongside it.
config = {
    'batch_size': batch_size, # Increase this if your GPU can handle it
    'lr': learning_rate,
    'epochs': max_iters,
    'max_iters': max_iters,
    'eval_interval': eval_interval,
    'eval_iters': eval_iters,
    'block_size': block_size,
    'n_embd': n_embd,
    'n_head': n_head,
    'n_layer': n_layer,
    'dropout': dropout,
}

In [None]:
!pip install wandb --quiet
import wandb
wandb.login(key="8198cb5f5316ad7597e44dcf4e6b5d063b12f7e8") # API Key is in your wandb account, under settings (wandb.ai/settings)

# Create your wandb run
run = wandb.init(
    name = "Testing-my-Transformer", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitalizing runs when you re-run this cell
    # run_id = ### Insert specific run id here if you want to resume a previous run
    # resume = "must" ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "hw2p2-ablations", ### Project should be created in your wandb account
    config = config ### Wandb Config for your run
)

[34m[1mwandb[0m: Currently logged in as: [33mlnye[0m ([33m11-785project[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
### ------------------------------------------------------------------------- ###
### --------------------------- LANGUAGE MODEL ------------------------------ ###
### ------------------------------------------------------------------------- ###


class GPT(nn.Module):
    """Decoder-only, character-level transformer language model.

    Token + learned absolute position embeddings feed a stack of `n_layer`
    transformer Blocks, a final LayerNorm, and a linear head producing
    next-token logits over the `vocab_size` characters.
    """

    def __init__(self):
        super().__init__()
        # token id -> n_embd-dim embedding
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # absolute position (0 .. block_size-1) -> n_embd-dim embedding
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if `targets` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, with T <= block_size (the position
                embedding table has only block_size rows).
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            (logits, loss). logits is (B, T, vocab_size) when targets is None,
            otherwise it is flattened to (B*T, vocab_size); loss is None when
            targets is None.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # build positions on the input's device rather than the global `device`,
        # so the model still works after being moved elsewhere
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the context `idx`.

        Sampling runs without gradient tracking and in eval mode (dropout off).
        The module's previous train/eval mode is restored afterwards.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # crop idx to the last block_size tokens (the position table's size)
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :] # (B, vocab_size)
                probs = F.softmax(logits, dim=-1) # (B, vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
                # append the sampled index to the running sequence
                idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        finally:
            self.train(was_training)
        return idx

model = GPT()
m = model.to(device)  # nn.Module.to() moves parameters in place and returns the same module

# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

model.train()  # dropout active while training
for step in range(max_iters):

    # every eval_interval steps, and once more on the final step, estimate train/val loss
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # Log losses to WandB as plain Python floats
        wandb.log({"train_loss": float(losses['train']), "val_loss": float(losses['val']), "step": step})

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss and take one optimizer step
    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


# generate from the model with dropout disabled
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=1000)[0].tolist()))

14.467693 M parameters
step 0: train loss 4.7714, val loss 4.7747
step 500: train loss 2.1342, val loss 2.1948
step 1000: train loss 1.4814, val loss 1.5642
step 1500: train loss 1.1967, val loss 1.2757
step 2000: train loss 1.0520, val loss 1.1331
step 2500: train loss 0.9797, val loss 1.0579
step 3000: train loss 0.9261, val loss 1.0083
step 3500: train loss 0.8868, val loss 0.9696
step 4000: train loss 0.8573, val loss 0.9431
step 4500: train loss 0.8333, val loss 0.9174
step 5000: train loss 0.8123, val loss 0.8977
step 5500: train loss 0.7924, val loss 0.8840
step 6000: train loss 0.7763, val loss 0.8645
step 6500: train loss 0.7618, val loss 0.8500
step 7000: train loss 0.7509, val loss 0.8442
step 7500: train loss 0.7380, val loss 0.8292
step 8000: train loss 0.7235, val loss 0.8196
step 8500: train loss 0.7138, val loss 0.8106
step 9000: train loss 0.7043, val loss 0.8015
step 9500: train loss 0.6954, val loss 0.7929

HISTORY: Herpefal medical attempts and reviewed alcohold by 

TODO: Fix the Weights & Biases run setup, and change the model class name to `GPT` instead of `BigramLanguageModel`.

In [None]:
# Seed generation with a single token of id 0: a (1, 1) long tensor on `device`.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_ids = m.generate(context, max_new_tokens=3000)[0].tolist()
print(decode(generated_ids))


 PNeurosurble done 2/18/96, thrombed 9/6 to 28/7/92 with colonoscopy felt wires that is three year and than 8/30 minutes has previously dularied atrium and tenderness.,SENSORY:  ,Fever Analged and nursing facets have fever any limited mitral odoing listed as necessary.  There is no repropressive articular symptoms in the emergency deficits at home.  No thyros only the central condyle, attending into the emergency room and in the emergency depidurity of the emergency department with school XVA, with burition CT scan on 08/10/29, rapuls with right-sided lumbar puncta-by pulmonary sepsis.  There is no concerns of pancreas in the emergency desiccation to the intentional.  Examination reveals a rectal latent latencille, but it showed attended forms and it show any fasciculation in the apex of the carotid ointmenant.  Memoning is covered before being of Studie.,PLAN: , The risks, benefits, alternatives of the transfusion, common unchange by CSF.,PROCEDURE IN DETAIL: , After appropriate size

In [None]:
### ------------------------------------------------------------------------- ###
### -------------------------- TEXT OUTPUT FILE ----------------------------- ###
### ------------------------------------------------------------------------- ###

# Save a 5000-token sample from the model to a new text file.
import os

# idempotent replacement for `!mkdir`, which errors when the directory already exists
os.makedirs('/content/data', exist_ok=True)
sample_text = decode(m.generate(context, max_new_tokens=5000)[0].tolist())
# NOTE(review): the file is written to 'more.txt' in the current working directory
# (unchanged from before), not inside /content/data. Confirm which location is intended.
with open('more.txt', 'w', encoding='utf-8') as f:  # `with` closes the file; explicit UTF-8 for non-ASCII corpus characters
    f.write(sample_text)

mkdir: cannot create directory ‘/content/data’: File exists


5001

In [None]:
# Persist only the learned parameters and buffers (the state_dict), not the module object.
weights = model.state_dict()
torch.save(weights, '/content/clinicalGPT_decoder_only')

In [None]:
from google.colab import files

# Colab-only: send the saved checkpoint file to the browser as a download
checkpoint_path = '/content/clinicalGPT_decoder_only'
files.download(checkpoint_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Later on, if you want to reload the model and its weights (e.g., for inference or to continue training), run the cell below.

In [None]:
# Rebuild the architecture, then load the saved weights into it.
# The class definitions and hyperparameters must match those used when the
# checkpoint was saved; otherwise load_state_dict raises on missing/mismatched keys
# or shapes (or the model behaves differently if only the computation changed).
model = GPT()

# map_location lets a checkpoint saved in a GPU session load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers, which is
# safer for checkpoint files you did not create yourself.
state_dict = torch.load('/content/clinicalGPT_decoder_only', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)

# If you're doing inference only, switch to evaluation mode
model.eval()
