# Demo Notebook:
## Temporal Point Process Transformer For Causal Sequence Modelling 

Including time (and TODO: values)

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TPP.task_heads.causal_tpp import TPPTransformerForCausalSequenceModelling

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [3]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
# identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
all_identifiers = identifiers2

if False:
    # Lets take only the first N for faster run-time
    N = np.min((len(all_identifiers), 10000))
    print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
    identifiers = random.choices(all_identifiers, k=N)
else:
    print(f"Using all available {len(all_identifiers)} samples")
    identifiers = all_identifiers

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="tabular",    # non-
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=False,
                            include_diagnoses=True
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building DL-friendly representation


Using all available 129717 samples


INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


107680 training samples
5983 validation samples
5982 test samples
74 vocab elements


In [14]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (73, 3)
┌───────────────────────────────────┬────────┬──────────┐
│ EVENT                             ┆ counts ┆ freq     │
│ ---                               ┆ ---    ┆ ---      │
│ str                               ┆ u32    ┆ f64      │
╞═══════════════════════════════════╪════════╪══════════╡
│ UNK                               ┆ 0      ┆ 0.0      │
│ DEPRESSION                        ┆ 86544  ┆ 0.177533 │
│ ANXIETY                           ┆ 48204  ┆ 0.098884 │
│ HYPERTENSION                      ┆ 32877  ┆ 0.067442 │
│ TYPE2DIABETES                     ┆ 27261  ┆ 0.055922 │
│ OSTEOARTHRITIS                    ┆ 25362  ┆ 0.052027 │
│ ASTHMA_PUSHASTHMA                 ┆ 25296  ┆ 0.051891 │
│ ATOPICECZEMA                      ┆ 23197  ┆ 0.047585 │
│ ALLERGICRHINITISCONJ              ┆ 18510  ┆ 0.037971 │
│ ANY_DEAFNESS_HEARING_LOSS         ┆ 17244  ┆ 0.035374 │
│ PREVALENT_IBS                     ┆ 11714  ┆ 0.02403  │
│ ALLCA_NOBCC_VFINAL                ┆ 11648  ┆ 0.023894 │

## Create models and train

In [4]:
models, m_names = [], []

# My development model
for tte_layer in ["Geometric", "Exponential"]:
    config = DemoConfig()
    config.TTELayer = tte_layer
    models.append(TPPTransformerForCausalSequenceModelling(config, vocab_size).to(device))
    m_names.append(f"TPPTransformerForCausalSequenceModelling_{tte_layer}")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using GeometricTTELayer. This module predicts the time until next event as a geometric distribution, supported on the set {0,1,...}
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution


In [5]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_tte = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_tte = [[] for _ in models]

In [6]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss = 0, 0, 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                     ages=batch['ages'].to(device), 
                                                     # values=batch['values'].to(device),
                                                     attention_mask=batch['attention_mask'].to(device)   
                                                     )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
        epoch_loss /= i
        epoch_clf_loss /= i
        epoch_tte_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss = 0, 0, 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                             ages=batch['ages'].to(device),
                                                             # values=batch['values'].to(device),
                                                             attention_mask=batch['attention_mask'].to(device)   
                                                             )
                    val_loss += loss.item()
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                val_loss /= j
                val_clf_loss /= j
                val_tte_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    # values = torch.tensor([[torch.nan]], device=device)
    # generate: sample the next 10 tokens
    new_tokens, new_ages = model.generate(tokens, ages, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    #    note, Not considering value yet.
    for _cat, _age in zip(generated.split(" "), new_ages[0, :]):
        print(f"\t {_cat} at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}


Training model `TPPTransformerForCausalSequenceModelling_Geometric`, with 10.669633 M parameters
Epoch 0:	Train loss 5.46  (3.10, 7.82). Val loss 5.44 (3.02, 7.86)
Epoch 1:	Train loss 5.37  (2.95, 7.78). Val loss 5.41 (2.97, 7.84)
Epoch 2:	Train loss 5.34  (2.90, 7.77). Val loss 5.38 (2.93, 7.84)
Epoch 3:	Train loss 5.32  (2.87, 7.77). Val loss 5.37 (2.91, 7.84)
Epoch 4:	Train loss 5.31  (2.86, 7.76). Val loss 5.37 (2.90, 7.83)
Epoch 5:	Train loss 5.31  (2.85, 7.76). Val loss 5.36 (2.89, 7.83)
Epoch 6:	Train loss 5.29  (2.83, 7.76). Val loss 5.36 (2.88, 7.84)
Epoch 7:	Train loss 5.29  (2.82, 7.76). Val loss 5.35 (2.86, 7.83)
Epoch 8:	Train loss 5.28  (2.80, 7.76). Val loss 5.34 (2.85, 7.83)
Epoch 9:	Train loss 5.27  (2.79, 7.75). Val loss 5.33 (2.84, 7.83)
Epoch 10:	Train loss 5.27  (2.78, 7.75). Val loss 5.34 (2.85, 7.83)
Epoch 11:	Train loss 5.27  (2.78, 7.75). Val loss 5.33 (2.83, 7.83)
Epoch 12:	Train loss 5.26  (2.77, 7.75). Val loss 5.33 (2.83, 7.83)
Epoch 13:	Train loss 5.26  (2



Epoch 24:	Train loss 5.23  (2.73, 7.74). Val loss 5.33 (2.81, 7.85)
	 DEPRESSION at age 20 (7300.0 days)
	 ANXIETY at age 20 (7447.0 days)
	 PREVALENT_IBS at age 22 (8199.0 days)
	 ALCOHOLMISUSE at age 27 (9723.0 days)
	 OTHER_CHRONIC_LIVER_DISEASE_OPTIMAL at age 33 (12045.0 days)
	 EATINGDISORDERS at age 35 (12911.0 days)
	 STROKEUNSPECIFIED at age 38 (13833.0 days)
	 ENDOMETRIOSIS_ADENOMYOSIS_V2 at age 40 (14687.0 days)
	 ALLCA_NOBCC_VFINAL at age 41 (14878.0 days)
	 OSTEOARTHRITIS at age 44 (15881.0 days)
	 HAEMOCHROMATOSIS at age 45 (16504.0 days)
Training model `TPPTransformerForCausalSequenceModelling_Exponential`, with 10.669633 M parameters
Epoch 0:	Train loss 2.25  (3.06, 1.43). Val loss 2.22 (3.02, 1.43)
Epoch 1:	Train loss 2.18  (2.95, 1.41). Val loss 2.20 (2.97, 1.43)
Epoch 2:	Train loss 2.15  (2.90, 1.41). Val loss 2.17 (2.91, 1.42)
Epoch 3:	Train loss 2.13  (2.87, 1.40). Val loss 2.16 (2.91, 1.42)
Epoch 4:	Train loss 2.13  (2.86, 1.40). Val loss 2.15 (2.90, 1.41)
Epoch 5:



Epoch 32:	Train loss 2.04  (2.71, 1.37). Val loss 2.11 (2.81, 1.42)
	 DEPRESSION at age 20 (7300.0 days)
	 ANXIETY at age 33 (12092.6 days)
	 ASTHMA_PUSHASTHMA at age 48 (17498.3 days)
	 OSTEOARTHRITIS at age 49 (17802.3 days)
	 HYPERTENSION at age 53 (19384.4 days)
	 SUBSTANCEMISUSE at age 54 (19769.8 days)
	 PREVALENT_IBS at age 54 (19771.3 days)
	 ALLERGICRHINITISCONJ at age 64 (23256.4 days)
	 ALLCA_NOBCC_VFINAL at age 65 (23618.5 days)
	 ANY_DEAFNESS_HEARING_LOSS at age 67 (24484.4 days)
	 ALCOHOLMISUSE at age 73 (26482.1 days)


## Comparing output to real data

In [7]:
for batch in dm.train_dataloader():
    break
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, age) in enumerate(zip(conditions[0], batch["ages"][0,:])):
    if token == 0 or idx >= 10:
        break
    print(f"{dm.decode([token])}, at age {age/365:.0f} ({age:.1f} days)")

ATOPICECZEMA, at age 68 (24782.0 days)
HYPERTENSION, at age 71 (25853.0 days)
TYPE2DIABETES, at age 73 (26609.0 days)
CKDSTAGE3TO5, at age 74 (27156.0 days)
HYPOTHYROIDISM_DRAFT_V1, at age 76 (27877.0 days)
GOUT, at age 84 (30580.0 days)
AF, at age 85 (31115.0 days)
PMRANDGCA, at age 85 (31127.0 days)
IHD_NOMI, at age 86 (31321.0 days)


In [8]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/loss_curves_TPP.png")

# Plot classifier loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_clf[m_idx]), len(loss_curves_train_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_clf[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_clf[m_idx]), len(loss_curves_val_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_clf[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/loss_clf_curves_TPP.png")

# Plot tte loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_tte[m_idx]), len(loss_curves_train_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_tte[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_tte[m_idx]), len(loss_curves_val_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_tte[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/loss_tte_curves_TPP.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [9]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
            model.eval()
    
            for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
                print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
                encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
                (lgts, tte_dist), _, _ = model(encoded_prompt,
                                                       # values=torch.tensor(value).to(device),
                                                       ages=to_days(age),
                                                       is_generation=True)
                probs = torch.nn.functional.softmax(lgts, dim=2)
                print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
                print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TPPTransformerForCausalSequenceModelling_Geometric
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.0738%
	probability of type II diabetes: 0.2351%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.0884%
	probability of type II diabetes: 2.5615%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 1.7410%
	probability of type II diabetes: 0.1133%


TPPTransformerForCausalSequenceModelling_Exponential
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.0653%
	probability of type II diabetes: 0.2518%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.3059%
	probability of type II diabetes: 2.1283%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 3.1414%
	probability of type II diabetes: 0.0096%


## Age: How increasing prompt age affects likelihood of age related diagnoses

In [10]:
prompt = ["DEPRESSION"]
ages = [[4],[8],[20],[30],[60],[80],[90]]

# target_conditions=["TYPE1DM"]#, "TYPE2DIABETES", "OSTEOARTHRITIS", "ANY_DEAFNESS_HEARING_LOSS"]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, age in enumerate(ages):
        print(f"\nAge {age[-1]}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist), _, _ = model(encoded_prompt,
                                       ages=to_days(age),
                                       is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100

        # top K
        k = 10
        print(f"Top {k}")
        topk_prob, topk_ind = torch.topk(probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {j:.2f}%")

        # bottom K
        k = 30
        print(f"Bottom {k}")
        topk_prob, topk_ind = torch.topk(-probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {-j:.2f}%")
        
            # print(f"Age: {age[-1]} years old:  {100*float(probs[0, 0, target_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TPPTransformerForCausalSequenceModelling_Geometric
--------------------------------------

Age 4
Top 10
	ANXIETY: 32.00%
	ATOPICECZEMA: 22.54%
	ASTHMA_PUSHASTHMA: 9.23%
	ALLERGICRHINITISCONJ: 7.06%
	EATINGDISORDERS: 4.33%
	ANY_DEAFNESS_HEARING_LOSS: 4.17%
	PREVALENT_IBS: 2.97%
	SUBSTANCEMISUSE: 2.40%
	ALCOHOLMISUSE: 1.97%
	AUTISM: 1.74%
Bottom 30
	UNK: 0.00%
	PAD: 0.00%
	SICKLE_CELL_DISEASE: 0.00%
	CHRONIC_LIVER_DISEASE_ALCOHOL: 0.00%
	CYSTICFIBROSIS: 0.00%
	PLASMACELL_NEOPLASM: 0.00%
	AORTICANEURYSM: 0.00%
	ILD_SH: 0.00%
	SYSTEMIC_SCLEROSIS: 0.00%
	HAEMOCHROMATOSIS: 0.01%
	PARKINSONS: 0.01%
	NAFLD: 0.01%
	DOWNSSYNDROME: 0.01%
	SJOGRENSSYNDROME: 0.01%
	PSORIATICARTHRITIS2021: 0.01%
	MINFARCTION: 0.01%
	PAD_STRICT: 0.01%
	PMRANDGCA: 0.01%
	ALL_DEMENTIA: 0.02%
	AF: 0.02%
	MENIERESDISEASE: 0.02%
	PERNICIOUSANAEMIA: 0.02%
	ISCHAEMICSTROKE: 0.02%
	ADDISON_DISEASE: 0.02%
	CKDSTAGE3TO5: 0.03%
	SYSTEMIC_LUPUS_ERYTHEMATOSUS: 0.03%
	LEUKAEMIA_PREVALENCE: 0.03%
	BRONCHIECTASIS: 0.03%
	OTHER_CHR

# Appendix: model architectures

In [11]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



TPPTransformerForCausalSequenceModelling_Geometric


TPPTransformerForCausalSequenceModelling(
  (transformer): TPPTransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): Embedding(74, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=True)
          (acti): Re