### Next-Order-Set Transformer

In this doc I start builing a machine-translation transformer which will learn a mapping from `O_t` to `O_t+1` **witout** a multimodal component (i.e. no outcomes or patient characteristics). 

Based on the translation transformer: https://pytorch.org/tutorials/beginner/translation_transformer.html

In [1]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
import pickle
import pandas as pd
import yaml
from order_path_dataset import OrderPathDataset
from order_path_dataset import OrderPathProcessing
from torch.utils.data import DataLoader


In [2]:
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # TODO: Eliminate the positional encoding of tgt inputs: 
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)
        
    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# NOTE: We do not need to positionally encode the src tokens since we're set-wise identified
class NullPositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(NullPositionalEncoding, self).__init__()
        # TODO: Eliminate the positional encoding of tgt inputs: 
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding)

In [3]:
# NOTE: We don't want to positionally encode the input sequence, just apply drop-out. 
# TODO: Check that the dim of the output here still makes sense. 
class ApplyDropout(nn.Module):
    def __init__(self,
                 dropout: float):
        self.dropout = nn.Dropout(dropout)
    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding)

In [4]:
# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size, padding_idx: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx)
        self.emb_size = emb_size
        self.padding_idx = padding_idx

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [5]:
class PatientEmbedding(nn.Module):
    def __init__(self, input_dim: int, emb_size: int):
        super().__init__()
        self.linear_embedding = nn.Linear(input_dim, emb_size)
        
    def forward(self, x: Tensor):
        # Keep the sqrt() for now. Seems like it applies here too. 
        x = x.to(torch.float32)
        return self.linear_embedding(x)

In [6]:
# token masking
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device='cpu')) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    # NOTE: our source mask is FALSE everywhere since the model can attend to the entire set
    src_mask = torch.zeros((src_seq_len, src_seq_len),device='cpu').type(torch.bool)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [7]:
# Seq2Seq Network
# This is the standard pytorch transformer model for translation tasks
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 d_model: int,
                 seq_length: int, 
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 pat_emb_size: int,
                 pat_input_dim: int, 
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int, 
                 padding_idx: int, 
                 dim_feedforward: int, 
                 dropout: float):
        super(Seq2SeqTransformer, self).__init__()
        self.d_model = d_model
        self.emb_size = emb_size
        self.seq_length = seq_length # NOTE: This is auto-pop. for the standard embeddings given the O-seq-length
        # Need this for the patient X emb. which needs to match the seq-length for repetition 
        self.transformer = Transformer(d_model=d_model,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        # may want a custom embedding class, for now keep this and ensure its learned. 
        self.src_ord_emb = TokenEmbedding(src_vocab_size, 
                                          emb_size, 
                                          padding_idx)

        # fix the embedding for the missing order outcomes to 0 vector (or something)
        self.src_res_emb = TokenEmbedding(src_vocab_size, 
                                          emb_size, 
                                          padding_idx)
        
        # patinput_dim = opd.__getitem__(0)[2][0].long().shape = 841
        self.pat_cov_emb = PatientEmbedding(946, 
                                            pat_emb_size)

        # weighted sum params. are learnable
        self.alpha_o = torch.nn.Parameter(torch.randn(1))
        # self.alpha_o.requires_grad = True
        self.alpha_r = torch.nn.Parameter(torch.randn(1))
        # self.alpha_r.requires_grad = True

        # The target token embeddings stay the same
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, 
                                          emb_size, 
                                          padding_idx)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self,
                orders: Tensor, 
                results: Tensor, 
                pat_cov: Tensor, 
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):

        print("Order:", orders.shape)
        print("Results:", results.shape)
        # Multimodal embedding: adding each source embedding w. weights from above
        src_emb = torch.add(torch.mul(self.alpha_o, self.src_ord_emb(orders)), 
                            torch.mul(self.alpha_r, self.src_res_emb(results)))  
        print("SRC Emb:", src_emb.shape)

        
        # add pat. embedding: 
        src_pat_emb = self.pat_cov_emb(pat_cov)
        src_pat_emb = src_pat_emb.unsqueeze(0).repeat(self.seq_length, 1, 1)
        print("PAT Emb:", src_pat_emb.shape)

        src_emb = torch.add(src_emb, src_pat_emb)
        # NOTE: need to call tensors to cuda() 
        # src_emb = torch.add(self.src_ord_emb(orders), 
        #                     self.src_res_emb(results))   
        tgt_emb = self.tgt_tok_emb(trg)
        outs = self.transformer(src_emb, 
                                tgt_emb, 
                                src_mask, 
                                tgt_mask, 
                                None,
                                src_padding_mask, 
                                tgt_padding_mask, 
                                memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src_ord: Tensor, src_res: Tensor, src_mask: Tensor):
        # Multimodal encoder: 
        src_emb = torch.add(torch.mul(self.alpha_o, self.src_ord_emb(src_ord)), 
                            torch.mul(self.alpha_r, self.src_res_emb(src_res)))  
        return self.transformer.encoder(self.dropout(src_emb), 
                                        src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.dropout(self.tgt_tok_emb(tgt)), 
                                        memory,
                                        tgt_mask)

#### Data-config: 

In [8]:
# Loading config: 
with open("/wynton/protected/home/rizk-jackson/jknecht/order_path_prediction/experiments/scripts/config_005.yaml", "r") as f:
    config = yaml.safe_load(f)

path = config["path"]

In [9]:
torch.manual_seed(config["seed"])

<torch._C.Generator at 0x7ff3a7f6d7b0>

In [10]:
# Process order-path-data, this takes a minute. 
opd_processed = OrderPathProcessing(config, path)
train_val_test = opd_processed._split_encounters(seed = config["seed"])

...NOTE: You are subsetting to only DX orders...
...Got new vocab mapping...


In [11]:
# Loading processed: 
opd_processed = torch.load((config["data_loader_path"] + "opd_processed.pt"))
opd_train = torch.load((config["data_loader_path"] + "opd_train.pt"))
opd_val = torch.load((config["data_loader_path"] + "opd_val.pt"))

In [None]:
# Testing longest set: 
_length_of_sets_train = [torch.count_nonzero(_order_set[0], dim=0).item() for _order_set in opd_train._order_pairs_list]
_length_of_sets_val = [torch.count_nonzero(_order_set[0], dim=0).item() for _order_set in opd_val._order_pairs_list]

In [None]:
print(max(_length_of_sets_train), max(_length_of_sets_val))

In [12]:
# Number of training pairs:
len(opd_train._order_pairs_list) 

395303

In [15]:
n_orders = sum([torch.count_nonzero(_order_set[0], dim=0).item() for _order_set in opd_train._order_pairs_list])

In [16]:
n_orders
# 1,603,290 orders in all the pairs. This is obviously double counting pairs. 
# So really the raw # of orders is less than this 

1603290

In [18]:
len(opd_processed._order_data)
# 1,418,716 unique orders across entire data-set 
# We have 13,200,000 unique orders in the CDW data, to approx. x10 

1418716

In [132]:
# Init. datasets for train/val:
opd_train = OrderPathDataset(config = config, 
                             path = path, 
                             order_df = opd_processed._order_data, 
                             results_df = opd_processed._outcomes_data, 
                             patient_df = opd_processed._patient_data,
                             encounter_ids = train_val_test["train_encounter_ids"][0:10])
# Collating encounters and returning training df: 
opd_train._collate_encounter_orders()
opd_train._collate_encounter_results()
opd_train._collate_encounter_patient_data()

opd_val = OrderPathDataset(config = config, 
                           path = path, 
                           order_df = opd_processed._order_data, 
                           results_df = opd_processed._outcomes_data, 
                           patient_df = opd_processed._patient_data,
                           encounter_ids = train_val_test["val_encounter_ids"][0:10])
# Collating encounters and returning training df: 
opd_val._collate_encounter_orders()
opd_val._collate_encounter_results()
opd_val._collate_encounter_patient_data()


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 27.71it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 28.52it/s]


In [196]:
# Checking shape of pat X is 946
# [x[0].shape for x in opd_train._patient_characteristics_pairs_list]

[torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),
 torch.Size([946]),


#### Init model

In [133]:
def get_unique_orders(order_df):
    # Returns unique order on id
    return order_df.dropna(subset=['encounter_id', 'order_types_id']).order_types_id.unique()

In [180]:
_order2idx = opd_train._order2idx
_idx2order = opd_train._idx2order
unique_orders = opd_processed.unique_orders # From the processed data
SRC_VOCAB_SIZE = len(_order2idx)
TGT_VOCAB_SIZE = len(_order2idx)
# Model params from config: 
SEQ_LENGTH = config["max_seq_length"]
MODEL_DIM = config["model_params"]["MODEL_DIM"]
EMB_SIZE = config["model_params"]["EMB_SIZE"]
NHEAD = config["model_params"]["NHEAD"]
FFN_HID_DIM = config["model_params"]["FFN_HID_DIM"]
BATCH_SIZE = config["model_params"]["BATCH_SIZE"]
NUM_ENCODER_LAYERS = config["model_params"]["NUM_ENCODER_LAYERS"]
NUM_DECODER_LAYERS = config["model_params"]["NUM_DECODER_LAYERS"]
DROPOUT =  config["model_params"]["DROPOUT"]
# Define special symbols and indices:
PAD_IDX = config["PAD_idx"]
BOS_IDX = config["BOS_idx"]
EOS_IDX = config["EOS_idx"]
SEP_IDX = config["SEP_idx"] # Otherwise this is 4 
UNK_IDX = config["UNK_idx"]

transformer = Seq2SeqTransformer(d_model = MODEL_DIM, 
                                 seq_length = SEQ_LENGTH,
                                 num_encoder_layers = NUM_ENCODER_LAYERS, 
                                 num_decoder_layers = NUM_DECODER_LAYERS, 
                                 emb_size = EMB_SIZE,
                                 pat_emb_size = EMB_SIZE, 
                                 pat_input_dim = 946, 
                                 nhead = NHEAD, 
                                 src_vocab_size = SRC_VOCAB_SIZE,
                                 tgt_vocab_size = TGT_VOCAB_SIZE, 
                                 padding_idx = PAD_IDX, 
                                 dim_feedforward = FFN_HID_DIM,
                                 dropout = DROPOUT)

In [181]:
MODEL_DIM

256

In [182]:
# Applyung U-transform to high-dim parameters: 
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [188]:
# Loss and g-optimizer: 
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# NOTE: need to think about using sparse optimizers given the signficant sparsity in inputs
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#### Training Loop: 

In [184]:
# function to collate data samples into batch tensors
# def collate_batch_fn(batch):
#     src_order_batch, tgt_order_batch = [], []
#     src_result_batch = []
    
#     for pair in batch: 
#         src_order_batch.append(pair[0][0])
#         src_result_batch.append(pair[1][0])
#         tgt_order_batch.append(pair[0][1])
        
#     return (src_order_batch, src_result_batch, tgt_order_batch)

# def _get_path_tensors(batch):
#     # pairing is handled by the data-loader
#     # we abstract away from the encounter level so now its a single call 
#     src_order_batch = batch[0] 
#     src_result_batch = batch[1]
#     tgt_order_batch = batch[2]
#     return src_order_batch, src_result_batch, tgt_order_batch

def collate_batch_fn(batch):
    src_order_batch, tgt_order_batch, src_result_batch, src_pat_batch = [], [], [], []
    
    for pair in batch: 
        src_order_batch.append(pair[0][0])
        tgt_order_batch.append(pair[0][1])
        src_result_batch.append(pair[1][0])
        src_pat_batch.append(pair[2][0])
        
    return (src_order_batch, src_result_batch, src_pat_batch, tgt_order_batch)

def _get_path_tensors(batch):
    # pairing is handled by the data-loader
    # we abstract away from the encounter level so now its a single call 
    src_order_batch = batch[0] 
    src_result_batch = batch[1]
    src_pat_batch = batch[2]
    tgt_order_batch = batch[3]
    return src_order_batch, src_result_batch, src_pat_batch, tgt_order_batch

In [191]:
def train_mm_epoch(model, optimizer):
    model.train()
    losses = 0
    train_dataloader = DataLoader(opd_train, 
                                  batch_size=10, 
                                  collate_fn=collate_batch_fn)
    # Returns single batch at a time: 
    for batch in train_dataloader:
        # NOTE: need to collect src and tgt order sets for the batch here: 
        src_ord, src_res, src_pat, tgt_ord = _get_path_tensors(batch)
        # need to return 'results' from collate_fn() in the batch 
        # collect src and tgt again since we have two layers (i.e. multiple sequences per batch) 
        # need to collect src and tgt into simple lists
        src_ord = torch.transpose(torch.stack(src_ord), 0,1)
        src_res = torch.transpose(torch.stack(src_res), 0,1)
        src_pat = torch.stack(src_pat)
        tgt_ord = torch.transpose(torch.stack(tgt_ord), 0,1)
        # cast to GPU
        # src = src.to(DEVICE)
        # tgt = tgt.to(DEVICE)
        # need to shift the target by one. This is more complicated for me. 
        tgt_input = tgt_ord[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src_ord, tgt_input)
        logits = model(src_ord,
                       src_res, 
                       src_pat, 
                       tgt_input, 
                       src_mask, 
                       tgt_mask,
                       src_padding_mask, 
                       tgt_padding_mask, 
                       src_padding_mask)
        optimizer.zero_grad()
        tgt_out = tgt_ord[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        losses += loss.item()
    return losses / len(list(train_dataloader))

In [192]:
train_mm_epoch(model = transformer, 
               optimizer = optimizer)

Order: torch.Size([80, 10])
Results: torch.Size([80, 10])
SRC Emb: torch.Size([80, 10, 256])
PAT Emb: torch.Size([80, 10, 256])


8.902532577514648

In [76]:
train_dataloader = DataLoader(opd_train, 
                                  batch_size=10, 
                                  collate_fn=collate_batch_fn)

In [101]:
src_ord = torch.transpose(torch.stack(batch[0]), 0,1)
src_res = torch.transpose(torch.stack(batch[1]), 0,1)
alpha_o = torch.nn.Parameter(torch.randn(1))
alpha_r = torch.nn.Parameter(torch.randn(1))

In [100]:
src_res_emb = TokenEmbedding(SRC_VOCAB_SIZE,
                             EMB_SIZE, 
                             PAD_IDX)

In [173]:
src_res_emb(src_ord).shape

torch.Size([80, 10, 256])

In [103]:
src_emb = torch.add(torch.mul(alpha_o, src_res_emb(src_ord)), torch.mul(alpha_r, src_res_emb(src_res))) 

In [125]:
src_res_emb(src_ord).shape

torch.Size([80, 10, 256])

In [92]:
src_pat = torch.transpose(torch.stack(batch[2]), 0,1)
src_pat.shape

torch.Size([946, 10])

In [94]:
src_pat = torch.stack(batch[2])
src_pat.shape

torch.Size([10, 946])

In [95]:
src_pat.shape

torch.Size([10, 946])

In [108]:
pat_emb_tst = PatientEmbedding(946, EMB_SIZE)
test_emb = pat_emb_tst(src_pat)

In [112]:
test_emb.shape

torch.Size([10, 256])

In [119]:
test_emb.unsqueeze(0).repeat(80, 1, 1).shape

torch.Size([80, 10, 256])

In [None]:
def evaluate_mm_transformer(model):
    model.eval()
    losses = 0
    val_dataloader = DataLoader(opd_val, 
                                batch_size=BATCH_SIZE, 
                                collate_fn=collate_batch_fn)

    for batch in val_dataloader:
        src_ord, src_res, tgt_ord = _get_path_tensors(batch)
        src_ord = torch.transpose(torch.stack(src_ord), 0,1)
        src_res = torch.transpose(torch.stack(src_res), 0,1)
        tgt_ord = torch.transpose(torch.stack(tgt_ord), 0,1)
        # src = src.to(DEVICE)
        # tgt = tgt.to(DEVICE)
        tgt_input = tgt_ord[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src_ord, tgt_input)
        logits = model(src_ord,
                       src_res, 
                       tgt_input, 
                       src_mask, 
                       tgt_mask,
                       src_padding_mask, 
                       tgt_padding_mask, 
                       src_padding_mask)
        tgt_out = tgt_ord[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [None]:
evaluate_mm_transformer(model = transformer)

In [None]:
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's 
    validation loss is less than the previous least, then save the
    model state.
    """
    def __init__(self, best_valid_loss=float('inf')):
        self.best_valid_loss = best_valid_loss
        
    def __call__(self, config, current_valid_loss, epoch, model, optimizer, criterion):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': criterion,
                       },config["best_model_path"])


In [None]:
# Training loop: 
save_best_model = SaveBestModel()
train_loss = []
val_loss = []

for epoch in range(1, config["num_epochs"]+1):
    
    train_epoch_loss = train_mm_epoch(model = transformer, 
                                      optimizer = optimizer)
    train_loss.append(train_epoch_loss)
    
    val_epoch_loss = evaluate_mm_transformer(model = transformer)
    val_loss.append(valid_epoch_loss)
    
    print(f"Epoch: {epoch}, Train loss: {train_epoch_loss:.3f}, Val loss: {val_epoch_loss:.3f}")

    save_best_model(config = config, 
                    current_valid_loss = val_epoch_loss, 
                    epoch = epoch, 
                    model = transformer,
                    optimizer = optimizer, 
                    criterion = loss_fn)


In [None]:
def _process_loss(train_loss, val_loss):
    loss_df = pd.DataFrame({"epochs": list(range(1,len(train_loss)+1)),
                            "train_loss": train_loss, 
                            "val_loss": val_loss})
    return loss_df

In [None]:
loss_df = _process_loss(train_loss, val_loss)

#### Greedy decoding 

#### Model Inference (greedy)

In [None]:
def greedy_decode(model, src_ord, src_res, src_mask, max_len, start_symbol):
    #src = src.to(DEVICE)
    #src_mask = src_mask.to(DEVICE)

    memory = model.encode(src_ord, 
                          src_res,
                          src_mask)
    memory = memory.unsqueeze(1) # shape to 3-D tensor
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long) #.to(Device)
    for i in range(max_len-1):
        # memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)) #.to(Device))
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src_ord.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys

In [None]:
def predict_next_order_set(model: torch.nn.Module, src_ord_set: Tensor, src_res_set: Tensor):
    model.eval()
    num_tokens = src_ord_set.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    
    o_hat = greedy_decode(model,
                          src_ord = src_ord_set, 
                          src_res  = src_res_set, 
                          src_mask = src_mask,
                          max_len=num_tokens, 
                          start_symbol=BOS_IDX).flatten()
    
    return o_hat

#### Loading model

In [None]:
from multi_temporal_order_transformer import Seq2SeqTransformer

In [None]:
with open("/wynton/protected/home/rizk-jackson/jknecht/order_path_prediction/experiments/scripts/config_005.yaml", "r") as f:
    config = yaml.safe_load(f)
path = config["path"]

In [None]:
# Loading processed: 
opd_processed = torch.load((config["data_loader_path"] + "opd_processed.pt"))
opd_train = torch.load((config["data_loader_path"] + "opd_train.pt"))
opd_val = torch.load((config["data_loader_path"] + "opd_val.pt"))

In [None]:
# Init model:     
_order2idx = opd_train._order2idx
_idx2order = opd_train._idx2order
unique_orders = opd_processed.unique_orders # From the processed data
SRC_VOCAB_SIZE = len(_order2idx)
TGT_VOCAB_SIZE = len(_order2idx)

# Model params from config: 
EMB_SIZE = config["model_params"]["EMB_SIZE"]
NHEAD = config["model_params"]["NHEAD"]
FFN_HID_DIM = config["model_params"]["FFN_HID_DIM"]
BATCH_SIZE = config["model_params"]["BATCH_SIZE"]
NUM_ENCODER_LAYERS = config["model_params"]["NUM_ENCODER_LAYERS"]
NUM_DECODER_LAYERS = config["model_params"]["NUM_DECODER_LAYERS"]
DROPOUT = config["model_params"]["DROPOUT"]

# Define special symbols and indices
PAD_IDX = config["PAD_idx"]
BOS_IDX = config["BOS_idx"]
EOS_IDX = config["EOS_idx"]
SEP_IDX = config["SEP_idx"] # Otherwise this is 4 
UNK_IDX = config["UNK_idx"]

transformer = Seq2SeqTransformer(num_encoder_layers = NUM_ENCODER_LAYERS, 
                                 num_decoder_layers = NUM_DECODER_LAYERS, 
                                 emb_size = EMB_SIZE,
                                 pat_emb_size = EMB_SIZE, 
                                 pat_input_dim = 841, 
                                 nhead = NHEAD, 
                                 src_vocab_size = SRC_VOCAB_SIZE,
                                 tgt_vocab_size = TGT_VOCAB_SIZE, 
                                 padding_idx = PAD_IDX, 
                                 dim_feedforward = FFN_HID_DIM,
                                 dropout = DROPOUT)

# Applyung U-transform to high-dim parameters: 
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [None]:
transformer.load_state_dict(torch.load(config["best_model_path"], map_location=torch.device('cpu'))["model_state_dict"] )

In [None]:
print(transformer.alpha_o.item(), transformer.alpha_r.item())