In [None]:
# default_exp data

In [None]:
#all_slow

# Data
> Classes and functions for managing data

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#hide
# `IPython.display` is the public home of these helpers (`IPython.core.display` is deprecated)
from IPython.display import display, HTML
# widen the notebook's `.container` element (classic Notebook layout) to 85% of the window
display(HTML("<style>.container { width:85% !important; }</style>"))

In [None]:
#export
from lemonpie.basics import *
from lemonpie.preprocessing.transform import *
from fastai.imports import *
import copy

In [None]:
#hide
from nbdev.showdoc import *

## Split

- Splitting is already done in the raw data before vocab creation.
- The following class is to load and manage the pre-processed splits together.

In [None]:
#export
class EHRDataSplits():
    '''Class to hold the PatientList splits (train, valid, test) loaded from a preprocessed store'''
    def __init__(self, path, age_start, age_range, start_is_date, age_in_months):
        loaded = self._load_splits(path, age_start, age_range, start_is_date, age_in_months)
        self.train, self.valid, self.test = loaded

    def _load_splits(self, path, age_start, age_range, start_is_date, age_in_months):
        '''Load splits of preprocessed `PatientList`s from persistent store using path'''
        # every split is loaded with the same age settings; only the split name varies
        load_args = (age_start, age_range, start_is_date, age_in_months)
        return tuple(PatientList.load(path, split_name, *load_args) for split_name in ('train', 'valid', 'test'))

    def get_splits(self):
        '''Return splits as a `(train, valid, test)` tuple'''
        return (self.train, self.valid, self.test)

    def get_lengths(self):
        '''Return a dataframe with lengths (# of patients) of the splits (train, valid, test) and total'''
        sizes = [len(split) for split in (self.train, self.valid, self.test)]
        sizes.append(sum(sizes))
        return pd.DataFrame(sizes, index=['train','valid','test','total'], columns=['lengths'])

    def get_label_counts(self, labels):
        '''Get prevalence counts of labels in each split - returns a dataframe with counts for each split and total count.
        A patient counts as positive for `label` when `patient.conditions[label] == 1`.'''
        def _n_positive(split, label):
            return sum(1 for idx in range(len(split)) if split[idx].conditions[label] == 1)

        rows = []
        for label in labels:
            per_split = [_n_positive(split, label) for split in (self.train, self.valid, self.test)]
            rows.append(per_split + [sum(per_split)])
        return pd.DataFrame(rows, index=labels, columns=['train','valid','test','total'])

    def get_pos_wts(self, labels):
        '''Get positive weights to be used in `nn.BCEWithLogitsLoss`: round(#negatives / #positives)
        per label and split. A split with zero positives for a label yields `inf`.'''
        positives = self.get_label_counts(labels)
        # (1, 4) row of split sizes broadcast against the (labels x 4) positive counts
        totals = self.get_lengths().T.values
        negatives = totals - positives
        return round(negatives / positives)

In [None]:
show_doc(EHRDataSplits, title_level=3)

<h3 id="EHRDataSplits" class="doc_header"><code>class</code> <code>EHRDataSplits</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>EHRDataSplits</code>(**`path`**, **`age_start`**, **`age_range`**, **`start_is_date`**, **`age_in_months`**)

Class to hold the PatientList splits

In [None]:
show_doc(EHRDataSplits._load_splits)

<h4 id="EHRDataSplits._load_splits" class="doc_header"><code>EHRDataSplits._load_splits</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits._load_splits</code>(**`path`**, **`age_start`**, **`age_range`**, **`start_is_date`**, **`age_in_months`**)

Load splits of preprocessed [`PatientList`](/lemonpie/preprocessing_transform.html#PatientList)s from persistent store using path

In [None]:
show_doc(EHRDataSplits.get_splits)

<h4 id="EHRDataSplits.get_splits" class="doc_header"><code>EHRDataSplits.get_splits</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_splits</code>()

Return splits

In [None]:
show_doc(EHRDataSplits.get_lengths)

<h4 id="EHRDataSplits.get_lengths" class="doc_header"><code>EHRDataSplits.get_lengths</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_lengths</code>()

Return a dataframe with lengths (# of patients) of the splits (train, valid, test) and total

In [None]:
show_doc(EHRDataSplits.get_label_counts)

<h4 id="EHRDataSplits.get_label_counts" class="doc_header"><code>EHRDataSplits.get_label_counts</code><a href="__main__.py#L23" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_label_counts</code>(**`labels`**)

Get prevalence counts of labels in each split - returns a dataframe with counts for each split and total count

In [None]:
show_doc(EHRDataSplits.get_pos_wts)

<h4 id="EHRDataSplits.get_pos_wts" class="doc_header"><code>EHRDataSplits.get_pos_wts</code><a href="__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_pos_wts</code>(**`labels`**)

Get positive weights to be used in `nn.BCEWithLogitsLoss`

**Tests**

In [None]:
PATH_1K, CONDITIONS

('/home/vinod/.lemonpie/datasets/synthea/1K',
 {'diabetes': '44054006',
  'stroke': '230690007',
  'alzheimers': '26929004',
  'coronary_heart': '53741008',
  'lung_cancer': '254637007',
  'breast_cancer': '254837009',
  'rheumatoid_arthritis': '69896004',
  'epilepsy': '84757009'})

In [None]:
labels = list(CONDITIONS.keys())

In [None]:
labels

['diabetes',
 'stroke',
 'alzheimers',
 'coronary_heart',
 'lung_cancer',
 'breast_cancer',
 'rheumatoid_arthritis',
 'epilepsy']

In [None]:
splits = EHRDataSplits(PATH_1K, age_start='2000-01-01', age_range=17, start_is_date=True, age_in_months=False)

In [None]:
lengths = splits.get_lengths()
lengths

Unnamed: 0,lengths
train,702
valid,234
test,235
total,1171


In [None]:
prevalence = splits.get_label_counts(labels)
prevalence

Unnamed: 0,train,valid,test,total
diabetes,43,14,19,76
stroke,30,7,11,48
alzheimers,12,7,6,25
coronary_heart,39,11,11,61
lung_cancer,12,0,2,14
breast_cancer,11,8,2,21
rheumatoid_arthritis,2,0,0,2
epilepsy,15,5,2,22


In [None]:
splits.get_pos_wts(labels)

Unnamed: 0,train,valid,test,total
diabetes,15.0,16.0,11.0,14.0
stroke,22.0,32.0,20.0,23.0
alzheimers,58.0,32.0,38.0,46.0
coronary_heart,17.0,20.0,20.0,18.0
lung_cancer,58.0,inf,116.0,83.0
breast_cancer,63.0,28.0,116.0,55.0
rheumatoid_arthritis,350.0,inf,inf,584.0
epilepsy,46.0,46.0,116.0,52.0


In [None]:
lengths.transpose().values

array([[ 702,  234,  235, 1171]])

In [None]:
neg_counts = lengths.transpose().values - prevalence
neg_counts

Unnamed: 0,train,valid,test,total
diabetes,659,220,216,1095
stroke,672,227,224,1123
alzheimers,690,227,229,1146
coronary_heart,663,223,224,1110
lung_cancer,690,234,233,1157
breast_cancer,691,226,233,1150
rheumatoid_arthritis,700,234,235,1169
epilepsy,687,229,233,1149


In [None]:
round(neg_counts / prevalence)

Unnamed: 0,train,valid,test,total
diabetes,15.0,16.0,11.0,14.0
stroke,22.0,32.0,20.0,23.0
alzheimers,58.0,32.0,38.0,46.0
coronary_heart,17.0,20.0,20.0,18.0
lung_cancer,58.0,inf,116.0,83.0
breast_cancer,63.0,28.0,116.0,55.0
rheumatoid_arthritis,350.0,inf,inf,584.0
epilepsy,46.0,46.0,116.0,52.0


**Cross check with raw**
- Check total counts against raw_csv
- Check split counts against split/raw_csv

In [None]:
raw_cnds = pd.read_csv(f'{PATH_1K}/raw_original/conditions.csv', low_memory=False)

In [None]:
cnd_codes = list(CONDITIONS.values())
cnd_codes

['44054006',
 '230690007',
 '26929004',
 '53741008',
 '254637007',
 '254837009',
 '69896004',
 '84757009']

In [None]:
int(CONDITIONS['diabetes'])

44054006

In [None]:
for label in labels:
    print(label,':: ', raw_cnds[raw_cnds.CODE == int(CONDITIONS[label])].CODE.count())

diabetes ::  76
stroke ::  48
alzheimers ::  25
coronary_heart ::  61
lung_cancer ::  14
breast_cancer ::  21
rheumatoid_arthritis ::  2
epilepsy ::  22


In [None]:
raw_cnds_train = pd.read_csv(f'{PATH_1K}/raw_split/train/conditions.csv', low_memory=False)
raw_cnds_valid = pd.read_csv(f'{PATH_1K}/raw_split/valid/conditions.csv', low_memory=False)
raw_cnds_test  = pd.read_csv(f'{PATH_1K}/raw_split/test/conditions.csv', low_memory=False)

In [None]:
# cross-check every label's prevalence against the raw CSVs, overall and per split
raw_by_column = (('total', raw_cnds), ('train', raw_cnds_train), ('valid', raw_cnds_valid), ('test', raw_cnds_test))
for label in labels:
    code = int(CONDITIONS[label])
    for column, frame in raw_by_column:
        expected = frame[frame.CODE == code].CODE.count()
        assert prevalence.loc[label][column] == expected

## Label

**Labeling**, as defined in fastai: some processes need to be run on `train` and then **applied** to `valid`

This is completed in preprocessing (vocab & transform) as follows
1. Vocabs created from train data
    - Tokenizing unique values for different record codes & demographic values
    - Calculating mean and std for age
2. Vocabs applied to train, valid and test data
    - With `numericalize` for record codes & demographic values
    - With normalizing of age with the mean / std from train

**Hence labeling in our case will be creating X and y**

- X is the patient object
- y (for a single patient) needs to be a tensor made of the patient's values for each label (e.g. 'diabetes', 'stroke', 'alzheimers', 'coronary_heart', 'lung_cancer', ...)

So **creating the `y` tensor** is simply a matter of:
1. extracting the values of each of the labels from each `Patient` object 
2. turning it into a `torch.FloatTensor`
3. and stacking them up using `torch.stack`

In [None]:
tst_y = np.array((True, False, False, True), dtype='float')
torch.from_numpy(tst_y), torch.FloatTensor(tst_y)

(tensor([1., 0., 0., 1.], dtype=torch.float64), tensor([1., 0., 0., 1.]))

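Two ways of creating a torch tensor from a numpy array: `torch.from_numpy` keeps the array's float64 dtype, while `torch.FloatTensor` produces float32. We will stick with the latter.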

In [None]:
y = []
for pt in splits.train:
    y.append(torch.FloatTensor(np.array([pt.conditions[label] for label in labels], dtype='float')) )

In [None]:
# y

In [None]:
y = torch.stack(y)

In [None]:
y.shape

torch.Size([702, 8])

In [None]:
y

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

Putting it into a function

In [None]:
def label_data(patient_ds, labels) -> 'x,y':
    '''Extracts y from patient object, returns x=Patient object, y=tensor of conditions.

    `x` is `patient_ds` itself; `y` stacks one float32 row per patient, holding
    `pt.conditions[label]` for each label in `labels` order.'''
    label_rows = [
        torch.FloatTensor(np.array([pt.conditions[lbl] for lbl in labels], dtype='float'))
        for pt in patient_ds
    ]
    return patient_ds, torch.stack(label_rows)

In [None]:
x_train,y_train = label_data(splits.train, labels)
x_valid,y_valid = label_data(splits.valid, labels)
x_test ,y_test  = label_data(splits.test , labels)

In [None]:
y_train.shape, y_valid.shape, y_test.shape

(torch.Size([702, 8]), torch.Size([234, 8]), torch.Size([235, 8]))

In [None]:
torch.full((10,1), 2)

tensor([[2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])

In [None]:
#export
class LabelEHRData():
    '''Class to hold labeled EHR data splits'''
    def __init__(self, train, valid, test, labels):
        '''Extracts y from patient object, each labelset a tuple of x,y: x=Patient object, y=tensor of conditions'''
        self.x_train, self.x_valid, self.x_test = train, valid, test
        # labels are extracted in split order: train, valid, test
        self.y_train, self.y_valid, self.y_test = (self._get_y(split, labels) for split in (train, valid, test))

        # convenience (x, y) pairs per split
        self.train = (self.x_train, self.y_train)
        self.valid = (self.x_valid, self.y_valid)
        self.test  = (self.x_test,  self.y_test)

    def _get_y(self, ds, labels):
        '''Extract y from each patient object in ds and stack them - ds is dataset containing patient objects'''
        label_rows = [
            torch.FloatTensor(np.array([pt.conditions[lbl] for lbl in labels], dtype='float'))
            for pt in ds
        ]
        return torch.stack(label_rows)

In [None]:
show_doc(LabelEHRData, title_level=3)

<h3 id="LabelEHRData" class="doc_header"><code>class</code> <code>LabelEHRData</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>LabelEHRData</code>(**`train`**, **`valid`**, **`test`**, **`labels`**)

Class to hold labeled EHR data splits

In [None]:
show_doc(LabelEHRData.__init__)

<h4 id="LabelEHRData.__init__" class="doc_header"><code>LabelEHRData.__init__</code><a href="__main__.py#L4" class="source_link" style="float:right">[source]</a></h4>

> <code>LabelEHRData.__init__</code>(**`train`**, **`valid`**, **`test`**, **`labels`**)

Extracts y from patient object, each labelset a tuple of x,y: x=Patient object, y=tensor of conditions

In [None]:
show_doc(LabelEHRData._get_y)

<h4 id="LabelEHRData._get_y" class="doc_header"><code>LabelEHRData._get_y</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>LabelEHRData._get_y</code>(**`ds`**, **`labels`**)

Extract y from each patient object in ds and stack them - ds is dataset containing patient objects

In [None]:
labeled = LabelEHRData(*splits.get_splits(), labels)

In [None]:
labeled.train

(PatientList (702 items)
 base path:/home/vinod/.lemonpie/datasets/synthea/1K; split:train
 age_start:2000-01-01; age_range:17; age_type:years
 ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:27a8b7b6-007d-4036-82a7-80a9ab670dcb, birthdate:2005-04-13 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:532696f2-0b76-4eb0-9aea-a74e2fb1bed2, birthdate:1967-05-18 00:00:00, [(

## Dataset

[Subclasses](https://pytorch.org/docs/master/data.html?highlight=dataloader#torch.utils.data.Dataset) `torch.utils.data.Dataset`<br>
- that is, implements `__len__()` and `__getitem__()`

In [None]:
# export
class EHRDataset(torch.utils.data.Dataset):
    """Class to hold a single EHR dataset (holds a tuple of x, y & m for modality type).
    Also handles lazy vs full loading of dataset on GPU."""

    def __init__(self, x_labeled: list, y_labeled: Tensor, modality_type: int, lazy_load_gpu: bool = True):
        """If `lazy_load_gpu` is `False`, load entire dataset on GPU."""
        # one modality-type entry per patient, shape (len(x_labeled), 1)
        modality_col = torch.full((len(x_labeled), 1), modality_type)
        self.lazy = bool(lazy_load_gpu)
        if self.lazy:
            # keep data where it is; patients are deep-copied on each access instead
            self.x, self.y, self.m = x_labeled, y_labeled, modality_col
            return
        # eager path: move every patient, the labels and the modality column to DEVICE up front
        self.x = [patient.to_gpu() for patient in x_labeled]
        self.y = y_labeled.to(DEVICE)
        self.m = modality_col.to(DEVICE)

    def __len__(self):
        return len(self.x)

    def _test_getitem(self, i):
        # raw access without the deep copy, for inspection in tests
        return (self.x[i], self.y[i], self.m[i])

    def __getitem__(self, i):
        """If lazy loading, return deep copy of patient object `i`
        else entire dataset already on GPU - just return `i`"""
        x_item = copy.deepcopy(self.x[i]) if self.lazy else self.x[i]
        return x_item, self.y[i], self.m[i]


In [None]:
show_doc(EHRDataset, title_level=3)

<h3 id="EHRDataset" class="doc_header"><code>class</code> <code>EHRDataset</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>EHRDataset</code>(**\*`args`**, **\*\*`kwds`**) :: `Dataset`

Class to hold a single EHR dataset (holds a tuple of x, y & m for modality type).
Also handles lazy vs full loading of dataset on GPU.

In [None]:
show_doc(EHRDataset.__init__)

<h4 id="EHRDataset.__init__" class="doc_header"><code>EHRDataset.__init__</code><a href="__main__.py#L6" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataset.__init__</code>(**`x_labeled`**:`list`, **`y_labeled`**:`Tensor`, **`modality_type`**:`int`, **`lazy_load_gpu`**:`bool`=*`True`*)

If `lazy_load_gpu` is `False`, load entire dataset on GPU.

In [None]:
show_doc(EHRDataset.__getitem__)

<h4 id="EHRDataset.__getitem__" class="doc_header"><code>EHRDataset.__getitem__</code><a href="__main__.py#L30" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataset.__getitem__</code>(**`i`**)

If lazy loading, return deep copy of patient object `i`
else entire dataset already on GPU - just return `i`

Since `Patient` is a custom object and not a typical tensor, we need to handle the behavior for `Dataset`, `DataLoader`, etc to function correctly.
- Memory pinning is a good idea for better performance if lazy loading to GPU
    - [A discussion - pin memory vs full load to GPU](https://discuss.pytorch.org/t/pin-memory-vs-sending-direct-to-gpu-from-dataset/33891)
- When a DataLoader pins memory for a tensor, a copy of the tensor is made in page-locked (pinned) RAM rather than swappable (pageable) memory, which speeds up transfers to the GPU
    - [A good explanation](https://stackoverflow.com/questions/5736968/why-is-cuda-pinned-memory-so-fast)
- But on custom data type like our `Patient` object, we need to define the behavior
    - [Pytorch docs](https://pytorch.org/docs/stable/data.html#memory-pinning)
- We make a [deep copy](https://docs.python.org/3/library/copy.html) of the `Patient` object to mimic tensor behavior
    - Otherwise, since the `Patient` keeps its modified tensors, all of its tensors are already CUDA tensors after the first epoch; the DataLoader then tries to pin memory again, which causes an error (TODO: Need to elaborate)

In [None]:
def get_ds(x_train, y_train, x_valid, y_valid, modality_type) -> 'train_ds, valid_ds':
    '''Wrap the labeled train and valid splits in `EHRDataset`s (default lazy loading) sharing one `modality_type`'''
    train_ds = EHRDataset(x_train, y_train, modality_type)
    valid_ds = EHRDataset(x_valid, y_valid, modality_type)
    return train_ds, valid_ds

**Testing Lazy Load**

In [None]:
train_ds, valid_ds = get_ds(*labeled.train, *labeled.valid, 3)

In [None]:
len(train_ds), len(valid_ds)

(702, 234)

In [None]:
len(labeled.train), len(labeled.x_train)

(2, 702)

In [None]:
assert len(train_ds)==len(labeled.x_train)==len(labeled.y_train)
assert len(valid_ds)==len(labeled.y_valid)==len(labeled.x_valid)

In [None]:
xb,yb, mb = train_ds[0:7]
xb,yb, mb

([ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:27a8b7b6-007d-4036-82a7-80a9ab670dcb, birthdate:2005-04-13 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:532696f2-0b76-4eb0-9aea-a74e2fb1bed2, birthdate:1967-05-18 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.

In [None]:
yb.shape, mb.shape

(torch.Size([7, 8]), torch.Size([7, 1]))

In [None]:
xb[0].obs_nums.is_pinned()

False

In [None]:
train_ds._test_getitem(0)

(ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 tensor([0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([3]))

## Multimodal Data

| Hierarchy             |   	|   	|   	|   	|
|---	                |---	|---	|---	|---	|
| ModalityTypeDataset  	| Different batch sizes  	|   	|   	|   	|
| MultimodalDatasets    | Uniform batch sizes   	|   	|   	|   	|
| UniModalDatasets      |   	|   	|   	|   	|

### Modality Type & Multimodal Dataset

**`ConcatDataset` & Custom Batch Sampler**

Modality type is the combination of data modalities available for a given patient.

The solution used here comes from this PyTorch forum discussion:
- https://discuss.pytorch.org/t/how-to-concatenate-different-datasets-each-with-different-dimensions/123218
- The `ConcatDataset` holds all the specific `MultimodalDataset`s together - one for each modality combination.
    - For example - (EHR + MRI + ECG + Notes), (EHR + MRI), (EHR + DNA + MRI + Notes), etc. 
- The custom batch sampler ensures that each batch only has elements from one of the `MultimodalDataset`s.
    - But also provides shuffling across the various types.

In [None]:
# export

class ModalityTypeBatchSampler(Sampler):
    """Custom BatchSampler for multimodal data: every batch comes from a single modality type.

    `indices_list` holds one list of (global, `ConcatDataset`-level) indices per modality-type
    dataset. Batches are cut separately from each list, so a batch never mixes modality types.
    With `shuffle=True`, indices are shuffled within each modality type and batches are shuffled
    across types, and both happen again on every epoch (every call to `__iter__`), so batch
    composition changes between epochs. The last batch of each type may be smaller than
    `batch_size`; an empty modality type contributes no batches.
    """

    def __init__(self, indices_list: list, batch_size: int, shuffle: bool):
        """Init with indices from every modality-type dataset and create all batches.

        Raises `ValueError` if `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Copy so in-place shuffling never mutates the caller's lists.
        self.indices_list = [list(indices) for indices in indices_list]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.all_batches = self._create_batches()

    def _chunk(self, indices, size):
        """Split `indices` into consecutive lists of at most `size` elements (none if empty)."""
        return [indices[start:start + size] for start in range(0, len(indices), size)]

    def _create_batches(self):
        """Create one epoch's batches, shuffling within each modality type if `shuffle`."""
        all_batches = []
        for indices in self.indices_list:
            if self.shuffle:
                random.shuffle(indices)
            all_batches.extend(self._chunk(indices, self.batch_size))
        return all_batches

    def __iter__(self):
        """Iterable used by dataloaders; re-batches and reshuffles each epoch when `shuffle`."""
        if self.shuffle:
            # rebuild so that batch *composition* (not just batch order) varies per epoch
            self.all_batches = self._create_batches()
            random.shuffle(self.all_batches)
        return iter(self.all_batches)

    def __len__(self):
        """Return the number of batches per epoch (unchanged by shuffling)."""
        return len(self.all_batches)


In [None]:
show_doc(ModalityTypeBatchSampler, title_level=3)

<h3 id="ModalityTypeBatchSampler" class="doc_header"><code>class</code> <code>ModalityTypeBatchSampler</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>ModalityTypeBatchSampler</code>(**\*`args`**, **\*\*`kwds`**) :: `Sampler`

Custom BatchSampler for multimodal data.

In [None]:
show_doc(ModalityTypeBatchSampler.__init__)

<h4 id="ModalityTypeBatchSampler.__init__" class="doc_header"><code>ModalityTypeBatchSampler.__init__</code><a href="__main__.py#L6" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__init__</code>(**`indices_list`**:`list`, **`batch_size`**:`int`, **`shuffle`**:`bool`)

Init with indicies from every modality-type dataset and create all batches.

In [None]:
show_doc(ModalityTypeBatchSampler._chunk)

<h4 id="ModalityTypeBatchSampler._chunk" class="doc_header"><code>ModalityTypeBatchSampler._chunk</code><a href="__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler._chunk</code>(**`indices`**, **`size`**)

Chunk indices into batch size.

In [None]:
show_doc(ModalityTypeBatchSampler._create_batches)

<h4 id="ModalityTypeBatchSampler._create_batches" class="doc_header"><code>ModalityTypeBatchSampler._create_batches</code><a href="__main__.py#L17" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler._create_batches</code>()

Create batches.

In [None]:
show_doc(ModalityTypeBatchSampler.__iter__)

<h4 id="ModalityTypeBatchSampler.__iter__" class="doc_header"><code>ModalityTypeBatchSampler.__iter__</code><a href="__main__.py#L28" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__iter__</code>()

Iterable used by dataloaders.

In [None]:
show_doc(ModalityTypeBatchSampler.__len__)

<h4 id="ModalityTypeBatchSampler.__len__" class="doc_header"><code>ModalityTypeBatchSampler.__len__</code><a href="__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__len__</code>()

Return length based on concated datasets.

In [None]:
# export
def create_modality_ds_sampler(
    ehr_dataset_list: list, batch_size: int, shuffle: bool
):
    """Create a custom ConcatDataset and BatchSampler for modality types.

    Returns `(ConcatDataset, ModalityTypeBatchSampler)`. The sampler is given, for each input
    dataset, the contiguous range of global `ConcatDataset` indices that dataset occupies.
    """
    modtype_dataset = torch.utils.data.ConcatDataset(ehr_dataset_list)
    # dataset k occupies global indices [cumulative_sizes[k-1], cumulative_sizes[k])
    range_ends = modtype_dataset.cumulative_sizes
    range_starts = [0] + range_ends[:-1]
    per_dataset_indices = [list(range(lo, hi)) for lo, hi in zip(range_starts, range_ends)]

    return modtype_dataset, ModalityTypeBatchSampler(per_dataset_indices, batch_size, shuffle)


In [None]:
show_doc(create_modality_ds_sampler)

<h4 id="create_modality_ds_sampler" class="doc_header"><code>create_modality_ds_sampler</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>create_modality_ds_sampler</code>(**`ehr_dataset_list`**:`list`, **`batch_size`**:`int`, **`shuffle`**:`bool`)

Create a custom ConcatDataset and BatchSampler for modality types.

**Multimodal Dataset**
- https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649
- This approach can be used to read from two Datasets simultaneously and get a tuple.

In [None]:
class MultimodalDataset(torch.utils.data.Dataset):
    """Class to hold Datasets of multiple modalities.

    `ds_list[0]` is the EHR dataset, whose items are `(patients, y, m)` triples. Every other
    dataset is indexed by the patient id(s) read from the EHR item's `ptid` attribute(s).
    """
    def __init__(self, ds_list):
        """Separate EHR and other modalities."""
        self.ehr_dataset = ds_list[0]
        self.other_datasets = ds_list[1:]

    def __getitem__(self, i):
        """Get patient ids from the EHR dataset and use them to fetch data of other modalities.

        Returns `(ehr_item, other_items)`: `ehr_item` is exactly what the EHR dataset returned
        for `i`, and `other_items` has one entry per other dataset. A single patient (integer
        `i`) is looked up by its `ptid`; a collection of patients (e.g. a slice) is looked up by
        the list of their `ptid`s.
        """
        # Fetch once: the EHR dataset may deep-copy on access, and the returned item must be
        # the same one the ids were read from.
        ehr_item = self.ehr_dataset[i]
        pts = ehr_item[0]
        ptids = pts.ptid if hasattr(pts, 'ptid') else [pt.ptid for pt in pts]
        return ehr_item, tuple(d[ptids] for d in self.other_datasets)

    def __len__(self):
        """Return count of patients in this modality type."""
        return len(self.ehr_dataset)

#### Testing with toy data

In [None]:
class ToyEHR_DS(torch.utils.data.Dataset):
    """Toy EHR Dataset for testing multimodal functionality.

    Items are `(x[i], y[i], modality_type)`; the modality type is a plain scalar shared by
    every item rather than a per-row tensor."""

    def __init__(self, x_labeled: list, y_labeled: Tensor, modality_type: int):
        self.m = modality_type
        self.x = x_labeled
        self.y = y_labeled

    def __len__(self):
        return len(self.x)

    def _test_getitem(self, i):
        return (self.x[i], self.y[i], self.m)

    def __getitem__(self, i):
        return (self.x[i], self.y[i], self.m)


In [None]:
type1_ds = ToyEHR_DS(['type1_1', 'type1_2', 'type1_3', 'type1_4'], torch.tensor((1, 0, 1, 0)), 1)
type2_ds = ToyEHR_DS(['type2_1', 'type2_2', 'type2_3', 'type2_4', 'type2_5', 'type2_6', 'type2_7'], torch.tensor((1, 0, 1, 0, 0, 1, 1)), 2)
type3_ds = ToyEHR_DS(['type3_1', 'type3_2', 'type3_3', 'type3_4', 'type3_5'], torch.tensor((1, 0, 1, 0, 0)), 3)

In [None]:
type2_ds[3], len(type2_ds)

(('type2_4', tensor(0), 2), 7)

`shuffle = False`

In [None]:
modtype_ds, sampler = create_modality_ds_sampler([type1_ds, type2_ds, type3_ds], batch_size=2, shuffle=False)

In [None]:
dl = DataLoader(modtype_ds,  batch_sampler=sampler)
len(dl)

9

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") #, y:{y} \n m:{m}")

i=0 -- x=('type1_1', 'type1_2')
i=1 -- x=('type1_3', 'type1_4')
i=2 -- x=('type2_1', 'type2_2')
i=3 -- x=('type2_3', 'type2_4')
i=4 -- x=('type2_5', 'type2_6')
i=5 -- x=('type2_7',)
i=6 -- x=('type3_1', 'type3_2')
i=7 -- x=('type3_3', 'type3_4')
i=8 -- x=('type3_5',)


`shuffle = True`

In [None]:
mm_ds, sampler = create_modality_ds_sampler([type1_ds, type2_ds, type3_ds], batch_size=2, shuffle=True)

In [None]:
dl = DataLoader(mm_ds,  batch_sampler=sampler)
len(dl)

9

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") # \n y:{y} \n m:{m}")

i=0 -- x=('type3_5', 'type3_3')
i=1 -- x=('type2_7', 'type2_2')
i=2 -- x=('type1_2', 'type1_3')
i=3 -- x=('type3_2',)
i=4 -- x=('type2_6', 'type2_4')
i=5 -- x=('type2_5', 'type2_1')
i=6 -- x=('type2_3',)
i=7 -- x=('type3_1', 'type3_4')
i=8 -- x=('type1_4', 'type1_1')


**Single modality type** - for example just EHR tabular.

In [None]:
mm_ds, sampler = create_modality_ds_sampler([type2_ds], batch_size=2, shuffle=False)

In [None]:
dl = DataLoader(mm_ds,  batch_sampler=sampler)
len(dl)

4

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") # \n y:{y} \n m:{m}")

i=0 -- x=('type2_1', 'type2_2')
i=1 -- x=('type2_3', 'type2_4')
i=2 -- x=('type2_5', 'type2_6')
i=3 -- x=('type2_7',)


**Testing Multimodal functionality**

In [None]:
class UnimodalDataset(torch.utils.data.Dataset):
    """Toy single-modality dataset: item `i` is the string "<type>-<i>"; length is fixed at 100."""
    def __init__(self, type: str):
        super().__init__()
        self.type = type

    def __len__(self):
        return 100

    def __getitem__(self, i):
        return f"{self.type}-{i}"


In [None]:
class ToyMMDataset(torch.utils.data.Dataset):
    """Toy multimodal dataset: pairs each EHR item with items from the other modality datasets.

    `ds_list[0]` is the EHR dataset; every other dataset is indexed with the EHR item's `x`
    value (a plain string id in the toy data).
    """
    def __init__(self, ds_list):
        self.ehr_dataset = ds_list[0]
        self.other_datasets = ds_list[1:]

    def __getitem__(self, i):
        """Return `(ehr_item, tuple of other-modality items looked up by ehr_item[0])`."""
        # Fetch the EHR item once and reuse it as both the return value and the lookup key.
        ehr_item = self.ehr_dataset[i]
        key = ehr_item[0]
        return ehr_item, tuple(d[key] for d in self.other_datasets)

    def __len__(self):
        return len(self.ehr_dataset)

In [None]:
type1_ehr_ds = ToyEHR_DS(['type1_1', 'type1_2', 'type1_3', 'type1_4'], torch.tensor((1, 0, 1, 0)), 1)
type2_ehr_ds = ToyEHR_DS(['type2_1', 'type2_2', 'type2_3', 'type2_4', 'type2_5', 'type2_6', 'type2_7'], torch.tensor((1, 0, 1, 0, 0, 1, 1)), 2)
type3_ehr_ds = ToyEHR_DS(['type3_1', 'type3_2', 'type3_3', 'type3_4', 'type3_5'], torch.tensor((1, 0, 1, 0, 0)), 3)

ehr_batch_sz = 2

mri_ds = UnimodalDataset("mri")
ecg_ds = UnimodalDataset("ecg")
dna_ds = UnimodalDataset("dna")
notes_ds = UnimodalDataset("notes")

type1_mm_ds = ToyMMDataset([type1_ehr_ds, mri_ds, ecg_ds])
type2_mm_ds = ToyMMDataset([type2_ehr_ds, dna_ds, notes_ds])
type3_mm_ds = ToyMMDataset([type3_ehr_ds, notes_ds, mri_ds, ecg_ds])



modtype_ds, modtype_sampler = create_modality_ds_sampler([type1_mm_ds, type2_mm_ds, type3_mm_ds], batch_size=ehr_batch_sz, shuffle=True)
ehr_dl = DataLoader(modtype_ds,  batch_sampler=modtype_sampler)

In [None]:
modtype_ds.cumulative_sizes

[4, 11, 16]

In [None]:
next(iter(ehr_dl))

[[('type2_3', 'type2_5'), tensor([1, 0]), tensor([2, 2])],
 [('dna-type2_3', 'dna-type2_5'), ('notes-type2_3', 'notes-type2_5')]]

#### Testing

In [None]:
class MRIDataset(Dataset):
    '''Work-in-progress stub for an MRI-modality dataset.

    TODO: accept the MRI data source and implement `__getitem__` / `__len__`.
    Until then, this class only exists so the notebook runs end to end.
    '''
    def __init__(self):
        super().__init__()

## DataLoader - Using the PyTorch `DataLoader`

**We need to define a custom collate function**, because PyTorch's default collate function cannot handle the list of `Patient` objects in `x` and raises the following error:
```
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class '__main__.Patient'>
```

In [None]:
valid_ds[0:4]

([ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
  ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
x_tmps,y_tmps, m_tmps = valid_ds[0:4]

In [None]:
x_tmps

[ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
 ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu]

In [None]:
y_tmps

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
m_tmps

tensor([[3],
        [3],
        [3],
        [3]])

**Old collate functions (kept for reference)**

**1. removed cuda calls**
```python
def collate(b):
    xs,ys = zip(*b)
    return [x.to_gpu() for x in xs], torch.unsqueeze(torch.tensor(ys), 1).cuda()
```
**2. removed unsqueeze**
```python
def collate(b):
    xs,ys = zip(*b)
    return xs, torch.unsqueeze(torch.tensor(ys), 1)
```

In [None]:
def collate_ehr(b):
    '''Custom collate function for use in `DataLoader`.

    Splits a list of `(x, y, m)` samples into three parts. The `x`s are kept as a
    tuple (e.g. `Patient` objects the default collate cannot handle); the `y`s and
    `m`s are stacked into batch tensors.
    '''
    patients, y_list, m_list = zip(*b)
    y_batch = torch.stack(y_list)
    m_batch = torch.stack(m_list)
    return patients, y_batch, m_batch

In [None]:
bs = 2

In [None]:
def get_dls(train_ds, valid_ds, bs, collate_fn=collate_ehr, lazy=True, **kwargs) -> 'train_dl, valid_dl':
    '''Create the training and validation `DataLoader`s.

    Args:
        train_ds: training dataset; its loader shuffles and uses batch size `bs`.
        valid_ds: validation dataset; its loader does not shuffle and uses `2 * bs`.
        bs: training batch size.
        collate_fn: batch collation function (defaults to `collate_ehr`).
        lazy: passed as `pin_memory`. Turn it off when the datasets already hold their
            tensors on the GPU (non-lazy loading).
        **kwargs: extra options forwarded to both `DataLoader`s (e.g. `num_workers`),
            matching `EHRData.create_dls`.

    Returns:
        `(train_dl, valid_dl)`
    '''
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn,
                          pin_memory=lazy, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs * 2, collate_fn=collate_fn,
                          pin_memory=lazy, **kwargs)
    return train_dl, valid_dl

In [None]:
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs)

**Tests - `iter()`, `next()` - Next Batch**

In [None]:
it = iter(valid_dl)
first_x, first_y, first_m = next(it)
second_x, second_y, second_m = next(it)

In [None]:
first_x, first_y, first_m

([ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
  ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
first_x[3].med_offsts.is_pinned(), first_y.is_pinned()

(True, True)

In [None]:
second_x, second_y, second_m

([ptid:6d048a56-edb8-4f29-891d-7a84d75a8e78, birthdate:1914-09-05 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4fc76a3b-e39e-4091-a6af-3595e0cb607e, birthdate:1948-06-01 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:26ca976d-0b5b-4662-af41-535ff670dd5a, birthdate:2014-09-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:59486a8b-389b-4355-9df4-edc62bbd1a11, birthdate:1951-10-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[0., 0., 1., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
second_x[0].alg_nums

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
second_x[0].alg_nums.is_pinned()

True

**Testing full GPU loading (non-lazy)**

In [None]:
# Non-lazy loading: with lazy_load_gpu=False the patient tensors are created on the GPU
# (see the `device:cuda:0` output below), so the DataLoaders are built with pin_memory off.
# NOTE(review): train uses modality_type=0 but valid uses modality_type=2 -- confirm this is intentional.
train_ds = EHRDataset(*labeled.train, modality_type=0, lazy_load_gpu=False)
valid_ds = EHRDataset(*labeled.valid, modality_type=2, lazy_load_gpu=False)

In [None]:
xb,yb,mb = train_ds[0:5]
xb,yb,mb

([ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0],
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0'),
 tensor([[0],
         [0],
         [0],
         [0],
         [0]], device='cuda:0'))

In [None]:
xb[0].demographics.is_pinned()

False

In [None]:
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs, lazy=False)

In [None]:
x_tmp, y_tmp, m_tmp = next(iter(valid_dl))

In [None]:
x_tmp[0].demographics.is_pinned(), m_tmp[0]

(False, tensor([2], device='cuda:0'))

In [None]:
x_tmp[0]

ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cuda:0

In [None]:
# export
class EHRData:
    """All encompassing class for EHR data
    Holds Splits, Labels, Datasets, DataLoaders and
    provides convenience fns for training and prediction.

    Args:
        path: dataset root. `get_data` treats each entry of `{path}/processed` as a modality type.
        labels: names of the labels to extract from patient objects.
        age_start, age_range, start_is_date, age_in_months: forwarded to `EHRDataSplits`.
        lazy_load_gpu: forwarded to `EHRDataset` and used as `pin_memory` for the `DataLoader`s.
    """

    def __init__(
        self,
        path,
        labels,
        age_start,
        age_range,
        start_is_date,
        age_in_months,
        lazy_load_gpu=True,
    ):
        self.path, self.labels = path, labels
        self.age_start, self.age_range = age_start, age_range
        self.start_is_date, self.age_in_months = start_is_date, age_in_months
        self.lazy_load_gpu = lazy_load_gpu

    def load_splits(self, modality_type):
        """Load the train/valid/test splits for `modality_type` into `self.splits`."""
        # NOTE(review): the `EHRDataSplits` defined at the top of this notebook takes
        # (path, age_start, age_range, start_is_date, age_in_months). Passing
        # `modality_type` here assumes an updated signature -- confirm.
        self.splits = EHRDataSplits(
            self.path,
            modality_type,
            self.age_start,
            self.age_range,
            self.start_is_date,
            self.age_in_months,
        )

    def label(self):
        """Run labeler - i.e. extract y from patient objects (sets `self.labeled`)."""
        self.labeled = LabelEHRData(*self.splits.get_splits(), self.labels)

    def create_datasets(self, modality_type):
        """Create the train/valid/test `EHRDataset`s tagged with `modality_type`."""
        # Pass by keyword, as the notebook does elsewhere, so the call cannot silently
        # bind to the wrong positional parameter.
        ds_kwargs = dict(modality_type=modality_type, lazy_load_gpu=self.lazy_load_gpu)
        self.train_ds = EHRDataset(*self.labeled.train, **ds_kwargs)
        self.valid_ds = EHRDataset(*self.labeled.valid, **ds_kwargs)
        self.test_ds = EHRDataset(*self.labeled.test, **ds_kwargs)

    @staticmethod
    def ehr_collate(b):
        """Custom collate function for use in `DataLoader`.

        Splits a list of `(x, y, m)` samples into three parts. The `x`s stay a tuple
        (e.g. `Patient` objects the default collate cannot handle); the `y`s and
        `m`s are stacked into batch tensors.
        """
        xs, ys, ms = zip(*b)
        return xs, torch.stack(ys), torch.stack(ms)

    def create_dls(self, bs, lazy, c_fn=None, **kwargs):
        """Create the train/valid/test `DataLoader`s.

        The training loader shuffles and uses batch size `bs`. The valid/test loaders
        use `2 * bs` without shuffling. `lazy` is passed as `pin_memory`. `c_fn`
        defaults to `ehr_collate`. Extra `kwargs` (e.g. `num_workers`) are forwarded
        to every `DataLoader`.
        """
        if c_fn is None:
            # Resolved at call time: a staticmethod object is not directly callable
            # on Python < 3.10, so it cannot be used as a class-body default.
            c_fn = self.ehr_collate
        self.train_dl = DataLoader(
            self.train_ds, batch_size=bs, shuffle=True, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )
        self.valid_dl = DataLoader(
            self.valid_ds, batch_size=bs * 2, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )
        self.test_dl = DataLoader(
            self.test_ds, batch_size=bs * 2, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )

    def _per_modality(self, modality_type, bs=64, num_workers=0):
        """Build splits, labels, datasets and `DataLoader`s for one modality type.

        Returns:
            `(train_dl, valid_dl, test_dl, pos_wts)`, where `pos_wts` maps
            'train'/'valid'/'test' to tensors built from `self.splits.get_pos_wts`.
        """
        self.load_splits(modality_type)
        self.label()
        self.create_datasets(modality_type)
        self.create_dls(bs, self.lazy_load_gpu, num_workers=num_workers)

        # Index the object returned by `get_pos_wts` (presumably a DataFrame with one
        # column per split) -- not the result dict being built.
        pos_wts_df = self.splits.get_pos_wts(self.labels)
        pos_wts = {
            split: torch.Tensor(pos_wts_df[split].values)
            for split in ("train", "valid", "test")
        }
        return self.train_dl, self.valid_dl, self.test_dl, pos_wts

    def get_data(self, bs=64, num_workers=0):
        """Return all data for every modality.

        Each entry of `{path}/processed` is treated as a modality type. Entries are
        visited in sorted order, so the result is deterministic.

        Returns:
            dict mapping each modality type to `(train_dl, valid_dl, test_dl, pos_wts)`.

        Every modality overwrites the instance attributes (`splits`, `labeled`,
        `*_ds`, `*_dl`), so afterwards they hold the last modality processed.
        """
        # NOTE(review): modality types are passed as directory-name strings, while
        # earlier cells pass ints to `EHRDataset(modality_type=...)` -- confirm.
        modality_types = sorted(os.listdir(os.path.join(self.path, "processed")))
        return {m: self._per_modality(m, bs, num_workers) for m in modality_types}


In [None]:
show_doc(EHRData.load_splits)

<h4 id="EHRData.load_splits" class="doc_header"><code>EHRData.load_splits</code><a href="__main__.py#L22" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.load_splits</code>()

Load data splits given dataset path

In [None]:
show_doc(EHRData.label)

<h4 id="EHRData.label" class="doc_header"><code>EHRData.label</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.label</code>()

Run labeler - i.e. extract y from patient objects

In [None]:
show_doc(EHRData.create_datasets)

<h4 id="EHRData.create_datasets" class="doc_header"><code>EHRData.create_datasets</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.create_datasets</code>()

Create [`EHRDataset`](/lemonpie/data.html#EHRDataset)s

In [None]:
show_doc(EHRData.ehr_collate)

<h4 id="EHRData.ehr_collate" class="doc_header"><code>EHRData.ehr_collate</code><a href="__main__.py#L24" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.ehr_collate</code>(**`b`**)

Custom collate function for use in `DataLoader`

In [None]:
show_doc(EHRData.create_dls)

<h4 id="EHRData.create_dls" class="doc_header"><code>EHRData.create_dls</code><a href="__main__.py#L29" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.create_dls</code>(**`bs`**, **`lazy`**, **`c_fn`**=*`ehr_collate`*, **\*\*`kwargs`**)

Create `DataLoader`s

In [None]:
show_doc(EHRData.get_data)

<h4 id="EHRData.get_data" class="doc_header"><code>EHRData.get_data</code><a href="__main__.py#L35" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.get_data</code>(**`bs`**=*`64`*, **`num_workers`**=*`0`*)

Convenience function - returns everything needed for training

In [None]:
# NOTE(review): the `EHRData` class defined in this notebook no longer has a `get_test_data`
# method, so this cell raises AttributeError on a fresh run -- remove it or restore the method.
show_doc(EHRData.get_test_data)

<h4 id="EHRData.get_test_data" class="doc_header"><code>EHRData.get_test_data</code><a href="__main__.py#L47" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.get_test_data</code>(**`bs`**=*`64`*, **`num_workers`**=*`0`*)

Convenience function - returns everything needed for prediction using test data

class MultiModalEHRData:
    """Work-in-progress container for multimodal EHR data.

    Takes the same configuration arguments as `EHRData` and currently only stores them.
    TODO: implement split loading, labeling and DataLoader creation across modalities.
    """

    def __init__(
        self,
        path,
        labels,
        age_start,
        age_range,
        start_is_date,
        age_in_months,
        lazy_load_gpu=True,
    ):
        self.path, self.labels = path, labels
        self.age_start, self.age_range = age_start, age_range
        self.start_is_date, self.age_in_months = start_is_date, age_in_months
        self.lazy_load_gpu = lazy_load_gpu


## Export -

In [None]:
#hide
from nbdev.export import *
notebook2script()

Converted 00_basics.ipynb.
Converted 01_preprocessing_clean.ipynb.
Converted 02_preprocessing_vocab.ipynb.
Converted 03_preprocessing_transform.ipynb.
Converted 04_data.ipynb.
Converted 05_metrics.ipynb.
Converted 06_learn.ipynb.
Converted 07_models.ipynb.
Converted 08_experiment.ipynb.
Converted 999_MMDS.ipynb.
Converted 999_amp_testing.ipynb.
Converted 999_fusion.ipynb.
Converted 99_quick_walkthru.ipynb.
Converted 99_running_exps.ipynb.
Converted index.ipynb.
