# HW3 _ RNN

## Overview

In this homework, you will build a bi-directional RNN on diagnosis codes. The recurrent nature of an RNN allows us to model the temporal relations among a patient's visits. More specifically, we will still perform heart failure prediction, but with a different input format.

In [10]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Get the current path:
current_dir = os.path.dirname(os.path.abspath(''))

# Define data path
DATA_PATH = "./HW3_RNN-lib/data"

## About Raw Data

To get started, we will implement a naive RNN model for heart failure prediction using the diagnosis codes.

We will use the same dataset synthesized from MIMIC-III, but with different input formats.

The data has been preprocessed for you. Let us load them and take a look.

In [21]:
def load_pickle(name):
    # Load one preprocessed file from the training folder, closing the file afterwards
    with open(os.path.join(DATA_PATH, 'train', name), 'rb') as f:
        return pickle.load(f)

pids = load_pickle('pids.pkl')
vids = load_pickle('vids.pkl')
hfs = load_pickle('hfs.pkl')
seqs = load_pickle('seqs.pkl')
types = load_pickle('types.pkl')
rtypes = load_pickle('rtypes.pkl')

assert len(pids) == len(vids) == len(hfs) == len(seqs) == 1000
assert len(types) == 619


where

- `pids`: contains the patient ids
- `vids`: contains a list of visit ids for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visits (each a list of ICD9 labels) for each patient
- `types`: contains the map from ICD9 codes to ICD9 labels
- `rtypes`: contains the map from ICD9 labels to ICD9 codes

Let us take a patient as an example.

In [22]:
# take the patient at index 3 as an example

print("Patient ID:", pids[3])
print("Heart Failure:", hfs[3])
print("# of visits:", len(vids[3]))
for visit in range(len(vids[3])):
    print(f"\t{visit}-th visit id:", vids[3][visit])
    print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


Note that `seqs` is a list of lists of lists. That is, `seqs[i][j][k]` gives you the k-th diagnosis label of the j-th visit of the i-th patient.
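As a minimal sketch of this nesting, the toy list below is indexed the same way as `seqs`. The values are hypothetical and are not taken from the dataset.

```python
# Toy stand-in for `seqs` (hypothetical values, not from the real data):
# patient -> visit -> diagnosis label
toy_seqs = [
    [[0, 1, 2], [8, 0]],                # patient 0: two visits
    [[12, 13, 6, 7], [12], [23, 11]],   # patient 1: three visits
]

i, j, k = 1, 2, 0
label = toy_seqs[i][j][k]  # k-th label of the j-th visit of the i-th patient

# Number of visits per patient, and number of labels per visit
num_visits = [len(patient) for patient in toy_seqs]
num_codes = [[len(visit) for visit in patient] for patient in toy_seqs]

print(label)       # 23
print(num_visits)  # [2, 3]
print(num_codes)   # [[3, 2], [4, 1, 2]]
```

Both the number of visits and the number of labels per visit vary, which is why the data cannot be stacked into a tensor without padding.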

You can look up the meaning of each ICD9 code online. For example, DIAG_276 represents disorders of fluid, electrolyte, and acid-base balance.

Next, let us see the number of heart failure patients.


In [23]:
print("number of heart failure patients:", sum(hfs))
print("ratio of heart failure patients: %.2f" % (sum(hfs) / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## 1 Build the dataset 

### 1.1 CustomDataset 

First, let us implement a custom dataset using PyTorch class `Dataset`, which will characterize the key features of the dataset we want to generate.

We will use the sequences of diagnosis codes `seqs` as input and heart failure `hfs` as output.

In [25]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        
        """
        TODO: Store `seqs` to `self.x` and `hfs` to `self.y`.
        
        Note that you DO NOT need to convert them to tensors as we will do this later.
        Do NOT permute the data.
        """
        
        self.x = seqs
        self.y = hfs
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        return len(self.x)
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to convert them to tensors as we will do this later.
        """
        
        return self.x[index], self.y[index]
        

dataset = CustomDataset(seqs, hfs)

print(dataset[:1])

([[[85, 112, 346, 380, 269, 511, 114, 103, 530, 597, 511], [85, 103, 112, 513, 511, 19, 149, 530, 186, 66]]], [1])


### 1.2 Collate Function [20 points]

As you may have noticed, we do not convert the data to tensors in `CustomDataset`. Instead, we will do this using a collate function `collate_fn()`.

This collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [ [0, 1, 2], [8, 0] ], 
  [ [12, 13, 6, 7], [12], [23, 11] ] ]
```

where the first sample has two visits `[0, 1, 2]` and `[8, 0]` and the second sample has three visits `[12, 13, 6, 7]`, `[12]`, and `[23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (3, 4), where 3 is the maximum number of visits and 4 is the maximum number of diagnosis codes.

``` 
[ [ [0, 1, 2, *0*], [8, 0, *0*, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [12, 13, 6, 7], [12, *0*, *0*, *0*], [23, 11, *0*, *0*] ] ]
```

Further, the padding information will be stored in a mask with the same shape, where 1 indicates that the diagnosis code at this position is from the original input, and 0 indicates that the diagnosis code at this position is the padded value.

```
[ [ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0] ] ]
```

Lastly, we will build another diagnosis sequence in reversed time order, which the RNN model uses for its backward direction. Note that we only flip the true visits; the padded visits stay at the end.

``` 
[ [ [8, 0, *0*, *0*], [0, 1, 2, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [23, 11, *0*, *0*], [12, *0*, *0*, *0*], [12, 13, 6, 7] ] ]
```

And a reversed mask as well.

```
[ [ [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1] ] ]
```

We need to pad the sequences to the same length so that we can do batch training on a GPU. We also need the mask so that, during training, we can ignore the padded values, as they do not contain any information.
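The padding, masking, and reversal described above can be sketched in plain Python before moving to tensors. The helper name `pad_batch` is hypothetical and is not part of the assignment; it only reproduces the worked example on nested lists.

```python
def pad_batch(batch, pad_value=0):
    """Pad a list of patients (lists of visits) to a common shape.

    Returns (x, masks, rev_x, rev_masks) as nested Python lists.
    Only the true visits are reversed; padded visits stay at the end.
    """
    max_visits = max(len(patient) for patient in batch)
    max_codes = max(len(visit) for patient in batch for visit in patient)

    def pad_patient(visits):
        codes, mask = [], []
        for visit in visits:
            n_pad = max_codes - len(visit)
            codes.append(list(visit) + [pad_value] * n_pad)
            mask.append([1] * len(visit) + [0] * n_pad)
        # Append fully padded visits after the true ones
        for _ in range(max_visits - len(visits)):
            codes.append([pad_value] * max_codes)
            mask.append([0] * max_codes)
        return codes, mask

    x, masks, rev_x, rev_masks = [], [], [], []
    for patient in batch:
        codes, mask = pad_patient(patient)
        rev_codes, rev_mask = pad_patient(patient[::-1])  # flip true visits only
        x.append(codes)
        masks.append(mask)
        rev_x.append(rev_codes)
        rev_masks.append(rev_mask)
    return x, masks, rev_x, rev_masks


batch = [[[0, 1, 2], [8, 0]], [[12, 13, 6, 7], [12], [23, 11]]]
x, masks, rev_x, rev_masks = pad_batch(batch)
print(rev_x[0])      # [[8, 0, 0, 0], [0, 1, 2, 0], [0, 0, 0, 0]]
print(rev_masks[0])  # [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]
```

Reversing the list of true visits before padding is what keeps the padded visits at the end of the reversed sequence.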

In [31]:
data = CustomDataset(seqs, hfs)
sequences, labels = zip(*data) 

num_patients = len(sequences)
num_visits = [len(patient) for patient in sequences]
num_codes = [len(visit) for patient in sequences for visit in patient]

max_num_visits = max(num_visits)
max_num_codes = max(num_codes)

x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)

print(x)


In [33]:
def collate_fn(data):
    """
    TODO: Collate the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the same shape (max # visits, max # diagnosis codes). The padding information
        is stored in `masks`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking
        rev_masks: same as masks but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patients) of type torch.float
        
    Note that you can obtain the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)

    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            # Fill in the diagnosis codes for this visit
            x[i_patient, j_visit, :len(visit)] = torch.tensor(visit)
            # Create mask for valid positions (True where we have actual values)
            masks[i_patient, j_visit, :len(visit)] = True
            
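            # Fill in the reversed version. Only the true visits are flipped,
            # so the index is reversed within this patient's own visit count
            # and the padded visits stay at the end.
            rev_j = len(patient) - j_visit - 1
            rev_x[i_patient, rev_j, :len(visit)] = torch.tensor(visit)
            rev_masks[i_patient, rev_j, :len(visit)] = True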
    
    return x, masks, rev_x, rev_masks, y

In [34]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

assert x.dtype == rev_x.dtype == torch.long
assert y.dtype == torch.float
assert masks.dtype == rev_masks.dtype == torch.bool

assert x.shape == rev_x.shape == masks.shape == rev_masks.shape == (10, 3, 24)
assert y.shape == (10,)

Now we have `CustomDataset` and `collate_fn()`. Let us split the dataset into training and validation sets.

In [35]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 800
Length of val dataset: 200


**hf_events.csv**

The data provided in *hf_events.csv* contains the pids of patients who have been diagnosed with heart failure (i.e., DIAG_398, DIAG_402, DIAG_404, DIAG_428) in at least one visit. Each row is a tuple with the format *(pid, vid, label)*. For example,

```
156,0,1
181,1,1
```

The vid indicates the index of the first visit with heart failure of that patient and a label of 1 indicates the presence of heart failure. **Note that only patients with heart failure are included in this file. Patients who are not mentioned in this file have never been diagnosed with heart failure.**
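As a minimal sketch of how these rows translate into per-patient labels, the snippet below parses the two sample rows shown above. The list `all_pids` and the pid 999 are made up for illustration, and the sketch does not assume a header row, which the real file may have.

```python
import csv
import io

# The two sample rows shown above, read from an in-memory buffer.
hf_csv = io.StringIO("156,0,1\n181,1,1\n")

# pid -> index of the first visit with heart failure
first_hf_visit = {int(pid): int(vid) for pid, vid, label in csv.reader(hf_csv)}

# Hypothetical list of all patient ids; 999 is made up for illustration.
all_pids = [156, 181, 999]

# Patients absent from the file are labeled 0 (never diagnosed).
labels = {pid: int(pid in first_hf_visit) for pid in all_pids}
print(labels)  # {156: 1, 181: 1, 999: 0}
```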

**event_feature_map.csv**

The *event_feature_map.csv* is a map from an event_id to an integer index. This file contains *(idx, event_id)* pairs for all event ids.

## 1 Descriptive Statistics [20 points]

Before starting analytic modeling, it is a good practice to get descriptive statistics of the input raw data. In this question, you need to write code that computes various metrics on the data described previously. A skeleton code is provided to you as a starting point.

The terms used in the result table are defined below:

- **Event count**: Number of events recorded for a given patient.
- **Encounter count**: Number of visits recorded for a given patient.

Note that every line in the input file is an event, while each visit consists of multiple events.
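The difference between the two counts can be illustrated on a few rows. This is a sketch under the assumption that each row carries a patient id, a visit id, and an event id; the rows themselves are hypothetical.

```python
from collections import defaultdict

# Hypothetical (pid, vid, event_id) rows; one row per event.
rows = [
    (1, 0, "DIAG_276"),
    (1, 0, "DIAG_428"),
    (1, 1, "DIAG_518"),
    (2, 0, "DIAG_041"),
]

event_count = defaultdict(int)  # pid -> number of rows (events)
visits = defaultdict(set)       # pid -> distinct visit ids

for pid, vid, _event in rows:
    event_count[pid] += 1
    visits[pid].add(vid)

# Encounter count is the number of distinct visits per patient
encounter_count = {pid: len(vids) for pid, vids in visits.items()}

print(dict(event_count))  # {1: 3, 2: 1}
print(encounter_count)    # {1: 2, 2: 1}
```

Patient 1 has three events spread over two visits, so the event count is 3 while the encounter count is 2.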

**Complete the following code cell to implement the required statistics.**

Please be aware that **you are NOT allowed to change the filename and any existing function declarations.** Only `numpy`, `scipy`, `scikit-learn`, `pandas` and other built-in modules of python will be available for you to use. The use of `pandas` library is suggested. 

In [None]:
import time
import pandas as pd
import numpy as np
import datetime

# PLEASE USE THE GIVEN FUNCTION NAME, DO NOT CHANGE IT.

def read_csv(filepath=TRAIN_DATA_PATH):

    '''
    Read the events.csv and hf_events.csv files. 
    Variables returned from this function are passed as input to the metric functions.
    
    NOTE: remember to use `filepath` whose default value is `TRAIN_DATA_PATH`.
    '''
    
    events = pd.read_csv(filepath + 'events.csv')
    hf = pd.read_csv(filepath + 'hf_events.csv')

    return events, hf

def event_count_metrics(events, hf):

    '''
    TODO : Implement this function to return the event count metrics.
    
    Event count is defined as the number of events recorded for a given patient.
    '''
    ## your code here
    
    # Count events per patient
    event_counts = events['pid'].value_counts().reset_index()
    event_counts.columns = ['pid', 'event_count']
    
    # Merge event counts with HF status
    patient_df = event_counts.merge(hf, on='pid', how='left')
    normal_patients = patient_df[patient_df['label'].isna()]
    hf_patients = patient_df[patient_df['label']==1]
    
    # Calculate metrics for HF patients
    avg_hf_event_count = hf_patients['event_count'].mean() if not hf_patients.empty else None
    max_hf_event_count = hf_patients['event_count'].max() if not hf_patients.empty else None
    min_hf_event_count = hf_patients['event_count'].min() if not hf_patients.empty else None
    
    # Calculate metrics for normal patients
    avg_norm_event_count = normal_patients['event_count'].mean() if not normal_patients.empty else None
    max_norm_event_count = normal_patients['event_count'].max() if not normal_patients.empty else None
    min_norm_event_count = normal_patients['event_count'].min() if not normal_patients.empty else None

    return avg_hf_event_count, max_hf_event_count, min_hf_event_count, \
           avg_norm_event_count, max_norm_event_count, min_norm_event_count

def encounter_count_metrics(events, hf):

    '''
    TODO : Implement this function to return the encounter count metrics.
    
    Encounter count is defined as the number of visits recorded for a given patient. 
    '''
    # your code here
    
    vid_counts = events.groupby('pid')['vid'].nunique().reset_index()
    vid_counts.columns = ['pid', 'encounter_count']

    patient_df = vid_counts.merge(hf, on='pid', how='left')
    normal_patients = patient_df[patient_df['label'].isna()]
    hf_patients = patient_df[patient_df['label']==1]
    
    avg_hf_encounter_count = hf_patients['encounter_count'].mean() if not hf_patients.empty else None
    max_hf_encounter_count = hf_patients['encounter_count'].max() if not hf_patients.empty else None
    min_hf_encounter_count = hf_patients['encounter_count'].min() if not hf_patients.empty else None
    avg_norm_encounter_count = normal_patients['encounter_count'].mean() if not normal_patients.empty else None
    max_norm_encounter_count = normal_patients['encounter_count'].max() if not normal_patients.empty else None
    min_norm_encounter_count = normal_patients['encounter_count'].min() if not normal_patients.empty else None
    
    return avg_hf_encounter_count, max_hf_encounter_count, min_hf_encounter_count, \
           avg_norm_encounter_count, max_norm_encounter_count, min_norm_encounter_count

In [None]:
events, hf = read_csv(TRAIN_DATA_PATH)

#Compute the event count metrics
start_time = time.time()
event_count = event_count_metrics(events, hf)
end_time = time.time()
print(("Time to compute event count metrics: " + str(end_time - start_time) + "s"))
print(event_count)

#Compute the encounter count metrics
start_time = time.time()
encounter_count = encounter_count_metrics(events, hf)
end_time = time.time()
print(("Time to compute encounter count metrics: " + str(end_time - start_time) + "s"))
print(encounter_count)

## 2 Naive RNN [35 points] 

Let us implement a naive bi-directional RNN model.

<img src="img/bi-rnn.jpg" width="600"/>

Remember from class that, first of all, we need to transform the diagnosis code for each visit of a patient to an embedding. To do this, we can use `nn.Embedding()`, where `num_embeddings` is the number of diagnosis codes and `embedding_dim` is the embedding dimension.
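A minimal sketch of the embedding lookup on a padded batch is shown below. The embedding dimension of 8 is an arbitrary choice for illustration, and the toy batch is hypothetical.

```python
import torch
import torch.nn as nn

num_codes = 619     # matches len(types) in the data loaded earlier
embedding_dim = 8   # arbitrary choice for illustration

embedding = nn.Embedding(num_embeddings=num_codes, embedding_dim=embedding_dim)

# A padded batch: 2 patients, 3 visits, 4 codes per visit.
# Padding uses 0, which is also a valid code, hence the separate mask later on.
x = torch.tensor([
    [[0, 1, 2, 0], [8, 0, 0, 0], [0, 0, 0, 0]],
    [[12, 13, 6, 7], [12, 0, 0, 0], [23, 11, 0, 0]],
])

emb = embedding(x)
print(emb.shape)  # torch.Size([2, 3, 4, 8])
```

The lookup simply appends an embedding dimension to the input shape, so every code position, padded or not, receives a vector.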

Then, we can construct a simple RNN structure. The input at each visit is a multi-hot vector of diagnosis codes: $\boldsymbol{X}_0$ at the 0-th visit and $\boldsymbol{X}_t$ at the t-th visit.

Each input is then mapped to a hidden state $\boldsymbol{\overleftrightarrow{h}}_t$. The forward hidden state $\boldsymbol{\overrightarrow{h}}_t$ is determined by the previous state $\boldsymbol{\overrightarrow{h}}_{t-1}$ and the current input $\boldsymbol{X}_t$.

Similarly, we will have another RNN to process the sequence in the reverse order, so that the hidden state $\boldsymbol{\overleftarrow{h}}_t$ is determined by $\boldsymbol{\overleftarrow{h}}_{t+1}$ and $\boldsymbol{X}_t$.

Finally, once we have the $\boldsymbol{\overrightarrow{h}}_T$ and $\boldsymbol{\overleftarrow{h}}_{0}$, we will concatenate the two vectors as the feature vector and train a NN to perform the classification.
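One way to write these recurrences down, assuming a standard Elman-style update with a $\tanh$ nonlinearity and a single sigmoid output layer, is given below. The weight names $\boldsymbol{W}$, $\boldsymbol{U}$, $\boldsymbol{b}$, $\boldsymbol{w}$, and $c$ are notation introduced here for illustration and do not appear elsewhere in this homework.

$$
\begin{aligned}
\boldsymbol{\overrightarrow{h}}_t &= \tanh\left(\boldsymbol{\overrightarrow{W}} \boldsymbol{X}_t + \boldsymbol{\overrightarrow{U}} \boldsymbol{\overrightarrow{h}}_{t-1} + \boldsymbol{\overrightarrow{b}}\right) \\
\boldsymbol{\overleftarrow{h}}_t &= \tanh\left(\boldsymbol{\overleftarrow{W}} \boldsymbol{X}_t + \boldsymbol{\overleftarrow{U}} \boldsymbol{\overleftarrow{h}}_{t+1} + \boldsymbol{\overleftarrow{b}}\right) \\
\hat{y} &= \sigma\left(\boldsymbol{w}^\top \left[\boldsymbol{\overrightarrow{h}}_T ; \boldsymbol{\overleftarrow{h}}_0\right] + c\right)
\end{aligned}
$$

Here $[\cdot\,;\cdot]$ denotes concatenation and $\hat{y}$ is the predicted probability of heart failure.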

Now, let us build this model. The forward steps will be:

    1. Pass the sequence through the embedding layer;
    2. Sum the embeddings of all diagnosis codes within each visit of a patient;
    3. Pass the embeddings through the RNN layer;
    4. Obtain the hidden state at the last visit;
    5. Do 1-4 for both directions and concatenate the hidden states;
    6. Pass the hidden state through the linear and activation layers.
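The six steps above can be sketched as a small module. This is only an illustration under assumptions: the class name `NaiveBiRNN`, the use of `nn.RNN`, the layer sizes, and the sigmoid output are choices made here, not the reference solution.

```python
import torch
import torch.nn as nn


class NaiveBiRNN(nn.Module):
    """Sketch only: layer sizes and the use of nn.RNN are assumptions."""

    def __init__(self, num_codes, embedding_dim=16, hidden_dim=16):
        super().__init__()
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.rev_rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(2 * hidden_dim, 1)

    def encode(self, rnn, x, masks):
        # Steps 1-2: embed, zero out padded codes, sum over codes per visit.
        emb = self.embedding(x) * masks.unsqueeze(-1).float()
        visit_emb = emb.sum(dim=2)
        # Step 3: run the RNN over the visits.
        out, _ = rnn(visit_emb)
        # Step 4: pick the hidden state at the last true visit.
        last = masks.any(dim=-1).sum(dim=1) - 1
        return out[torch.arange(x.size(0), device=x.device), last]

    def forward(self, x, masks, rev_x, rev_masks):
        # Step 5: encode both directions, then concatenate.
        h = torch.cat([self.encode(self.rnn, x, masks),
                       self.encode(self.rev_rnn, rev_x, rev_masks)], dim=-1)
        # Step 6: linear layer followed by a sigmoid activation.
        return torch.sigmoid(self.fc(h)).squeeze(-1)


# Toy padded batch: 2 patients, 3 visits, 4 codes per visit.
x = torch.tensor([[[0, 1, 2, 0], [8, 0, 0, 0], [0, 0, 0, 0]],
                  [[12, 13, 6, 7], [12, 0, 0, 0], [23, 11, 0, 0]]])
masks = torch.tensor([[[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
                      [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]]], dtype=torch.bool)
rev_x = torch.tensor([[[8, 0, 0, 0], [0, 1, 2, 0], [0, 0, 0, 0]],
                      [[23, 11, 0, 0], [12, 0, 0, 0], [12, 13, 6, 7]]])
rev_masks = torch.tensor([[[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]],
                          [[1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]]], dtype=torch.bool)

model = NaiveBiRNN(num_codes=24)
probs = model(x, masks, rev_x, rev_masks)
print(probs.shape)  # torch.Size([2])
```

Because the reversed inputs keep padded visits at the end, the same "last true visit" selection works for both directions: in the reversed sequence it picks the state that has seen the earliest visit last.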

### 2.1 Mask Selection [20 points]

Importantly, you need to use `masks` to mask out the paddings before step 2 and before step 4. So, let us first perform the mask selection.

In [38]:
def sum_embeddings_with_mask(x, masks):
    """
    Sum embeddings for true visits using mask selection
    
    Arguments:
        x: embeddings tensor (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: boolean mask tensor (batch_size, # visits, # diagnosis codes)
    
    Returns:
        sum_embeddings: summed embeddings (batch_size, # visits, embedding_dim)
    """
    # Expand mask to match embedding dimensions
    expanded_masks = masks.unsqueeze(-1).float()
    
    # Apply mask by multiplying with embeddings
    # masked_x shape: (batch_size, # visits, # diagnosis codes, embedding_dim)
    masked_x = x * expanded_masks
    
    # Sum along diagnosis codes dimension (dim=2)
    # sum_embeddings shape: (batch_size, # visits, embedding_dim)
    sum_embeddings = torch.sum(masked_x, dim=2)
    
    return sum_embeddings


Test Function

In [37]:
def test_sum_embeddings_with_mask():
    # Create sample data
    batch_size = 2
    num_visits = 3
    num_codes = 4
    embedding_dim = 5
    
    # Create sample embeddings
    x = torch.randn(batch_size, num_visits, num_codes, embedding_dim)
    
    # Create sample masks
    masks = torch.zeros(batch_size, num_visits, num_codes, dtype=torch.bool)
    masks[:, :, :2] = True  # First two codes in each visit are valid
    
    # Apply function
    result = sum_embeddings_with_mask(x, masks)
    
    # Print shapes and results
    print("Input shapes:")
    print(f"x shape: {x.shape}")
    print(f"masks shape: {masks.shape}")
    print(f"Output shape: {result.shape}")
    
    # Verify results
    print("\nVerifying results:")
    # Manual calculation for first visit of first batch
    manual_sum = x[0, 0, :2].sum(dim=0)  # Sum only first two codes
    print(f"Manual sum for first visit: {manual_sum}")
    print(f"Function result for first visit: {result[0, 0]}")
    
    return result

# Run test
test_result = test_sum_embeddings_with_mask()


Input shapes:
x shape: torch.Size([2, 3, 4, 5])
masks shape: torch.Size([2, 3, 4])
Output shape: torch.Size([2, 3, 5])

Verifying results:
Manual sum for first visit: tensor([ 0.3717,  1.6714,  1.0547, -0.2694, -2.4057])
Function result for first visit: tensor([ 0.3717,  1.6714,  1.0547, -0.2694, -2.4057])


In [39]:
import random
import ast
import inspect


def uses_loop(function):
    loop_statements = ast.For, ast.While, ast.AsyncFor

    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    return any(isinstance(node, loop_statements) for node in nodes)

def generate_random_mask(batch_size, max_num_visits , max_num_codes):
    num_visits = [random.randint(1, max_num_visits) for _ in range(batch_size)]
    num_codes = []
    for n in num_visits:
        num_codes_visit = [0] * max_num_visits
        for i in range(n):
            num_codes_visit[i] = (random.randint(1, max_num_codes))
        num_codes.append(num_codes_visit)
    masks = [torch.ones((l,), dtype=torch.bool) for num_codes_visit in num_codes for l in num_codes_visit]
    masks = torch.stack([torch.cat([i, i.new_zeros(max_num_codes - i.size(0))], 0) for i in masks], 0)
    masks = masks.view((batch_size, max_num_visits, max_num_codes)).bool()
    return masks


batch_size = 16
max_num_visits = 10
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
x = torch.randn((batch_size, max_num_visits , max_num_codes, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = sum_embeddings_with_mask(x, masks)

assert uses_loop(sum_embeddings_with_mask) is False
assert out.shape == (batch_size, max_num_visits, embedding_dim)


In [40]:
def get_last_visit(hidden_states, masks):
    """
    TODO: obtain the hidden state for the last true visit (not padding visits)

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
        
    NOTE: DO NOT use for loop.
    
    HINT: First convert the mask to a vector of shape (batch_size,) containing the true visit length; 
          and then use this length vector as index to select the last visit.
    """
    
    # Get any non-zero values along the diagnosis codes dimension to identify valid visits
    # Shape: (batch_size, # visits)
    valid_visits = masks.any(dim=-1)
    
    # Get the indices of the last true visit for each batch
    # Shape: (batch_size,)
    last_visit_indices = valid_visits.sum(dim=1) - 1
    
    # Create batch indices
    # Shape: (batch_size,)
    batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
    
    # Get the last hidden state for each sequence in the batch
    # Shape: (batch_size, embedding_dim)
    last_hidden_state = hidden_states[batch_indices, last_visit_indices]
    
    return last_hidden_state

In [41]:
assert uses_loop(get_last_visit) is False

max_num_visits = 10
batch_size = 16
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
hidden_states = torch.randn((batch_size, max_num_visits, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = get_last_visit(hidden_states, masks)

assert out.shape == (batch_size, embedding_dim)
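The index arithmetic behind selecting the last true visit can be checked by hand on a tiny example. The tensors below are hypothetical, with a hidden size of 1 and values chosen so that each (patient, visit) pair is recognizable.

```python
import torch

# Two patients, up to 3 visits, up to 4 codes per visit.
# Patient 0 has 2 true visits; patient 1 has 3.
masks = torch.tensor([[[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
                      [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]]], dtype=torch.bool)

# Hidden size 1: the value 1x belongs to patient 0, 2x to patient 1.
hidden_states = torch.tensor([[[10.], [11.], [12.]],
                              [[20.], [21.], [22.]]])

visit_lengths = masks.any(dim=-1).sum(dim=1)     # tensor([2, 3])
last_idx = visit_lengths - 1                     # tensor([1, 2])
last = hidden_states[torch.arange(2), last_idx]  # tensor([[11.], [22.]])
print(last)
```

Counting true visits to find the last index relies on the padded visits sitting at the end of each sequence, which is how the collate step lays them out.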